from datasets import load_dataset
import datasets
from datasets.distributed import split_dataset_by_node
from transformers import (
    AutoTokenizer,
    AutoConfig,
    LlamaForCausalLM,
    DataCollatorForLanguageModeling,
    get_cosine_schedule_with_warmup,
)
import torch
import math
import torch.distributed as dist
import os
from loguru import logger
from tqdm import tqdm
import argparse
from torch.utils.data import DataLoader
import time
from compressor.compress_hook import create_compression_hook
import numpy as np
import yaml
import json
import itertools
from safetensors.torch import load_file

# Number of training examples used to size one epoch: main() computes the
# micro-batches per rank per epoch as C4_LENGTH // world_size // batch_size.
# NOTE(review): presumably the (approximate) document count of the "c4/en"
# train split loaded in main() — confirm against the dataset actually used.
C4_LENGTH = 364_860_000

def parse_args(args):
    """Build the training-script CLI and parse ``args``.

    Args:
        args: list of argument strings, or ``None`` to parse ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with one attribute per option defined below.
    """
    parser = argparse.ArgumentParser()
    add = parser.add_argument

    # ---- training ----
    add("--lr", default=1e-3, type=float)
    add("--batch_size", type=int, default=16, help="batch size per round")
    add("--total_batch_size", type=int, default=32, help="batch size per step")
    add("--max_length", default=1024, type=int)
    add("--num_training_steps", default=10000, type=int)
    add("--grad_clip", default=0.0, type=float)
    add("--warmup", default=1000, type=int)
    add("--use_wandb", action="store_true", default=False)
    add("--use_tqdm", action="store_true", default=False)
    # A string default is run through `type` by argparse, so this yields 0.5.
    add("--p_t", type=float, default="0.5", help='the probability for lazy sampling')
    add("--lazy_sampling", action="store_true")
    add("--lazy_schedule", type=str, default="constant")
    add("--scale", action="store_true")
    add("--seed", type=int, default=42)
    add("--new_batch_protect", action="store_true")
    add("--warmup_protect", default=2000, type=int)

    # ---- evaluation ----
    add("--eval_freq", type=int, default=500)

    # ---- compression ----
    add("--comp_path", type=str)
    add("--use_compression", action="store_true")
    add("--cp_method", type=str, default=None)
    add("--cp_param", type=str, default=None)

    # ---- model ----
    add("--model_config", default="configs/llama_1b.json", type=str)

    # ---- optimizer ----
    add("--optimizer", type=str, default='adam', choices=['sgd', 'adam'],
        help="assign the optimization algorithm")
    add("--momentum", type=float, default=0)
    add("--nesterov", action="store_true")
    add("--per_layer_weight_update", action="store_true")

    # ---- checkpointing ----
    add("--save_dir", type=str, default="checkpoints")
    add("--save_freq", type=int, default=2000)
    add("--continue_from", default=None, type=str)

    # ---- wandb ----
    add("--wandb_run_name", default='clapping')

    return parser.parse_args(args)

def log(info):
    """Log ``info`` via loguru on the rank-0 process only.

    Messages are prefixed with ``[<rank>]: ``. The rank comes from the default
    process group when it is initialised. Otherwise it falls back to the ``RANK``
    environment variable, defaulting to 0, so that logging also works in a
    single-process run. Non-zero ranks return without logging.
    """
    if dist.is_available() and dist.is_initialized():
        rank = dist.get_rank()
    else:
        rank = int(os.environ.get("RANK", 0))
    if rank == 0:
        logger.info(f"[{rank}]: " + info)

def ddp_setup():
    """Initialise the default NCCL process group for this process.

    The process is first bound to ``cuda:LOCAL_RANK``, as PyTorch recommends for
    NCCL. Otherwise, collectives that are not given an explicit tensor device
    may fall back to ``cuda:0`` on every rank. Rendezvous settings (MASTER_ADDR,
    RANK, WORLD_SIZE, ...) are read from the environment, because no
    ``init_method`` is passed (default ``env://``).
    """
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    if torch.cuda.is_available():
        torch.cuda.set_device(local_rank)
    dist.init_process_group(backend="nccl")

