import os
import time
from contextlib import nullcontext

import torch
from model import Transformer, ModelArgs
from torch.distributed import destroy_process_group, init_process_group
from torch.nn.parallel import DistributedDataParallel as DDP
from dataset import PretrainDataset
import json
import argparse
from sample_order import sample




def output_performance(model, test_loader):
    """Score ``model`` on ``test_loader`` and return two perplexity figures.

    The model is switched to eval mode (and left there).  For every batch the
    model is called as ``model(X, Y)`` under ``torch.no_grad()`` and the loss
    is then read back from ``model.last_loss``.

    Args:
        model: module callable as ``model(X, Y)`` that exposes ``last_loss``
            after the call.
        test_loader: iterable of ``(X, Y)`` batches that also supports ``len()``.

    Returns:
        A ``(perplexity_all, perplexity_single)`` tuple of Python floats:
        ``perplexity_all`` is ``exp(total summed loss / total number of rows)``
        over the whole loader, ``perplexity_single`` is the mean over batches
        of ``exp(batch summed loss / batch rows)``.

    Raises:
        ZeroDivisionError: if ``test_loader`` yields no batches.
    """
    model.eval()

    loss_total = torch.zeros(1, device=device)
    batch_ppl_total = torch.zeros(1, device=device)
    rows_seen = 0

    for inputs, targets in test_loader:
        inputs = inputs.to(device)
        targets = targets.to(device)

        # The forward result itself is not needed; only `last_loss` is read.
        with torch.no_grad():
            model(inputs, targets)

        rows = inputs.size(0)
        batch_loss = model.last_loss.sum()

        loss_total += batch_loss.item()
        rows_seen += rows
        batch_ppl_total += torch.exp(batch_loss / rows)

    # Corpus-level figure: exponentiate the loss averaged over every row seen.
    perplexity_all = torch.exp(loss_total / rows_seen).item()
    # Batch-level figure: plain mean of the per-batch perplexities.
    perplexity_single = batch_ppl_total.item() / len(test_loader)

    return perplexity_all, perplexity_single
                
