"""Evaluate an ending model's next-token loss on cached activations under interventions.

Run with the path of a YAML experiment config as the only argument. For each
configured intervention (identity / zero / stitch / inverse) the mean loss is
written to ``results/<ending model name>_ends_metrics.json``.
"""
import os
import yaml
# Shared settings are read from ``global_config.yaml`` in the current working
# directory (relative path) before the remaining imports.
with open('global_config.yaml') as global_stream:
    global_cfg = yaml.safe_load(global_stream)
CACHE_DIR = global_cfg['CACHE_DIR']
# Exported before ``transformer_lens`` is imported below -- presumably so that
# Hugging Face model downloads it triggers can authenticate. Keep this ordering.
os.environ["HF_TOKEN"] = global_cfg['hf_access_token']
from transformer_lens import HookedTransformer
import yaml  # NOTE(review): duplicate of the import above; harmless.
import torch
import argparse
from tqdm import tqdm
from stitching.losses import next_token_cross_entropy_loss
from stitching.stitching_utils import open_experiment
import json
import numpy as np  # NOTE(review): not referenced in the code shown here.
import gc  # NOTE(review): not referenced in the code shown here.
# Hard-coded: models and batches are moved to this device, so the script
# assumes a CUDA GPU is available.
device = 'cuda'

# Autograd is disabled globally; this script only runs forward passes.
torch.set_grad_enabled(False)


def get_intervention(intervention, intervention_kwargs, checkpoints_dir):
    """Build the function applied to start-model activations before the ending model.

    Args:
        intervention: ``None`` or ``'identity'`` (pass-through), ``'zero'``
            (replace with zeros), ``'stitch'`` (apply the learned affine map once),
            or ``'inverse'`` (apply the map and then its counterpart).
        intervention_kwargs: For ``'stitch'``/``'inverse'`` it must provide
            ``d_A``, ``d_B`` and ``transfer_id`` (forwarded to ``open_experiment``)
            and ``start_model`` (``'A'`` or ``'B'``). Unused otherwise.
        checkpoints_dir: Directory forwarded to ``open_experiment``.

    Returns:
        A callable taking a tensor and returning a tensor.

    Raises:
        ValueError: if ``intervention`` is not recognised, or if ``start_model``
            is not ``'A'``/``'B'`` for ``'stitch'``/``'inverse'``. Previously these
            cases silently returned ``None``, which only failed later when called.
    """
    if intervention is None or intervention == 'identity':
        return lambda x: x
    if intervention == 'zero':
        return lambda x: torch.zeros_like(x)
    if intervention not in ('stitch', 'inverse'):
        raise ValueError(f"Unknown intervention: {intervention!r}")

    start_model = intervention_kwargs['start_model']
    # Validate before loading checkpoints so a bad config fails fast and cheaply.
    if start_model not in ('A', 'B'):
        raise ValueError(f"start_model must be 'A' or 'B', got {start_model!r}")

    # open_experiment returns (P, Pinv, beta, bias, biasinv); beta is not needed here.
    P, Pinv, _beta, bias, biasinv = open_experiment(
        intervention_kwargs['d_A'],
        intervention_kwargs['d_B'],
        checkpoints_dir,
        intervention_kwargs['transfer_id'],
        biases=True,
        device=device,
    )

    def apply_P(x):
        return x @ P + bias

    def apply_Pinv(x):
        return x @ Pinv + biasinv

    # Starting from A uses P first; starting from B uses Pinv first.
    first, second = (apply_P, apply_Pinv) if start_model == 'A' else (apply_Pinv, apply_P)
    if intervention == 'stitch':
        return first
    # 'inverse': map with `first`, then map back with `second`.
    return lambda x: second(first(x))
        
def _shard_suffix(filename):
    """Return the last ``_``-separated part of ``filename``; it names the matching token/B shards."""
    return os.path.basename(filename).split("_")[-1]


def _load_shard(acts_dir, modelA_name, modelB_name, layer_B, filename):
    """Load ``(tokens, activations_A, activations_B)`` for one model-A shard file.

    Both activation tensors are cast with ``.float()``; ``tokens`` is returned as
    stored. Activations are presumably (batch, seq, d_model) -- TODO confirm.
    """
    # The suffix is computed outside the f-strings: reusing the enclosing quote
    # character inside an f-string replacement field is a SyntaxError before
    # Python 3.12 (PEP 701).
    suffix = _shard_suffix(filename)
    tokens = torch.load(os.path.join(acts_dir, f"tokens_{suffix}"), weights_only=True)
    activations_A = torch.load(
        os.path.join(acts_dir, f"{modelA_name}/", filename),
        weights_only=True,
    ).float()
    activations_B = torch.load(
        os.path.join(
            acts_dir,
            f"{modelB_name}/",
            f"{modelB_name}_layer_{layer_B}_cached_activations_{suffix}",
        ),
        weights_only=True,
    ).float()
    return tokens, activations_A, activations_B


def _accumulate_shard_losses(model, layer, runs, losses, shard, filename, batch_size, ignore_tokens):
    """Run every intervention in ``runs`` over one shard, appending flattened losses to ``losses[key]``."""
    tokens, activations_A, activations_B = shard
    dataset = torch.utils.data.TensorDataset(activations_A, activations_B, tokens)
    for intervention, start_model, key, intervention_func in runs:
        print(intervention, filename)
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)
        for batch_A, batch_B, batch_tokens in tqdm(dataloader):
            start_acts = batch_A if start_model == 'A' else batch_B
            resid = intervention_func(start_acts.to(device))
            # Resume the ending model's forward pass at ``layer`` with the
            # (possibly transformed) activations.
            logits = model(resid, start_at_layer=layer)
            batch_losses = next_token_cross_entropy_loss(
                logits, batch_tokens.to(device), ignore_index=ignore_tokens, reduction='none'
            )
            losses[key].append(batch_losses.flatten().cpu())


