from argparse import ArgumentParser
import glob
import math
import sys
import time
from pathlib import Path
from typing import List, Optional, Tuple, Union
import math
import lightning as L
import torch
from lightning.fabric.strategies import FSDPStrategy, XLAStrategy
from torch.utils.data import DataLoader
from functools import partial
# support running without installing as a package
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))
# from apex.optimizers import FusedAdam #torch optimizer has a cuda backend, which is faster actually
from lit_gpt.model import GPT, Block, Config, CausalSelfAttention
from lit_gpt.packed_dataset import CombinedDataset, PackedDataset
from lit_gpt.speed_monitor import SpeedMonitorFabric as Monitor
from lit_gpt.speed_monitor import estimate_flops, measure_flops
from lit_gpt.utils import chunked_cross_entropy, get_default_supported_precision, num_parameters, step_csv_logger, lazy_load
from pytorch_lightning.loggers import WandbLogger, TensorBoardLogger
from lit_gpt import FusedCrossEntropyLoss, DistillLoss
import random

from prts_lit import (
    bert2BERTConfig, 
    IncubationConfig, 
    GradualIncubationConfig, 
    LiGOConfig, 
    SolarConfig, 
    StackingptConfig, 
    MilkConfig, 
    GradualStakcingConfig,
    DistillationConfig,
    MsgConfig,
    MsltConfig,
    ZeroConfig,
    get_prts_model
)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.cuda.amp import autocast as autocast


# (filename prefix, sampling weight) pairs used to build the training dataset;
# weights are normalised to sum to 1 before being handed to the combined dataset.
train_config = [("train_slimpajama", 1)]