def apply_compression_hooks(model, compression_config, device):
    """Attach a forward compression hook to each configured layer of ``model``.

    Every value of ``compression_config`` is a per-layer dict. Entries with no
    ``layer_idx`` are ignored. The target module is the first entry of
    ``model.named_modules()`` whose name ends with ``layers.<layer_idx>``; if
    there is none, the entry is skipped. The hook built by
    ``create_compression_hook(layer_idx, layer_config, device)`` is registered
    with ``register_forward_hook``. Its handle is appended to
    ``model.compression_handles``, which is created on first use, so the hook
    can be removed later.
    """
    for layer_config in compression_config.values():
        layer_idx = layer_config.get('layer_idx', None)
        if layer_idx is None:
            continue

        suffix = f'layers.{layer_idx}'
        match = next(
            ((mod_name, mod) for mod_name, mod in model.named_modules()
             if mod_name.endswith(suffix)),
            None,
        )
        if match is None:
            # This model has no such layer; nothing to hook for this entry.
            continue

        matched_name, target_module = match
        hook_fn = create_compression_hook(layer_idx, layer_config, device)
        hook_handle = target_module.register_forward_hook(hook_fn)
        if not hasattr(model, 'compression_handles'):
            model.compression_handles = []
        model.compression_handles.append(hook_handle)
        log(f'Applied compression hook to layer: {matched_name}')

class cosine_lazy_sampling_schedule():
    """Cosine-annealed probability of fetching a fresh batch under lazy sampling.

    Calling an instance with a step returns
    ``1 - 0.5 * (1 + cos(pi * step / T_max)) * (eta_max - eta_min)``. This rises
    monotonically from ``1 - (eta_max - eta_min)`` at step 0 to ``1`` at
    ``T_max``. Steps outside ``[0, T_max]`` are clamped into that range.

    Args:
        num_training_steps: length ``T_max`` of the annealing window.
        p_t: ``eta_max``; the step-0 value is ``1 - (p_t - eta_min)``.
        eta_min: offset subtracted from ``p_t`` (default 0).

    NOTE(review): with the constant schedule, ``main`` uses ``p_t`` itself as
    the new-batch probability, whereas this schedule starts from ``1 - p_t``.
    The two agree only for ``p_t == 0.5``; confirm which convention is intended.
    """

    def __init__(self, num_training_steps, p_t, eta_min=0):
        self.eta_max = p_t
        self.eta_min = eta_min
        self.T_max = num_training_steps

    def __call__(self, current_batch):
        # An empty annealing window means the schedule is already at its end value
        # (this also avoids dividing by zero).
        if self.T_max <= 0:
            return 1.0
        # cos is periodic, so without clamping, steps past T_max (callers may
        # query one step beyond the last) would make the probability fall back
        # towards its starting value.
        step = min(max(current_batch, 0), self.T_max)
        cosine_part = 0.5 * (1 + math.cos(math.pi * step / self.T_max))
        return 1 - cosine_part * (self.eta_max - self.eta_min)

def _set_layer_attr(model, compression_config, attr, value):
    """Set ``attr = value`` on every layer listed in ``compression_config``.

    Layers are looked up as ``getattr(model.model.layers, str(layer_idx))``, so
    an int and a numeric-string ``layer_idx`` both work. Entries without a
    ``layer_idx`` are skipped, because ``apply_compression_hooks`` never hooks
    such entries. Previously they crashed with an AttributeError on ``'None'``.
    """
    for layer_config in compression_config.values():
        layer_idx = layer_config.get('layer_idx', None)
        if layer_idx is None:
            continue
        layer = getattr(model.model.layers, str(layer_idx))
        setattr(layer, attr, value)

