import os
import argparse
import time
import math
import random
import pickle
import tempfile
from contextlib import nullcontext
import numpy as np
import torch
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group, reduce, ReduceOp
from modelling_llama_new import build_llama_models
from transformers import AutoTokenizer
import datasets
import datasets.distributed
from transformers import (
    PreTrainedTokenizerFast,
    DataCollatorForLanguageModeling,
)
from torch.utils.data import DataLoader
from itertools import cycle
from typing import Dict




# Output locations. `ckpt_path` is where train() reads and writes checkpoints and
# loss histories; `dir_path` is not referenced anywhere else in the visible code.
base_path = 'checkpoints'
dir_path = os.path.join(base_path, 'c4')
ckpt_path = os.path.join(base_path, 'checkpoints')


# DDP settings
backend = 'nccl' # 'nccl', 'gloo', etc.
# Default device for single-process runs; the __main__ section rebinds `device`
# to f'cuda:{local_rank}' when launched under DDP.
device = torch.device("cuda")
# examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks
dtype = 'bfloat16' # 'float32', 'bfloat16'


torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
# NOTE: 'float16' is accepted by the lookup below, but no GradScaler is created
# anywhere in the visible code, so float16 would run without loss scaling.
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
# Autocast context entered around the training forward pass; a no-op for float32.
ctx = nullcontext() if dtype == 'float32' else torch.amp.autocast(device_type='cuda', dtype=ptdtype, cache_enabled=True)

# save model checkpoint
def atomic_torch_save(state: dict, final_path: str):
    """Write ``state`` with ``torch.save`` and move it over ``final_path`` in one step.

    The payload is first written to a temporary file created in the destination
    directory and then renamed onto ``final_path`` with ``os.replace``, so an
    interrupted save never leaves a truncated file at ``final_path``.

    Args:
        state: object to serialize; passed unchanged to ``torch.save``.
        final_path: destination file. A bare file name (no directory part)
            targets the current working directory.

    Raises:
        Any exception raised by ``torch.save`` or ``os.replace``; the temporary
        file is removed (best effort) before the exception propagates.
    """
    # dirname() is '' for a bare file name, and os.makedirs('') raises.
    target_dir = os.path.dirname(final_path) or '.'
    os.makedirs(target_dir, exist_ok=True)
    # Only reserve a unique name here; torch.save reopens the path itself.
    with tempfile.NamedTemporaryFile(dir=target_dir, delete=False) as tmp:
        tmp_path = tmp.name
    try:
        torch.save(state, tmp_path)
        os.replace(tmp_path, final_path)
    except BaseException:
        # BaseException so that Ctrl-C during a long save does not leak the temp file.
        try:
            os.remove(tmp_path)
        except OSError:
            pass  # cleanup is best effort; the original error is what matters
        raise

