import os
import argparse
import time
import math
import random
import pickle
from contextlib import nullcontext
import numpy as np
import torch
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group, reduce, ReduceOp
from modelling_llama_moe_new import build_llama_models


import datasets
import datasets.distributed
from transformers import (
    PreTrainedTokenizerFast,
    DataCollatorForLanguageModeling,
)
from torch.utils.data import DataLoader
from itertools import cycle


base_path = '/anonymous/'
dir_path = '/anonymous/datasets'
if not os.path.exists(base_path):
    base_path = '/mnt/anonymous/' 
    dir_path = os.path.join(base_path, 'anonymous/dataset')
ckpt_path = os.path.join(base_path, 'anonymous/checkpoints')

print("defaulting to vocab_size to 32100")

# DDP settings
backend = 'nccl' # 'nccl', 'gloo', etc.
device = torch.device("cuda")
# examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks
dtype = 'bfloat16' # 'float32', 'bfloat16'


torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
# note: float16 data type will automatically use a GradScaler
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
ctx = nullcontext() if dtype == 'float32' else torch.amp.autocast(device_type='cuda', dtype=ptdtype, cache_enabled=True)


def setup_seed(seed):
    """Seed every random number generator used by the script with ``seed``.

    Covers Python's ``random``, NumPy, and PyTorch (CPU plus all CUDA
    devices), and switches cuDNN to deterministic mode.
    """
    torch.backends.cudnn.deterministic = True
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_rng in seeders:
        seed_rng(seed)


# helps estimate an arbitrarily accurate loss over either split using many batches
@torch.no_grad()
def estimate_loss(model, eval_iters, batch_size):
    """Estimate the validation loss by averaging over ``eval_iters`` batches.

    Batches are pulled from the module-level ``val_data`` iterator and moved
    to the module-level ``device``. The model is put in eval mode for the
    duration of the call and switched back to train mode before returning.

    Args:
        model: Called as ``model(X, output_router_logits=True)``; the
            ``.values()`` of the returned object must yield the auxiliary
            (router) loss first and the logits second.
        eval_iters: Number of validation batches to average over.
        batch_size: Unused; kept so existing callers keep working.

    Returns:
        ``(loss, aux_loss)`` as Python floats: the mean cross-entropy (without
        the auxiliary term) and the mean auxiliary loss. When a process group
        is initialised, both means are averaged across ranks with a
        ``reduce`` to rank 0, so only rank 0 is guaranteed to hold the
        global average.
    """
    model.eval()
    losses, aux_losses = (torch.zeros(eval_iters, device=device) for _ in range(2))
    for k in range(eval_iters):
        batch_data = next(val_data)['input_ids']
        # next-token prediction: inputs are tokens [0, T-1), targets are [1, T)
        X = batch_data[:, :-1].to(device, non_blocking=True)
        Y = batch_data[:, 1:].to(device, non_blocking=True)
        # NOTE(review): unlike the training loop, this forward pass is not
        # wrapped in the autocast context `ctx` -- confirm that is intended.
        aux_loss, logits, *_ = model(X, output_router_logits=True).values()
        # NOTE(review): targets come straight from `input_ids`, so padded
        # positions are only skipped if the pad id is -1 -- confirm.
        # Use the logits' own last dimension instead of a hard-coded vocab
        # size, and store tensors directly to avoid a host sync per batch.
        losses[k] = F.cross_entropy(logits.view(-1, logits.size(-1)), Y.view(-1), ignore_index=-1)
        aux_losses[k] = aux_loss
    mean_loss, mean_aux_loss = losses.mean(), aux_losses.mean()
    # Collectives require an initialised process group; single-process runs
    # simply report their local means.
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        reduce(mean_loss, 0, ReduceOp.AVG)
        reduce(mean_aux_loss, 0, ReduceOp.AVG)
    model.train()
    return mean_loss.item(), mean_aux_loss.item()

