import os
import time
from contextlib import nullcontext

import torch
from model import Transformer, ModelArgs
from torch.distributed import destroy_process_group, init_process_group
from dataset import PretrainDataset
import logging
import json
import argparse




def construct_random_matrix(model, K=80):
    """Build and persist one random projection matrix per 2-D parameter.

    For every named parameter of ``model`` whose shape has exactly two
    dimensions, a ``(d_in, K)`` matrix is drawn with ``torch.randn`` and scaled
    by ``sqrt(1 / K)``, where ``d_in`` is the parameter's second dimension.
    Parameters with any other number of dimensions get no entry.

    The resulting dict is written to ``<out_dir>/random_matrix.pt`` with
    ``torch.save`` and also returned. Reads the module-level globals
    ``device`` and ``out_dir``.

    Args:
        model: object exposing ``named_parameters()``.
        K: number of columns of every projection matrix (default 80).

    Returns:
        dict mapping parameter name to its projection tensor on ``device``.
    """
    scale = (1 / K) ** 0.5
    random_matrix = {}
    for name, param in model.named_parameters():
        if len(param.shape) != 2:
            continue
        d_in = param.shape[1]
        # Draw first, then move to `device` (kept deliberately: sampling
        # directly on the target device may yield a different random stream).
        random_matrix[name] = torch.randn(d_in, K).to(device) * scale

    torch.save(random_matrix, os.path.join(out_dir, 'random_matrix.pt'))

    return random_matrix