def setup_seed(seed):
    """Seed Python's, NumPy's and PyTorch's (CPU and all-CUDA) RNGs with ``seed``.

    Also switches cuDNN into deterministic mode.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.cuda.manual_seed_all,
        torch.manual_seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    torch.backends.cudnn.deterministic = True

# helps estimate an arbitrarily accurate validation loss using many batches
@torch.no_grad()
def estimate_loss(model, eval_iters, batch_size):
    """Return the mean next-token cross-entropy over ``eval_iters`` validation batches.

    Batches are pulled from the module-level ``val_data`` iterator and moved to
    the module-level ``device``. Targets equal to 0 are ignored (0 is the pad
    token id configured in the ``__main__`` section).

    Args:
        model: callable whose output exposes ``.logits``; switched to eval mode
            for the duration of the call and back to train mode afterwards.
        eval_iters: number of validation batches to average over.
        batch_size: unused; kept so existing callers keep working.

    Returns:
        The mean loss as a Python float. When a process group is initialised the
        mean is additionally averaged across ranks with ``reduce(..., dst=0)``,
        so only rank 0 is guaranteed to hold the reduced value.
    """
    model.eval()
    try:
        losses = torch.zeros(eval_iters, device=device)
        for k in range(eval_iters):
            batch_data = next(val_data)['input_ids']
            X = batch_data[:, :-1].to(device, non_blocking=True)
            Y = batch_data[:, 1:].to(device, non_blocking=True)
            logits = model(X).logits
            # Flatten using the model's own vocabulary dimension rather than a literal.
            logits = logits.view(-1, logits.size(-1))
            # Store the 0-dim tensor directly: avoids a host sync per batch.
            # reshape (not view) because Y originates from a sliced tensor.
            losses[k] = F.cross_entropy(logits, Y.reshape(-1), ignore_index=0)
        out = losses.mean()
        # Collectives need a process group; skip them in single-process runs.
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            reduce(out, 0, ReduceOp.AVG)
    finally:
        # Restore train mode even if evaluation fails part-way.
        model.train()
    return out.item()

# learning rate decay scheduler (cosine with warmup)
def get_lr(it, min_lr, max_lr, warmup_iters, max_iters, alpha=1.0):
    """Return the learning rate for iteration ``it``.

    The schedule has three phases:
      1. linear warmup from 0 to ``max_lr * alpha`` over ``warmup_iters`` steps
         (skipped entirely when ``warmup_iters <= 0``);
      2. cosine decay from ``max_lr * alpha`` down to ``min_lr`` between
         ``warmup_iters`` and ``max_iters``;
      3. constant ``min_lr`` once ``it >= max_iters``.

    Args:
        it: current iteration index.
        min_lr: floor of the schedule (not scaled by ``alpha``).
        max_lr: peak learning rate before scaling.
        warmup_iters: number of warmup iterations.
        max_iters: iteration at which the decay reaches ``min_lr``.
        alpha: multiplier applied to ``max_lr``.

    Returns:
        The learning rate as a float.
    """
    peak_lr = max_lr * alpha
    # 1) linear warmup for warmup_iters steps; the guard avoids 0/0 when no warmup is requested
    if warmup_iters > 0 and it <= warmup_iters:
        return peak_lr * it / warmup_iters
    # 2) past the end of the schedule: hold the floor instead of letting the
    #    cosine swing back up (also avoids dividing by zero when
    #    max_iters == warmup_iters)
    if it >= max_iters:
        return min_lr
    # 3) cosine decay after the warmup steps
    decay_ratio = (it - warmup_iters) / (max_iters - warmup_iters)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1
    return min_lr + coeff * (peak_lr - min_lr)

def train(args):
    """Build the model and optimizer, optionally resume, and run the training loop.

    Besides ``args`` this function reads module-level state prepared elsewhere in
    the file: ``ddp``, ``ddp_rank``, ``ddp_local_rank``, ``world_size``,
    ``master_process``, ``block_size``, ``train_data`` (and ``val_data`` through
    ``estimate_loss``), plus ``device``, ``ctx`` and ``ckpt_path``.

    Args:
        args: parsed command-line namespace (see the ``__main__`` section).

    Raises:
        ValueError: if ``args.grad_micro_steps`` is smaller than 1.
        NotImplementedError: if ``args.optimizer`` is neither 'AdamW' nor 'LANTON'.

    Side effects:
        Writes ``<name>_ckpt.pt``, ``<name>_ana.p`` and ``<name>_finished_ana.p``
        under ``ckpt_path`` from the master process, optionally logs to wandb,
        and destroys the process group at the end of a DDP run.
    """

    # wandb logging
    wandb_log = args.wandb_log
    wandb_project = args.wandb_project
    wandb_run_name = args.wandb_run_name

    gradient_accumulation_steps = args.grad_micro_steps # used to simulate larger batch sizes
    if gradient_accumulation_steps < 1:
        # With zero micro-steps no gradient is ever produced and the loss
        # accumulator stays a plain int, which breaks the logging path.
        raise ValueError(
            f"grad_micro_steps must be >= 1, got {gradient_accumulation_steps}"
        )
    batch_size = args.batch_size # if gradient_accumulation_steps > 1, this is the micro-batch size
    # learning rate decay settings
    decay_lr = True # whether to decay the learning rate
    min_lr = args.max_lr / 10 # minimum learning rate, should be ~= learning_rate/10 per Chinchilla
    max_iters = args.max_iters # total number of training iterations
    warmup_iters = args.warmup_iters
    beta1 = args.beta1
    beta2 = args.beta2
    eps = args.eps
    weight_decay = args.weight_decay
    log_interval = args.log_interval
    eval_interval = args.eval_interval
    save_interval = args.save_interval
    eval_iters = args.eval_iters
    resume_from_checkpoint = args.resume_from_checkpoint
    model_name = args.model_name
    norm_type = args.norm_type
    max_lr = args.max_lr

    # set the dimension of the model input
    d_input = 32100
    # build the llama model
    model = build_llama_models(model_name, d_input, block_size, device)
    print(sum(p.numel() for p in model.parameters()), ddp_rank, ddp_local_rank, world_size)
    # ckpt name
    ckpt_default_name = f'{wandb_project}_{model_name}_{max_lr}_{wandb_run_name}'

    # load the ckpt if needed
    if resume_from_checkpoint == 'auto':
        if os.path.exists(f'{ckpt_path}/{ckpt_default_name}_ckpt.pt'):
            resume_from_checkpoint = ckpt_default_name
        else:
            resume_from_checkpoint = None
    if resume_from_checkpoint:
        ckpt = torch.load(f'{ckpt_path}/{resume_from_checkpoint}_ckpt.pt', map_location=device)
        print(f"recover from step {ckpt['iter_num']}")
        model.load_state_dict(ckpt['model'])

    # Keep handles to the unwrapped module and to the DDP wrapper: checkpoints
    # are taken from `raw_model` so their keys never carry wrapper prefixes, and
    # the gradient-sync flag is set on the DDP object itself.
    raw_model = model
    ddp_model = None
    # wrap model with ddp
    if ddp:
        model = DDP(model, device_ids=[ddp_local_rank])
        ddp_model = model
    if args.faster_path:
        model = torch.compile(model)

    # group the training parameters (substring match on parameter names)
    named_params = list(raw_model.named_parameters())
    embed_param = [p for name, p in named_params if 'embed' in name]
    head_param = [p for name, p in named_params if 'lm_head' in name]
    ln_param = [p for name, p in named_params if 'norm' in name]
    qk_param = [p for name, p in named_params if 'q_proj' in name or 'k_proj' in name]
    vo_param = [p for name, p in named_params if 'v_proj' in name or 'o_proj' in name]
    mlp_param = [p for name, p in named_params if 'mlp' in name]

    if args.optimizer == 'AdamW':
        param_groups = [
            {'params': group, 'lr': max_lr, "name": label, 'weight_decay': weight_decay}
            for label, group in (
                ("embed", embed_param),
                ("head", head_param),
                ("ln", ln_param),
                ("qk", qk_param),
                ("vo", vo_param),
                ("mlp", mlp_param),
            )
        ]
        # eps comes from --eps; its default (1e-8) matches AdamW's own default.
        optimizer = torch.optim.AdamW(
            param_groups, lr=max_lr, betas=(beta1, beta2), eps=eps, weight_decay=weight_decay
        )

    elif args.optimizer == 'LANTON':
        from lanton import LANTON
        # embeddings, lm_head and norm weights go to the optimizer's "adamw" bucket
        sign_params = embed_param + head_param + ln_param
        # attention projections and the FFN go to the "muon" bucket
        muon_params = qk_param + vo_param + mlp_param
        optimizer = LANTON(
            lr=max_lr,
            wd=weight_decay,
            beta=args.noise_momentum,
            muon_params=muon_params,
            adamw_params=sign_params,
            adaptive_warmup_steps=args.warmup_iters,
            scale1=args.scale1,
            scale2=args.scale2,
        )
    else:
        raise NotImplementedError
    print(f"using {args.optimizer} optimizer")

    # defaults for a fresh run; overridden below when resuming
    iter_num = 0
    losses = {"train/loss": [], "val/loss": [],
              "train/iterval": log_interval, "val/iterval": eval_interval}
    if resume_from_checkpoint:
        optimizer.load_state_dict(ckpt['optimizer'])
        iter_num = ckpt['iter_num']
        # NOTE(review): 'Adaptive_Muon' can never reach this point because the
        # optimizer dispatch above rejects it; confirm whether LANTON needs the
        # same step_count restore.
        if args.optimizer == 'Adaptive_Muon':
            optimizer.step_count = iter_num - 1

        del ckpt
        # The loss history is stored next to the checkpoint under the same prefix.
        filepath = f"{ckpt_path}/{resume_from_checkpoint}_ana.p"
        if os.path.exists(filepath):
            # pickle can execute arbitrary code: only load files written by this script.
            with open(filepath, "rb") as f:
                losses = pickle.load(f)
            print(f'loading loss checkpoints')
        else:
            print(f"no loss history found at {filepath}, starting a fresh one")
        print("resuming done....")

    # logging
    if wandb_log and master_process:
        import wandb
        wandb.init(project=wandb_project, name=wandb_run_name)
        config = wandb.config
        config.batch_size = batch_size
        config.gradient_accumulation_steps = gradient_accumulation_steps
        config.max_iters = max_iters
        config.warmup_iters = warmup_iters
        config.beta1 = beta1
        config.beta2 = beta2
        config.eps = eps
        config.weight_decay = weight_decay
        config.seed = args.seed
        config.log_interval = log_interval
        config.eval_interval = eval_interval
        config.save_interval = save_interval
        config.eval_iters = eval_iters
        config.softcapping = args.softcapping
        config.norm_type = norm_type

    t0 = time.time()

    while True:
        # determine and set the learning rate for this iteration (same value for every group)
        lr = get_lr(iter_num, min_lr, max_lr, warmup_iters, max_iters) if decay_lr else args.max_lr
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        total_loss = 0
        # forward backward update, with optional gradient accumulation to simulate larger batch size
        for micro_step in range(gradient_accumulation_steps):
            batch_data = next(train_data)['input_ids']
            # inputs are tokens [0, T-1), targets are the same sequence shifted by one
            X = batch_data[:, :-1].to(device, non_blocking=True)
            Y = batch_data[:, 1:].to(device, non_blocking=True)
            if ddp:
                # only all-reduce gradients on the last micro-step
                ddp_model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1)
            with ctx:
                logits = model(X).logits.view(-1, d_input)
            # NOTE(review): the loss is computed outside the autocast context, as
            # in the original code; confirm whether it was meant to be inside it.
            loss = F.cross_entropy(logits, Y.reshape(-1), ignore_index=0)
            (loss / gradient_accumulation_steps).backward()
            total_loss += loss.detach().float() / gradient_accumulation_steps

        optimizer.step()
        # flush the gradients as soon as we can, no need for this memory anymore
        optimizer.zero_grad(set_to_none=True)

        # timing: dt covers the optimisation step only (t0 is reset after eval/logging below)
        t1 = time.time()
        dt = t1 - t0

        is_eval_step = iter_num % eval_interval == 0 or iter_num == max_iters
        is_log_step = iter_num % log_interval == 0

        if is_eval_step:
            loss_val = estimate_loss(model, eval_iters, batch_size)
        if is_log_step and ddp:
            # average the accumulated training loss across ranks (result lands on rank 0);
            # skipped in single-process runs where no process group exists
            reduce(total_loss, 0, ReduceOp.AVG)
        if is_log_step and master_process:
            lossf = total_loss.item()
            losses["train/loss"].append(lossf)
            metrics = {"iter": iter_num, "train/loss": lossf}
            if is_eval_step:
                losses["val/loss"].append(loss_val)
                metrics["val/loss"] = loss_val
            metrics["lr"] = lr
            print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms")
            if wandb_log:
                wandb.log(metrics, step=iter_num)

        # Checkpointing is independent of the logging cadence.
        if iter_num % save_interval == 0 and iter_num != 0 and master_process:
            checkpoint = {
                'model': raw_model.state_dict(),
                'optimizer': optimizer.state_dict(),
                # this iteration is complete, so a resumed run starts at the next one
                'iter_num': iter_num + 1,
            }
            print(f"saving checkpoint ... ", end='')
            atomic_torch_save(checkpoint, f'{ckpt_path}/{ckpt_default_name}_ckpt.pt')
            print(f"saving logs ... ", end='')
            filepath = f"{ckpt_path}/{ckpt_default_name}_ana.p"
            os.makedirs(os.path.dirname(filepath), exist_ok=True)
            with open(filepath, "wb") as f:
                pickle.dump(losses, f)
            print(f'saved')

        iter_num += 1

        # restart the clock so evaluation/logging/saving time is not charged to the next step
        t0 = time.time()

        # termination conditions
        if iter_num > max_iters:
            break
        if args.debug and iter_num > 1000:
            break

    if master_process:
        # the directory may not exist yet if no periodic checkpoint was written
        os.makedirs(ckpt_path, exist_ok=True)
        print(f"saving logs ... ", end='')
        with open(f"{ckpt_path}/{ckpt_default_name}_finished_ana.p", "wb") as f:
            pickle.dump(losses, f)
        print(f'saved')
        checkpoint = {
            'model': raw_model.state_dict(),
            'optimizer': optimizer.state_dict(),
            # iter_num was already advanced past the last completed iteration
            'iter_num': iter_num,
        }
        print(f"saving checkpoint ... ", end='')
        atomic_torch_save(checkpoint, f'{ckpt_path}/{ckpt_default_name}_ckpt.pt')
        print(f'saved')
        if wandb_log:
            wandb.finish()
    if ddp:
        destroy_process_group()


if __name__=="__main__":
    # Script entry point: parse arguments, set up (optional) DDP state, build the
    # streaming C4 data pipeline, seed the RNGs and call train().
    # Every name assigned here is a module-level global; train() and
    # estimate_loss() read ddp, ddp_rank, ddp_local_rank, master_process,
    # world_size, device, block_size, train_data and val_data.
    parser = argparse.ArgumentParser()
    parser.add_argument("--wandb_log", action='store_true', help="Use Wandb Log.")
    parser.add_argument("--wandb_project", default= 'llama_c4', type=str, help="Wandb project.")
    parser.add_argument("--wandb_run_name", default=None , type=str, help="Wandb run name.")
    parser.add_argument("--model_name", default="0.5B", type=str)
    parser.add_argument("--n_layer", default=12, type=int, help="model depth.")
    parser.add_argument("--seed", default=42, type=int, help="Random seed.")
    parser.add_argument("--batch_size", default=64, type=int, help="Batch size.")
    parser.add_argument("--grad_micro_steps", default=16, type=int, help="Gradient accumulation steps.")
    parser.add_argument("--total_bs", default=1024, type=int, help="Total batch size.")
    parser.add_argument("--log_interval", default=10, type=int, help="Log iterations.")
    parser.add_argument("--eval_interval", default=200, type=int, help="evaluate every n steps.")
    parser.add_argument("--save_interval", default=500, type=int, help="save checkpoint every n steps.")
    parser.add_argument("--eval_iters", default=100, type=int, help="the max iteration when evaluating models.")
    parser.add_argument("--max_lr", default=5e-3, type=float, help="max lr for base learing rate.")
    parser.add_argument("--softcapping", default=0., type=float)
    parser.add_argument("--norm_type", default="pre_norm", choices=["pre_norm", "post_norm", "npost_norm", "rotate_norm", "nrotate_norm"])
    parser.add_argument("--max_iters", default=None, type=int, help="max iterations.")
    parser.add_argument("--warmup_iters", default=1000, type=int, help="warmup iterations.")
    parser.add_argument("--optimizer", default="LANTON")
    parser.add_argument("--dataset", default='c4', type=str, choices=['c4'])
    parser.add_argument("--faster_path", action="store_true")
    parser.add_argument("--switch_iters", default=None, type=int, help="warning: only kept for compatibility")
    parser.add_argument("--beta1", default=0.9, type=float, help="beta1 in AdamW.")
    parser.add_argument("--beta2", default=0.95, type=float, help="beta2 in AdamW.")
    parser.add_argument("--eps", default=1e-8, type=float, help="epsilon in AdamW.")
    parser.add_argument("--weight_decay", default=0.1, type=float, help="weight decay in AdamW.")
    parser.add_argument("--resume_from_checkpoint", type=str, default=None)
    parser.add_argument("--local_rank", default=-1, type=int)
    parser.add_argument("--gpu_count", type=int, default=1, help="")
    parser.add_argument('--dist_url', type=str, default="")
    parser.add_argument("--debug", action="store_true", help="Enable debug mode.")
    parser.add_argument("--noise_momentum", default=0.9, type=float, help="noise momentum in adaptive muon.")
    parser.add_argument("--scale1", default=300, type=float, help="learning rate scale for embed/head layers")
    parser.add_argument("--scale2", default=1.0, type=float, help="learning rate scale for norm layers")
    args = parser.parse_args()
    # run name: derive one only when the user did not pass --wandb_run_name.
    # The derived name is built before the max_iters default is applied (so it
    # contains "None" when --max_iters is omitted); that is kept as-is because
    # train() uses the run name in checkpoint file names.
    if args.wandb_run_name is None:
        args.wandb_run_name = f"{args.optimizer}_{args.max_lr}_{args.max_iters}"

    os.environ["TOKENIZERS_PARALLELISM"] = "true"

    ddp = int(os.environ.get('RANK', -1)) != -1 # is this a ddp run?

    if ddp:
        init_process_group(backend=backend)
        ddp_rank = int(os.environ['RANK'])
        ddp_local_rank = args.local_rank if args.local_rank != -1 else int(os.environ['LOCAL_RANK'])
        world_size = int(os.environ['WORLD_SIZE'])
        device = f'cuda:{ddp_local_rank}'
        torch.cuda.set_device(device)
        master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
    else:
        # if not ddp, we are running on a single gpu, and one process
        ddp_rank = 0          # used when sharding the dataset below
        ddp_local_rank = 0    # train() prints it even in single-process runs
        master_process = True
        world_size = 1

    block_size = 256
    args.max_iters = args.max_iters or 10000
    args.warmup_iters = args.warmup_iters or 1000
    if args.debug:
        args.batch_size = 32
        args.wandb_project = f'debug_{args.wandb_project}'
    # overrides --grad_micro_steps: micro-steps are derived from the two batch sizes
    args.grad_micro_steps = int(args.total_bs // args.batch_size)

    # stream the dataset (nothing is downloaded up front with streaming=True)
    dataset = datasets.load_dataset("allenai/c4", "en", cache_dir="data/c4",  streaming=True)

    # load the tokenizer
    tokenizer = AutoTokenizer.from_pretrained("t5-small")
    tokenizer.pad_token_id = 0
    tokenizer.eos_token_id = 1

    def tokenize_fun(data):
        # block_size + 1 tokens so that inputs and shifted targets both have block_size positions
        return tokenizer(data["text"], truncation=True, max_length=block_size+1, padding=False)

    tokenized_data = dataset.map(tokenize_fun, batched=True, remove_columns=["text", "url", "timestamp"])
    # each rank reads its own shard of the stream
    tokenized_train_data = datasets.distributed.split_dataset_by_node(tokenized_data['train'], rank=ddp_rank, world_size=world_size)
    tokenized_val_data = datasets.distributed.split_dataset_by_node(tokenized_data['validation'], rank=ddp_rank, world_size=world_size)
    collate_fn = DataCollatorForLanguageModeling(tokenizer, mlm=False)
    # construct the training/validation dataloaders.
    # pin_memory_device is a string option; `device` is a torch.device object in
    # the single-process path, so normalise it with str().
    pin_device = str(device)
    train_dataloader = DataLoader(tokenized_train_data, batch_size=args.batch_size, collate_fn=collate_fn, pin_memory=True, pin_memory_device=pin_device)
    val_dataloader = DataLoader(tokenized_val_data, batch_size=args.batch_size, collate_fn=collate_fn, pin_memory=True, pin_memory_device=pin_device)

    train_data = iter(train_dataloader)
    val_data = cycle(val_dataloader)
    # setup the random seed
    setup_seed(args.seed)

    train(args)