# learning rate decay scheduler (cosine with warmup)
def get_lr(it, min_lr, max_lr, warmup_iters, max_iters, alpha=1.0):
    """Return the learning rate for iteration ``it``.

    The schedule is a linear warmup from 0 to ``max_lr * alpha`` over
    ``warmup_iters * alpha`` iterations, followed by a cosine decay that
    reaches ``min_lr`` at ``max_iters`` and stays there afterwards.

    Args:
        it: Current iteration number.
        min_lr: Learning rate at (and after) the end of the schedule.
        max_lr: Peak learning rate before scaling by ``alpha``.
        warmup_iters: Warmup length before scaling by ``alpha``; 0 disables
            the warmup phase.
        max_iters: Iteration at which the cosine decay reaches ``min_lr``.
        alpha: Multiplier applied to both ``max_lr`` and ``warmup_iters``.

    Returns:
        The learning rate as a float.
    """
    max_lr *= alpha
    warmup_iters *= alpha
    # 1) linear warmup; skipped when there is no warmup to avoid dividing by 0
    if warmup_iters > 0 and it <= warmup_iters:
        return max_lr * it / warmup_iters
    # 2) at or past the end of the schedule: hold at min_lr instead of letting
    #    the cosine rise again (this also covers max_iters <= warmup_iters)
    if it >= max_iters:
        return min_lr
    # 3) cosine decay from max_lr down to min_lr
    decay_ratio = (it - warmup_iters) / (max_iters - warmup_iters)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # coeff ranges 0..1
    return min_lr + coeff * (max_lr - min_lr)