def each_model(ckpt_idx, start_time, additional_time, last_param, last_grad_list, last_d_mt, last_d_vt, last_second_grad_list, last_d_2_mt, last_d_2_vt):
    """Precompute and save the per-large-batch quantities for one checkpoint.

    Loads the model for ``ckpt_idx`` (``initial_model.pth`` for index 0,
    ``checkpoints/model_batch_<ckpt_idx>.pth`` otherwise), runs forward and
    backward over every batch of the global ``new_loader`` and, each time
    ``gradient_accumulation_steps`` micro-batches have been accumulated,
    derives three tensors per named parameter:

    * ``m_t / (sqrt(v_t) + eps)`` from the optimizer state ``exp_avg`` /
      ``exp_avg_sq`` (collected as "gamma"),
    * ``result_d`` (collected as "d_gamma"),
    * ``result_d_2`` (collected as "d_2_gamma").

    ``result_d`` / ``result_d_2`` are built from ``second_order_grad`` and
    ``third_order_grad``, which for ``ckpt_idx > 0`` are tanh-squashed finite
    differences against the values carried over from the previous checkpoint
    and are zero for ``ckpt_idx == 0``.
    NOTE(review): the names suggest first/second derivatives of gamma with
    respect to the parameters -- confirm against the method's derivation.

    Two-dimensional tensors are multiplied by ``random_matrix[name]`` before
    being stored; all stored results are moved to CPU. The three lists (one
    dict per large batch) are written to ``<save_dir>/precompute`` as
    ``gamma_<ckpt_idx>.pt``, ``d_gamma_<ckpt_idx>.pt`` and
    ``d_2_gamma_<ckpt_idx>.pt``.

    Reads the module-level globals ``device``, ``init_out_dir``, ``args``,
    ``weight_decay``, ``learning_rate``, ``beta1``, ``beta2``, ``device_type``,
    ``new_loader``, ``ddp``, ``gradient_accumulation_steps``, ``ctx``,
    ``log_interval``, ``logger``, ``iter_per_epoch``, ``random_matrix`` and
    ``save_dir``.

    Args:
        ckpt_idx: index of the checkpoint to process; 0 selects the initial model.
        start_time: ``time.time()`` reference used for the ETA in the progress log.
        additional_time: running total (seconds) of the explicitly timed
            sections; the updated total is returned.
        last_param: dict name -> parameter tensor captured by the previous
            call (empty for ``ckpt_idx == 0``).
        last_grad_list: list indexed by large-batch index of dicts
            name -> stored gradient (projected for 2-D tensors); updated in place.
        last_d_mt: dict name -> ``d_m_t_theta`` captured by the previous call.
        last_d_vt: dict name -> ``d_v_t_theta`` captured by the previous call.
        last_second_grad_list: like ``last_grad_list`` but for
            ``second_order_grad``; updated in place.
        last_d_2_mt: dict name -> ``d_2_m_t_theta`` captured by the previous call.
        last_d_2_vt: dict name -> ``d_2_v_t_theta`` captured by the previous call.

    Returns:
        Tuple ``(additional_time, last_param, last_grad_list, last_d_mt,
        last_d_vt, last_second_grad_list, last_d_2_mt, last_d_2_vt)`` to feed
        into the next call. The dict entries are CPU clones captured on the
        large batch whose index equals ``ckpt_idx``.
    """
    # Carried-over dicts are stored on CPU between calls; move them to the compute device.
    last_param = {k: v.to(device) for k, v in last_param.items()}
    last_d_mt = {k: v.to(device) for k, v in last_d_mt.items()}
    last_d_vt = {k: v.to(device) for k, v in last_d_vt.items()}
    last_d_2_mt = {k: v.to(device) for k, v in last_d_2_mt.items()}
    last_d_2_vt = {k: v.to(device) for k, v in last_d_2_vt.items()}

    last_param_copy, last_d_mt_copy, last_d_vt_copy, last_d_2_mt_copy, last_d_2_vt_copy = {}, {}, {}, {}, {}
    torch.cuda.empty_cache()

    # Create the output folder up front so a missing directory fails before the
    # expensive pass over the data instead of at the final torch.save calls.
    write_path = os.path.join(save_dir, 'precompute')
    os.makedirs(write_path, exist_ok=True)

    if ckpt_idx == 0:
        ckpt_path = os.path.join(init_out_dir, 'initial_model.pth')
    else:
        ckpt_path = os.path.join(init_out_dir, args.batch_num, f"checkpoints/model_batch_{ckpt_idx}.pth")
    model = init_model(ckpt_path)
    model.to(device)
    optimizer = model.configure_optimizers(weight_decay, learning_rate, (beta1, beta2), device_type)

    # Constants shared by every parameter and every step of this checkpoint.
    epsilon = 1e-2
    factor = 1e3
    bias1 = 1 - beta1 ** (ckpt_idx + 1)
    bias2 = 1 - beta2 ** (ckpt_idx + 1)

    ckpt_gamma, ckpt_d_gamma, ckpt_d_2_gamma = [], [], []
    for step, (X, Y) in enumerate(new_loader):
        X = X.to(device)
        Y = Y.to(device)

        lr = learning_rate
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        if ddp:
            # Toggles the flag that model.no_sync() would toggle, without the
            # context manager; true only when gradient_accumulation_steps == 1.
            model.require_backward_grad_sync = 0 == gradient_accumulation_steps - 1
        with ctx:
            # The forward pass stores the loss on the model; its return value is not needed.
            model(X, Y)
            loss = model.last_loss

        # Gradients keep accumulating in param.grad until the next zero_grad below.
        loss.backward(retain_graph=False)

        if step % log_interval == 0:
            spend_time = time.time() - start_time
            logger.info(
                    'Checkpoint:[{}/{}] ({}/{}) loss:{:.3f} epoch_Time:{}min:'.format(
                        ckpt_idx,
                        int(args.batch_num) - 1,
                        step,
                        iter_per_epoch,
                        loss.item(),
                        spend_time / (step+1) * iter_per_epoch // 60 - spend_time // 60))

        with torch.no_grad():
            if (step + 1) % gradient_accumulation_steps == 0 or step == len(new_loader) - 1:
                # Index of the large batch that has just been fully accumulated.
                b_idx = (step + 1) // gradient_accumulation_steps - 1
                # The moment estimates are only advanced (and the carry-over
                # state only captured) on the large batch matching this checkpoint.
                is_ckpt_batch = b_idx == ckpt_idx

                batch_t_t_b_t, batch_d_t_t_b_t, batch_d_2_t_t_b_t = {}, {}, {}
                for name, param in model.named_parameters():
                    state = optimizer.state[param]
                    if 'exp_avg' not in state:
                        state['exp_avg'] = torch.zeros_like(param.data)
                        state['exp_avg_sq'] = torch.zeros_like(param.data)

                    now_grad = param.grad.clone().detach()
                    now_param = param.clone().detach()
                    if is_ckpt_batch:
                        state['exp_avg'] = beta1 * state['exp_avg'] + (1 - beta1) * now_grad
                        state['exp_avg_sq'] = beta2 * state['exp_avg_sq'] + (1 - beta2) * (now_grad ** 2)

                    m_t, v_t = state['exp_avg'], state['exp_avg_sq']

                    # Every per-parameter tensor below has the parameter's shape,
                    # so one check decides whether results get projected.
                    is_matrix = len(now_grad.shape) == 2
                    sqrt_v_t = torch.sqrt(v_t)
                    adam_denom = sqrt_v_t + epsilon
                    sqrt_denom = bias2 * 2 * sqrt_v_t + epsilon

                    result = m_t / adam_denom
                    if is_matrix:
                        result = torch.mm(result, random_matrix[name])

                    batch_t_t_b_t[name] = result.detach().cpu()

                    del result

                    if ckpt_idx == 0:
                        # No previous checkpoint to difference against: the
                        # higher-order estimates start at zero.
                        second_order_grad = torch.zeros_like(now_param)

                        d_m_t_theta = (1 - beta1) * second_order_grad / bias1
                        d_sqrt_v_t_theta = 2 * (1 - beta2) * now_grad * second_order_grad / sqrt_denom
                        d_v_t_theta = d_sqrt_v_t_theta * 2 * sqrt_v_t

                        add_time_1 = time.time()
                        third_order_grad = torch.zeros_like(now_param)
                        d_2_m_t_theta = (1 - beta1) * third_order_grad / bias1
                        d_2_sqrt_v_t_theta = 2 * (1 - beta2) * (second_order_grad * second_order_grad + now_grad * third_order_grad) / sqrt_denom \
                            - (2 * (1 - beta2) * now_grad * second_order_grad) * d_v_t_theta / (bias2 * 4 * torch.pow(v_t, 1.5) + epsilon)

                        d_2_v_t_theta = d_2_sqrt_v_t_theta * 2 * sqrt_v_t + d_v_t_theta ** 2 / (2 * v_t + epsilon)

                    else:
                        if is_matrix:
                            # Stored values live in the projected space; map them
                            # back with the pseudo-inverse, computed once and
                            # reused for both tensors.
                            proj_pinv = torch.linalg.pinv(random_matrix[name])
                            last_grad = torch.mm(last_grad_list[b_idx][name], proj_pinv)
                            last_second_grad = torch.mm(last_second_grad_list[b_idx][name], proj_pinv)
                        else:
                            last_grad = last_grad_list[b_idx][name]
                            last_second_grad = last_second_grad_list[b_idx][name]

                        # epsilon keeps the finite-difference denominator away from an exact zero.
                        param_delta = now_param - last_param[name] + epsilon

                        second_order_grad = torch.tanh((now_grad - last_grad) / param_delta) / factor
                        d_m_t_theta = (last_d_mt[name] * beta1 + (1 - beta1) * second_order_grad) / bias1
                        d_sqrt_v_t_theta = (last_d_vt[name] * beta2 + 2 * (1 - beta2) * now_grad * second_order_grad) / sqrt_denom
                        d_v_t_theta = d_sqrt_v_t_theta * 2 * sqrt_v_t

                        add_time_1 = time.time()
                        third_order_grad = torch.tanh((second_order_grad - last_second_grad) / param_delta) / factor
                        d_2_m_t_theta = (last_d_2_mt[name] * beta1 + (1 - beta1) * third_order_grad) / bias1
                        d_2_sqrt_v_t_theta = (last_d_2_vt[name] * beta2 + 2 * (1 - beta2) * (second_order_grad * second_order_grad + now_grad * third_order_grad)) / sqrt_denom \
                            - (beta2 * last_d_vt[name] + 2 * (1 - beta2) * now_grad * second_order_grad) * d_v_t_theta / (bias2 * 4 * torch.pow(v_t, 1.5) + epsilon)
                        d_2_v_t_theta = d_2_sqrt_v_t_theta * 2 * sqrt_v_t + d_v_t_theta ** 2 / (2 * v_t + epsilon)

                    # Remember this batch's gradient estimates for the next
                    # checkpoint (projected for 2-D tensors). Still inside the
                    # section timed from add_time_1.
                    if is_matrix:
                        last_grad_list[b_idx][name] = torch.mm(now_grad, random_matrix[name])
                        last_second_grad_list[b_idx][name] = torch.mm(second_order_grad, random_matrix[name])
                    else:
                        last_grad_list[b_idx][name] = now_grad
                        last_second_grad_list[b_idx][name] = second_order_grad

                    add_time_2 = time.time()
                    additional_time += add_time_2 - add_time_1

                    result_d = (d_m_t_theta * adam_denom - d_sqrt_v_t_theta * m_t) / adam_denom ** 2
                    result_d = torch.nan_to_num(result_d, nan=0.0)

                    add_time_1 = time.time()

                    result_d_2 = (d_2_m_t_theta * adam_denom - d_2_sqrt_v_t_theta * m_t - 2 * d_sqrt_v_t_theta * d_m_t_theta + 2 * d_sqrt_v_t_theta ** 2 * (m_t / adam_denom)) / adam_denom ** 2
                    result_d_2 = torch.nan_to_num(result_d_2, nan=0.0)

                    add_time_2 = time.time()
                    additional_time += add_time_2 - add_time_1

                    if is_matrix:
                        result_d = torch.mm(result_d, random_matrix[name])
                        result_d_2 = torch.mm(result_d_2, random_matrix[name])

                    batch_d_t_t_b_t[name] = result_d.detach().cpu()
                    batch_d_2_t_t_b_t[name] = result_d_2.detach().cpu()
                    del result_d
                    del result_d_2

                    if is_ckpt_batch:
                        last_d_mt_copy[name] = d_m_t_theta.clone().cpu()
                        last_d_vt_copy[name] = d_v_t_theta.clone().cpu()
                        last_param_copy[name] = now_param.clone().cpu()
                        last_d_2_mt_copy[name] = d_2_m_t_theta.clone().cpu()
                        last_d_2_vt_copy[name] = d_2_v_t_theta.clone().cpu()

                ckpt_gamma.append(batch_t_t_b_t)
                ckpt_d_gamma.append(batch_d_t_t_b_t)
                ckpt_d_2_gamma.append(batch_d_2_t_t_b_t)

                # Start the next large batch from clean gradients and release
                # the per-parameter temporaries before emptying the CUDA cache.
                optimizer.zero_grad(set_to_none=True)
                del loss
                del batch_t_t_b_t, batch_d_t_t_b_t, second_order_grad, d_m_t_theta, d_sqrt_v_t_theta, d_v_t_theta
                del batch_d_2_t_t_b_t, third_order_grad, d_2_m_t_theta, d_2_sqrt_v_t_theta, d_2_v_t_theta
                del sqrt_v_t, adam_denom, sqrt_denom

                torch.cuda.empty_cache()

        del X, Y

    torch.save(ckpt_gamma, os.path.join(write_path, f'gamma_{ckpt_idx}.pt'))
    torch.save(ckpt_d_gamma, os.path.join(write_path, f'd_gamma_{ckpt_idx}.pt'))

    # Only the d_2_gamma save is counted towards additional_time.
    add_time_1 = time.time()
    torch.save(ckpt_d_2_gamma, os.path.join(write_path, f'd_2_gamma_{ckpt_idx}.pt'))
    add_time_2 = time.time()
    additional_time += add_time_2 - add_time_1

    print(f'Successfully store gamma_{ckpt_idx}.pt, d_gamma_{ckpt_idx}.pt and d_2_gamma_{ckpt_idx} !')

    del ckpt_gamma, ckpt_d_gamma, ckpt_d_2_gamma
    del model
    del optimizer
    torch.cuda.empty_cache()
    return additional_time, last_param_copy, last_grad_list, last_d_mt_copy, last_d_vt_copy, last_second_grad_list, last_d_2_mt_copy, last_d_2_vt_copy

def compute_gradient(epoch):
    """Run ``each_model`` once per checkpoint index and return the timed total.

    The checkpoint count is ``int(args.batch_num)``. The state returned by one
    ``each_model`` call (everything after its first return value) is passed
    unchanged, in the same order, as the trailing arguments of the next call.
    The initial state consists of empty dicts and, for the two ``*_list``
    slots, one empty dict per checkpoint index.

    Args:
        epoch: accepted for the caller's loop; not read inside this function.

    Returns:
        The ``additional_time`` value returned by the last ``each_model`` call
        (0 when there are no checkpoints).
    """
    extra_seconds = 0
    began_at = time.time()

    print('The len of batches after cutting is ' + str(len(new_loader))
          + ", which is equal to " + str(gradient_accumulation_steps) + ' * ' + args.batch_num)

    n_ckpt = int(args.batch_num)

    # Positional order must match each_model's trailing parameters:
    # last_param, last_grad_list, last_d_mt, last_d_vt,
    # last_second_grad_list, last_d_2_mt, last_d_2_vt.
    carried = [
        {},
        [{} for _ in range(n_ckpt)],
        {},
        {},
        [{} for _ in range(n_ckpt)],
        {},
        {},
    ]

    for ckpt_idx in range(n_ckpt):
        extra_seconds, *carried = each_model(ckpt_idx, began_at, extra_seconds, *carried)

    return extra_seconds


       
def get_logger(filename, verbosity=1, name=None):
    """Return a logger that writes to ``filename`` (append mode) and to the console.

    Safe to call repeatedly: a file handler for the same path and a plain
    console ``StreamHandler`` are attached at most once per logger, so records
    are not duplicated when the same logger is requested again.

    Args:
        filename: path of the log file; opened in append mode.
        verbosity: 0 -> DEBUG, 1 -> INFO, 2 -> WARNING.
        name: logger name passed to ``logging.getLogger``; ``None`` selects
            the root logger.

    Returns:
        The configured ``logging.Logger``.

    Raises:
        KeyError: if ``verbosity`` is not 0, 1 or 2.
    """
    level_dict = {0: logging.DEBUG, 1: logging.INFO, 2: logging.WARNING}
    formatter = logging.Formatter(
        "[%(asctime)s][%(filename)s][%(levelname)s] %(message)s"
    )
    logger = logging.getLogger(name)
    logger.setLevel(level_dict[verbosity])

    # FileHandler stores the absolute path in baseFilename; compare like with like.
    log_path = os.path.abspath(os.fspath(filename))
    has_file_handler = any(
        isinstance(handler, logging.FileHandler) and handler.baseFilename == log_path
        for handler in logger.handlers
    )
    if not has_file_handler:
        fh = logging.FileHandler(filename, "a")
        fh.setFormatter(formatter)
        logger.addHandler(fh)

    # Exact type match on purpose: FileHandler (and other handlers) subclass
    # StreamHandler and must not count as the console handler.
    has_console_handler = any(
        type(handler) is logging.StreamHandler for handler in logger.handlers
    )
    if not has_console_handler:
        sh = logging.StreamHandler()
        sh.setFormatter(formatter)
        logger.addHandler(sh)
    return logger
            

def init_model(ckpt_path):
    """Build a ``Transformer`` from the global config, optionally loading weights.

    The architecture comes from the module-level globals ``dim``,
    ``n_layers``, ``n_heads``, ``multiple_of``, ``max_seq_len`` and
    ``dropout``; ``n_kv_heads`` is set to ``n_heads`` and ``vocab_size`` is
    fixed at 64793. The global ``init_from`` selects the mode:

    * ``"scratch"``: return a freshly constructed model; ``ckpt_path`` is ignored.
    * ``"resume"``: load the state dict at ``ckpt_path`` onto the global
      ``device``, strip a leading ``"_orig_mod."`` from its keys, and load it
      into the model.

    Args:
        ckpt_path: path of the state-dict file used in ``"resume"`` mode.

    Returns:
        The constructed ``Transformer``.

    Raises:
        ValueError: if ``init_from`` is neither ``"scratch"`` nor ``"resume"``.
    """
    if init_from not in ("scratch", "resume"):
        # Without this guard an unknown mode would fall through to an
        # unassigned `model` and fail with a confusing UnboundLocalError.
        raise ValueError(
            f"Unsupported init_from={init_from!r}; expected 'scratch' or 'resume'"
        )

    # model init
    model_args = dict(
        dim=dim,
        n_layers=n_layers,
        n_heads=n_heads,
        n_kv_heads=n_heads,
        vocab_size=64793,
        multiple_of=multiple_of,
        max_seq_len=max_seq_len,
        dropout=dropout,
    )

    if init_from == "scratch":
        # init a new model from scratch
        print("Initializing a new model from scratch")
        return Transformer(ModelArgs(**model_args))

    # resume from a checkpoint
    print(f"Resuming training from {ckpt_path}")
    # NOTE(security): torch.load unpickles the file -- only load checkpoints
    # that come from a trusted source.
    state_dict = torch.load(ckpt_path, map_location=device)

    model = Transformer(ModelArgs(**model_args))

    # Keys saved with this prefix (presumably from a wrapped/compiled model --
    # confirm with the training script) are renamed to the bare parameter names.
    unwanted_prefix = "_orig_mod."
    for key in list(state_dict):
        if key.startswith(unwanted_prefix):
            state_dict[key[len(unwanted_prefix):]] = state_dict.pop(key)
    model.load_state_dict(state_dict)
    return model


# I/O
def load_config(config_path='configs/config.json'):
    """Read the JSON configuration file and return its parsed content.

    The file is decoded as UTF-8 explicitly so that non-ASCII values are read
    the same way on every platform instead of depending on the locale's
    default encoding.

    Args:
        config_path: path of the JSON file (default ``configs/config.json``).

    Returns:
        The object produced by ``json.load`` for the file's content.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if the content is not valid JSON.
    """
    with open(config_path, 'r', encoding='utf-8') as f:
        return json.load(f)

# Script entry point. Parses the CLI arguments, copies the JSON config into
# module-level globals (the functions above read them by name), builds the
# data loader and the random projection matrices, then runs compute_gradient
# once per epoch.
if __name__=="__main__":
    parser = argparse.ArgumentParser(description='Counterfacutal LLM')
    # batch_num stays a string: it is used as a path component (save_dir,
    # init_out_dir) and converted with int() wherever a count is needed.
    parser.add_argument('--batch_num', type=str, default='8', 
                    help='the number of large batchs for true model')
    # Used as the torch device unless this is a DDP run (see below).
    parser.add_argument('--device', type=str, default="cuda:2", 
                    help='the device number')
    args = parser.parse_args()

    
    
    # NOTE(review): begin_run_time (and end_run_time at the bottom) are recorded
    # but not read anywhere in the visible part of this file.
    begin_run_time = time.time()
    
    
    device = args.device
    
    # Loaded from the default path 'configs/config.json'.
    config = load_config()
    


    out_dir = config["inference_out_dir"]  # Output directory
    init_out_dir = config['pretrain_out_dir']  # Directory holding the pretrained models/checkpoints to load

    max_epoch = config["max_epoch"]  # Number of times compute_gradient is run (see loop at the bottom)
    eval_interval = config["eval_interval"]  # Evaluation interval (in epochs)
    log_interval = config["log_interval"]  # Log interval (in number of batches)
    eval_iters = config["eval_iters"]  # Number of iterations used for each evaluation
    eval_only = config["eval_only"]  # If True, the script exits after the first evaluation
    init_from = config["init_from"]  # Initialization method; init_model handles 'scratch' and 'resume'
    max_seq_len = config["max_seq_len"]  # Maximum sequence length
    dim = config["dim"]  # Embedding dimension of the model
    n_layers = config["n_layers"]  # Number of layers in the Transformer model
    n_heads = config["n_heads"]  # Number of attention heads in the Transformer
    multiple_of = config["multiple_of"]  # Forwarded to ModelArgs(multiple_of=...) in init_model
    dropout = config["dropout"]  # Dropout probability
    bias = config["bias"]  # Whether to use bias in linear layers
    learning_rate = config["learning_rate"]  # Learning rate
    weight_decay = config["weight_decay"]  # Weight decay (L2 regularization)
    beta1 = config["beta1"]  # Beta1 parameter for Adam optimizer
    beta2 = config["beta2"]  # Beta2 parameter for Adam optimizer
    grad_clip = config["grad_clip"]  # Gradient clipping threshold
    decay_lr = config["decay_lr"]  # Whether to apply learning rate decay
    # NOTE(review): the init_process_group calls below hard-code their backend
    # and do not use this value.
    backend = config["backend"]  # Backend for distributed training, e.g., 'nccl', 'gloo'
    dtype = config["dtype"]  # Data type, e.g., 'float16', 'float32'
    # NOTE(review): this name shadows the builtin compile() at module level.
    compile = config["compile"]  # Whether to use PyTorch 2.0 compilation acceleration

    
    batch_size = config['batch_size']  # Micro-batch size given to the DataLoader
    
    # device = config["device"]  #
    # order = config['order']
    
    train_data_path_list=config['train_data']
    test_data_path_list=config['test_data']
    # -----------------------------------------------------------------------------
    # Names of all public module-level globals holding a plain int/float/bool/str
    # value at this point.
    config_keys = [
        k
        for k, v in globals().items()
        if not k.startswith("_") and isinstance(v, (int, float, bool, str))
    ]
    
    # Everything produced by this run goes under <out_dir>/<batch_num>.
    save_dir = os.path.join(out_dir , args.batch_num)
    if not os.path.exists(save_dir): os.makedirs(save_dir)
    logger = get_logger(os.path.join(save_dir,f'InferenceLog-Precompute.log'))
    
    
    # various inits, derived attributes, I/O setup
    ddp = int(os.environ.get("RANK", -1)) != -1  # is this a ddp run?
    
    if ddp:
        # Check if the operating system is Windows
        if os.name == 'nt':
            # Diff between backends: https://pytorch.org/docs/stable/distributed.html
            init_process_group(backend="gloo")
        else:
            # If the operating system is Linux based, os.name == 'posix'
            init_process_group(backend="nccl")
        ddp_rank = int(os.environ["RANK"])
        ddp_local_rank = int(os.environ["LOCAL_RANK"])
        ddp_world_size = int(os.environ["WORLD_SIZE"])
        # In a DDP run the --device argument is overridden by the local rank.
        device = f"cuda:{ddp_local_rank}"
        torch.cuda.set_device(device)
        master_process = ddp_rank == 0  # this process will do logging, checkpointing etc.
        seed_offset = ddp_rank  # each process gets a different seed

    else:
        # if not ddp, we are running on a single gpu, and one process
        master_process = True
        seed_offset = 0
        ddp_world_size = 1

    if master_process:
        os.makedirs(out_dir, exist_ok=True)
    torch.manual_seed(1337 + seed_offset)
    torch.backends.cuda.matmul.allow_tf32 = True  # allow tf32 on matmul
    torch.backends.cudnn.allow_tf32 = True  # allow tf32 on cudnn
    device_type = "cuda" if "cuda" in device else "cpu"  # for later use in torch.autocast
    # NOTE(review): ptdtype is looked up from the configured dtype (an unknown
    # dtype string raises KeyError here) but is not passed to the autocast
    # context below, which is created with its default arguments.
    ptdtype = {"float32": torch.float32, "bfloat16": torch.bfloat16, "float16": torch.float16}[dtype]
    # Context used around the forward pass in each_model: a no-op on CPU,
    # CUDA autocast otherwise.
    ctx = (
        nullcontext()
        if device_type == "cpu"
        else torch.cuda.amp.autocast()
    )
    #
    best_val_loss = 1e9
    #
    #-----init dataloader------
    train_ds = PretrainDataset(train_data_path_list, max_length=max_seq_len, memmap=True)
    if ddp:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_ds)
    else:
        train_sampler = None

    # shuffle=False: batches are produced in dataset order (or in the
    # sampler's order under DDP).
    train_loader = torch.utils.data.DataLoader(
        train_ds,
        batch_size=batch_size,
        pin_memory=False,
        drop_last=False,
        shuffle=False,        
        num_workers=0 if os.name == 'nt' else 4,
        sampler=train_sampler
    )


    # From here on init_out_dir points at <pretrain_out_dir>/<batch_num>;
    # each_model builds its checkpoint paths relative to this value.
    init_out_dir = os.path.join(init_out_dir, f'{args.batch_num}')
    # The initial model is loaded only to enumerate parameter names and shapes
    # for the projection matrices.
    model_temp = init_model(os.path.join(init_out_dir, 'initial_model.pth'))
    construct_random_matrix(model_temp)
    # Re-read the matrices just written by construct_random_matrix, mapped onto `device`.
    random_matrix = torch.load(os.path.join(out_dir, 'random_matrix.pt'), map_location=device)


    # training loop
    iter_per_epoch=len(train_loader)
    # Truncate the batch list so that its length is divisible by args.batch_num.
    new_iter_per_epoch = (iter_per_epoch // int(args.batch_num)) * int(args.batch_num)
    # All batches are materialized in memory once and reused for every checkpoint.
    train_batches = list(train_loader)  # Convert DataLoader to a list
    temp_loader_1 = train_batches[:new_iter_per_epoch]
    new_loader = temp_loader_1


    print('train loader length is ' + str(len(new_loader)))
    
    iter_per_epoch = new_iter_per_epoch
    # Number of micro-batches that make up one "large batch".
    # NOTE(review): this is 0 when batch_num exceeds the number of batches,
    # which would make the modulo in each_model divide by zero.
    gradient_accumulation_steps = new_iter_per_epoch // int(args.batch_num)
    
    print('gradient_accumulation_steps is ' + str(gradient_accumulation_steps))
    tokens_per_iter = gradient_accumulation_steps * ddp_world_size * batch_size * max_seq_len
    if master_process:
        print(f"tokens per iteration will be: {tokens_per_iter:,}")
        print(f"breaks down as: {gradient_accumulation_steps} grad accum steps * {ddp_world_size} processes * {batch_size} batch size * {max_seq_len} max seq len")

    print('large batch size is ' + str(batch_size * gradient_accumulation_steps))
    print('The number of large-batch is ' + args.batch_num)
    
    # NOTE(review): each_model names its output files by checkpoint index only,
    # so with max_epoch > 1 every epoch overwrites the previous epoch's files;
    # additional_time keeps only the last epoch's value and is not used afterwards.
    for epoch in range(max_epoch):
        additional_time = compute_gradient(epoch)

    if ddp:
        destroy_process_group()
            
    
    end_run_time = time.time()
    
        
    