@torch.inference_mode()
def compute_losses(interventions, intervention_kwargs_list, acts_dir, modelA_name, modelB_name, layer_A, layer_B, A_ends, n_batches_per_file, batch_size):
    """Evaluate the ending model's next-token loss on cached activations under each intervention.

    For every file in ``acts_dir/<modelA_name>/`` this loads the matching token shard
    ``acts_dir/tokens_<suffix>`` and model-B shard
    ``acts_dir/<modelB_name>/<modelB_name>_layer_<layer_B>_cached_activations_<suffix>``.
    Here ``<suffix>`` is the last ``_``-separated part of the model-A file name.
    Each shard is loaded once and every intervention is evaluated on it.

    Args:
        interventions: Intervention names accepted by ``get_intervention``.
        intervention_kwargs_list: Per-intervention kwargs (parallel to
            ``interventions``); each must contain ``'start_model'``.
        acts_dir: Root directory of the cached tokens/activations.
        modelA_name, modelB_name: Model names; also used as sub-directory names.
        layer_A, layer_B: Layer at which the ending model (A or B respectively)
            resumes its forward pass; ``layer_B`` also names the model-B shards.
        A_ends: If true, model A is the ending model, otherwise model B.
        n_batches_per_file: NOTE(review): currently unused -- every batch of
            every shard is evaluated.
        batch_size: Batch size for the per-shard DataLoader.

    Returns:
        Dict mapping ``"<intervention>_<start_model>_to_<end_model>"`` to the mean
        of all flattened ``reduction='none'`` losses for that key (whether ignored
        positions contribute depends on ``next_token_cross_entropy_loss``), or
        ``None`` if nothing was collected.
    """
    if A_ends:
        model_name, layer, end_model = modelA_name, layer_A, 'A'
    else:
        model_name, layer, end_model = modelB_name, layer_B, 'B'

    torch_dtype = torch.float16 if '9b' in model_name else torch.float32
    print("Loading", model_name, "as the ending model")
    model = HookedTransformer.from_pretrained(model_name=model_name, device=device, cache_dir=CACHE_DIR, torch_dtype=torch_dtype)
    # NOTE(review): activations are cast to float32 in _load_shard even when the
    # model is loaded in float16 -- confirm HookedTransformer accepts the mismatch.

    # Token ids excluded from the loss. Loop-invariant, so built once; ``None`` ids
    # (e.g. a tokenizer without a pad token) are dropped.
    tok = model.tokenizer
    ignore_tokens = list({
        token_id
        for token_id in (tok.bos_token_id, tok.eos_token_id, tok.pad_token_id, 0)
        if token_id is not None
    })

    project_name = f"stitch_training_{modelA_name}_to_{modelB_name}_bidirectional_mse"
    checkpoints_dir = os.path.join('checkpoints/', f"{project_name}/")

    # Build every intervention up front. This lets each shard be read from disk
    # once instead of once per intervention, and a bad intervention config fails
    # before any evaluation work is done.
    losses = {}
    runs = []
    for intervention, intervention_kwargs in zip(interventions, intervention_kwargs_list):
        start_model = intervention_kwargs['start_model']
        key = f"{intervention}_{start_model}_to_{end_model}"
        losses.setdefault(key, [])
        intervention_func = get_intervention(intervention, intervention_kwargs, checkpoints_dir)
        runs.append((intervention, start_model, key, intervention_func))

    if runs:
        # Sorted so the shard order is deterministic (os.listdir order is arbitrary).
        for filename in sorted(os.listdir(os.path.join(acts_dir, f"{modelA_name}/"))):
            shard = _load_shard(acts_dir, modelA_name, modelB_name, layer_B, filename)
            print(shard[0].shape, shard[1].shape, shard[2].shape)
            _accumulate_shard_losses(model, layer, runs, losses, shard, filename, batch_size, ignore_tokens)
            # Drop this shard before loading the next to keep peak memory near one shard.
            del shard

    return {
        key: torch.concatenate(chunks).mean().item() if chunks else None
        for key, chunks in losses.items()
    }

if __name__ == '__main__':
    # Usage: pass the path of a YAML experiment config as the only argument. Keys
    # read: seed, interventions, interventions_kwargs_list, acts_dir, modelA_name,
    # modelB_name, layer_A, layer_B, A_ends, n_batches_per_file, batch_size.
    parser = argparse.ArgumentParser()
    parser.add_argument('cfg_filename')
    args = parser.parse_args()
    with open(args.cfg_filename) as stream:
        cfg = yaml.safe_load(stream)
    print("Config", cfg)

    torch.manual_seed(cfg['seed'])
    losses = compute_losses(
        cfg['interventions'],
        cfg['interventions_kwargs_list'],
        cfg['acts_dir'],
        cfg['modelA_name'],
        cfg['modelB_name'],
        cfg['layer_A'],
        cfg['layer_B'],
        cfg['A_ends'],
        cfg['n_batches_per_file'],
        cfg['batch_size'],
    )
    # Results are named after the ending model.
    save_file = cfg['modelA_name'] if cfg['A_ends'] else cfg['modelB_name']
    out_path = f"results/{save_file}_ends_metrics.json"
    # Ensure the output directory exists so the final write cannot fail after the
    # expensive evaluation. This also covers a sub-directory implied by a model
    # name containing "/".
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    with open(out_path, "w") as file:
        json.dump(losses, file)
    
