import torch
import os
import numpy as np

@torch.inference_mode()
def open_experiment(dim_A, dim_B, checkpoints_dir, exp_name, biases=False, device='cpu'):
    """Load the linear map(s) saved for one experiment from a checkpoint folder.

    Every ``*.pt`` file in ``checkpoints_dir`` whose name contains ``exp_name``
    is considered and classified by its *file name* (not its full path):

    * name contains ``'beta'``  -> a scalar/tensor scale factor ``beta``;
    * name contains ``'mat2'``  -> state dict of a ``Linear(dim_B, dim_A)``;
    * anything else             -> state dict of a ``Linear(dim_A, dim_B)``;
      if the name contains ``'force_ortho'`` the layer is wrapped with the
      orthogonal parametrization before the state dict is loaded.

    Args:
        dim_A: input features of the forward layer (output of the ``mat2`` layer).
        dim_B: output features of the forward layer (input of the ``mat2`` layer).
        checkpoints_dir: directory that is scanned for checkpoint files.
        exp_name: substring a file name must contain to be picked up.
        biases: whether the saved layers carry a bias term.
        device: device the returned tensors are moved to.

    Returns:
        ``(mat, mat2, beta)`` when ``biases`` is false, otherwise
        ``(mat, mat2, beta, bias, bias2)``.

        ``mat`` is the transposed forward weight (shape ``(dim_A, dim_B)``)
        multiplied by ``beta``; ``mat2`` is the transposed ``mat2`` weight
        (shape ``(dim_B, dim_A)``) divided by ``beta``. When no ``mat2``
        checkpoint exists, ``mat2`` falls back to the transpose of the
        unscaled ``mat``, and ``bias2`` is ``None`` because there is no
        second layer to take a bias from.

    Raises:
        ValueError: if no forward-layer checkpoint matching ``exp_name`` exists.
    """
    beta = torch.ones(1)
    mat = mat2 = bias = bias2 = None

    # os.listdir order is arbitrary; sorting makes the result reproducible
    # when several files match the same category (the last one wins).
    for name in sorted(os.listdir(checkpoints_dir)):
        if exp_name not in name or not name.endswith('.pt'):
            continue
        path = os.path.join(checkpoints_dir, name)
        # Classify on the bare file name: the directory part of the path may
        # itself contain 'beta' / 'mat2' / 'force_ortho'.
        if 'beta' in name:
            # map_location keeps loading working on hosts without the
            # accelerator the checkpoint was saved from.
            beta = torch.load(path, weights_only=True, map_location='cpu')
        elif 'mat2' in name:
            layer2 = torch.nn.Linear(dim_B, dim_A, bias=biases, device='cpu')
            layer2.load_state_dict(
                torch.load(path, weights_only=True, map_location='cpu')
            )
            mat2 = layer2.weight.T
            if biases:
                bias2 = layer2.bias
        else:
            # Forward-layer checkpoint.
            layer = torch.nn.Linear(dim_A, dim_B, bias=biases, device='cpu')
            if 'force_ortho' in name:
                # The parametrization must be registered before loading so
                # that the state-dict keys match.
                layer = torch.nn.utils.parametrizations.orthogonal(layer)
            layer.load_state_dict(
                torch.load(path, weights_only=True, map_location='cpu')
            )
            mat = layer.weight.T
            if biases:
                bias = layer.bias

    if mat is None:
        raise ValueError(f"{exp_name} not found in {checkpoints_dir}")
    if mat2 is None:
        # No dedicated reverse layer: use the transpose of the forward map.
        mat2 = mat.T

    mat = mat * beta
    mat2 = mat2 / beta

    if not biases:
        return mat.to(device), mat2.to(device), beta.to(device)
    # bias2 only exists when a 'mat2' checkpoint was found.
    bias2_out = bias2.to(device) if bias2 is not None else None
    return mat.to(device), mat2.to(device), beta.to(device), bias.to(device), bias2_out

def load_activation_store(pfx, size, d_sae, subdir=''):
    """Load and concatenate batched activation archives from ``activations_store/``.

    One ``.npz`` archive is read per batch offset ``i`` in
    ``np.arange(0, d_sae, size)``, from the path
    ``activations_store/{subdir}{pfx}_size_{size}_batch_{i}.npz``
    (relative to the current working directory). ``subdir`` is inserted
    verbatim, so it must include its own trailing separator.

    Args:
        pfx: file-name prefix of the archives.
        size: batch size; used both as the offset step and in the file name.
        d_sae: exclusive upper bound for the batch offsets.
        subdir: optional sub-directory prefix inside ``activations_store/``.

    Returns:
        A dict mapping every array name found in the first archive to the
        concatenation (along axis 0) of that array across all batches, in
        offset order. An empty dict is returned when there are no batches.

    Raises:
        KeyError: if a later archive lacks an array present in the first one.
    """
    chunks = {}
    for offset in np.arange(0, d_sae, size):
        path = f'activations_store/{subdir}{pfx}_size_{size}_batch_{offset}.npz'
        # NpzFile keeps the archive open and reads arrays lazily; the context
        # manager closes the handle once this batch's arrays are materialized.
        with np.load(path) as archive:
            if not chunks:
                chunks = {key: [archive[key]] for key in archive.files}
            else:
                for key, parts in chunks.items():
                    parts.append(archive[key])
    return {key: np.concatenate(parts) for key, parts in chunks.items()}