import os
from transformer_lens import HookedTransformer
import torch
import torch.optim.lr_scheduler as lr_scheduler
from torch.utils.flop_counter import FlopCounterMode
from tqdm import tqdm
import wandb
import argparse
import yaml
import itertools

from stitching.stitch import inference_mode, BidirectionalStitchingLayer
from stitching.losses import next_token_cross_entropy_loss, kl_div_loss, mse_loss, get_ignore_mask, get_all_special_tokens

def load_data(prefix, length=100000, truncation_length=512):
    """Load the pre-tokenized train and test splits saved with ``torch.save``.

    Files are read from ``data/{prefix}_tokenized_dataset_{length}_{split}_{truncation_length}.pt``
    for ``split`` in ``train`` then ``test``, using ``weights_only=True``.

    Returns:
        dict with keys ``'train'`` and ``'test'`` holding whatever object each
        file contains.
    """
    stem = f"data/{prefix}_tokenized_dataset_{length}"
    print(f"Loading data from {stem}_[train,test]_{truncation_length}.pt")
    return {
        split: torch.load(f"{stem}_{split}_{truncation_length}.pt", weights_only=True)
        for split in ("train", "test")
    }

@torch.inference_mode()
def validate(A, B, layer_A, layer_B, projection_layer, val_dataloader_tokenized):
    """Score the stitching layer on the validation loader for both models.

    For every batch, the residual stream of ``A`` (stopped at ``layer_A``) and
    of ``B`` (stopped at ``layer_B``) are fed to ``projection_layer`` with
    ``run_inverse=False``. The predicted streams are compared with the true
    ones (MSE). Both true and predicted streams are also run through the rest
    of the corresponding model to compare logits (KL) and to score next-token
    prediction (cross-entropy). Positions flagged by ``get_ignore_mask`` are
    excluded from MSE and KL. The CE helper receives the special tokens as
    ``ignore_index``.

    Returns:
        dict of Python floats with these keys:
        - ``val_mse_loss_{A,B}``: mean of the per-batch MSE values.
        - ``val_ce_loss``: mean of the element-wise sum of the A and B
          per-position CE values.
        - ``val_ce_loss_{A,B}`` and ``val_kl_loss_{A,B}``: means over all
          collected per-position values.
    """
    special_tokens = get_all_special_tokens(A.tokenizer)
    # Keyed by model tag: per-batch scalar MSEs, per-position KL and CE tensors.
    mse = {"A": [], "B": []}
    kl = {"A": [], "B": []}
    ce = {"A": [], "B": []}

    for batch in val_dataloader_tokenized:
        tokens = batch.to(projection_layer.device)
        resid_A = A(tokens, stop_at_layer=layer_A)
        resid_B = B(tokens, stop_at_layer=layer_B)
        stitched_A, stitched_B = projection_layer(resid_A, resid_B, run_inverse=False)
        keep = ~get_ignore_mask(batch, special_tokens)

        per_model = (
            ("A", A, layer_A, resid_A, stitched_A),
            ("B", B, layer_B, resid_B, stitched_B),
        )
        for tag, model, layer, resid, stitched in per_model:
            logits_true = model(resid, start_at_layer=layer)
            logits_pred = model(stitched, start_at_layer=layer)
            mse[tag].append(mse_loss(stitched[keep], resid[keep]).item())
            # Sum over the last (presumably vocabulary) axis -> one KL value per kept position.
            kl[tag].append(
                kl_div_loss(logits_pred[keep], logits_true[keep], reduction='none').sum(dim=-1)
            )
            ce[tag].append(
                next_token_cross_entropy_loss(logits_pred, tokens, ignore_index=special_tokens, reduction='none')
            )

    ce_A = torch.concatenate(ce["A"])
    ce_B = torch.concatenate(ce["B"])
    return {
        "val_mse_loss_A": torch.tensor(mse["A"]).mean().item(),
        "val_mse_loss_B": torch.tensor(mse["B"]).mean().item(),
        "val_ce_loss": (ce_A + ce_B).mean().item(),
        "val_ce_loss_A": ce_A.mean().item(),
        "val_ce_loss_B": ce_B.mean().item(),
        "val_kl_loss_A": torch.concatenate(kl["A"]).mean().item(),
        "val_kl_loss_B": torch.concatenate(kl["B"]).mean().item(),
    }