def train(args):
    """Run the full training loop configured by ``args``.

    Besides ``args`` this function reads module-level globals that the
    ``__main__`` section sets up: ``ddp``, ``ddp_rank``, ``ddp_local_rank``
    (DDP runs only), ``world_size``, ``master_process``, ``block_size``,
    ``train_data``, ``val_data``, ``device``, ``ctx`` and ``ckpt_path``.

    Per iteration it sets the scheduled learning rate, accumulates gradients
    over ``args.grad_micro_steps`` micro-batches, optionally clips and
    power-transforms the gradients, and steps the optimizer. On the master
    process it periodically prints/logs metrics and pickles the loss history
    to ``ckpt_path``.

    Args:
        args: Parsed command-line namespace. Fields read here include
            ``wandb_*``, ``grad_micro_steps``, ``batch_size``, ``total_bs``,
            ``max_lr``, ``max_iters``, ``warmup_iters``, ``beta1``, ``beta2``,
            ``eps``, ``use_gradpower``, ``gradpower``, ``weight_decay``,
            ``grad_clip``, the ``*_interval`` values, ``eval_iters``,
            ``resume_from_checkpoint``, ``model_name``, ``optimizer``,
            ``faster_path``, ``seed``, and the per-group ``max_*_lr`` /
            ``*_alpha`` / ``*_wd`` values (the per-group values are only
            recorded in the wandb config; the optimizer uses a single
            parameter group with ``max_lr`` and ``weight_decay``).

    Raises:
        ValueError: If ``args.optimizer`` is neither ``'AdamW'`` nor ``'Lion'``.
    """
    # wandb logging
    wandb_log = args.wandb_log
    wandb_project = args.wandb_project
    wandb_run_name = args.wandb_run_name

    gradient_accumulation_steps = args.grad_micro_steps  # used to simulate larger batch sizes
    batch_size = args.batch_size  # micro-batch size when gradient_accumulation_steps > 1
    # learning rate schedule
    decay_lr = True  # whether to decay the learning rate
    max_lr = args.max_lr
    min_lr = max_lr / 20  # floor of the cosine schedule
    max_iters = args.max_iters  # total number of training iterations
    warmup_iters = args.warmup_iters
    # optimizer settings
    beta1 = args.beta1
    beta2 = args.beta2
    eps = args.eps  # only consumed by the gradient-power transform below
    use_gradpower = args.use_gradpower
    gradpower = args.gradpower
    weight_decay = args.weight_decay
    grad_clip = args.grad_clip
    # bookkeeping intervals
    log_interval = args.log_interval
    eval_interval = args.eval_interval
    save_interval = args.save_interval
    eval_iters = args.eval_iters
    resume_from_checkpoint = args.resume_from_checkpoint
    model_name = args.model_name
    aux_loss_weight = 0.01  # weight of the router auxiliary loss in the training objective

    # `ddp_local_rank` is only assigned on DDP runs; do not touch it otherwise.
    local_rank = ddp_local_rank if ddp else 0

    d_input = 32100  # vocabulary size
    model = build_llama_models(model_name, d_input, block_size, device)
    print(sum(p.numel() for p in model.parameters()), ddp_rank, local_rank, world_size)

    if use_gradpower:
        ckpt_default_name = f'{wandb_project}_{wandb_run_name}_{model_name}_{max_lr}_{gradpower}'
    else:
        ckpt_default_name = f'{wandb_project}_{wandb_run_name}_{model_name}_{max_lr}'

    # 'auto' resumes from this run's default checkpoint when one exists.
    if resume_from_checkpoint == 'auto':
        if os.path.exists(f'{ckpt_path}/{ckpt_default_name}_ckpt.pt'):
            resume_from_checkpoint = ckpt_default_name
        else:
            resume_from_checkpoint = None
    ckpt = None
    if resume_from_checkpoint:
        ckpt = torch.load(f'{ckpt_path}/{resume_from_checkpoint}_ckpt.pt', map_location=device)
        print(f"recover from step {ckpt['iter_num']}")
        model.load_state_dict(ckpt['model'])
    if ddp:
        model = DDP(model, device_ids=[local_rank])
    if args.faster_path:
        model = torch.compile(model, dynamic=False)

    # Build the optimizer that was actually requested.
    if args.optimizer == 'AdamW':
        optimizer_cls = torch.optim.AdamW
        extra_args = {'fused': True}
    elif args.optimizer == 'Lion':
        from lion import Lion
        optimizer_cls = Lion
        extra_args = {}
    else:
        raise ValueError(f"unsupported optimizer: {args.optimizer!r} (expected 'AdamW' or 'Lion')")
    optimizer = optimizer_cls(model.parameters(), lr=max_lr, betas=(beta1, beta2),
                              weight_decay=weight_decay, **extra_args)

    # start from scratch unless a checkpoint was loaded above
    iter_num = 0
    if ckpt is not None:
        optimizer.load_state_dict(ckpt['optimizer'])
        iter_num = ckpt['iter_num']
        del ckpt

    # logging
    if wandb_log and master_process:
        import wandb
        wandb.init(project=wandb_project, name=wandb_run_name)
        groups = ("embed", "head", "ln", "qk", "vo", "shared", "experts", "gated")
        run_config = {
            "total_batch_size": args.total_bs,
            "batch_size": batch_size,
            "gradient_accumulation_steps": gradient_accumulation_steps,
            "max_iters": max_iters,
            "warmup_iters": warmup_iters,
        }
        run_config.update({f"max_{g}_lr": getattr(args, f"max_{g}_lr") for g in groups})
        run_config.update({f"{g}_alpha": getattr(args, f"{g}_alpha") for g in groups})
        run_config.update({f"{g}_wd": getattr(args, f"{g}_wd") for g in groups})
        run_config.update({
            "beta1": beta1,
            "beta2": beta2,
            "eps": eps,
            "use_gradpower": use_gradpower,
            "gradpower": gradpower,
            "weight_decay": weight_decay,
            "seed": args.seed,
            "log_interval": log_interval,
            "eval_interval": eval_interval,
            "save_interval": save_interval,
            "eval_iters": eval_iters,
            "grad_clip": grad_clip,
            "softcapping": args.softcapping,
            "norm_type": args.norm_type,
        })
        config = wandb.config
        for key, value in run_config.items():
            setattr(config, key, value)

    if master_process:
        # make sure the loss-history dumps below have somewhere to go
        os.makedirs(ckpt_path, exist_ok=True)

    t0 = time.time()

    losses = {"train/loss": [], "val/loss": [],
              "train/aux_loss": [], "val/aux_loss": [],
              "train/iterval": log_interval, "val/iterval": eval_interval}

    while True:
        lr = get_lr(iter_num, min_lr, max_lr, warmup_iters, max_iters) if decay_lr else max_lr
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        total_loss, total_aux_loss = 0, 0
        # forward backward update, with optional gradient accumulation to simulate larger batch size
        for micro_step in range(gradient_accumulation_steps):
            batch_data = next(train_data)['input_ids']
            # next-token prediction: inputs are tokens [0, T-1), targets are [1, T)
            X = batch_data[:, :-1].to(device, non_blocking=True)
            Y = batch_data[:, 1:].to(device, non_blocking=True)
            if ddp:
                # only all-reduce gradients on the last micro-step
                model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1)
            with ctx:
                aux_loss, logits, *_ = model(X, output_router_logits=True).values()
            # NOTE(review): targets come straight from `input_ids`, so padded
            # positions are only skipped if the pad id is -1 -- confirm.
            # NOTE(review): the cross-entropy runs outside the autocast
            # context (the original carried a "more indent" remark) -- confirm.
            loss = F.cross_entropy(logits.view(-1, d_input), Y.view(-1), ignore_index=-1) + aux_loss_weight * aux_loss
            (loss / gradient_accumulation_steps).backward()
            total_loss += loss.detach().float() / gradient_accumulation_steps
            total_aux_loss += aux_loss.detach().float() / gradient_accumulation_steps

        # gradient clip
        if grad_clip != 0.0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)

        # gradient power: g <- sign(g) * (|g| + eps) ** gradpower
        if use_gradpower:
            for param in model.parameters():
                if param.grad is not None:
                    g = param.grad
                    param.grad = torch.sign(g) * torch.pow(torch.abs(g) + eps, gradpower)

        optimizer.step()
        # flush the gradients as soon as we can, no need for this memory anymore
        optimizer.zero_grad(set_to_none=True)

        # time spent on the optimisation step (evaluation/logging excluded)
        t1 = time.time()
        dt = t1 - t0
        t0 = t1

        do_eval = iter_num % eval_interval == 0 or iter_num == max_iters
        do_log = iter_num % log_interval == 0

        # evaluation runs on every rank
        if do_eval:
            loss_val, aux_loss_val = estimate_loss(model, eval_iters, batch_size)
        # average the training losses onto rank 0; collectives need a process group
        if do_log and ddp:
            reduce(total_loss, 0, ReduceOp.AVG)
            reduce(total_aux_loss, 0, ReduceOp.AVG)
        if do_log and master_process:
            aux_lossf = total_aux_loss.item()
            # report the plain cross-entropy: strip the aux term added above
            lossf = total_loss.item() - aux_loss_weight * aux_lossf
            losses["train/loss"].append(lossf)
            losses["train/aux_loss"].append(aux_lossf)

            with torch.no_grad():
                total_param_norm = torch.norm(torch.stack(torch._foreach_norm(list(model.parameters()), 2)), 2).item()

            metrics = {
                "iter": iter_num,
                "train/loss": lossf,
                "train/aux_loss": aux_lossf,
                "lr": lr,
                "param_norm": total_param_norm,
            }
            if do_eval:
                losses["val/loss"].append(loss_val)
                losses["val/aux_loss"].append(aux_loss_val)
                metrics["val/loss"] = loss_val
                metrics["val/aux_loss"] = aux_loss_val
                print(f"iter {iter_num}: train loss {lossf:.4f}, train aux loss {aux_lossf:.4f}, val loss {loss_val:.4f}, val aux loss {aux_loss_val:.4f}, time {dt*1000:.2f}ms")
            else:
                print(f"iter {iter_num}: loss {lossf:.4f}, train aux loss {aux_lossf:.4f}, time {dt*1000:.2f}ms")
            if wandb_log:
                wandb.log(metrics, step=iter_num)

        # Periodic dump of the loss history. This no longer depends on the
        # iteration also being a logging step.
        if master_process and iter_num != 0 and iter_num % save_interval == 0:
            # NOTE: model/optimizer checkpointing is currently disabled; only
            # the loss history is written. The disabled code is kept for reference:
            # checkpoint = {
            #     'model': model.module.state_dict() if ddp else model.state_dict(),
            #     'optimizer': optimizer.state_dict(),
            #     'iter_num': iter_num + 1,
            # }
            # print(f"saving checkpoint ... ", end='')
            # torch.save(checkpoint, f'{ckpt_path}/{ckpt_default_name}_tmpckpt.pt')
            # os.rename(f'{ckpt_path}/{ckpt_default_name}_tmpckpt.pt', f'{ckpt_path}/{ckpt_default_name}_ckpt.pt')
            with open(f"{ckpt_path}/{ckpt_default_name}_ana.p", "wb") as f:
                pickle.dump(losses, f)
            print('saved')

        iter_num += 1

        # restart the clock so evaluation/logging time is not billed to the next step
        t0 = time.time()

        # termination conditions
        if iter_num > max_iters:
            break

    if master_process:
        with open(f"{ckpt_path}/{ckpt_default_name}_finished_ana.p", "wb") as f:
            pickle.dump(losses, f)
            print('saved')
        if wandb_log:
            wandb.finish()
    if ddp:
        destroy_process_group()