def run():
    """Replay precomputed update terms for every sampled order and report perplexity.

    For each order in ``sample[batch_num]`` a fresh model is loaded from
    ``<init_out_dir>/initial_model.pth`` and then updated once per entry
    ``(model_idx, batch_idx)`` of the order, without any backward pass.  The
    update terms are read from ``<save_dir>/precompute``:

    * ``gamma_<model_idx>.pt``      -> ``g``
    * ``d_gamma_<model_idx>.pt``    -> ``g1``
    * ``d_2_gamma_<model_idx>.pt``  -> ``g2`` (only when
      ``config['first_second_order'] == 'second'``)

    Each file is indexed as ``ckpt[batch_idx][param_name]``.  With
    ``delta = param - old_param`` (``old_param`` comes from
    ``checkpoints/model_batch_<model_idx>.pth``) the step is::

        model_idx == 0 : param -= lr * g
        'first'        : param -= lr * (g1 * delta + g)
        otherwise      : param -= lr * (0.5 * g2 * delta**2 + g1 * delta + g)

    For 2-D parameters every term is first right-multiplied by the
    pseudo-inverse of ``random_matrix[name]`` (presumably undoing a random
    projection applied when the terms were precomputed -- TODO confirm).
    After every step all parameters are clamped to ``[-0.5, 0.5]``.

    Once an order has been replayed the model is evaluated on ``test_loader``
    and both perplexities are printed.

    Relies on module-level state prepared by the ``__main__`` section
    (``config``, ``device``, ``device_type``, ``save_dir``, ``out_dir``,
    ``init_out_dir``, ``batch_num``, ``test_loader``, ``weight_decay``,
    ``beta1``, ``beta2``, ``compile``, ``ddp``, ``ddp_local_rank``).
    """
    order_mode = config['first_second_order']
    lr = config['lr']
    outlier_grad = 0.5

    read_path = os.path.join(save_dir, 'precompute')
    old_model_path = os.path.join(os.path.join(init_out_dir, batch_num), 'checkpoints')

    # Loop-invariant: the same file was previously re-read on every step.
    random_matrix = torch.load(os.path.join(out_dir, 'random_matrix.pt'), map_location=device)

    for order in sample[batch_num]:

        # init model
        model = init_model(os.path.join(init_out_dir, 'initial_model.pth'))
        model.to(device)
        # The optimizer is never stepped here; the call is kept so that any
        # side effects of `configure_optimizers` still happen.
        optimizer = model.configure_optimizers(weight_decay, config['lr'], (beta1, beta2), device_type)

        # Keep a handle on the unwrapped module.  torch.compile and DDP wrap
        # the module and prefix its parameter names (`_orig_mod.` / `module.`),
        # which would break both the name check against `old_model` and the
        # lookups into the precomputed dictionaries.  The wrappers share the
        # very same parameter tensors, so updating through `raw_model` also
        # updates `model`.
        raw_model = model

        # compile the model
        if compile:
            print("compiling the model... (takes a ~minute)")
            model = torch.compile(model) # requires PyTorch 2.0
        # wrap model into DDP container
        if ddp:
            # Ignore the `freqs_cis` buffer so that DDP does not broadcast it at
            # construction time since NCCL does not support `ComplexFloat`
            prefix = "_orig_mod." if compile else ""
            model._ddp_params_and_buffers_to_ignore = {prefix + "freqs_cis"}
            model = DDP(model, device_ids=[ddp_local_rank])

        # Pre-bind so the report below still works when `order` is empty.
        model_idx = batch_idx = None

        for model_idx, batch_idx in enumerate(order):

            ckpt_gamma = torch.load(os.path.join(read_path, 'gamma_' + str(model_idx) + '.pt'), map_location=device)
            ckpt_d_gamma = torch.load(os.path.join(read_path, 'd_gamma_' + str(model_idx) + '.pt'), map_location=device)
            if order_mode == 'second':
                ckpt_d_2_gamma = torch.load(os.path.join(read_path, 'd_2_gamma_' + str(model_idx) + '.pt'), map_location=device)
            else:
                ckpt_d_2_gamma = None

            # Step 0 has no predecessor: `model_batch_1.pth` is loaded anyway,
            # but only its parameter names are used (for the consistency check).
            old_model = init_model(os.path.join(old_model_path, 'model_batch_' + str(max(model_idx, 1)) + '.pth'))
            old_model.to(device)

            with torch.no_grad():
                for (name1, param1), (name2, param2) in zip(raw_model.named_parameters(), old_model.named_parameters()):
                    assert name1 == name2
                    name = name1

                    result = ckpt_gamma[batch_idx][name]
                    result_d = ckpt_d_gamma[batch_idx][name]
                    result_d_2 = ckpt_d_2_gamma[batch_idx][name] if ckpt_d_2_gamma is not None else None

                    if len(param1.shape) == 2:
                        # One pseudo-inverse per parameter instead of one per term.
                        pinv = torch.linalg.pinv(random_matrix[name])
                        result = torch.mm(result, pinv)
                        result_d = torch.mm(result_d, pinv)
                        if result_d_2 is not None:
                            result_d_2 = torch.mm(result_d_2, pinv)

                    if model_idx == 0:
                        new_param = param1 - lr * (result)
                    elif order_mode == 'first':
                        new_param = param1 - lr * (result_d * (param1 - param2) + result)
                    else:
                        new_param = param1 - lr * (0.5 * result_d_2 * (param1 - param2)**2 + result_d * (param1 - param2) + result)
                    param1.data.copy_(new_param)

                # Keep every weight inside [-outlier_grad, outlier_grad] after the step.
                for p in raw_model.parameters():
                    p.data = torch.clamp(p.data, min=-outlier_grad, max=outlier_grad)

        # `last_loss` is an attribute of the wrapped module and is not reachable
        # through the DDP container, so evaluate on `.module` in that case.
        eval_model = model.module if ddp else model
        perplexity_all, perplexity_single = output_performance(eval_model, test_loader)
        print('model idx is ' + str(model_idx))
        print('The last batch_idx is ' + str(batch_idx))
        print('perplexity all is ' + str(perplexity_all))
        print('perplexity single is ' + str(perplexity_single))
        


