from transformer_lens import HookedTransformer
import torch
import numpy as np
from tqdm import tqdm
import torch.nn.functional as F
import torch.optim.lr_scheduler as lr_scheduler
from torch.utils.flop_counter import FlopCounterMode
import wandb
from sae_lens import SAE
import argparse
import yaml
import itertools
import os

from stitching.projection import inference_mode, BidirectionalStitchingLayer
from stitching.losses import next_token_cross_entropy_loss, kl_div_loss, mse_loss
import yaml
# Load the global settings once, at import time. The path is relative to the
# current working directory, so importing this module from anywhere other than
# the directory containing global_config.yaml raises FileNotFoundError.
with open('global_config.yaml') as global_stream:
    global_cfg = yaml.safe_load(global_stream)
# Module-level constant holding the 'CACHE_DIR' entry of the global config
# (KeyError at import time if the entry is missing).
CACHE_DIR = global_cfg['CACHE_DIR']

def load_data(prefix, length=100000, truncation_length=512):
    """Load the saved 'train' and 'test' tokenized splits from ``data/``.

    Each split is read with ``torch.load(..., weights_only=True)`` from
    ``data/{prefix}_tokenized_dataset_{length}_{split}_{truncation_length}.pt``
    (relative to the current working directory).

    Args:
        prefix: Leading component of the dataset file names.
        length: Dataset-size component of the file names.
        truncation_length: Truncation-length component of the file names.

    Returns:
        A dict with keys ``'train'`` and ``'test'`` (in that order) mapping to
        whatever object was stored in the corresponding file.
    """
    print(f"Loading data from data/{prefix}_tokenized_dataset_{length}_[train,test]_{truncation_length}.pt")
    return {
        split: torch.load(
            f'data/{prefix}_tokenized_dataset_{length}_{split}_{truncation_length}.pt',
            weights_only=True,
        )
        for split in ('train', 'test')
    }