def _config_value(section, key, default):
    """Return ``section[key]``, or ``default`` when the key is missing or ``None``.

    An empty YAML value (``key:``) loads as ``None``, so ``dict.get(key, default)``
    alone would not fall back in that case.
    """
    value = section.get(key)
    return default if value is None else value


def _save_best_checkpoint(projection_layer, checkpoints_dir, method, suffix=""):
    """Write the stitching layer's weights into ``checkpoints_dir``.

    Saves ``projection.state_dict()`` to ``best_ckpt{suffix}.pt``. When
    ``method == 'separate_mat'``, also saves ``projection2.state_dict()`` to
    ``best_ckpt_mat2{suffix}.pt``. Saves ``projection_layer.beta`` to
    ``best_ckpt{suffix}_beta.pt``. ``suffix`` is ``"_<wandb run name>"`` when
    logging to wandb and ``""`` otherwise.
    """
    torch.save(projection_layer.projection.state_dict(),
               os.path.join(checkpoints_dir, f"best_ckpt{suffix}.pt"))
    if method == 'separate_mat':
        torch.save(projection_layer.projection2.state_dict(),
                   os.path.join(checkpoints_dir, f"best_ckpt_mat2{suffix}.pt"))
    torch.save(projection_layer.beta,
               os.path.join(checkpoints_dir, f"best_ckpt{suffix}_beta.pt"))