# Script entry point: parse the command line, set up (optional) DDP, build the
# streaming data pipeline, then hand over to train(). The names assigned here
# (ddp, ddp_rank, ddp_local_rank, master_process, world_size, device,
# block_size, train_data, val_data) are module-level globals read by train()
# and estimate_loss().
if __name__=="__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--wandb_log", action='store_true', help="Use Wandb Log.")
    parser.add_argument("--wandb_project", default= 'llama_web_blockwise', type=str, help="Wandb project.")
    parser.add_argument("--wandb_run_name", default='moving_4_01' , type=str, help="Wandb run name.")
    parser.add_argument("--model_name", default="0.23B", type=str)
    parser.add_argument("--seed", default=41, type=int, help="Random seed.")
    parser.add_argument("--batch_size", default=15, type=int, help="Batch size.")
    parser.add_argument("--grad_micro_steps", default=10, type=int, help="Gradient accumulation steps.")
    parser.add_argument("--total_bs", default=300, type=int, help="Total batch size.")
    parser.add_argument("--log_interval", default=20, type=int, help="Log iterations.")
    parser.add_argument("--eval_interval", default=200, type=int, help="..")
    parser.add_argument("--save_interval", default=2000, type=int, help="..")
    parser.add_argument("--mask_interval", default=10, type=int, help="Mask iterations.")
    parser.add_argument("--eval_iters", default=100, type=int, help="...")
    parser.add_argument("--max_lr", default=6e-4, type=float, help="max lr in AdamW.")
    parser.add_argument("--max_embed_lr", default=None, type=float, help="max embed lr in AdamW.")
    parser.add_argument("--max_head_lr", default=None, type=float, help="max head lr in AdamW.")
    parser.add_argument("--max_ln_lr", default=None, type=float, help="max ln lr in AdamW.")
    parser.add_argument("--max_qk_lr", default=None, type=float, help="max qk lr in AdamW.")
    parser.add_argument("--max_vo_lr", default=None, type=float, help="max vo lr in AdamW.")
    parser.add_argument("--max_shared_lr", default=None, type=float, help="max shared lr in AdamW.")
    parser.add_argument("--max_experts_lr", default=None, type=float, help="max experts lr in AdamW.")
    parser.add_argument("--max_gated_lr", default=None, type=float, help="max gated lr in AdamW.")
    parser.add_argument("--embed_alpha", default=1.0, type=float, help="embed alpha in learning rate.")
    parser.add_argument("--head_alpha", default=1.0, type=float, help="head alpha in learning rate.")
    parser.add_argument("--ln_alpha", default=1.0, type=float, help="ln alpha in learning rate.")
    parser.add_argument("--qk_alpha", default=1.0, type=float, help="qk alpha in learning rate.")
    parser.add_argument("--vo_alpha", default=1.0, type=float, help="vo alpha in learning rate.")
    parser.add_argument("--shared_alpha", default=1.0, type=float, help="shared alpha in learning rate.")
    parser.add_argument("--experts_alpha", default=1.0, type=float, help="experts alpha in learning rate.")
    parser.add_argument("--gated_alpha", default=1.0, type=float, help="gated alpha in learning rate.")
    parser.add_argument("--embed_wd", default=None, type=float, help="embed wd in learning rate.")
    parser.add_argument("--head_wd", default=None, type=float, help="head wd in learning rate.")
    parser.add_argument("--ln_wd", default=None, type=float, help="ln wd in learning rate.")
    parser.add_argument("--qk_wd", default=None, type=float, help="qk wd in learning rate.")
    parser.add_argument("--vo_wd", default=None, type=float, help="vo wd in learning rate.")
    parser.add_argument("--shared_wd", default=None, type=float, help="shared wd in learning rate.")
    parser.add_argument("--experts_wd", default=None, type=float, help="experts wd in learning rate.")
    parser.add_argument("--gated_wd", default=None, type=float, help="gated wd in learning rate.")
    parser.add_argument("--softcapping", default=0., type=float)
    parser.add_argument("--norm_type", default="pre-ln", choices=["pre-ln", "post-ln", "npost-ln", "rotate-ln", "nrotate-ln"])
    parser.add_argument("--max_iters", default=None, type=int, help="max iterations.")
    parser.add_argument("--warmup_iters", default=None, type=int, help="warmup iterations.")
    parser.add_argument("--optimizer", default="AdamW")
    parser.add_argument("--dataset", default='openwebtext', type=str, choices=['openwebtext', 'minipile'])
    parser.add_argument("--faster_path", action="store_true")
    parser.add_argument("--switch_iters", default=None, type=int, help="warning: only kept for compatibility")
    parser.add_argument("--beta1", default=None, type=float, help="beta1 in AdamW.")
    parser.add_argument("--beta2", default=None, type=float, help="beta2 in AdamW.")
    parser.add_argument("--eps", default=1e-8, type=float, help="epsilon in AdamW.")
    parser.add_argument("--use_gradpower", action="store_true")
    parser.add_argument("--gradpower", default=1.0, type=float, help="grad power in AdamWpower.")
    parser.add_argument("--workspace", default='llm', choices=['llm', 'moe'])
    parser.add_argument("--weight_decay", default=0.1, type=float, help="weight decay in AdamW.")
    parser.add_argument("--grad_clip", default=1.0, type=float, help="grad clip in AdamW.")
    parser.add_argument("--resume_from_checkpoint", type=str, default=None)
    parser.add_argument("--local_rank", default=-1, type=int)
    parser.add_argument("--gpu_count", type=int, default=1, help="")
    parser.add_argument('--dist_url', type=str, default="")
    args = parser.parse_args()

    # Per-group learning rates / weight decays fall back to the global value
    # only when they were not given on the command line; an explicit 0.0 is
    # respected (previously `or` silently replaced it for some options).
    for _group in ("embed", "head", "ln", "qk", "vo", "shared", "experts", "gated"):
        if getattr(args, f"max_{_group}_lr") is None:
            setattr(args, f"max_{_group}_lr", args.max_lr)
        if getattr(args, f"{_group}_wd") is None:
            setattr(args, f"{_group}_wd", args.weight_decay)

    if args.switch_iters is not None:
        import warnings
        warnings.warn("switch_iters is only kept for compatibility reason!")

    # Optimizer-specific defaults for anything left unset.
    # NOTE(review): --weight_decay has a non-None argparse default, so the
    # weight-decay fallbacks below never trigger -- confirm whether Lion was
    # meant to default to 1.0.
    if args.optimizer == 'AdamW':
        args.beta1 = 0.9 if args.beta1 is None else args.beta1
        args.beta2 = 0.95 if args.beta2 is None else args.beta2
        args.weight_decay = 0.1 if args.weight_decay is None else args.weight_decay
    elif args.optimizer == 'Lion':
        args.beta1 = 0.95 if args.beta1 is None else args.beta1
        args.beta2 = 0.98 if args.beta2 is None else args.beta2
        args.weight_decay = 1. if args.weight_decay is None else args.weight_decay
    os.environ["TOKENIZERS_PARALLELISM"] = "true"

    ddp = int(os.environ.get('RANK', -1)) != -1 # is this a ddp run?

    if ddp:
        init_process_group(backend=backend)
        ddp_rank = int(os.environ['RANK'])
        ddp_local_rank = args.local_rank if args.local_rank != -1 else int(os.environ['LOCAL_RANK'])
        world_size = int(os.environ['WORLD_SIZE'])
        device = f'cuda:{ddp_local_rank}'
        torch.cuda.set_device(device)
        master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
    else:
        # if not ddp, we are running on a single gpu, and one process
        ddp_rank = 0  # still needed to shard the dataset below
        ddp_local_rank = 0  # train() reports it, so it must exist here too
        master_process = True
        world_size = 1

    block_size = 256
    # a value of 0 (or no value) selects the default
    args.max_iters = args.max_iters or 30000
    args.warmup_iters = args.warmup_iters or 1000

    dataset = datasets.load_dataset(f"{dir_path}/c4", streaming=True)
    tokenizer = PreTrainedTokenizerFast(tokenizer_file=f"{dir_path}/t5-tokenizer/tokenizer.json")
    tokenizer.pad_token_id = 0
    tokenizer.eos_token_id = 1

    def tokenize_fun(data):
        """Tokenize a batch of documents, truncated to block_size + 1 tokens (inputs plus shifted targets)."""
        return tokenizer(data["text"], truncation=True, max_length=block_size+1, padding=False)

    def _repeat_forever(loader):
        """Yield batches from ``loader`` endlessly, restarting it whenever it is exhausted.

        Unlike ``itertools.cycle`` this does not keep a copy of every batch
        it has seen. An empty loader ends the generator instead of spinning.
        """
        while True:
            exhausted_empty = True
            for batch in loader:
                exhausted_empty = False
                yield batch
            if exhausted_empty:
                return

    tokenized_data = dataset.map(tokenize_fun, batched=True, remove_columns=["text", "url", "timestamp"])
    # each rank reads its own shard of the stream
    tokenized_train_data = datasets.distributed.split_dataset_by_node(tokenized_data['train'], rank=ddp_rank, world_size=world_size)
    tokenized_val_data = datasets.distributed.split_dataset_by_node(tokenized_data['validation'], rank=ddp_rank, world_size=world_size)
    collate_fn = DataCollatorForLanguageModeling(tokenizer, mlm=False)

    # `device` is a torch.device on single-process runs and a string on DDP
    # runs; the DataLoader expects a string.
    pin_device = str(device)
    train_dataloader = DataLoader(tokenized_train_data, batch_size=args.batch_size, collate_fn=collate_fn, pin_memory=True, pin_memory_device=pin_device)
    val_dataloader = DataLoader(tokenized_val_data, batch_size=args.batch_size, collate_fn=collate_fn, pin_memory=True, pin_memory_device=pin_device)
    # num_workers=4
    train_data = iter(train_dataloader)
    val_data = _repeat_forever(val_dataloader)

    setup_seed(args.seed)
    train(args)