def block_hook(model, compression_config, value):
    """Set the ``block`` attribute to ``value`` on each configured layer.

    ``main`` sets it to True while warmup protection is active and back to
    False afterwards. How the flag is consumed is defined by the compression
    hook, which is not visible here.
    """
    _set_layer_attr(model, compression_config, 'block', value)

def set_new_batch(model, compression_config, value):
    """Set the ``new_batch`` attribute to ``value`` on each configured layer.

    ``main`` sets it to True when a fresh batch is fetched and to False when the
    previous batch is reused (with ``--new_batch_protect``).
    """
    _set_layer_attr(model, compression_config, 'new_batch', value)

def load_config(config_path):
    """Parse the YAML file at ``config_path`` with ``yaml.safe_load`` and return it."""
    with open(config_path, 'r') as stream:
        parsed = yaml.safe_load(stream)
    return parsed
    
def save_checkpoint(model, optimizer, schedule, update_step, args):
    """Write a resumable checkpoint to ``<args.save_dir>/model_<update_step>``.

    The directory receives:
      * the model via ``model.save_pretrained``. ``max_shard_size='100GB'``
        keeps weights under that size in a single file, which
        ``load_checkpoint`` reads back as ``model.safetensors``;
      * ``optimizer.pt``, containing the optimizer and LR-scheduler state dicts
        plus ``update_step``;
      * ``training_state.json``, containing ``update_step``.
    """
    current_model_directory = f"{args.save_dir}/model_{update_step}"
    log(f"Saving model and optimizer to {current_model_directory}, update step {update_step}")
    os.makedirs(args.save_dir, exist_ok=True)

    # The Llama config uses -1 as pad_token_id, which conflicts with
    # save_pretrained, so a valid id is swapped in only for the duration of the
    # save. The previous value is restored even if saving fails, so the
    # in-memory model is left unchanged.
    generation_config = model.generation_config
    previous_pad_token_id = generation_config.pad_token_id
    generation_config.pad_token_id = 0
    try:
        model.save_pretrained(current_model_directory, max_shard_size='100GB')
    finally:
        generation_config.pad_token_id = previous_pad_token_id

    optimizer_checkpoint = {
        "optimizer": optimizer.state_dict(),
        "schedule": schedule.state_dict(),
        "update_step": update_step,
    }
    torch.save(optimizer_checkpoint, f"{current_model_directory}/optimizer.pt")

    training_state_checkpoint = {
        "update_step": update_step,
    }
    with open(f"{current_model_directory}/training_state.json", "w") as f:
        json.dump(training_state_checkpoint, f, indent=4)

def load_checkpoint(model, optimizer, schedule, args, device):
    """Restore model, optimizer and LR-scheduler state from ``args.continue_from``.

    Expects the layout written by ``save_checkpoint``: ``model.safetensors``
    (loaded with ``strict=True``), ``optimizer.pt`` and, optionally,
    ``training_state.json``.

    Returns:
        The update step to resume from. It is read from ``training_state.json``
        when that file exists. Otherwise it comes from the ``update_step`` stored
        in ``optimizer.pt``, or 0 if that entry is absent.
    """
    log("*" * 40)
    log(f"Loading model from {args.continue_from}")
    checkpoint_path = os.path.join(args.continue_from, "model.safetensors")
    checkpoint = load_file(checkpoint_path, device=device)
    model.load_state_dict(checkpoint, strict=True)
    opt_checkpoint_path = os.path.join(args.continue_from, "optimizer.pt")
    opt_checkpoint = torch.load(opt_checkpoint_path, map_location=torch.device(device))
    optimizer.load_state_dict(opt_checkpoint['optimizer'])
    schedule.load_state_dict(opt_checkpoint['schedule'])

    log(f"Model and Optimizer successfully loaded (strict=True policy)")

    training_state_path = os.path.join(args.continue_from, "training_state.json")
    if os.path.exists(training_state_path):
        log(f"Loading training state like global_step, update_step, and tokens_seen from {args.continue_from}")
        with open(training_state_path) as f:
            _old_state = json.load(f)
        update_step = _old_state["update_step"]
        log(f"update_step       : {update_step}")
        log(f"Will train for {args.num_training_steps - update_step} update steps")
    else:
        # update_step must be bound on this path as well (the caller uses the
        # return value). save_checkpoint stores it in optimizer.pt too, which
        # keeps it consistent with the scheduler state just loaded.
        update_step = opt_checkpoint.get("update_step", 0)
        logger.warning(f"Did not find training state in {args.continue_from}, resuming from update_step {update_step} (taken from optimizer.pt, 0 if absent)")
    log("*" * 40)
    return update_step
    

