import torch
import json

from src.sae import (
    BaseSAE,
    BatchTopKSAE,
    TopKSAE,
    VanillaSAE,
    JumpReLUSAE,
    SAEConfig,
    TrainingConfig,
    SAELoRAWrapper,
    KronSAE
)

from src.sae.config import SAEConfig

def load_trained_sae(path, sae_type = None):
    """Load a trained SAE and its configuration from a checkpoint directory.

    Reads ``{path}/config.json``, instantiates the matching SAE class from it
    and loads the weights stored in ``{path}/sae.pt`` (mapped to the CPU).

    Args:
        path: Directory that contains ``config.json`` and ``sae.pt``.
        sae_type: Optional registry key selecting the SAE class. One of
            ``"batchtopk"``, ``"topk"``, ``"vanilla"``, ``"jumprelu"`` or
            ``"kronsae"``. When ``None``, the ``"sae_type"`` entry of the
            config file is used.

    Returns:
        A ``(config, sae)`` tuple: the object produced by
        ``SAEConfig.from_dict`` and the SAE instance with the weights loaded.

    Raises:
        ValueError: If the requested (or configured) SAE type is not one of
            the registered implementations.
    """

    # Load config (JSON is UTF-8 by specification, so do not rely on the
    # platform's default encoding).
    with open(f'{path}/config.json', 'r', encoding='utf-8') as file:
        config = json.load(file)

    # Registry of the SAE implementations that can be restored.
    available = {
        "batchtopk": BatchTopKSAE,
        "topk": TopKSAE,
        "vanilla": VanillaSAE,
        "jumprelu": JumpReLUSAE,
        "kronsae": KronSAE,
    }

    if sae_type is None:
        sae_type = config["sae_type"]

    # Validate both the explicit argument and the value taken from the
    # config, so that an unknown type always fails early with the same
    # descriptive error instead of a bare KeyError on the registry lookup.
    if sae_type not in available:
        raise ValueError(f"SAE type {sae_type} is not implemented.")

    # NOTE(review): none of the registry keys above contains "sum" or "mul",
    # so these branches cannot trigger with the current registry. They are
    # kept in case suffixed keys are (re)introduced - confirm the intent.
    if "sum" in sae_type:
        config["cartesian_op"] = "sum"
    elif "mul" in sae_type:
        config["cartesian_op"] = "mul"

    config = SAEConfig.from_dict(config)
    sae = available[sae_type](config)

    # weights_only=True restricts unpickling to tensors and plain containers.
    # strict=False means missing or unexpected keys in the checkpoint are
    # tolerated rather than raising.
    state_dict = torch.load(f"{path}/sae.pt", weights_only=True, map_location='cpu')
    sae.load_state_dict(state_dict, strict=False)

    return config, sae

def nanmax(tensor, dim=None, keepdim=False):
    """Return the maximum of ``tensor`` while ignoring NaN entries.

    NaNs are replaced by the lowest finite value of the tensor's dtype before
    reducing, so a slice that contains only NaNs yields that value.
    Infinities are left untouched.

    Args:
        tensor: Input tensor. Non-floating dtypes (which cannot hold NaN) are
            reduced directly.
        dim: Dimension to reduce over. ``None`` (the default) reduces over
            all elements.
        keepdim: Whether the reduced dimension(s) are kept with size 1.

    Returns:
        With ``dim`` given, the ``(values, indices)`` named tuple returned by
        ``torch.Tensor.max``. With ``dim=None``, a tensor holding the global
        maximum (0-dim, or with every dimension of size 1 if ``keepdim``).
    """
    if tensor.is_floating_point():
        lowest = torch.finfo(tensor.dtype).min
        # Pass the infinities explicitly: by default nan_to_num would also
        # clamp +/-inf to the finite extremes and distort the result.
        filled = tensor.nan_to_num(
            nan=lowest, posinf=float("inf"), neginf=float("-inf")
        )
    else:
        # torch.finfo rejects non-floating dtypes, and they have no NaN.
        filled = tensor

    if dim is None:
        # Tensor.max has no overload accepting dim=None; the argument-free
        # form is the one that reduces over every element.
        result = filled.max()
        if keepdim:
            result = result.reshape((1,) * tensor.dim())
        return result

    return filled.max(dim=dim, keepdim=keepdim)

def nanmin(tensor, dim=None, keepdim=False):
    """Return the minimum of ``tensor`` while ignoring NaN entries.

    NaNs are replaced by the highest finite value of the tensor's dtype
    before reducing, so a slice that contains only NaNs yields that value.
    Infinities are left untouched.

    Args:
        tensor: Input tensor. Non-floating dtypes (which cannot hold NaN) are
            reduced directly.
        dim: Dimension to reduce over. ``None`` (the default) reduces over
            all elements.
        keepdim: Whether the reduced dimension(s) are kept with size 1.

    Returns:
        With ``dim`` given, the ``(values, indices)`` named tuple returned by
        ``torch.Tensor.min``. With ``dim=None``, a tensor holding the global
        minimum (0-dim, or with every dimension of size 1 if ``keepdim``).
    """
    if tensor.is_floating_point():
        highest = torch.finfo(tensor.dtype).max
        # Pass the infinities explicitly: by default nan_to_num would also
        # clamp +/-inf to the finite extremes and distort the result.
        filled = tensor.nan_to_num(
            nan=highest, posinf=float("inf"), neginf=float("-inf")
        )
    else:
        # torch.finfo rejects non-floating dtypes, and they have no NaN.
        filled = tensor

    if dim is None:
        # Tensor.min has no overload accepting dim=None; the argument-free
        # form is the one that reduces over every element.
        result = filled.min()
        if keepdim:
            result = result.reshape((1,) * tensor.dim())
        return result

    return filled.min(dim=dim, keepdim=keepdim)