def init_model(ckpt_path):
    """Build a ``Transformer`` and, in resume mode, load its weights.

    The architecture is taken from the module-level hyper-parameters
    (``dim``, ``n_layers``, ``n_heads``, ``multiple_of``, ``max_seq_len``,
    ``dropout``); ``n_kv_heads`` is set equal to ``n_heads`` and
    ``vocab_size`` is fixed at 64793.  The module-level ``init_from`` selects
    the mode:

    * ``"scratch"`` -- return a freshly constructed model; ``ckpt_path`` is
      ignored.
    * ``"resume"``  -- load a state dict from ``ckpt_path`` (tensors mapped
      to ``device``), strip any ``_orig_mod.`` key prefix, and load it into
      the model.

    Args:
        ckpt_path: path of the state-dict file used in ``"resume"`` mode.

    Returns:
        The constructed model.  It is not moved to ``device`` here.

    Raises:
        ValueError: if ``init_from`` is neither ``"scratch"`` nor ``"resume"``
            (this used to surface as an ``UnboundLocalError`` on return).
    """
    if init_from not in ("scratch", "resume"):
        raise ValueError(
            f"Unsupported init_from={init_from!r}; expected 'scratch' or 'resume'"
        )

    # model init
    model_args = dict(
        dim=dim,
        n_layers=n_layers,
        n_heads=n_heads,
        n_kv_heads=n_heads,
        vocab_size=64793,
        multiple_of=multiple_of,
        max_seq_len=max_seq_len,
        dropout=dropout,
    )

    if init_from == "scratch":
        # init a new model from scratch
        print("Initializing a new model from scratch")
        state_dict = None
    else:
        print(f"Resuming training from {ckpt_path}")
        state_dict = torch.load(ckpt_path, map_location=device)

    gptconf = ModelArgs(**model_args)
    model = Transformer(gptconf)

    if state_dict is not None:
        # Checkpoints saved from a torch.compile'd model carry this key prefix;
        # remove it so the keys match the plain module.
        unwanted_prefix = "_orig_mod."
        for key in list(state_dict):
            if key.startswith(unwanted_prefix):
                state_dict[key[len(unwanted_prefix):]] = state_dict.pop(key)
        model.load_state_dict(state_dict)

    return model


# I/O
def load_config(config_path='configs/config.json'):
    """Read a JSON configuration file and return the decoded object.

    The file is opened explicitly as UTF-8 so that decoding does not depend
    on the platform's locale encoding.

    Args:
        config_path: path of the JSON file to read.

    Returns:
        Whatever ``json.load`` decodes from the file (a dict for the config
        files this script expects).

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if the content is not valid JSON.
    """
    with open(config_path, 'r', encoding='utf-8') as f:
        return json.load(f)