def run_job(cfg):
    """Train a BidirectionalStitchingLayer on pre-cached activations of two models.

    The models themselves are never loaded: every file found in
    ``<activations_dir>/<model_a_name>/`` is paired with the file
    ``<activations_dir>/<model_b_name>/<model_b_name>_layer_<layer_b>_cached_activations_<suffix>``
    where ``<suffix>`` is the part of model A's file name after the last
    underscore. Both tensors are cut into consecutive batches of
    ``token_batch_size`` rows and the layer is optimised with AdamW on the
    mean of the two MSE reconstruction losses (plus ``alpha_inv`` times the
    inverse MSE when ``run_inverse`` is enabled).

    Args:
        cfg: Parsed job configuration. Keys read here:
            ``device``, ``lr``, ``weight_decay``, ``epochs``, ``use_wandb``,
            ``use_scheduler``, ``verbose``, ``alpha_inv`` (only when
            ``run_inverse`` is set);
            ``model``: ``model_a_name``, ``model_b_name``, ``d_A``, ``d_B``,
            ``layer_a``, ``layer_b``, ``force_ortho`` and the optional
            ``method`` (default ``'transpose'``), ``run_inverse``,
            ``use_bias1``, ``use_bias2`` (default ``False``);
            ``data``: ``activations_dir``, ``token_batch_size``, ``size``,
            ``truncation_length``.

    Returns:
        ``True`` once training has finished and all artefacts are written.

    Side effects:
        Writes the projection weights, ``projection_layer.beta`` and a FLOP
        count file under ``checkpoints/<project_name>/`` (created if missing),
        and, when ``use_wandb`` is set, logs every step to wandb and dumps the
        config to ``checkpoints/<run name>_cfg.yaml``.
    """
    model_cfg = cfg['model']
    data_cfg = cfg['data']
    device = cfg['device']
    use_wandb = cfg['use_wandb']
    d_A = model_cfg['d_A']
    d_B = model_cfg['d_B']
    layer_A = model_cfg['layer_a']  # only used to name the FLOP-count file
    layer_B = model_cfg['layer_b']
    model_a_name = model_cfg['model_a_name']
    model_b_name = model_cfg['model_b_name']
    batch_size = data_cfg['token_batch_size']

    def _opt(key, default):
        # A missing key and an explicit null in the YAML both mean "use default".
        value = model_cfg.get(key)
        return default if value is None else value

    project_name = f"stitch_training_{model_a_name}_to_{model_b_name}_bidirectional_mse"
    notes = ''
    if model_cfg['force_ortho']:
        notes += "_force_ortho"
    method = _opt('method', 'transpose')

    if use_wandb:
        wandb.init(
            project=project_name,
            config=cfg,
            notes=notes
        )
    run_inverse = _opt('run_inverse', False)
    use_bias1 = _opt('use_bias1', False)
    use_bias2 = _opt('use_bias2', False)

    projection_layer = BidirectionalStitchingLayer(
        d_A, d_B,
        device=device,
        method=method,
        use_bias1=use_bias1,
        use_bias2=use_bias2,
        force_orthogonal=model_cfg['force_ortho'],
    )

    # Materialise the parameters into a list: ``parameters()`` and
    # ``itertools.chain`` are one-shot iterators, and the optimizer consumes
    # them, which would leave nothing for gradient clipping to act on later.
    if method == 'separate_mat':
        params = list(itertools.chain(
            projection_layer.projection.parameters(),
            projection_layer.projection2.parameters(),
        ))
    else:
        params = list(projection_layer.projection.parameters())

    optimizer = torch.optim.AdamW(params, lr=cfg['lr'], weight_decay=cfg['weight_decay'])
    approx_num_batches = data_cfg['size'] * 0.9 * data_cfg['truncation_length'] / batch_size
    if cfg['use_scheduler']:
        # T_max is documented as an int; also guard against a zero-length schedule.
        t_max = max(1, int(approx_num_batches * cfg['epochs']))
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=t_max, eta_min=cfg['lr'] / 10)

    if cfg['verbose']:
        print("Starting ", project_name, " with arguments:\n", cfg)

    # Every artefact below lives under checkpoints/; make sure the tree exists
    # up front instead of failing after the whole training run.
    os.makedirs(f"checkpoints/{project_name}", exist_ok=True)

    # dump the yaml
    if use_wandb:
        with open(f"checkpoints/{wandb.run.name}_cfg.yaml", "w") as file:
            yaml.dump(cfg, file)

    dir_A = os.path.join(data_cfg['activations_dir'], model_a_name)
    dir_B = os.path.join(data_cfg['activations_dir'], model_b_name)

    total_iters = 0
    flopcounter = FlopCounterMode(display=False)
    flopcount = 0
    for epoch in range(cfg['epochs']):
        # instead of having a dataloader, we're just working with each individual file.
        # Sorted so the file order does not depend on the filesystem.
        for filename in sorted(os.listdir(dir_A)):
            # look for file of corresponding number in model B's dir
            file_suffix = os.path.basename(filename).split('_')[-1]
            corresponding_file = f"{model_b_name}_layer_{layer_B}_cached_activations_{file_suffix}"
            print(filename, corresponding_file)
            corresponding_file = os.path.join(dir_B, corresponding_file)
            modelA_activations = torch.load(os.path.join(dir_A, filename), weights_only=True).float()
            modelB_activations = torch.load(corresponding_file, weights_only=True).float()
            assert(modelA_activations.shape[0] == modelB_activations.shape[0]), f"Shapes differ {modelA_activations.shape}, {modelB_activations.shape}"

            # now, train between them in batches. Floor division already
            # guarantees every batch is full, so no bounds check is needed;
            # trailing rows that do not fill a batch are dropped.
            num_batches = len(modelA_activations) // batch_size
            for batch_idx in tqdm(range(num_batches)):
                start = batch_idx * batch_size
                stop = start + batch_size

                with flopcounter:
                    residual_stream_A = modelA_activations[start:stop].to(device)
                    residual_stream_B = modelB_activations[start:stop].to(device)
                    optimizer.zero_grad()
                    if run_inverse:
                        pred_resid_stream_A, pred_resid_stream_B, inv_A, inv_B = projection_layer(
                            residual_stream_A, residual_stream_B,
                            run_inverse=True
                        )
                    else:
                        pred_resid_stream_A, pred_resid_stream_B = projection_layer(
                            residual_stream_A, residual_stream_B,
                            run_inverse=False
                        )
                    mse_A = mse_loss(pred_resid_stream_A, residual_stream_A)
                    mse_B = mse_loss(pred_resid_stream_B, residual_stream_B)
                    mse = 0.5 * (mse_A + mse_B)
                    loss = mse
                    if run_inverse:
                        inv_loss_A = mse_loss(inv_A, residual_stream_A)
                        inv_loss_B = mse_loss(inv_B, residual_stream_B)
                        inv_mse = 0.5 * (inv_loss_A + inv_loss_B)
                        loss = loss + cfg['alpha_inv'] * inv_mse

                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(params, 1.0)
                    optimizer.step()
                    step_result_dict = {
                        "train_mse": mse.item(),
                        "mse_A": mse_A.item(),
                        "mse_B": mse_B.item(),
                        "overall_loss": loss.item()
                    }
                    if run_inverse:
                        step_result_dict["inv_mse"] = inv_mse.item()
                        step_result_dict["inv_mse_A"] = inv_loss_A.item()
                        step_result_dict["inv_mse_B"] = inv_loss_B.item()

                    if use_wandb:
                        wandb.log(step_result_dict)

                    if cfg['use_scheduler']:
                        scheduler.step()
                # Read the counter after each step and accumulate the total here.
                flopcount += flopcounter.get_total_flops()
                total_iters += 1

    # NOTE(review): no validation pass is implemented in this function; the
    # message is kept unchanged so existing logs stay comparable.
    if cfg['verbose']:
        print(f"Validating at {total_iters} steps")

    # wandb runs embed the run name in the checkpoint file names; plain runs don't.
    run_suffix = f"_{wandb.run.name}" if use_wandb else ""
    torch.save(projection_layer.projection.state_dict(), f"checkpoints/{project_name}/best_ckpt{run_suffix}.pt")
    if method == 'separate_mat':
        torch.save(projection_layer.projection2.state_dict(), f"checkpoints/{project_name}/best_ckpt_mat2{run_suffix}.pt")
    torch.save(projection_layer.beta, f"checkpoints/{project_name}/best_ckpt{run_suffix}_beta.pt")

    if use_wandb:
        wandb.finish()
    formatted_flopcount = f"{flopcount:e}"
    with open(f"checkpoints/{project_name}/{layer_A}_to_{layer_B}_flops.txt", 'w') as f:
        f.write(str(flopcount))
        f.write('\n')
        f.write(formatted_flopcount)
    return True

if __name__ == '__main__':
    # Script entry point: the single positional argument is the path to the
    # YAML job configuration that is handed to run_job.
    cli = argparse.ArgumentParser()
    cli.add_argument('cfg_filename')
    cfg_path = cli.parse_args().cfg_filename
    with open(cfg_path) as cfg_stream:
        job_cfg = yaml.safe_load(cfg_stream)
    run_job(job_cfg)
    
