import os
import argparse
import time
import math
import random
import pickle
from contextlib import nullcontext
import numpy as np
import torch
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group, reduce, ReduceOp
from modelling_llama_new import build_llama_models

import datasets
import datasets.distributed
from transformers import (
    PreTrainedTokenizerFast,
    DataCollatorForLanguageModeling,
)
from torch.utils.data import DataLoader
from itertools import cycle


# Data / checkpoint roots: use /anonymous/ when it exists, otherwise fall back
# to the /mnt/anonymous/ layout (note the different dataset sub-directory).
base_path = '/anonymous/'
dir_path = '/anonymous/datasets'
if not os.path.exists(base_path):
    base_path = '/mnt/anonymous/'
    dir_path = os.path.join(base_path, 'anonymous/dataset')
ckpt_path = os.path.join(base_path, 'anonymous/checkpoints')
os.makedirs(ckpt_path, exist_ok=True)

print("defaulting to vocab_size to 32100")

# DDP settings
backend = 'nccl'  # 'nccl', 'gloo', etc.
# Default device; the __main__ section rebinds it to 'cuda:<local_rank>' under DDP.
device = torch.device("cuda")
dtype = 'bfloat16'  # 'float32' or 'bfloat16' (see the float16 note below)

torch.backends.cuda.matmul.allow_tf32 = True  # allow tf32 on matmul
torch.backends.cudnn.allow_tf32 = True  # allow tf32 on cudnn
# NOTE: this script never creates a GradScaler, so 'float16' autocast would run
# without loss scaling and risks gradient underflow; prefer 'bfloat16'/'float32'.
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
# Mixed-precision context wrapped around forward passes (no-op for float32).
ctx = nullcontext() if dtype == 'float32' else torch.amp.autocast(device_type='cuda', dtype=ptdtype, cache_enabled=True)