if __name__=="__main__":
    # Script entry point: parse the CLI, load the JSON config, publish the
    # settings as module-level names (they are read as globals by `run`,
    # `init_model` and `output_performance`), set up the optional DDP process
    # group, build the test DataLoader and start `run()`.
    parser = argparse.ArgumentParser(description='Counterfacutal LLM')
    parser.add_argument('--batch_num', type=str, default='8', 
                    help='the number of large batchs for true model')
    parser.add_argument('--device', type=str, default="cuda:2", 
                    help='the device number')
    parser.add_argument('--lr', type=float, default=1e-11)
    # The three options below are read further down (`args.min_max`,
    # `args.wandb`, `args.first_second_order`) but were never declared, so the
    # script failed with AttributeError on every start.
    # NOTE(review): types and defaults are assumptions -- confirm against the
    # sibling scripts that share this config.
    parser.add_argument('--min_max', type=str, default=None,
                    help='stored in config["min_max"]; not used by this script')
    parser.add_argument('--wandb', action='store_true',
                    help='stored in config["wandb"]; not used by this script')
    parser.add_argument('--first_second_order', type=str, default='first',
                    choices=['first', 'second'],
                    help='which precomputed expansion order to apply')

    args = parser.parse_args()

    begin_run_time = time.time()

    config = load_config()

    out_dir = config["inference_out_dir"]  # Output directory
    init_out_dir = config['pretrain_out_dir']  # Directory holding initial_model.pth and per-batch checkpoints

    # Most of the following values are not read by this script; they are kept
    # as module-level names exactly as before.
    max_epoch = config["max_epoch"]  # Maximum number of training epochs
    eval_interval = config["eval_interval"]  # Evaluation interval (in number of epochs)
    log_interval = config["log_interval"]  # Log interval (in number of batches)
    eval_iters = config["eval_iters"]  # Number of iterations used for each evaluation
    eval_only = config["eval_only"]  # If True, the script exits after the first evaluation
    init_from = config["init_from"]  # Initialization mode; init_model accepts 'scratch' or 'resume'
    max_seq_len = config["max_seq_len"]  # Maximum sequence length
    dim = config["dim"]  # Embedding dimension of the model
    n_layers = config["n_layers"]  # Number of Transformer layers
    n_heads = config["n_heads"]  # Number of attention heads in the Transformer
    multiple_of = config["multiple_of"]  # Forwarded unchanged to ModelArgs(multiple_of=...)
    dropout = config["dropout"]  # Dropout probability
    bias = config["bias"]  # Whether to use bias in linear layers
    weight_decay = config["weight_decay"]  # Weight decay coefficient
    beta1 = config["beta1"]  # Beta1 parameter for Adam optimizer
    beta2 = config["beta2"]  # Beta2 parameter for Adam optimizer
    grad_clip = config["grad_clip"]  # Gradient clipping threshold
    decay_lr = config["decay_lr"]  # Whether to apply learning rate decay
    backend = config["backend"]  # Backend for distributed training, e.g., 'nccl', 'gloo'
    dtype = config["dtype"]  # Data type, e.g., 'float16', 'float32'
    compile = config["compile"]  # Whether to use PyTorch 2.0 compile acceleration

    batch_size = config['batch_size']  # Batch size of the test DataLoader

    # Command-line values override / extend the JSON config.
    device = args.device
    batch_num = args.batch_num
    config['device'] = device
    config['batch_num'] = batch_num
    config['min_max'] = args.min_max
    config['lr'] = args.lr
    config['wandb'] = args.wandb
    config['first_second_order'] = args.first_second_order
    train_data_path_list=config['train_data']
    test_data_path_list=config['test_data']
    # -----------------------------------------------------------------------------
    config_keys = [
        k
        for k, v in globals().items()
        if not k.startswith("_") and isinstance(v, (int, float, bool, str))
    ]

    save_dir = os.path.join(out_dir , batch_num)
    # exist_ok avoids the check-then-create race when several ranks start at once.
    os.makedirs(save_dir, exist_ok=True)

    ddp = int(os.environ.get("RANK", -1)) != -1  # is this a ddp run?

    if ddp:
        # Check if the operating system is Windows
        if os.name == 'nt':
            # Diff between backends: https://pytorch.org/docs/stable/distributed.html
            init_process_group(backend="gloo")
        else:
            # If the operating system is Linux based, os.name == 'posix'
            init_process_group(backend="nccl")
        ddp_rank = int(os.environ["RANK"])
        ddp_local_rank = int(os.environ["LOCAL_RANK"])
        ddp_world_size = int(os.environ["WORLD_SIZE"])
        device = f"cuda:{ddp_local_rank}"
        torch.cuda.set_device(device)
        master_process = ddp_rank == 0  # this process will do logging, checkpointing etc.
        seed_offset = ddp_rank  # each process gets a different seed

    else:
        # if not ddp, we are running on a single gpu, and one process
        master_process = True
        seed_offset = 0
        ddp_world_size = 1

    if master_process:
        os.makedirs(out_dir, exist_ok=True)
    torch.manual_seed(1337 + seed_offset)
    torch.backends.cuda.matmul.allow_tf32 = True  # allow tf32 on matmul
    torch.backends.cudnn.allow_tf32 = True  # allow tf32 on cudnn
    device_type = "cuda" if "cuda" in device else "cpu"  # for later use in torch.autocast
    # note: float16 data type will automatically use a GradScaler
    ptdtype = {"float32": torch.float32, "bfloat16": torch.bfloat16, "float16": torch.float16}[dtype]
    # `ctx` is built but not entered anywhere in this script.
    ctx = (
        nullcontext()
        if device_type == "cpu"
        else torch.cuda.amp.autocast()
    )
    #
    best_val_loss = 1e9
    #
    #-----init dataloader------

    test_ds = PretrainDataset(test_data_path_list, max_length=256)
    test_loader = torch.utils.data.DataLoader(
        test_ds,
        batch_size=batch_size,
        pin_memory=False,
        drop_last=False,
        shuffle=False,        
        num_workers=0,
    )

    # os.nice is not available on Windows, which the DDP branch above supports.
    if hasattr(os, "nice"):
        os.nice(0)

    try:
        run()
    finally:
        # Tear the process group down even when run() raises.
        if ddp:
            destroy_process_group()
