import os
from random import sample
import time
from contextlib import nullcontext

import torch
from model import Transformer, ModelArgs
from torch.distributed import destroy_process_group, init_process_group
from torch.nn.parallel import DistributedDataParallel as DDP
from chatglm_tokenizer.tokenization_chatglm import ChatGLMTokenizer
from dataset import PretrainDataset
import logging
import json
import argparse
from sample_order import sample
import itertools

def output_performance(test_loader):
    """Evaluate the module-level ``model`` on ``test_loader`` and report perplexity.

    Two figures are printed and returned:

    * ``perplexity_all``: ``exp`` of the summed loss divided by the total
      number of examples seen across all batches.
    * ``perplexity_single``: the mean over batches of
      ``exp(batch loss sum / batch size)``.

    The function reads the globals ``model`` and ``device`` and leaves the
    model in eval mode.

    Args:
        test_loader: sized iterable yielding ``(X, Y)`` tensor pairs.

    Returns:
        Tuple ``(perplexity_all, perplexity_single)`` of Python floats.
    """
    model.eval()
    loss_total = torch.zeros(1, device=device)
    batch_ppl_total = torch.zeros(1, device=device)
    examples_seen = 0

    for X, Y in test_loader:
        X, Y = X.to(device), Y.to(device)
        rows = X.size(0)

        # The forward result itself is not needed; the loss is read from
        # model.last_loss afterwards.
        with torch.no_grad():
            model(X, Y)

        loss_total += model.last_loss.sum().item()
        examples_seen += rows
        batch_ppl_total += torch.exp(model.last_loss.sum() / rows)

    perplexity_all = torch.exp(loss_total / examples_seen).item()
    perplexity_single = batch_ppl_total.item() / len(test_loader)

    print(f'Performance (Perplexity_All) is {perplexity_all:.7f}')
    print(f'Performance (Perplexity_single) is {perplexity_single:.7f}')

    return perplexity_all, perplexity_single
                
# To run with DDP on 4 GPUs on 1 node, for example:
#   torchrun --standalone --nproc_per_node=4 pretrain.py
#   OR: python -m torch.distributed.launch --nproc_per_node=4 pretrain.py
def get_logger(filename, verbosity=1, name=None):
    """Return a logger that writes to ``filename`` (append mode) and to the console.

    The function is safe to call repeatedly for the same ``name``: handlers
    installed by a previous call are removed and closed first, so records are
    not emitted several times and earlier log files do not keep receiving
    output. Handlers attached by other code are left untouched.

    Args:
        filename: path of the log file; opened in append mode.
        verbosity: 0 for DEBUG, 1 for INFO, 2 for WARNING.
        name: logger name passed to ``logging.getLogger``; ``None`` selects
            the root logger.

    Returns:
        The configured ``logging.Logger``.

    Raises:
        KeyError: if ``verbosity`` is not 0, 1 or 2.
    """
    level_dict = {0: logging.DEBUG, 1: logging.INFO, 2: logging.WARNING}
    formatter = logging.Formatter(
        "[%(asctime)s][%(filename)s][%(levelname)s] %(message)s"
    )
    logger = logging.getLogger(name)
    logger.setLevel(level_dict[verbosity])

    # Drop only the handlers this function attached on an earlier call;
    # iterate over a copy because removeHandler mutates logger.handlers.
    for stale in list(logger.handlers):
        if getattr(stale, "_installed_by_get_logger", False):
            logger.removeHandler(stale)
            stale.close()

    fh = logging.FileHandler(filename, "a")
    sh = logging.StreamHandler()
    for handler in (fh, sh):
        handler.setFormatter(formatter)
        # Marker used by the cleanup loop above on the next call.
        handler._installed_by_get_logger = True
        logger.addHandler(handler)
    return logger