def setup_seed(seed):
    """Seed every RNG used by training and make cuDNN deterministic.

    Seeds Python's ``random``, NumPy's global RNG, the torch CPU RNG and all
    CUDA device RNGs with the same ``seed``.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for apply_seed in seeders:
        apply_seed(seed)
    torch.backends.cudnn.deterministic = True

# helps estimate an arbitrarily accurate loss over either split using many batches
@torch.no_grad()
def estimate_loss(model, eval_iters, batch_size):
    out = {}
    model.eval()
    losses = torch.zeros(eval_iters, device=device)
    for k in range(eval_iters):
        batch_data = next(val_data)['input_ids']
        X = batch_data[:,:-1]
        Y = batch_data[:,1:]
        X, Y = X.to(device, non_blocking=True), Y.to(device, non_blocking=True)
        logits = model(X).logits.view(-1, 32100)
        loss = F.cross_entropy(logits, Y.view(-1), ignore_index=0)
        losses[k] = loss.item()
    out = losses.mean()
    reduce(out, 0, ReduceOp.AVG)
    model.train()
    return out.item()

def get_lr(it, min_lr, max_lr, warmup_iters, max_iters, alpha=1.0):
    """Return the learning rate for iteration ``it``: linear warmup, then constant.

    No decay is applied after warmup; ``min_lr`` and ``max_iters`` are accepted
    only for call-site compatibility and are ignored.

    Args:
        it: current iteration (0-based).
        min_lr: ignored.
        max_lr: peak learning rate (scaled by ``alpha``).
        warmup_iters: number of warmup iterations (scaled by ``alpha``).
        max_iters: ignored.
        alpha: multiplier applied to both ``max_lr`` and ``warmup_iters``.

    Returns:
        ``peak * it / warmup`` while ``it < warmup``, otherwise ``peak``, where
        ``peak = alpha * max_lr`` and ``warmup = alpha * warmup_iters``. A zero
        warmup length means no warmup (instead of dividing by zero).
    """
    peak_lr = max_lr * alpha
    warmup = warmup_iters * alpha
    # `<` gives the same value as `<=` at it == warmup (ratio 1.0) and never
    # evaluates 0 / 0 when warmup is zero.
    if it < warmup:
        return peak_lr * it / warmup
    return peak_lr
    
class SplitDataLoader:
    """Iterator that cuts each batch from ``data_loader`` into ``split`` pieces.

    Every ``split`` calls to ``next`` consume exactly one batch from the wrapped
    iterator and return its ``'input_ids'`` tensor split along dim 0. ``split``
    may be reassigned between batches; a batch already being served finishes
    with the piece count it was cut into.
    """

    def __init__(self, data_loader, split=1):
        # data_loader: an *iterator* (``next`` is called on it directly) whose
        # items are mappings holding an ``'input_ids'`` tensor.
        self.data_loader = data_loader
        self.split = split
        self.current_data = None  # pieces of the batch being served, or None
        self.current_iter = 0  # index of the next piece to hand out

    def __iter__(self):
        return self

    def __next__(self):
        if self.current_data is None:
            batch = next(self.data_loader)['input_ids']
            # tensor_split always yields exactly `split` pieces; chunk() can
            # yield fewer (e.g. 6 rows into 4 -> 3 chunks), which made the
            # lookup below raise IndexError.
            self.current_data = batch.tensor_split(self.split)
        data = self.current_data[self.current_iter]
        self.current_iter += 1
        if self.current_iter >= len(self.current_data):
            self.current_data = None
            self.current_iter = 0
        return data

def train(args):
    """Run the LLaMA pre-training loop described by ``args``.

    Reads module-level state prepared elsewhere in this file: ``ddp``,
    ``ddp_rank``, ``ddp_local_rank``, ``world_size``, ``master_process``,
    ``block_size``, ``train_data`` (a ``SplitDataLoader``) and ``val_data``
    from the ``__main__`` section, and ``device``, ``ctx`` and ``ckpt_path``
    from the top of the file.

    Each outer iteration consumes ``args.grad_micro_steps`` whole batches from
    ``train_data``. With ``train_data.split == s`` each batch is cut into ``s``
    pieces and ``s`` optimizer steps are taken per iteration, so each optimizer
    step sees ``1/s`` of the nominal batch. ``args.bs_schedule`` picks when
    ``s`` drops 4 -> 2 -> 1 (``'large'`` starts at 1; ``'small'`` stays at 4
    because its switch points lie beyond ``max_iters``).

    Side effects: periodic checkpoints and pickled loss logs under
    ``ckpt_path`` (rank 0 only), optional wandb logging, and destruction of the
    process group at the end of a DDP run.
    """
    # wandb logging
    wandb_log = args.wandb_log
    wandb_project = args.wandb_project
    wandb_run_name = args.wandb_run_name

    gradient_accumulation_steps = args.grad_micro_steps  # micro-steps per optimizer step
    batch_size = args.batch_size  # micro-batch size as loaded, before train_data.split
    total_batch_size = args.total_bs  # only recorded in the wandb config
    decay_lr = True  # when False the lr is fixed at max_lr
    min_lr = args.max_lr / 10  # handed to get_lr
    max_iters = args.max_iters  # last iteration index that is executed
    warmup_iters = args.warmup_iters
    beta1 = args.beta1
    beta2 = args.beta2
    eps = args.eps
    bs_schedule = args.bs_schedule
    weight_decay = args.weight_decay
    grad_clip = args.grad_clip
    log_interval = args.log_interval
    eval_interval = args.eval_interval
    save_interval = args.save_interval
    eval_iters = args.eval_iters
    resume_from_checkpoint = args.resume_from_checkpoint
    model_name = args.model_name
    norm_type = args.norm_type
    max_lr = args.max_lr

    d_input = 32100  # vocabulary size; hard-coded, must match the tokenizer
    raw_model = build_llama_models(model_name, d_input, block_size, device)
    print(sum(p.numel() for p in raw_model.parameters()), ddp_rank, ddp_local_rank, world_size)

    ckpt_default_name = f'{wandb_project}_{wandb_run_name}_{model_name}_{bs_schedule}'

    # 'auto' resumes from this run's own checkpoint when one exists.
    if resume_from_checkpoint == 'auto':
        if os.path.exists(f'{ckpt_path}/{ckpt_default_name}_ckpt.pt'):
            resume_from_checkpoint = ckpt_default_name
        else:
            resume_from_checkpoint = None
    if resume_from_checkpoint:
        ckpt = torch.load(f'{ckpt_path}/{resume_from_checkpoint}_ckpt.pt', map_location=device)
        print(f"recover from step {ckpt['iter_num']}")
        raw_model.load_state_dict(ckpt['model'])

    # Keep a handle on every wrapping layer:
    #   raw_model - unprefixed state_dict for checkpoints (a torch.compile
    #               wrapper prefixes keys with '_orig_mod.', which the loader
    #               above could not read back);
    #   ddp_model - target of DDP's require_backward_grad_sync flag (some torch
    #               versions do not forward attribute assignment through the
    #               torch.compile wrapper);
    #   model     - outermost module, used for forward passes.
    model = raw_model
    if ddp:
        model = DDP(model, device_ids=[ddp_local_rank])
    ddp_model = model
    if args.faster_path:
        model = torch.compile(model)

    extra_args = {}
    if args.optimizer == 'AdamW':
        optimizer_cls = torch.optim.AdamW
        extra_args['eps'] = eps  # --eps only applies to AdamW
        if args.faster_path:
            extra_args['fused'] = True
    elif args.optimizer == 'Lion':
        from lion import Lion
        optimizer_cls = Lion
    else:
        raise ValueError(f"unsupported optimizer: {args.optimizer!r}")
    optimizer = optimizer_cls(model.parameters(), lr=max_lr, betas=(beta1, beta2),
                              weight_decay=weight_decay, **extra_args)

    iter_num = 0
    if resume_from_checkpoint:
        optimizer.load_state_dict(ckpt['optimizer'])
        iter_num = ckpt['iter_num']
        del ckpt
        print("resuming dataloader....")
        # train_data is built with split=1 and only re-split further down, so
        # each next() here consumes one whole batch; every training iteration
        # consumes gradient_accumulation_steps whole batches.
        for _ in range(iter_num * gradient_accumulation_steps):
            next(train_data)
        for _ in range((iter_num // eval_interval) * eval_iters):
            next(val_data)
        print("resuming done....")

    # logging
    if wandb_log and master_process:
        import wandb_wrapper as wandb
        wandb.init(project=wandb_project, name=ckpt_default_name)
        config = wandb.config
        config.total_batch_size = total_batch_size
        config.batch_size = batch_size
        config.gradient_accumulation_steps = gradient_accumulation_steps
        config.max_iters = max_iters
        config.warmup_iters = warmup_iters
        config.max_lr = max_lr
        config.beta1 = beta1
        config.beta2 = beta2
        config.eps = eps
        config.weight_decay = weight_decay
        config.seed = args.seed
        config.log_interval = log_interval
        config.eval_interval = eval_interval
        config.save_interval = save_interval
        config.eval_iters = eval_iters
        config.grad_clip = grad_clip
        config.norm_type = norm_type

    t0 = time.time()
    losses = {"train/loss": [], "val/loss": [], "train/signchange": [],
              "train/paramnorm": [], "train/gradnorm": [], "train/cosine": [], "train/cosine2": [], "train/river": [], "train/river2": [],
              "train/iterval": log_interval, "val/iterval": eval_interval}

    # Iterations at which train_data.split drops 4 -> 2 and 2 -> 1.
    if bs_schedule == 'early':
        switch_iter_1 = int(max_iters * 0.05)
        switch_iter_2 = int(max_iters * 0.1)
    elif bs_schedule == 'middle':
        switch_iter_1 = int(max_iters * 0.3)
        switch_iter_2 = int(max_iters * 0.35)
    elif bs_schedule == 'late':
        switch_iter_1 = int(max_iters * 0.55)
        switch_iter_2 = int(max_iters * 0.6)
    else:
        # Never reached: the loop stops after iteration max_iters.
        switch_iter_1 = max_iters + 1
        switch_iter_2 = max_iters + 2

    train_data.split = 1 if bs_schedule == 'large' else 4
    while True:
        # Every param group gets the same learning rate.
        lr = get_lr(iter_num, min_lr, max_lr, warmup_iters, max_iters) if decay_lr else max_lr
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        if iter_num == switch_iter_1:
            train_data.split = 2
        if iter_num == switch_iter_2:
            train_data.split = 1

        # `split` optimizer steps, each accumulating gradients over
        # gradient_accumulation_steps micro-batches.
        total_loss = 0
        for _ in range(train_data.split):
            for micro_step in range(gradient_accumulation_steps):
                batch_data = next(train_data)
                # Next-token targets: inputs drop the last token, targets the first.
                X = batch_data[:, :-1].to(device, non_blocking=True)
                Y = batch_data[:, 1:].to(device, non_blocking=True)
                if ddp:
                    # All-reduce gradients only on the last micro-step.
                    ddp_model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1)
                with ctx:
                    logits = model(X).logits.view(-1, d_input)
                    # Inside CUDA autocast, cross_entropy is computed in float32.
                    loss = F.cross_entropy(logits, Y.view(-1), ignore_index=0)  # id 0 is padding
                (loss / gradient_accumulation_steps).backward()
                total_loss += loss.detach().float() / gradient_accumulation_steps

            if grad_clip != 0.0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)

        total_loss = total_loss / train_data.split

        is_log_iter = iter_num % log_interval == 0 and iter_num > 5
        is_eval_iter = (iter_num % eval_interval == 0 and iter_num > 5) or iter_num == max_iters

        if is_log_iter:
            # .item() blocks until the GPU work is done, so dt below measures
            # completed training steps on logging iterations.
            lossf = total_loss.item()
            with torch.no_grad():
                total_param_norm = torch.norm(torch.stack(torch._foreach_norm(list(model.parameters()), 2)), 2)

        # timing (training steps of this iteration only)
        t1 = time.time()
        dt = t1 - t0
        t0 = t1

        if is_eval_iter:
            loss_val = estimate_loss(model, eval_iters, batch_size)

        # reduce() needs an initialized process group; skip on single-GPU runs.
        if is_log_iter and ddp:
            reduce(total_loss, 0, ReduceOp.AVG)
            reduce(total_param_norm, 0, ReduceOp.AVG)

        if is_log_iter and master_process:
            lossf = total_loss.item()
            losses["train/loss"].append(lossf)
            losses["train/paramnorm"].append(total_param_norm.item())

            if is_eval_iter:
                losses["val/loss"].append(loss_val)
                print(f"iter {iter_num}: train loss {lossf:.4f}, val loss {loss_val:.4f}, time {dt*1000:.2f}ms")
                if wandb_log:
                    wandb.log({
                        "iter": iter_num,
                        "train/loss": lossf,
                        "val/loss": loss_val,
                        "param_norm": total_param_norm,
                        "lr": lr
                    }, step=iter_num)
            else:
                print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms")
                if wandb_log:
                    wandb.log({
                        "iter": iter_num,
                        "train/loss": lossf,
                        "param_norm": total_param_norm,
                        "lr": lr
                    }, step=iter_num)

            # NOTE: only reached on logging iterations, so save_interval should
            # be a multiple of log_interval.
            if iter_num % save_interval == 0 and iter_num > 5:
                checkpoint = {
                    'model': raw_model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'iter_num': iter_num + 1,
                }
                print(f"saving checkpoint ... ", end='')
                # Write to a temp file then rename so a crash never leaves a
                # half-written checkpoint under the final name.
                torch.save(checkpoint, f'{ckpt_path}/{ckpt_default_name}_tmpckpt.pt')
                os.rename(f'{ckpt_path}/{ckpt_default_name}_tmpckpt.pt', f'{ckpt_path}/{ckpt_default_name}_ckpt.pt')
                print(f"saving logs ... ", end='')
                with open(f"{ckpt_path}/{ckpt_default_name}_ana.p", "wb") as f:
                    pickle.dump(losses, f)
                print(f'saved')

        iter_num += 1
        # Restart the timer so eval/log/checkpoint time is excluded from the
        # next iteration's dt.
        t0 = time.time()

        # termination conditions
        if iter_num > max_iters:
            break

    if master_process:
        with open(f"{ckpt_path}/{ckpt_default_name}_finished_ana.p", "wb") as f:
            pickle.dump(losses, f)
            print(f'saved')
        if wandb_log:
            wandb.finish()
    if ddp:
        destroy_process_group()


if __name__=="__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--wandb_log", action='store_true', help="Use Wandb Log.")
    parser.add_argument("--wandb_project", default= 'llama_c4', type=str, help="Wandb project.")
    parser.add_argument("--wandb_run_name", default='moving_4_01' , type=str, help="Wandb run name.")
    parser.add_argument("--model_name", default="0.13B", type=str)
    parser.add_argument("--n_layer", default=12, type=int, help="model depth.")
    parser.add_argument("--seed", default=41, type=int, help="Random seed.")
    parser.add_argument("--batch_size", default=15, type=int, help="Batch size.")
    parser.add_argument("--grad_micro_steps", default=10, type=int, help="Gradient accumulation steps.")
    parser.add_argument("--total_bs", default=300, type=int, help="Total batch size.")
    parser.add_argument("--log_interval", default=20, type=int, help="Log iterations.")
    parser.add_argument("--eval_interval", default=200, type=int, help="Iterations between validation runs.")
    parser.add_argument("--save_interval", default=2000, type=int, help="Iterations between checkpoints.")
    parser.add_argument("--eval_iters", default=100, type=int, help="Validation batches per evaluation.")
    parser.add_argument("--max_lr", default=6e-4, type=float, help="max lr in AdamW.")
    parser.add_argument("--softcapping", default=0., type=float)
    parser.add_argument("--norm_type", default="pre_norm", choices=["pre_norm", "post_norm", "npost_norm", "rotate_norm", "nrotate_norm"])
    parser.add_argument("--max_iters", default=None, type=int, help="max iterations.")
    parser.add_argument("--warmup_iters", default=1000, type=int, help="warmup iterations.")
    parser.add_argument("--optimizer", default="AdamW", choices=["AdamW", "Lion"])
    parser.add_argument("--dataset", default='c4', type=str, choices=['c4'])
    parser.add_argument("--faster_path", action="store_true")
    parser.add_argument("--switch_iters", default=None, type=int, help="warning: only kept for compatibility")
    parser.add_argument("--beta1", default=None, type=float, help="beta1 in AdamW.")
    parser.add_argument("--beta2", default=None, type=float, help="beta2 in AdamW.")
    parser.add_argument("--eps", default=1e-8, type=float, help="epsilon in AdamW.")
    parser.add_argument("--use_bs_switch", action="store_true")
    parser.add_argument("--bs_switched_iter", default=50000, type=int, help="Batch-size switch iteration (currently unused).")
    parser.add_argument("--bs_switched_value", default=1024, type=int, help="Batch size after the switch (currently unused).")
    parser.add_argument("--bs_schedule", default='early', choices=['early', 'late', 'middle', 'small', 'large'])
    parser.add_argument("--workspace", default='llm', choices=['llm', 'moe'])
    # None -> optimizer-specific default below (0.1 for AdamW, 1.0 for Lion).
    parser.add_argument("--weight_decay", default=None, type=float, help="weight decay (default: 0.1 for AdamW, 1.0 for Lion).")
    parser.add_argument("--grad_clip", default=1.0, type=float, help="grad clip in AdamW.")
    parser.add_argument("--resume_from_checkpoint", type=str, default=None)
    parser.add_argument("--local-rank", default=-1, type=int)
    parser.add_argument("--gpu_count", type=int, default=1, help="")
    parser.add_argument('--dist_url', type=str, default="")
    args = parser.parse_args()
    if args.switch_iters is not None:
        import warnings
        warnings.warn("switch_iters is only kept for compatibility reason!")

    # Optimizer-specific defaults; `is None` keeps explicit zeros (e.g. --beta1=0).
    if args.optimizer == 'AdamW':
        args.beta1 = 0.9 if args.beta1 is None else args.beta1
        args.beta2 = 0.95 if args.beta2 is None else args.beta2
        args.weight_decay = 0.1 if args.weight_decay is None else args.weight_decay
    elif args.optimizer == 'Lion':
        args.beta1 = 0.95 if args.beta1 is None else args.beta1
        args.beta2 = 0.98 if args.beta2 is None else args.beta2
        args.weight_decay = 1. if args.weight_decay is None else args.weight_decay
    os.environ["TOKENIZERS_PARALLELISM"] = "true"

    # The names below are module-level globals read by train()/estimate_loss().
    global ddp, ddp_rank, ddp_local_rank, master_process, world_size
    ddp = int(os.environ.get('RANK', -1)) != -1  # is this a ddp run?

    if ddp:
        init_process_group(backend=backend)
        ddp_rank = int(os.environ['RANK'])
        ddp_local_rank = args.local_rank if args.local_rank != -1 else int(os.environ['LOCAL_RANK'])
        world_size = int(os.environ['WORLD_SIZE'])
        device = f'cuda:{ddp_local_rank}'
        torch.cuda.set_device(device)
        master_process = ddp_rank == 0  # this process will do logging, checkpointing etc.
    else:
        # Single GPU, single process.
        ddp_rank = 0  # used by split_dataset_by_node below
        ddp_local_rank = 0  # printed by train(), so it must exist here too
        master_process = True
        world_size = 1

    global block_size, train_data, val_data
    block_size = 256
    args.max_iters = args.max_iters or 30000  # 50000 -> 30000
    args.warmup_iters = args.warmup_iters or 600  # 10000 -> 600

    dataset = datasets.load_dataset(f"{dir_path}/c4", streaming=True)
    tokenizer = PreTrainedTokenizerFast(tokenizer_file=f"{dir_path}/t5-tokenizer/tokenizer.json")
    tokenizer.pad_token_id = 0
    tokenizer.eos_token_id = 1

    def tokenize_fun(data):
        # block_size + 1 tokens so the shifted input/target pair is block_size long.
        output = tokenizer(data["text"], truncation=True, max_length=block_size+1, padding=False)
        return output

    tokenized_data = dataset.map(tokenize_fun, batched=True, remove_columns=["text", "url", "timestamp"])
    tokenized_train_data = datasets.distributed.split_dataset_by_node(tokenized_data['train'], rank=ddp_rank, world_size=world_size)
    tokenized_val_data = datasets.distributed.split_dataset_by_node(tokenized_data['validation'], rank=ddp_rank, world_size=world_size)
    collate_fn = DataCollatorForLanguageModeling(tokenizer, mlm=False)

    # DataLoader expects pin_memory_device as a string; `device` is a
    # torch.device on single-GPU runs.
    pin_device = str(device)
    train_dataloader = DataLoader(tokenized_train_data, batch_size=args.batch_size, collate_fn=collate_fn, pin_memory=True, pin_memory_device=pin_device)
    val_dataloader = DataLoader(tokenized_val_data, batch_size=args.batch_size, collate_fn=collate_fn, pin_memory=True, pin_memory_device=pin_device)
    # num_workers=4
    train_data = SplitDataLoader(iter(train_dataloader))
    # NOTE: itertools.cycle keeps a copy of every batch from its first pass in
    # memory; acceptable while the validation shard is small.
    val_data = cycle(val_dataloader)

    setup_seed(args.seed)
    train(args)

# python3 -m torch.distributed.run --standalone --nproc_per_node=4 train_llama_c4.py --batch_size=8 --grad_micro_steps=16 --total_bs=512 --max_lr=3e-3 --weight_decay=0.1 --warmup_iters=500 --max_iters=1000 --model_name=0.13B --n_layer=12 --save_interval=100 --wandb_run_name=debug_llama_C4_3e3_500_1k