def run_job(cfg, global_cfg):
    """Train a BidirectionalStitchingLayer between two HookedTransformer models.

    The job runs in these steps:
    1. Load the tokenized dataset and models A and B named in ``cfg['model']``.
    2. Train the stitching layer with AdamW so that its predicted residual
       streams (A at ``layer_a``, B at ``layer_b``) match the true ones under
       MSE on non-padding positions. An optional inverse-reconstruction term
       weighted by ``cfg['alpha_inv']`` is added when ``run_inverse`` is set.
    3. Validate every ``cfg['val_every']`` steps and once more at the end.
    4. Checkpoint whenever the validation ``val_ce_loss`` improves.
    5. Write the FLOPs counted by ``FlopCounterMode`` over the training steps
       to ``<checkpoints_dir>/<layer_a>_to_<layer_b>_flops.txt``.

    Args:
        cfg: job configuration. Keys read: ``model``, ``data``, ``device``,
            ``verbose``, ``use_wandb``, ``lr``, ``weight_decay``,
            ``use_scheduler``, ``epochs``, ``val_every`` and, when
            ``run_inverse`` is set, ``alpha_inv``.
        global_cfg: global configuration; only ``CACHE_DIR`` is read.

    Returns:
        True once training, validation and checkpointing have finished.
    """
    cache_dir = global_cfg['CACHE_DIR']
    model_cfg = cfg['model']
    data_cfg = cfg['data']
    device = cfg['device']
    use_wandb = cfg['use_wandb']
    verbose = cfg['verbose']

    tokenized_dataset = load_data(model_cfg['model_a_name'], data_cfg['size'], data_cfg['truncation_length'])
    A = inference_mode(HookedTransformer.from_pretrained(model_cfg['model_a_name'], cache_dir=cache_dir, device=device))
    B = inference_mode(HookedTransformer.from_pretrained(model_cfg['model_b_name'], cache_dir=cache_dir, device=device))
    layer_A = model_cfg['layer_a']
    layer_B = model_cfg['layer_b']
    PADDING_TOKEN = A.tokenizer.pad_token_id

    train_dataloader_tokenized = torch.utils.data.DataLoader(
        tokenized_dataset['train'],
        batch_size=data_cfg['batch_size'],
        shuffle=True
    )
    val_dataloader_tokenized = torch.utils.data.DataLoader(
        tokenized_dataset['test'][:len(tokenized_dataset['test']) // data_cfg['val_scale']],
        batch_size=data_cfg['batch_size'],
        shuffle=False
    )
    if verbose:
        print(f"Loaded data of train {len(train_dataloader_tokenized)} and val {len(val_dataloader_tokenized)}")

    project_name = f"stitch_training_{model_cfg['model_a_name']}_to_{model_cfg['model_b_name']}_bidirectional_mse"
    checkpoints_dir = os.path.join('checkpoints/', f"{project_name}/")
    os.makedirs(checkpoints_dir, exist_ok=True)
    notes = "_force_ortho" if model_cfg['force_ortho'] else ''
    method = _config_value(model_cfg, 'method', 'transpose')

    if use_wandb:
        wandb.init(
            project=project_name,
            config=cfg,
            notes=notes
        )
    run_inverse = _config_value(model_cfg, 'run_inverse', False)
    use_bias1 = _config_value(model_cfg, 'use_bias1', False)
    use_bias2 = _config_value(model_cfg, 'use_bias2', False)

    projection_layer = BidirectionalStitchingLayer(
        A.cfg.d_model, B.cfg.d_model,
        device=device,
        method=method,
        use_bias1=use_bias1,
        use_bias2=use_bias2,
        force_orthogonal=model_cfg['force_ortho'],
    )

    # Materialise the parameters as a list. ``.parameters()`` and
    # ``itertools.chain`` are one-shot iterators. AdamW consumes them, so
    # clip_grad_norm_ below would otherwise see an empty iterable and never clip.
    if method == 'separate_mat':
        params = list(itertools.chain(projection_layer.projection.parameters(),
                                      projection_layer.projection2.parameters()))
    else:
        params = list(projection_layer.projection.parameters())
    optimizer = torch.optim.AdamW(params, lr=cfg['lr'], weight_decay=cfg['weight_decay'])
    scheduler = None
    if cfg['use_scheduler']:
        scheduler = lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=len(train_dataloader_tokenized) * cfg['epochs'], eta_min=cfg['lr'] / 10
        )

    # Checkpoint file-name suffix: the wandb run name when logging to wandb.
    ckpt_suffix = f"_{wandb.run.name}" if use_wandb else ""
    best_val_loss = None
    if verbose:
        print("Starting ", project_name, " with arguments:\n", cfg)
    if use_wandb:
        # Dump the job config next to the checkpoints, named after the run.
        with open(os.path.join(checkpoints_dir, f"{wandb.run.name}_cfg.yaml"), "w") as file:
            yaml.dump(cfg, file)

    def validate_and_checkpoint(step):
        """Validate, log to wandb, and save a checkpoint only if val CE improved."""
        nonlocal best_val_loss
        val_results_dict = validate(A, B, layer_A, layer_B, projection_layer, val_dataloader_tokenized)
        if use_wandb:
            wandb.log(val_results_dict)
        cur_val_loss = val_results_dict['val_ce_loss']
        if (best_val_loss is None) or (cur_val_loss < best_val_loss):
            print(f"{cur_val_loss} was better than {best_val_loss}, saving new checkpoint at iteration {step}")
            best_val_loss = cur_val_loss
            _save_best_checkpoint(projection_layer, checkpoints_dir, method, ckpt_suffix)

    total_iters = 0
    flopcounter = FlopCounterMode(display=False)
    flopcount = 0
    for _ in range(cfg['epochs']):
        for sample in tqdm(train_dataloader_tokenized, leave=True):
            with flopcounter:
                optimizer.zero_grad()
                tokenized = sample.to(projection_layer.device)
                with torch.inference_mode():
                    residual_stream_A = A(tokenized, stop_at_layer=layer_A)
                    residual_stream_B = B(tokenized, stop_at_layer=layer_B)
                # Tensors created under inference_mode cannot be saved for
                # backward; cloning outside it yields ordinary tensors.
                residual_stream_A = residual_stream_A.clone()
                residual_stream_B = residual_stream_B.clone()

                if run_inverse:
                    pred_resid_stream_A, pred_resid_stream_B, inv_A, inv_B = projection_layer(
                        residual_stream_A, residual_stream_B,
                        run_inverse=True
                    )
                else:
                    pred_resid_stream_A, pred_resid_stream_B = projection_layer(
                        residual_stream_A, residual_stream_B,
                        run_inverse=False
                    )

                # NOTE(review): training masks only PADDING_TOKEN, while validate()
                # masks via get_ignore_mask(..., special tokens). Confirm the
                # difference is intended.
                ignore_indices = (tokenized == PADDING_TOKEN)
                keep = ~ignore_indices
                mse_A = mse_loss(pred_resid_stream_A[keep], residual_stream_A[keep])
                mse_B = mse_loss(pred_resid_stream_B[keep], residual_stream_B[keep])
                mse = 0.5 * (mse_A + mse_B)
                loss = mse
                if run_inverse:
                    inv_loss_A = mse_loss(inv_A[keep], residual_stream_A[keep])
                    inv_loss_B = mse_loss(inv_B[keep], residual_stream_B[keep])
                    inv_mse = 0.5 * (inv_loss_A + inv_loss_B)
                    loss = loss + cfg['alpha_inv'] * inv_mse

                loss.backward()
                torch.nn.utils.clip_grad_norm_(params, 1.0)
                optimizer.step()
                step_result_dict = {
                    "train_mse": mse.item(),
                    "mse_A": mse_A.item(),
                    "mse_B": mse_B.item(),
                    "overall_loss": loss.item()
                }
                if run_inverse:
                    step_result_dict["inv_mse"] = inv_mse.item()
                    step_result_dict["inv_mse_A"] = inv_loss_A.item()
                    step_result_dict["inv_mse_B"] = inv_loss_B.item()

                if use_wandb:
                    wandb.log(step_result_dict)

                if scheduler is not None:
                    scheduler.step()
            # FLOPs of this step only: the counter is re-entered every iteration.
            flopcount += flopcounter.get_total_flops()

            if total_iters > 0 and total_iters % cfg['val_every'] == 0:
                if verbose:
                    # Without a scheduler the optimizer's configured LR is the current LR.
                    current_lr = (scheduler.get_last_lr() if scheduler is not None
                                  else [group['lr'] for group in optimizer.param_groups])
                    print(f"Current learning rate: {current_lr}")
                    print(f"Validating at {total_iters} steps")
                validate_and_checkpoint(total_iters)
            total_iters += 1

    # Validate at the final step.
    if verbose:
        print(f"Validating at {total_iters} steps")
    validate_and_checkpoint(total_iters)

    if use_wandb:
        wandb.finish()
    formatted_flopcount = f"{flopcount:e}"
    with open(os.path.join(checkpoints_dir, f"{layer_A}_to_{layer_B}_flops.txt"), 'w') as f:
        f.write(str(flopcount))
        f.write('\n')
        f.write(formatted_flopcount)
    return True

if __name__ == '__main__':
    # Entry point: `python <script> <job_config.yaml>`; shared settings come
    # from global_config.yaml in the current working directory.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('cfg_filename')
    cfg_path = arg_parser.parse_args().cfg_filename

    with open(cfg_path) as job_file:
        job_cfg = yaml.safe_load(job_file)
    with open('global_config.yaml') as global_file:
        shared_cfg = yaml.safe_load(global_file)

    run_job(job_cfg, shared_cfg)
    