def train_epoch(epoch, order_idx):
    """Train the module-level ``model`` for one pass over ``new_loader``.

    The function relies on state created by the ``__main__`` block:
    ``new_loader``, ``model``, ``optimizer``, ``device``, ``ddp``, ``dtype``,
    ``learning_rate``, ``gradient_accumulation_steps``, ``log_interval``,
    ``iter_per_epoch``, ``max_epoch``, ``save_dir``, ``logger`` and ``args``.

    Gradients are accumulated over ``gradient_accumulation_steps`` consecutive
    mini-batches before each optimizer step. When ``order_idx`` is 0 a
    checkpoint is written after every optimizer step to
    ``<save_dir>/checkpoints/model_batch_<k>.pth``.

    Args:
        epoch: epoch index, used only in the log line.
        order_idx: index of the batch ordering being trained; used in the log
            line and to decide whether per-step checkpoints are saved.
    """
    start_time = time.time()

    print('The len of batches after cutting is ' + str(len(new_loader)) +
          ", which is equal to " + str(gradient_accumulation_steps) + str(' * ') + args.batch_num)

    # DistributedDataParallel does not forward custom attributes such as
    # ``last_loss``, so the loss is read from the wrapped module under DDP.
    loss_owner = model.module if ddp else model

    # Per-step checkpoints are written only for the first ordering, and under
    # DDP only by rank 0, matching the end-of-epoch save in the main block.
    save_ckpt = os.path.join(save_dir, 'checkpoints')
    saves_checkpoints = order_idx == 0 and (not ddp or torch.distributed.get_rank() == 0)
    if saves_checkpoints:
        # torch.save does not create missing parent directories.
        os.makedirs(save_ckpt, exist_ok=True)

    for step, (X, Y) in enumerate(new_loader):
        X = X.to(device)
        Y = Y.to(device)

        # Constant learning rate, re-applied on every step.
        lr = learning_rate
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        # True on the last mini-batch of each accumulation window.
        is_update_step = (step + 1) % gradient_accumulation_steps == 0
        if ddp:
            # Only all-reduce gradients on the backward pass that precedes the
            # optimizer step; the earlier passes just accumulate locally.
            model.require_backward_grad_sync = is_update_step
        with torch.amp.autocast('cuda', enabled=(dtype == 'float16')):
            model(X, Y)
            loss = loss_owner.last_loss
        loss.backward()
        if is_update_step:
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)
            torch.cuda.empty_cache()

        if step % log_interval == 0:
            spend_time = time.time() - start_time
            # NOTE(review): the "10 - 1" below is hard-coded; it looks like it
            # assumes ten orderings per run. Confirm against sample_order.
            logger.info(
                    'Model:{}/{} Epoch:[{}/{}]({}/{}) loss:{:.3f} lr:{:.7f} epoch_Time:{}min:'.format(
                        order_idx,
                        10 - 1,
                        epoch,
                        max_epoch,
                        step,
                        iter_per_epoch - 1,
                        loss.item(),
                        optimizer.param_groups[-1]['lr'],
                        spend_time / (step+1) * iter_per_epoch // 60 - spend_time // 60))

        if saves_checkpoints and is_update_step:
            model.eval()
            torch.save(model.state_dict(),'{}/model_batch_{}.pth'.format(save_ckpt,int((step + 1) // gradient_accumulation_steps)))
            model.train()
    

def init_model():
    """Build the ``Transformer`` described by the module-level configuration.

    Reads the globals ``init_from``, ``dim``, ``n_layers``, ``n_heads``,
    ``multiple_of``, ``max_seq_len`` and ``dropout``; the ``"resume"`` path
    additionally reads ``out_dir`` and ``device``.

    * ``init_from == "scratch"``: return a freshly constructed model.
    * ``init_from == "resume"``: construct the model and load the weights
      stored in ``<out_dir>/initial_model.pth``.

    Returns:
        The constructed ``Transformer`` (not yet moved to ``device``).

    Raises:
        ValueError: if ``init_from`` is neither ``"scratch"`` nor ``"resume"``.
    """
    # Fail early with a clear message instead of reaching ``return model``
    # with the name unbound.
    if init_from not in ("scratch", "resume"):
        raise ValueError(
            f"Unsupported init_from {init_from!r}; expected 'scratch' or 'resume'"
        )

    # model init
    model_args = dict(
        dim=dim,
        n_layers=n_layers,
        n_heads=n_heads,
        n_kv_heads=n_heads,
        vocab_size=64793,
        multiple_of=multiple_of,
        max_seq_len=max_seq_len,
        dropout=dropout,
    )

    if init_from == "scratch":
        # init a new model from scratch
        print("Initializing a new model from scratch")
        return Transformer(ModelArgs(**model_args))

    # resume training from a checkpoint.
    print(f"Resuming training from {out_dir}")
    ckpt_path = os.path.join(out_dir, "initial_model.pth")
    state_dict = torch.load(ckpt_path, map_location=device)

    model = Transformer(ModelArgs(**model_args))

    # Strip the "_orig_mod." prefix from checkpoint keys so they match the
    # parameter names of the freshly constructed model.
    unwanted_prefix = "_orig_mod."
    for key in list(state_dict):
        if key.startswith(unwanted_prefix):
            state_dict[key[len(unwanted_prefix):]] = state_dict.pop(key)
    model.load_state_dict(state_dict)
    return model


# I/O
def load_config(config_path='configs/config.json'):
    """Load the JSON configuration file at ``config_path``.

    The file is decoded as UTF-8 regardless of the platform's default
    encoding, so non-ASCII values read the same on Windows and Linux.

    Args:
        config_path: path to the JSON file.

    Returns:
        The parsed JSON value (a dict for the configs this script expects).

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if the content is not valid JSON.
    """
    with open(config_path, 'r', encoding='utf-8') as config_file:
        return json.load(config_file)

# Script entry point. For every batch ordering stored under
# sample[args.batch_num] (``sample`` is the name imported from sample_order,
# which shadows the earlier ``from random import sample``), the loop below
# rebuilds the data pipeline and the model, trains for ``max_epoch`` epochs,
# saves weights and evaluates perplexity on the test set.
# All names assigned here are module-level globals; train_epoch, init_model
# and output_performance read them directly.
if __name__=="__main__":
    parser = argparse.ArgumentParser(description='Counterfacutal LLM')
    parser.add_argument('--batch_num', type=str, default='8', 
                    help='the number of large batchs for true model')
    parser.add_argument('--device', type=str, default="cuda:0", 
                    help='the device number')
    parser.add_argument('--seed', type=int, default=1337)
    parser.add_argument('--lr', type=float, default=3e-3)
    parser.add_argument('--weight_decay', type=float, default=1e-1)
    parser.add_argument('--dropout', type=float, default=0.4)

    args = parser.parse_args()

    
    # args.batch_num is kept as a string: it is used as the lookup key into
    # ``sample`` and as a directory name, and converted with int() where a
    # number is needed.
    for order_idx, order in enumerate(sample[args.batch_num]): 
        begin_run_time = time.time()
        device = args.device
        
        # NOTE(review): the config file is re-read on every ordering.
        config = load_config()
        
        out_dir = config["pretrain_out_dir"]  # Output directory
        max_epoch = config["max_epoch"]  # Maximum number of training epochs
        eval_interval = config["eval_interval"]  # Evaluation interval (in number of epochs)
        log_interval = config["log_interval"]  # Logging interval (in number of batches)
        eval_iters = config["eval_iters"]  # Number of iterations used per evaluation
        eval_only = config["eval_only"]  # If True, the script exits after the first evaluation
        init_from = config["init_from"]  # Initialization method: 'scratch', 'resume', or 'gpt2*'
        max_seq_len = config["max_seq_len"]  # Maximum sequence length
        dim = config["dim"]  # Embedding dimension of the model
        n_layers = config["n_layers"]  # Number of Transformer layers
        n_heads = config["n_heads"]  # Number of attention heads in the Transformer
        multiple_of = config["multiple_of"]  # Sequence length must be a multiple of this value
        dropout = config["dropout"]  # Dropout probability
        bias = config["bias"]  # Whether to use bias in linear layers
        learning_rate = config["learning_rate"]  # Learning rate
        weight_decay = config["weight_decay"]  # Weight decay coefficient
        beta1 = config["beta1"]  # Beta1 parameter for the Adam optimizer
        beta2 = config["beta2"]  # Beta2 parameter for the Adam optimizer
        grad_clip = config["grad_clip"]  # Gradient clipping threshold
        decay_lr = config["decay_lr"]  # Whether to apply learning rate decay
        backend = config["backend"]  # Backend for distributed training, e.g., 'nccl' or 'gloo'
        dtype = config["dtype"]  # Data type, e.g., 'float16' or 'float32'
        compile = config["compile"]  # Whether to enable PyTorch 2.0 compile acceleration

        batch_size = config['batch_size']
        

        # Command-line values take precedence over the config file for these
        # three hyper-parameters.
        learning_rate = args.lr
        dropout = args.dropout
        weight_decay = args.weight_decay
        
        
        train_data_path_list=config['train_data']
        test_data_path_list=config['test_data']
        # -----------------------------------------------------------------------------
        # Names of the simple-typed module globals.
        # NOTE(review): config_keys is not used anywhere in the visible code.
        config_keys = [
            k
            for k, v in globals().items()
            if not k.startswith("_") and isinstance(v, (int, float, bool, str))
        ]

        # Outputs for this run go to <out_dir>/<batch_num>/, with one log file
        # per ordering.
        save_dir = os.path.join(out_dir , args.batch_num)
        if not os.path.exists(save_dir): os.makedirs(save_dir)
        logger = get_logger(os.path.join(save_dir,f'TrainLog-{order_idx}.log'))
        
        # various inits, derived attributes, I/O setup
        ddp = int(os.environ.get("RANK", -1)) != -1  # is this a ddp run?
        
        if ddp:
            # NOTE(review): the ``backend`` value read from the config is not
            # used here; the backend is chosen from the operating system.
            # Check if the operating system is Windows
            if os.name == 'nt':
                # Diff between backends: https://pytorch.org/docs/stable/distributed.html
                init_process_group(backend="gloo")
            else:
                # If the operating system is Linux based, os.name == 'posix'
                init_process_group(backend="nccl")
            ddp_rank = int(os.environ["RANK"])
            ddp_local_rank = int(os.environ["LOCAL_RANK"])
            ddp_world_size = int(os.environ["WORLD_SIZE"])
            device = f"cuda:{ddp_local_rank}"
            torch.cuda.set_device(device)
            master_process = ddp_rank == 0  # this process will do logging, checkpointing etc.
            seed_offset = ddp_rank  # each process gets a different seed

        else:
            # if not ddp, we are running on a single gpu, and one process
            master_process = True
            seed_offset = 0
            ddp_world_size = 1

        if master_process:
            os.makedirs(out_dir, exist_ok=True)
        # Re-seeding on every ordering makes each run start from the same RNG state.
        torch.manual_seed(args.seed + seed_offset)
        torch.backends.cuda.matmul.allow_tf32 = True  # allow tf32 on matmul
        torch.backends.cudnn.allow_tf32 = True  # allow tf32 on cudnn
        device_type = "cuda" if "cuda" in device else "cpu"  # for later use in torch.autocast
        # note: float16 data type will automatically use a GradScaler
        # NOTE(review): no GradScaler is created in the visible code, and
        # ptdtype / ctx are not used afterwards (train_epoch opens its own
        # autocast context).
        ptdtype = {"float32": torch.float32, "bfloat16": torch.bfloat16, "float16": torch.float16}[dtype]
        ctx = (
            nullcontext()
            if device_type == "cpu"
            else torch.cuda.amp.autocast()
        )
        #
        best_val_loss = 1e9
        #
        #-----init dataloader------
        
        train_ds = PretrainDataset(train_data_path_list, max_length=max_seq_len, memmap=True)
        if ddp:
            train_sampler = torch.utils.data.distributed.DistributedSampler(train_ds)
        else:
            train_sampler = None

        train_loader = torch.utils.data.DataLoader(
            train_ds,
            batch_size=batch_size,
            pin_memory=False,
            drop_last=False,
            shuffle=False,        
            num_workers=0 if os.name == 'nt' else 4,
            sampler=train_sampler
        )
        
        # NOTE(review): the test set uses a fixed max_length of 256 rather
        # than max_seq_len; confirm this is intentional.
        test_ds = PretrainDataset(test_data_path_list, max_length=256)
        test_loader = torch.utils.data.DataLoader(
            test_ds,
            batch_size=batch_size,
            pin_memory=False,
            drop_last=False,
            shuffle=False,        
            num_workers=0,
        )
        
        
        # init model
        model=init_model()
        model.to(device)
        # The starting weights are written to <out_dir>/initial_model.pth,
        # the same path init_model reads when init_from == "resume".
        torch.save(model.state_dict(), os.path.join(out_dir, 'initial_model.pth'))
        print('Initial Model Successsfully Store!')
        optimizer = model.configure_optimizers(weight_decay, learning_rate, (beta1, beta2), device_type)
        
        # compile the model
        if compile:
            print("compiling the model... (takes a ~minute)")
            unoptimized_model = model
            model = torch.compile(model) # requires PyTorch 2.0
        # wrap model into DDP container
        if ddp:
            # Ignore the `freqs_cis` buffer so that DDP does not broadcast it at
            # construction time since NCCL does not support `ComplexFloat`
            prefix = "_orig_mod." if compile else ""
            model._ddp_params_and_buffers_to_ignore = {prefix + "freqs_cis"}
            model = DDP(model, device_ids=[ddp_local_rank])
            #
        raw_model = model.module if ddp else model # unwrap DDP container if needed
        # training loop
        iter_per_epoch=len(train_loader)
        # Trim the list of mini-batches so its length is a multiple of
        # args.batch_num; the whole DataLoader is materialised in memory first.
        new_iter_per_epoch = (iter_per_epoch // int(args.batch_num)) * int(args.batch_num)
        train_batches = list(train_loader)  # Convert DataLoader to a list
        temp_loader_1 = train_batches[:new_iter_per_epoch]
        train_loader = temp_loader_1
        
        # Group consecutive mini-batches into args.batch_num "large batches"
        # of l_batch mini-batches each.
        # NOTE(review): if there are fewer mini-batches than args.batch_num,
        # l_batch is 0, no group is ever completed, and the indexing by
        # ``order`` below fails.
        l_batch = iter_per_epoch // int(args.batch_num)
        large_batches = []
        current_batch = []
        for i, batch in enumerate(train_loader):
            current_batch.append(batch)
            if len(current_batch) == l_batch:
                large_batches.append(current_batch)
                current_batch = [] 


        # ``order`` is a sequence of indices into large_batches; the groups
        # are rearranged accordingly and flattened back into mini-batches.
        ordered_large_batches = [large_batches[i] for i in order]
        new_loader = list(itertools.chain(*ordered_large_batches))

        
        print('train loader length is ' + str(len(new_loader)))
        
        iter_per_epoch = new_iter_per_epoch
        # One optimizer step per large batch: accumulate over all of its mini-batches.
        gradient_accumulation_steps = new_iter_per_epoch // int(args.batch_num)
        
        print('gradient_accumulation_steps is ' + str(gradient_accumulation_steps))
        tokens_per_iter = gradient_accumulation_steps * ddp_world_size * batch_size * max_seq_len
        if master_process:
            print(f"tokens per iteration will be: {tokens_per_iter:,}")
            print(f"breaks down as: {gradient_accumulation_steps} grad accum steps * {ddp_world_size} processes * {batch_size} batch size * {max_seq_len} max seq len")

        print('large batch size is ' + str(batch_size * gradient_accumulation_steps))
        print('The number of large-batch is ' + args.batch_num)
        # Max_epoch = 1
        for epoch in range(max_epoch):
            train_epoch(epoch, order_idx)
            # Ordering 0 is checkpointed per step inside train_epoch; the
            # other orderings save their weights once per epoch, overwriting
            # <save_dir>/model_<order_idx>.pth. Under DDP only rank 0 writes.
            if order_idx != 0:
                if ddp:
                    if torch.distributed.get_rank() == 0: 
                        torch.save(model.state_dict(),'{}/model_{}.pth'.format(save_dir, order_idx))
                else:
                    torch.save(model.state_dict(),'{}/model_{}.pth'.format(save_dir, order_idx))
        # NOTE(review): the process group is destroyed before the evaluation
        # below, which still runs the (possibly DDP-wrapped) model; confirm
        # this works in DDP runs.
        if ddp:
            destroy_process_group()
            
        
        perplexity_all, perplexity_single = output_performance(test_loader)
                

        # Release the model and optimizer before the next ordering builds new ones.
        del model
        del optimizer
        torch.cuda.empty_cache()
            
        
    