def main(args):
    """Train a Llama causal LM on streamed C4 across ranks, with optional compression hooks.

    Expects the default process group to be initialised already (``ddp_setup``).
    It reads ``RANK``, ``LOCAL_RANK`` and ``WORLD_SIZE`` from the environment.
    The model is not wrapped in DDP. Instead:
      * gradients are accumulated over ``total_batch_size // batch_size``
        micro-batches per rank;
      * they are then averaged across ranks with ``dist.all_reduce`` before each
        optimizer step.

    With ``--lazy_sampling``, a micro-batch may be reused instead of fetching a
    new one; the fetch probability is ``args.p_t(update_step)``. Evaluation runs
    at update step 1 and every ``eval_freq`` steps. Rank 0 writes a checkpoint
    every ``save_freq`` steps and when warmup protection ends.
    """
    rank = int(os.environ.get("RANK", 0))
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))

    # Set a different random seed per rank.
    torch.manual_seed(args.seed + rank)
    torch.cuda.manual_seed_all(args.seed + rank)

    log("Process group initialize")
    log("*" * 40)
    log("Start training with arguments")
    for k, v in vars(args).items():
        log(f"{k:30} {v}")
    log("*" * 40)

    device = f"cuda:{local_rank}"

    # Ensure the number of accumulation micro-batches is an integer.
    assert args.total_batch_size % args.batch_size == 0, "grad accumulation must be integer"
    grad_accumulation = args.total_batch_size // args.batch_size

    # tokenizer (a local copy can be used instead, e.g. "/data/t5")
    tokenizer = AutoTokenizer.from_pretrained("t5-base", model_max_length=args.max_length)

    # model
    model_config = AutoConfig.from_pretrained(args.model_config)
    model = LlamaForCausalLM(model_config).to(device)

    # optimizer
    trainable_param = [p for p in model.parameters() if p.requires_grad == True]

    if args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(trainable_param,
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    nesterov=args.nesterov,
                                    foreach=False)
        # schedule
        schedule = get_cosine_schedule_with_warmup(optimizer,
                                                   num_warmup_steps=args.warmup,
                                                   num_training_steps=args.num_training_steps + 1)
    elif args.optimizer == 'adam':
        optimizer = torch.optim.Adam(trainable_param,
                                     lr=args.lr,
                                     )

        # schedule
        schedule = get_cosine_schedule_with_warmup(optimizer,
                                                   num_warmup_steps=args.warmup,
                                                   num_training_steps=args.num_training_steps)

    if args.continue_from:
        update_step = load_checkpoint(model, optimizer, schedule, args, device)
        model.to(device)

    # Synchronise initial parameters from rank 0 (there is no DDP wrapper to do it).
    for param in model.parameters():
        dist.broadcast(param, src=0)

    # dataset
    ds = datasets.load_dataset("/data/datasets/c4/en", split="train", streaming=True)
    val_ds = datasets.load_dataset("/data/datasets/c4/en", split="validation", streaming=True)

    val_ds = val_ds.shuffle(seed=42+rank)

    val_ds = datasets.Dataset.from_list(list(itertools.islice(val_ds, args.total_batch_size * 100)))

    def tokenize_fun(data):
        output = tokenizer(data["text"],
                           truncation=True,
                           max_length=args.max_length,
                           padding="max_length",)
        return output

    dataset = ds.map(tokenize_fun, batched=True, remove_columns=["url", "text", "timestamp"])
    val_dataset = val_ds.map(tokenize_fun, batched=True, remove_columns=["url", "text", "timestamp"])
    dataset = split_dataset_by_node(dataset, rank, world_size)
    val_dataset = split_dataset_by_node(val_dataset, rank, world_size)
    collate_fun = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    def indexed_collator(features):
        batch = collate_fun(features)
        if isinstance(features[0], dict) and "indices" in features[0]:
            batch["indices"] = torch.tensor([int(f["indices"]) for f in features])
        return batch

    if args.lazy_sampling:
        dataloader = DataLoader(dataset, batch_size=args.batch_size, collate_fn=indexed_collator)
    else:
        dataloader = DataLoader(dataset, batch_size=args.batch_size, collate_fn=collate_fun)

    val_dataloader = DataLoader(val_dataset, batch_size=args.batch_size, collate_fn=collate_fun)

    if rank == 0 and args.use_tqdm:
        if args.continue_from:
            pbar = tqdm(total=args.num_training_steps - update_step, desc="update step", ncols=80)
        else:
            pbar = tqdm(total=args.num_training_steps, desc="update step", ncols=80)

    local_step = 0
    if not args.continue_from:
        update_step = 0
    token_seen = 0
    token_seen_before = 0
    pad_idx = tokenizer.pad_token_id
    # Exponential moving average of seconds per update step, used for the ETA.
    # It stays None until the first update is timed (on fresh runs and resumes).
    time_per_iter = None

    # Stays empty without --use_compression. The per-layer helpers below
    # (block_hook, set_new_batch, the indices loop) are then no-ops instead of
    # raising NameError.
    compression_config = {}
    if args.use_compression:
        config = load_config(args.comp_path)
        compression_config = config['compression_config']
        # Override the compression config using command-line args.
        for layer in compression_config.keys():
            if args.cp_method is not None:
                compression_config[layer]['forward'] = args.cp_method
                compression_config[layer]['backward'] = args.cp_method
                if args.cp_method == 'topk':
                    compression_config[layer]['forward-params'] = {'topk': float(args.cp_param)}
                    compression_config[layer]['backward-params'] = {'topk': float(args.cp_param)}
                elif args.cp_method == 'natural_compress':
                    compression_config[layer]['forward-params'] = {'k': int(args.cp_param)}
                    compression_config[layer]['backward-params'] = {'k': int(args.cp_param)}
                elif args.cp_method == 'block_compress':
                    param = [float(item) if '.' in item else int(item) for item in args.cp_param.split(',')]
                    if len(param) == 3:
                        # don't use fine grain
                        param.append(False)
                        log("Do not use fine-grain")
                    else:
                        log(f"Granuarity is {param[3]}")
                        log(f"The Sequence will be divided into {args.max_length // param[3]} part to quantize")
                    compression_config[layer]['forward-params'] = {'k': param[0],
                                                                   'k_error': param[1],
                                                                   'topk': param[2],
                                                                   'is_row': False,
                                                                   'grain': param[3]}
                    compression_config[layer]['backward-params'] = {'k': param[0],
                                                                    'k_error': param[1],
                                                                    'topk': param[2],
                                                                    'is_row': False,
                                                                    'grain': param[3]}
                elif args.cp_method == 'block_natural':
                    param = [float(item) if '.' in item else int(item) for item in args.cp_param.split(',')]
                    compression_config[layer]['forward-params'] = {'k': param[0],
                                                                   'topk': param[1],}
                    compression_config[layer]['backward-params'] = {'k': param[0],
                                                                    'topk': param[1],}

        apply_compression_hooks(model, compression_config, device)

    if args.scale:
        scaler = torch.amp.GradScaler()

    if args.lazy_schedule != 'constant':
        log("Set Cosine Lazy Sampling p_t scheduler")
        log(f"Prob of new batch: {1 - args.p_t} -------> {1}")
        args.p_t = cosine_lazy_sampling_schedule(args.num_training_steps, args.p_t)
    else:
        num = float(args.p_t)
        args.p_t = lambda x: num

    if args.lazy_sampling == False:
        args.p_t = lambda x: 1.0

    # Remember whether lazy sampling was requested. Warmup protection switches
    # args.lazy_sampling off temporarily and back on afterwards. Setting this
    # unconditionally keeps it defined even with --warmup_protect 0.
    args.if_lazy = args.lazy_sampling

    for epoch in range(2):
        log(f"Start Epoch: {epoch}")
        if args.lazy_sampling:
            rng = np.random.RandomState(42 + epoch + rank)
        # create generator
        dl = iter(dataloader)
        # Per-sample indices of the current micro-batch; advanced by batch_size
        # whenever a new batch is fetched.
        indices = [i for i in range(args.batch_size)]
        if epoch == 0 and args.warmup_protect:
            if args.if_lazy:
                args.lazy_sampling = False
            block_hook(model, compression_config, True)
            log(f"Start Warmup Protection: warmup protection step {args.warmup_protect}")
        batch = next(dl)
        update_time = time.time()
        # Number of micro-batches per rank in one pass over C4.
        epoch_len = C4_LENGTH // world_size // args.batch_size
        if args.continue_from and epoch == 0:
            n_skip = update_step * grad_accumulation
            log(f"Start to filter data from first {n_skip} batch.")
            if args.if_lazy:
                # Indices advance by batch_size per fetched batch, so skipping
                # n_skip batches shifts them by n_skip * batch_size.
                indices = [i + n_skip * args.batch_size for i in range(args.batch_size)]
            log(f"May need a few of times. Please wait...")
            for idx in range(n_skip):
                # drop the already-used data
                _ = next(dl)
                if idx % 500 == 0:
                    log(f"Already drop {idx} batch. Please wait...")
            log(f"Already drop {n_skip - 1} batch, start from {n_skip} batch")
            if update_step >= args.warmup_protect:
                if args.if_lazy:
                    args.lazy_sampling = True
                block_hook(model, compression_config, False)
                log("Checkpoint has already finish Warmup Protection, Stop Warmup Protection")
            # epoch_len counts micro-batches, and n_skip of them were just dropped.
            epoch_len -= n_skip
        for _ in range(epoch_len):
            # Fetch a new batch unless lazy sampling chooses to reuse the last one.
            # A fresh batch is always taken right after an evaluation step,
            # at step 1, and at the start of each accumulation cycle.
            if not args.lazy_sampling or rng.random() < args.p_t(update_step) or update_step % args.eval_freq == 0 or update_step == 1 or (local_step % grad_accumulation == 0):
                if args.new_batch_protect:
                    set_new_batch(model, compression_config, True)
                batch = next(dl)
                if args.if_lazy:
                    indices = [x + args.batch_size for x in indices]
            else:
                if args.new_batch_protect:
                    set_new_batch(model, compression_config, False)
                for key in batch.keys():
                    batch[key] = batch[key].detach().clone()
            local_step += 1

            if update_step > args.num_training_steps:
                log(f"attain assigned training step {args.num_training_steps}. Stop Training")
                print(f"Rank {rank} stopping training")
                break
            batch = {k: v.to(device) for k, v in batch.items()}
            labels = batch["input_ids"].clone()
            labels[labels == pad_idx] = -100
            token_seen += (batch["input_ids"] != pad_idx).sum().item() * world_size

            # set indices
            if args.lazy_sampling:
                for layer in compression_config.keys():
                    layer_idx = compression_config[layer]['layer_idx']
                    layer_path = model.model.layers[layer_idx]
                    setattr(layer_path, 'current_indices', indices)

            if args.scale:
                with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16):
                    loss = model(**batch).loss
                    scaled_loss = loss / grad_accumulation
                # Backpropagate the accumulation-scaled loss, as the unscaled
                # branch does. Using `loss` would make the accumulated gradient
                # grad_accumulation times too large.
                scaler.scale(scaled_loss).backward()
            else:
                loss = model(**batch).loss
                scaled_loss = loss / grad_accumulation
                scaled_loss.backward()

            if local_step % grad_accumulation != 0 or local_step == 0:
                continue

            # The code below runs only on update steps.
            # Manual gradient averaging across ranks (no DDP wrapper).
            for p in model.parameters():
                if p.requires_grad:
                    dist.all_reduce(p.grad.data, op=dist.ReduceOp.AVG)

            if args.scale:
                scaler.unscale_(optimizer)

            # gradient clipping
            if args.grad_clip != 0.0:
                torch.nn.utils.clip_grad_norm_(trainable_param, args.grad_clip)

            if args.scale:
                scaler.step(optimizer)
                scaler.update()
                schedule.step()
                optimizer.zero_grad()
            else:
                optimizer.step()
                optimizer.zero_grad()
                schedule.step()

            if rank == 0 and args.use_tqdm:
                pbar.update(1)

            update_step += 1
            update_time = time.time() - update_time
            token_in_update = token_seen - token_seen_before
            token_seen_before = token_seen
            batch_in_update = grad_accumulation * world_size

            if rank == 0 and args.use_wandb:
                # NOTE(review): record_dict is built but never sent anywhere;
                # no wandb call exists in this script.
                record_dict = {
                    "loss": loss.item(),
                    "update_step": update_step,
                    "throughput_tokens": token_in_update / update_time,
                    "throughput_examples": args.total_batch_size * world_size / update_time,
                    "throughput_batchs": batch_in_update,
                }
                record_dict.update({"lr": optimizer.param_groups[0]["lr"]})
            if time_per_iter is None:
                time_per_iter = update_time
            else:
                time_per_iter = 0.9 * time_per_iter + 0.1 * update_time
            remain_total_seconds = time_per_iter * (args.num_training_steps - update_step)
            lr = optimizer.param_groups[0]["lr"]
            log(f"step: {update_step}/{args.num_training_steps} Loss: {loss:.8f} Lr: {lr:.5f} Pnew: {args.p_t(update_step):.3f}")
            if update_step % 10 == 0:
                hours = int(remain_total_seconds // 3600)
                minutes = int((remain_total_seconds % 3600) // 60)
                seconds = int(remain_total_seconds % 60)
                log(f"ETA: {hours:02d}:{minutes:02d}:{seconds:02d}")

            if update_step % args.eval_freq == 0 or update_step == 1:
                log(f"start eval in step {update_step}")
                model.eval()
                loss_sum = 0
                eval_length = 0
                # Separate names so the training `batch` / `loss` are not clobbered.
                with torch.no_grad():
                    for val_batch in val_dataloader:
                        eval_length += 1
                        val_batch = {k: v.to(device) for k, v in val_batch.items()}
                        if args.scale:
                            with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16):
                                val_loss = model(**val_batch).loss
                        else:
                            val_loss = model(**val_batch).loss
                        loss_sum += val_loss
                    loss_avg = loss_sum / eval_length
                    dist.all_reduce(loss_avg, op=dist.ReduceOp.AVG)
                log(f"Eval loss: {loss_avg} PPL: {math.exp(loss_avg)}")
                model.train()

            if update_step == args.warmup_protect and args.warmup_protect:
                if args.if_lazy:
                    args.lazy_sampling = True
                block_hook(model, compression_config, False)
                log("Finish Warmup Protection")

            if (update_step % args.save_freq == 0 or update_step == args.warmup_protect) and rank == 0:
                save_checkpoint(model, optimizer, schedule, update_step, args)
            update_time = time.time()

        # The inner `break` only ends the current epoch. Stop here as well,
        # rather than starting another epoch just to break again.
        if update_step > args.num_training_steps:
            break

    log("finish training")
    dist.destroy_process_group()
    if rank == 0 and args.use_tqdm:
        pbar.close()

if __name__ == "__main__":
    # Create the process group before anything else: main() issues collectives
    # (dist.broadcast / dist.all_reduce) and logs via dist.get_rank().
    ddp_setup()
    main(parse_args(None))