def main(
    model_name = "tiny_LLaMA_1b",
    name = "tiny_LLaMA_1b",
    out_dir = Path("out"),
    train_data_dir: Path = "datapath",
    val_data_dir: Optional[Path] = None,
    checkpoint_path: Optional[str] = None,
    # Hyperparameters
    devices = 8,
    num_nodes = 4,
    global_batch_size = 360,
    learning_rate = 2e-4,
    min_lr = 2e-5,
    micro_batch_size = 6,
    max_step = 10000,
    warmup_steps = 0 ,
    log_step_interval = 1,
    eval_iters = 1000000,
    save_step_interval = 2000,
    eval_step_interval = 2000,
    weight_decay = 1e-1,
    beta1 = 0.9,
    beta2 = 0.95,
    grad_clip = 1.0,
    decay_lr = True,
    method: str = "scratch",
    config_path: str = "",
    src_init_path: Optional[str] = None,
    resume_id: int = 0,
    resume_ckpt: Path = None,
) -> None:
    """Configure a training run and hand it off to the nested ``setup`` helper.

    The function derives the per-device batch size and the number of
    gradient-accumulation micro-steps from the batch-size arguments, converts
    the step-based settings into micro-batch ("iter") counts, stores every
    argument in the module-global ``hparams`` dict, and creates the W&B logger.
    The nested helpers defined further down read these values as closures.

    Args:
        model_name: Name passed to ``Config.from_name`` when ``method`` is "scratch".
        name: Run name; only recorded in ``hparams`` here.
        out_dir: Directory in which checkpoints are written.
        train_data_dir: Directory that is globbed for the training files.
        val_data_dir: Optional directory for validation data; validation is skipped when unset.
        checkpoint_path: Optional state dict loaded into the model before ``fabric.setup``.
        devices: Number of devices per node.
        num_nodes: Number of nodes.
        global_batch_size: Batch size summed over all devices and nodes.
        learning_rate: Peak learning rate (also the constant rate when ``decay_lr`` is false).
        min_lr: Floor of the cosine learning-rate schedule.
        micro_batch_size: Batch size of a single forward/backward pass on one device.
        max_step: Number of optimizer steps to train for.
        warmup_steps: Number of optimizer steps of linear warmup.
        log_step_interval: Logging interval in optimizer steps.
        eval_iters: Maximum number of validation batches per evaluation.
        save_step_interval: Checkpoint interval in optimizer steps.
        eval_step_interval: Validation interval in optimizer steps.
        weight_decay: AdamW weight decay.
        beta1: First AdamW beta.
        beta2: Second AdamW beta.
        grad_clip: Max gradient norm passed to ``fabric.clip_gradients``.
        decay_lr: Whether to apply the warmup/cosine schedule.
        method: Training method; selects which config class is loaded.
        config_path: Path given to ``from_pretrained`` for every method except "scratch".
        src_init_path: Optional checkpoint whose ``'model'`` entry initialises the source model.
        resume_id: Number of leading training batches to skip before training starts.
        resume_ckpt: Optional checkpoint restored through ``fabric.load`` before training.
    """
    # Per-device batch size, then how many micro-batches make up one optimizer step.
    batch_size = global_batch_size // (devices * num_nodes)
    gradient_accumulation_steps = batch_size // micro_batch_size
    assert gradient_accumulation_steps > 0
    # Step-based settings expressed in micro-batches ("iters").
    warmup_iters = warmup_steps * gradient_accumulation_steps
    max_iters = max_step * gradient_accumulation_steps
    lr_decay_iters = max_iters
    log_iter_interval = log_step_interval * gradient_accumulation_steps

    # logger = step_csv_logger("log", name, flush_logs_every_n_steps=log_iter_interval)
    # hparams is published as a module global; it is printed at start-up and stored in checkpoints.
    global hparams
    hparams = {
        "model_name": model_name,
        "name": name,
        "out_dir": out_dir,
        "train_data_dir": train_data_dir,
        "val_data_dir": val_data_dir,
        "checkpoint_path": checkpoint_path,
        "devices": devices,
        "num_nodes": num_nodes,
        "global_batch_size": global_batch_size,
        "learning_rate": learning_rate,
        "min_lr": min_lr,
        "micro_batch_size": micro_batch_size,
        "max_step": max_step,
        "warmup_steps": warmup_steps,
        "log_step_interval": log_step_interval,
        "eval_iters": eval_iters,
        "save_step_interval": save_step_interval,
        "eval_step_interval": eval_step_interval,
        "weight_decay": weight_decay,
        "beta1": beta1,
        "beta2": beta2,
        "grad_clip": grad_clip,
        "decay_lr": decay_lr,
        "method": method,
        "config_path": config_path,
        "src_init_path": src_init_path,
        "resume_id": resume_id,
        "resume_ckpt": resume_ckpt
    }
    wandb_logger = WandbLogger()
    def setup(
        devices: int = 8,
        train_data_dir: Path = Path("data/redpajama_sample"),
        val_data_dir: Optional[Path] = None,
        precision: Optional[str] = None,
        tpu: bool = False,
        resume_id: int = 0,
        resume_ckpt: Path = None,
    ) -> None:
        """Create the Fabric instance for this run and launch the training entry point.

        The precision falls back to ``get_default_supported_precision`` when not
        given. With a single device the strategy is left to Fabric ("auto");
        with several devices an XLA strategy is used for TPUs and an FSDP
        strategy (wrapping each ``Block``) otherwise. The hyperparameters are
        printed and the nested ``main`` is started through ``fabric.launch``.
        """
        if not precision:
            precision = get_default_supported_precision(training=True, tpu=tpu)

        if devices <= 1:
            strategy = "auto"
        elif tpu:
            # For multi-host TPU training, the device count for Fabric is limited to the count on a single host.
            devices = "auto"
            strategy = XLAStrategy(sync_module_states=False)
        else:
            strategy = FSDPStrategy(
                auto_wrap_policy={Block},
                activation_checkpointing_policy=None,
                state_dict_type="full",
                limit_all_gathers=True,
                cpu_offload=False,
            )

        fabric = L.Fabric(
            accelerator='gpu',
            devices=devices,
            num_nodes=num_nodes,
            strategy=strategy,
            precision=precision,
            loggers=[wandb_logger],
        )
        fabric.print(hparams)
        fabric.launch(main, train_data_dir, val_data_dir, resume_id, resume_ckpt)
        # main(fabric, train_data_dir, val_data_dir, resume)


    def main(fabric, train_data_dir, val_data_dir, resume_id, resume_ckpt):
        """Per-process training entry point started through ``fabric.launch``.

        Resolves the model configuration for the selected ``method``, builds the
        dataloaders, instantiates the model (optionally from source/target
        models and checkpoints), sets up the AdamW optimizer, optionally
        restores a full training state, and finally runs ``train``.

        Args:
            fabric: The launched Fabric instance.
            train_data_dir: Directory holding the training files.
            val_data_dir: Optional directory holding validation files.
            resume_id: Number of leading training batches to skip.
            resume_ckpt: Optional checkpoint restored with ``fabric.load``.

        Raises:
            ValueError: If ``method`` is not a supported value, or if
                ``src_init_path`` is given for a method that builds no source model.
        """
        monitor = Monitor(fabric, window_size=2, time_unit="seconds", log_iter_interval=log_iter_interval)

        if fabric.global_rank == 0:
            out_dir.mkdir(parents=True, exist_ok=True)

        # Every non-"scratch" method loads its own config class from config_path.
        prts_config_classes = {
            "b2b": bert2BERTConfig,
            "solar": SolarConfig,
            "ligo": LiGOConfig,
            "stacking": StackingptConfig,
            "milk": MilkConfig,
            "gs": GradualStakcingConfig,
            "distill": DistillationConfig,
            "msg": MsgConfig,
            "mslt": MsltConfig,
            "zero": ZeroConfig,
        }
        if method == "scratch":
            config = Config.from_name(model_name)
        elif method in prts_config_classes:
            config = prts_config_classes[method].from_pretrained(config_path)
        else:
            # Fail early with a readable message instead of an UnboundLocalError on `config` below.
            supported = ", ".join(["scratch", *prts_config_classes])
            raise ValueError(f"Unknown method {method!r}; expected one of: {supported}")

        train_dataloader, val_dataloader = create_dataloaders(
            batch_size=micro_batch_size,
            block_size=config.block_size,
            fabric=fabric,
            train_data_dir=train_data_dir,
            val_data_dir=val_data_dir,
            seed=3407,
        )
        if val_dataloader is None:
            train_dataloader = fabric.setup_dataloaders(train_dataloader)
        else:
            train_dataloader, val_dataloader = fabric.setup_dataloaders(train_dataloader, val_dataloader)

        fabric.seed_everything(3407)  # same seed for every process to init model (FSDP)
        if checkpoint_path is not None:
            fabric.print(f"Loading model {str(checkpoint_path)!r} with {config.__dict__}")
        t0 = time.perf_counter()
        with fabric.init_module(empty_init=False):
            if method != "scratch":
                # Some methods need no source model and some need no target model.
                src_model = None if method in ['stacking', 'msg', 'zero'] else GPT(Config.from_name(config.src_config_name))
                if src_model is not None:
                    src_model.apply(partial(src_model._init_weights ,n_layer=src_model.config.n_layer))
                trg_model = None if method in ['b2b', 'solar', 'stacking', 'gs', 'mslt'] else GPT(Config.from_name(config.trg_config_name))
                if trg_model is not None:
                    trg_model.apply(partial(trg_model._init_weights ,n_layer=trg_model.config.n_layer))
                if src_init_path is not None:
                    if src_model is None:
                        # Previously this dereferenced None and died with an AttributeError.
                        raise ValueError(
                            f"src_init_path was given, but method {method!r} does not build a source model."
                        )
                    src_model.load_state_dict(torch.load(src_init_path)['model'])
                model = get_prts_model(src_model, trg_model, prts_config=config)
            else:
                model = GPT(config)
        if checkpoint_path is not None:
            # The checkpoint is either a bare state dict or a dict holding it under 'model'.
            state_dict = torch.load(checkpoint_path)
            if 'model' in state_dict:
                model.load_state_dict(state_dict['model'])
            else:
                model.load_state_dict(state_dict)
            state_dict = None  # drop the reference so the loaded tensors can be freed

        model = fabric.setup(model)
        fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.")
        fabric.print(f"Total parameters {num_parameters(model):,}")

        optimizer = torch.optim.AdamW(
            model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2), foreach=False
        )
        optimizer = fabric.setup_optimizers(optimizer)

        # Everything in `state` is saved to / restored from full training checkpoints.
        state = {"model": model, "optimizer": optimizer, "hparams": hparams, "iter_num": 0, "step_count": 0}

        if resume_ckpt:
            fabric.print(f"Resuming training from {resume_ckpt}")
            fabric.load(resume_ckpt, state)

        train_time = time.perf_counter()
        train(fabric, state, train_dataloader, val_dataloader, monitor, resume_id, resume_ckpt)
        fabric.print(f"Training time: {(time.perf_counter()-train_time):.2f}s")
        if fabric.device.type == "cuda":
            fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB")


    def train(fabric, state, train_dataloader, val_dataloader, monitor, resume_id, resume_ckpt):
        """Run the training loop until ``max_iters`` micro-batches have been processed.

        Each loop iteration handles one micro-batch: set the learning rate, run
        forward/backward, and on every ``gradient_accumulation_steps``-th
        micro-batch clip gradients and step the optimizer. Progress is printed
        and reported to ``monitor`` each iteration; validation and checkpointing
        happen at their configured optimizer-step intervals.

        Args:
            fabric: The launched Fabric instance.
            state: Dict with "model", "optimizer", "hparams", "iter_num" and
                "step_count"; the two counters are updated in place.
            train_dataloader: Iterable of training batches.
            val_dataloader: Validation batches, or ``None`` to skip validation.
            monitor: Speed monitor notified after each batch and each evaluation.
            resume_id: Number of leading batches to skip before training.
            resume_ckpt: When not ``None``, batches are skipped until the loader
                has advanced to the restored ``state["iter_num"]``.
        """
        model = state["model"]
        optimizer = state["optimizer"]

        if val_dataloader is not None:
            validate(fabric, model, val_dataloader)  # sanity check

        # with torch.device("meta"):
        #     meta_model = GPT(Config.from_name(config.trg_config_name))
        #     # "estimated" is not as precise as "measured". Estimated is optimistic but widely used in the wild.
        #     # When comparing MFU or FLOP numbers with other projects that use estimated FLOPs,
        #     # consider passing `SpeedMonitor(flops_per_batch=estimated_flops)` instead
        #     estimated_flops = estimate_flops(meta_model) * micro_batch_size
        #     fabric.print(f"Estimated TFLOPs: {estimated_flops * fabric.world_size / 1e12:.2f}")
        #     x = torch.randint(0, 1, (micro_batch_size, model.config.block_size))
        #     # measured_flos run in meta. Will trigger fusedRMSNorm error
        #     #measured_flops = measure_flops(meta_model, x)
        #     #fabric.print(f"Measured TFLOPs: {measured_flops * fabric.world_size / 1e12:.2f}")
        #     del meta_model, x

        total_lengths = 0
        total_t0 = time.perf_counter()

        if fabric.device.type == "xla":
            import torch_xla.core.xla_model as xm

            xm.mark_step()
        
        
        # iter_num at entry (non-zero after restoring a checkpoint); used for the
        # resume skip below and as the baseline of the remaining-time estimate.
        initial_iter = state["iter_num"]
        curr_iter = 0
                
        loss_func = FusedCrossEntropyLoss()
        distill_loss_func = DistillLoss()
        for  train_data in train_dataloader:
            # resume loader state. This is not elegant but it works. Should rewrite it in the future.
            # Phase 1: discard the first `resume_id` batches.
            if resume_id > 0:
                if curr_iter < resume_id:
                    curr_iter += 1
                    continue
                else:
                    resume_id = 0
                    curr_iter = -1
                    fabric.barrier()
                    fabric.print("resume finished, taken {} seconds".format(time.perf_counter() - total_t0))
            # Phase 2: when a checkpoint was restored, discard batches until the
            # loader position matches the restored iteration counter.
            if resume_ckpt is not None:
                if curr_iter < initial_iter:
                    curr_iter += 1
                    continue
                else:
                    resume_ckpt = None
                    curr_iter = -1
                    fabric.barrier()
                    fabric.print("resume finished, taken {} seconds".format(time.perf_counter() - total_t0))
            if state["iter_num"] >= max_iters:
                break
            
            # determine and set the learning rate for this iteration
            lr = get_lr(state["iter_num"]) if decay_lr else learning_rate
            for param_group in optimizer.param_groups:
                param_group["lr"] = lr

            iter_t0 = time.perf_counter()
            # Inputs are the first block_size tokens; targets are the same window shifted by one.
            input_ids = train_data[:, 0 : model.config.block_size].contiguous()
            targets = train_data[:, 1 : model.config.block_size + 1].contiguous()

            # True for every micro-batch except the last one of an accumulation window.
            is_accumulating = (state["iter_num"] + 1) % gradient_accumulation_steps != 0
            with fabric.no_backward_sync(model, enabled=is_accumulating):
                logits = model(input_ids)
                if hparams['method'] == 'distill':
                    # For distillation the model returns a (teacher, student) pair of logits.
                    teacher_logits, student_logits = logits
                    loss = distill_loss_func(student_logits, teacher_logits)
                else:
                    loss = loss_func(logits, targets)
                    # loss = chunked_cross_entropy(logits, targets, chunk_size=0)
                # Scale so the accumulated gradient is the average over the window.
                fabric.backward(loss / gradient_accumulation_steps)

            if not is_accumulating:
                fabric.clip_gradients(model, optimizer, max_norm=grad_clip)
                optimizer.step()
                optimizer.zero_grad()
                state["step_count"] += 1

                # Method-specific schedules advance once per optimizer step.
                if hparams['method'] == 'milk':
                    if isinstance(state["model"].global_mask_schedule, List):
                        for mask_schedule in state["model"].global_mask_schedule:
                            mask_schedule.step()
                    else:
                        state["model"].global_mask_schedule.step()
                if hparams['method'] == 'gs' or hparams['method'] == 'mslt':
                    state["model"].step(state["step_count"])
                
                if hparams['method'] == 'msg':
                    state["model"].step()

            elif fabric.device.type == "xla":
                xm.mark_step()
            state["iter_num"] += 1
            # input_id: B L 
            total_lengths += input_ids.size(1)
            t1 = time.perf_counter()
            # The remaining-time estimate is the average wall time per iteration since
            # `total_t0` (which includes any resume skipping) times the iterations left.
            fabric.print(
                    f"iter {state['iter_num']} step {state['step_count']}: loss {loss.item():.4f}, iter time:"
                    f" {(t1 - iter_t0) * 1000:.2f}ms{' (optimizer.step)' if not is_accumulating else ''}"
                    f" remaining time: {(t1 - total_t0) / (state['iter_num'] - initial_iter) * (max_iters - state['iter_num']) / 3600:.2f} hours. " 
                    # print days as well
                    f" or {(t1 - total_t0) / (state['iter_num'] - initial_iter) * (max_iters - state['iter_num']) / 3600 / 24:.2f} days. "
                )
    
            monitor.on_train_batch_end(
                state["iter_num"] * micro_batch_size,
                t1 - total_t0,
                # this assumes that device FLOPs are the same and that all devices have the same batch size
                fabric.world_size,
                state["step_count"],
                lengths=total_lengths,
                train_loss = loss.item()
            )
                
            # Validate only right after an optimizer step that lands on the eval interval.
            if val_dataloader is not None and not is_accumulating and state["step_count"] % eval_step_interval == 0:
                
                t0 = time.perf_counter()
                val_loss = validate(fabric, model, val_dataloader)
                t1 = time.perf_counter() - t0
                monitor.eval_end(t1)
                fabric.print(f"step {state['iter_num']}: val loss {val_loss:.4f}, val time: {t1 * 1000:.2f}ms")
                fabric.log_dict({"metric/val_loss": val_loss.item(), "total_tokens":  model.config.block_size * (state["iter_num"] + 1) * micro_batch_size * fabric.world_size},state["step_count"])
                fabric.log_dict({"metric/val_ppl": math.exp(val_loss.item()), "total_tokens":  model.config.block_size * (state["iter_num"] + 1) * micro_batch_size * fabric.world_size},state["step_count"])
                fabric.barrier()
            # Checkpoint only right after an optimizer step that lands on the save interval.
            if not is_accumulating and state["step_count"] % save_step_interval == 0:
                if hparams['method'] == 'ligo':
                    # LiGO writes two files from rank 0: the result of get_trg_params()
                    # and the result of get_hypernet_dict(), both read while the full
                    # (unsharded) parameters are summoned.
                    with FSDP.summon_full_params(model, with_grads=False):
                        with autocast():
                            state_dict = state['model'].get_trg_params()
                            hyper_state_dcit = state['model'].get_hypernet_dict()
                        if fabric.global_rank == 0:
                            checkpoint_trg_path = out_dir / f"iter-{state['iter_num']:06d}-ckpt.pth"
                            checkpoint_hypernet_path = out_dir / f"iter-{state['iter_num']:06d}-hypernet-ckpt.pth"
                            fabric.print(f"Saving checkpoint to {str(checkpoint_trg_path)!r}")
                            torch.save(state_dict, checkpoint_trg_path)
                            fabric.print(f"Saving checkpoint to {str(checkpoint_hypernet_path)!r}")
                            torch.save(hyper_state_dcit, checkpoint_hypernet_path)
                        fabric.barrier()
                else:
                    # All other methods save the whole training state through Fabric.
                    checkpoint_path = out_dir / f"iter-{state['iter_num']:06d}-ckpt.pth"
                    fabric.print(f"Saving checkpoint to {str(checkpoint_path)!r}")
                    fabric.save(checkpoint_path, state)

            
    @torch.no_grad()
    def validate(fabric: L.Fabric, model: torch.nn.Module, val_dataloader: DataLoader) -> torch.Tensor:
        """Return the mean validation loss over at most ``eval_iters`` batches.

        The model is switched to eval mode for the duration of the pass and
        back to train mode afterwards. Only batches that were actually
        evaluated contribute to the mean, so a loader that yields fewer than
        ``eval_iters`` batches no longer pulls the result towards zero.

        Args:
            fabric: Fabric instance (used for printing and device placement).
            model: Model to evaluate.
            val_dataloader: Iterable of validation batches.

        Returns:
            A scalar tensor on ``fabric.device``; NaN when no batch was evaluated.
        """
        fabric.print("Validating ...")
        model.eval()

        batch_losses = []
        for k, val_data in enumerate(val_dataloader):
            if k >= eval_iters:
                break
            # Inputs are the first block_size tokens; targets are shifted by one.
            input_ids = val_data[:, 0 : model.config.block_size].contiguous()
            targets = val_data[:, 1 : model.config.block_size + 1].contiguous()
            logits = model(input_ids)
            loss = chunked_cross_entropy(logits, targets, chunk_size=0)
            batch_losses.append(loss.item())

        model.train()

        if not batch_losses:
            # Matches the mean of an empty tensor: there is nothing to average.
            return torch.tensor(float("nan"), device=fabric.device)
        return torch.tensor(batch_losses, device=fabric.device).mean()


    def create_dataloader(
        batch_size: int, block_size: int, data_dir: Path, fabric, shuffle: bool = True, seed: int = 12345, split="train"
    ) -> DataLoader:
        """Build a ``DataLoader`` over the packed dataset files found in ``data_dir``.

        For every ``(prefix, weight)`` entry of the split's data config, the
        files matching ``<data_dir>/<prefix>*`` are shuffled deterministically
        and wrapped in a ``PackedDataset``; the datasets are then mixed by a
        ``CombinedDataset`` using the normalised weights.

        Args:
            batch_size: Batch size of the returned loader.
            block_size: Number of tokens per sample.
            data_dir: Directory to search; a plain string is accepted as well.
            fabric: Fabric instance providing ``global_rank`` and ``world_size``.
            shuffle: Passed through to ``PackedDataset``.
            seed: Seed for file shuffling and dataset sampling.
            split: Which data config to use; only "train" is defined so far.

        Raises:
            ValueError: If no data config exists for ``split``.
            RuntimeError: If a prefix matches no files in ``data_dir``.
        """
        data_config = train_config if split == "train" else None  # TODO: define a validation data config
        if data_config is None:
            # Without this guard the loop below fails with an opaque TypeError on None.
            raise ValueError(
                f"No dataset configuration is defined for split {split!r}; only 'train' is currently supported."
            )

        # The directory may arrive as a plain string (e.g. an unconverted default value).
        data_dir = Path(data_dir)

        datasets = []
        for prefix, _ in data_config:
            filenames = sorted(glob.glob(str(data_dir / f"{prefix}*")))
            if not filenames:
                # Check the glob result: a dataset would otherwise be created over zero files.
                raise RuntimeError(
                    f"No data found at {data_dir}. Make sure you ran prepare_redpajama.py to create the dataset."
                )
            random.seed(seed)
            random.shuffle(filenames)

            dataset = PackedDataset(
                filenames,
                # n_chunks control the buffer size.
                # Note that the buffer size also impacts the random shuffle
                # (PackedDataset is an IterableDataset. So the shuffle is done by prefetch a buffer and shuffle the buffer)
                n_chunks=8,
                block_size=block_size,
                shuffle=shuffle,
                seed=seed+fabric.global_rank,
                num_processes=fabric.world_size,
                process_rank=fabric.global_rank,
            )
            datasets.append(dataset)

        if not datasets:
            raise RuntimeError(
                f"No data found at {data_dir}. Make sure you ran prepare_redpajama.py to create the dataset."
            )

        # Normalise the sampling weights so they sum to one.
        weights = [weight for _, weight in data_config]
        sum_weights = sum(weights)
        weights = [el / sum_weights for el in weights]

        combined_dataset = CombinedDataset(datasets=datasets, seed=seed, weights=weights)

        return DataLoader(combined_dataset, batch_size=batch_size, shuffle=False, pin_memory=True)


    def create_dataloaders(
        batch_size: int,
        block_size: int,
        fabric,
        train_data_dir: Path = Path("data/redpajama_sample"),
        val_data_dir: Optional[Path] = None,
        seed: int = 12345,
    ) -> Tuple[DataLoader, DataLoader]:
        """Create the training loader and, when ``val_data_dir`` is set, the validation loader.

        Samples are one token longer than ``block_size`` because the target
        sequence is the input shifted by one position. The second element of
        the returned pair is ``None`` when no validation directory is given.
        """
        shared_kwargs = dict(
            batch_size=batch_size,
            # one extra token so that the next-token target is available
            block_size=block_size + 1,
            fabric=fabric,
            seed=seed,
        )

        train_loader = create_dataloader(data_dir=train_data_dir, shuffle=True, split="train", **shared_kwargs)
        if not val_data_dir:
            return train_loader, None

        val_loader = create_dataloader(data_dir=val_data_dir, shuffle=False, split="validation", **shared_kwargs)
        return train_loader, val_loader


    # learning rate decay scheduler (cosine with warmup)
    def get_lr(it):
        """Return the learning rate for micro-batch index ``it`` (linear warmup, then cosine decay).

        Reads ``learning_rate``, ``min_lr``, ``warmup_iters`` and
        ``lr_decay_iters`` from the enclosing scope.

        Args:
            it: Zero-based iteration (micro-batch) index.

        Returns:
            ``learning_rate * it / warmup_iters`` during warmup, ``min_lr`` from
            ``lr_decay_iters`` onwards, and a cosine interpolation from
            ``learning_rate`` down to ``min_lr`` in between.
        """
        # 1) linear warmup for warmup_iters steps
        if it < warmup_iters:
            return learning_rate * it / warmup_iters
        # 2) from lr_decay_iters onwards, return the min learning rate.
        #    `>=` gives the same value as the cosine at it == lr_decay_iters and
        #    avoids a division by zero when lr_decay_iters == warmup_iters.
        if it >= lr_decay_iters:
            return min_lr
        # 3) in between, use cosine decay down to min learning rate
        decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
        assert 0 <= decay_ratio <= 1
        coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # coeff ranges 0..1
        return min_lr + coeff * (learning_rate - min_lr)

    # Start the run: setup() builds the Fabric instance and launches the nested training entry point.
    setup(devices, train_data_dir, val_data_dir, resume_id=resume_id, resume_ckpt=resume_ckpt)

if __name__ == "__main__":
    from jsonargparse import CLI

    # If you hit "Expected is_sm80 to be true, but got false", uncomment the next line:
    # torch.backends.cuda.enable_flash_sdp(False)
    torch.set_float32_matmul_precision("high")

    # Expose the arguments of main() as command-line options.
    CLI(main)
