import os
import json
import math
import sys
import time
from pathlib import Path
from typing import Optional, Union,Literal,Tuple
from tqdm import tqdm
import lightning as L
import numpy as np
import torch
from types import SimpleNamespace
from lightning.fabric.strategies import FSDPStrategy
from torch.utils.data import DataLoader, IterableDataset
# support running without installing as a package
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))

from lit_gpt import Config
from lit_gpt.model_code import GPT, Block
from lit_gpt.speed_monitor import estimate_flops, measure_flops
from lit_gpt.utils import (
    chunked_cross_entropy,
    get_default_supported_precision,
    num_parameters,
    find_resume_path
)

import wandb





def setup(
    model_name: str = "pythia-160m",
    out_dir: Path = "/data/nhird_pretrained/flop_4e18/12b",
    data_dir: Path = "/data/nhird_hf_rope",
    eval_iters: int = 200,
    log_interval: int = 50,
    learning_rate: float = 6e-4,
    batch_size: int = 64,
    micro_batch_size: int = 2,
    max_iters: int = 100000,
    weight_decay: float = 1e-1,
    beta1: float = 0.9,
    beta2: float = 0.95,
    grad_clip: float = 1.0,
    decay_lr: bool = True,
    warmup_ratio: float = 0.1,
    stable_ratio: float = 0.8,
    decay_ratio: float = 0.1,
    devices: int = 1,
    nodes: int = 1,
    precision: Optional[str] = None,
    resume: Union[bool, Literal["auto"], Path] = "auto",
    seed: int = 1337
) -> None:
    """Build the run configuration and launch ``main`` under Lightning Fabric.

    Every local variable that exists when ``locals()`` is captured below
    becomes an attribute of the ``args`` namespace passed to ``main`` (and
    logged / checkpointed from there), so do not introduce scratch variables
    before that point.

    Args:
        model_name: lit_gpt ``Config`` preset name.
        out_dir: directory for checkpoints; its last two path components also
            form the W&B run id.
        data_dir: directory holding ``{train,valid}_{ids,pos}.bin``.
        eval_iters: number of validation batches per evaluation.
        log_interval: log every this many micro-batch iterations.
        batch_size: samples per optimizer step summed over all devices/nodes.
        micro_batch_size: samples per device per forward/backward pass.
        max_iters: number of optimizer steps; rescaled below to micro-batch
            iterations.
        warmup_ratio, stable_ratio, decay_ratio: fractions of the micro-batch
            iterations given to each phase of ``get_lr``.
        devices, nodes: devices per node and node count; FSDP is used when
            ``devices > 1``.
        precision: Fabric precision string; when ``None`` a supported default
            is chosen via ``get_default_supported_precision``.
        resume: forwarded to ``find_resume_path``.

    Raises:
        ValueError: if ``batch_size`` is smaller than
            ``micro_batch_size * devices * nodes`` (no gradient accumulation
            step would fit).
    """
    precision = precision or get_default_supported_precision(training=True)
    # The defaults above are plain strings, but the training code relies on
    # Path operations (`.mkdir`, `.parts`, `/`), so normalise both here.
    out_dir = Path(out_dir)
    data_dir = Path(data_dir)

    gradient_accumulation_steps = batch_size // micro_batch_size // devices // nodes
    if gradient_accumulation_steps <= 0:
        raise ValueError(
            f"batch_size ({batch_size}) must be at least micro_batch_size * devices * nodes "
            f"({micro_batch_size * devices * nodes})"
        )
    # Both intervals are compared against the optimizer-step counter in
    # `train`; clamp to 1 so short runs never take a modulo by zero.
    eval_interval = max(1, int(max_iters // 50))
    save_interval = max(1, int(max_iters // 8))
    # From here on max_iters counts micro-batch iterations, not optimizer steps.
    max_iters = max_iters * gradient_accumulation_steps
    warmup_iters = int(warmup_ratio * max_iters)
    stable_iters = int(stable_ratio * max_iters)
    decay_iters = int(decay_ratio * max_iters)

    args = SimpleNamespace(**locals())

    if devices > 1:
        strategy = FSDPStrategy(
            auto_wrap_policy={Block},
            activation_checkpointing_policy={Block},
            state_dict_type="full",
            limit_all_gathers=True,
            cpu_offload=False,
        )
    else:
        strategy = "auto"
    fabric = L.Fabric(
        devices=devices, num_nodes=nodes, strategy=strategy, precision=precision, loggers=None
    )
    fabric.print(args)
    fabric.launch(main, args)


def main(fabric, args) -> None:
    """Per-process entry point launched by ``fabric.launch``.

    Builds the model, optimizer and dataloaders, restores a checkpoint when
    ``find_resume_path`` returns one, initialises Weights & Biases on rank 0
    (unless ``WANDB_DISABLED`` is set) and then runs ``train``.

    Args:
        fabric: the launched Lightning Fabric instance.
        args: the ``SimpleNamespace`` built by ``setup``.
    """
    # `out_dir` may arrive as a plain str (it is the CLI default); the code
    # below needs Path operations (`.mkdir`, `.parts`, and `/` in `train`).
    args.out_dir = Path(args.out_dir)
    if fabric.global_rank == 0:
        args.out_dir.mkdir(parents=True, exist_ok=True)

    # Same seed for every process so all ranks initialise identical weights
    # (FSDP). For per-process seeds (DDP) call fabric.seed_everything(workers=True).
    fabric.seed_everything(args.seed, workers=True)

    config = Config.from_name(args.model_name)
    # Hard-coded vocabulary size, overriding the preset's value.
    config.padded_vocab_size = 185138

    fabric.print(f"Loading model with {config.__dict__}")
    t0 = time.perf_counter()
    with fabric.init_module(empty_init=True):  # empty_init=False (DDP)
        model = GPT(config)
        model.apply(model._init_weights)

    fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.")
    fabric.print(f"Total parameters {num_parameters(model):,}")

    model = torch.compile(model)
    model = fabric.setup(model)

    # Biases and norm weights get no weight decay. Names are matched by
    # substring, so any wrapper prefix on the parameter names does not matter.
    no_decay = ["bias", "norm_1.weight", "norm_2.weight", "ln_f.weight"]
    decay_params, nodecay_params = [], []
    for param_name, param in model.named_parameters():
        if any(nd in param_name for nd in no_decay):
            nodecay_params.append(param)
        else:
            decay_params.append(param)
    fabric.print(
        f"num decayed parameter tensors: {len(decay_params)}, with {sum(p.numel() for p in decay_params):,} parameters"
    )
    fabric.print(
        f"num non-decayed parameter tensors: {len(nodecay_params)}, with {sum(p.numel() for p in nodecay_params):,} parameters"
    )
    optimizer = torch.optim.AdamW(
        [
            {"params": decay_params, "weight_decay": args.weight_decay},
            {"params": nodecay_params, "weight_decay": 0.0},
        ],
        lr=args.learning_rate,
        weight_decay=args.weight_decay,
        betas=(args.beta1, args.beta2),
        foreach=True,
    )
    optimizer = fabric.setup_optimizers(optimizer)

    train_data, val_data = load_datasets(Path(args.data_dir), block_size=model.config.block_size)
    train_dataloader = DataLoader(train_data, batch_size=args.micro_batch_size, num_workers=8)
    val_dataloader = DataLoader(val_data, batch_size=args.micro_batch_size, num_workers=8)
    # Let Fabric adapt the loaders to the active strategy/device.
    train_dataloader, val_dataloader = fabric.setup_dataloaders(train_dataloader, val_dataloader)

    # Everything in `state` is written by fabric.save and restored by fabric.load.
    state = {
        "model": model,
        "optimizer": optimizer,
        "config": args,
        "iter_num": 0,
        "step_count": 0,
    }

    # Deterministic W&B run id built from the last two out_dir components, so a
    # restarted job points at the same run.
    wandb_id = f"{args.out_dir.parts[-2]}_{args.out_dir.parts[-1]}_huge"
    resume = find_resume_path(args.resume, args.out_dir)
    if resume:
        fabric.print(f"Resuming training from {resume}")
        fabric.load(resume, state)

    if not os.getenv("WANDB_DISABLED") and fabric.global_rank == 0:
        wandb.init(
            project=os.getenv("WANDB_PROJECT", "xxxxxx"),
            entity=os.getenv("WANDB_ENTITY", "xxxxxx"),
            group=os.getenv("WANDB_GROUP", "xxxxxx"),
            name=os.getenv("WANDB_NAME", "xxxxxx"),
            config=args,
            id=wandb_id,
            # A resumed job must continue its existing run; a fresh job may
            # reuse the id if such a run already exists.
            resume="must" if resume else "allow",
        )

    train_time = time.perf_counter()
    train(fabric, args, state, train_dataloader, val_dataloader)

    fabric.print(f"Training time: {(time.perf_counter()-train_time):.2f}s")
    if fabric.device.type == "cuda":
        fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB")


def train(fabric, args, state, train_dataloader, val_dataloader):
    """Run the optimisation loop over micro-batch iterations.

    ``state["iter_num"]`` counts micro-batch iterations and
    ``state["step_count"]`` counts optimizer steps. One step is taken every
    ``args.gradient_accumulation_steps`` iterations. Both counters are updated
    in place, so checkpoints written with ``fabric.save(..., state)`` carry
    them. Training metrics are logged every ``args.log_interval`` iterations.
    Validation runs every ``args.eval_interval`` optimizer steps and
    checkpoints are saved every ``args.save_interval`` optimizer steps.
    """
    model = state["model"]
    optimizer = state["optimizer"]
    log_to_wandb = not os.getenv("WANDB_DISABLED") and fabric.global_rank == 0

    validate(fabric, args, model, val_dataloader)  # sanity check

    _report_flops(fabric, args, model)

    # Checkpoints are written inside iteration `iter_num` *after* that
    # iteration's optimizer step. When resuming (step_count > 0), that
    # iteration is already done. Replaying it would add an extra optimizer step
    # built from a single micro-batch, so start right after it.
    start_iter = state["iter_num"] + 1 if state["step_count"] > 0 else state["iter_num"]
    initial_iter = start_iter
    total_t0 = time.perf_counter()
    # Guard against a modulo by zero if the intervals were configured as 0.
    eval_every = max(1, args.eval_interval)
    save_every = max(1, args.save_interval)
    # Set after the first optimizer step; a log iteration can come earlier
    # (e.g. log_interval < gradient_accumulation_steps).
    global_grad_norm = None

    train_iter = iter(train_dataloader)
    progress_bar = tqdm(total=args.max_iters - start_iter, desc="Pretraining", disable=not fabric.is_global_zero)
    for state["iter_num"] in range(start_iter, args.max_iters):
        # determine and set the learning rate for this iteration
        if args.decay_lr:
            lr = get_lr(state["iter_num"], args.warmup_iters, args.stable_iters, args.decay_iters, args.learning_rate)
        else:
            lr = args.learning_rate
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr

        iter_t0 = time.perf_counter()

        input_ids, positions, targets = next(train_iter)

        is_accumulating = (state["iter_num"] + 1) % args.gradient_accumulation_steps != 0
        # Skip the gradient all-reduce while accumulating; it happens on the
        # last micro-batch of each accumulation window.
        with fabric.no_backward_sync(model, enabled=is_accumulating):
            logits = model(input_ids, positions)
            loss = chunked_cross_entropy(logits, targets, chunk_size=0)
            fabric.backward(loss / args.gradient_accumulation_steps)

        if not is_accumulating:
            global_grad_norm = fabric.clip_gradients(model, optimizer, max_norm=args.grad_clip)
            optimizer.step()
            optimizer.zero_grad()
            state["step_count"] += 1

        if state["iter_num"] % args.log_interval == 0 and state["iter_num"] != 0:
            t1 = time.perf_counter()
            loss_value = loss.item()  # single host/device sync for this log
            tokens = state["iter_num"] * args.micro_batch_size * model.config.block_size
            metrics = {
                "train/loss": loss_value,
                "iter": state["iter_num"],
                "iter_time": t1 - iter_t0,
                "remaining_time": (
                    (t1 - total_t0) / (state["iter_num"] + 1 - initial_iter) * (args.max_iters - state["iter_num"])
                ),
                "tokens": tokens,
                "total_tokens": tokens * fabric.world_size,
                "learning_rate": lr,
            }
            if global_grad_norm is not None:
                metrics["train/global_grad_norm"] = global_grad_norm.item()
            progress_bar.update(args.log_interval)
            progress_bar.set_description(f"Loss: {loss_value}")
            if log_to_wandb:
                wandb.log(metrics, step=state["step_count"])

        if not is_accumulating and state["step_count"] % eval_every == 0:
            t0 = time.perf_counter()
            # every rank runs validation; only rank 0 logs it
            val_loss = validate(fabric, args, model, val_dataloader)
            elapsed = time.perf_counter() - t0
            if log_to_wandb:
                wandb.log(
                    {
                        "valid/loss": val_loss.item(),
                        "val time": elapsed * 1000,
                    },
                    step=state["step_count"],
                )
            fabric.barrier()

        if not is_accumulating and state["step_count"] % save_every == 0:
            _save_checkpoint(fabric, args, state, model)

    progress_bar.close()

    val_loss = validate(fabric, args, model, val_dataloader)
    if fabric.global_rank == 0:
        checkpoint_dir = args.out_dir / f"step-{state['step_count']}-ckpt"
        os.makedirs(checkpoint_dir, exist_ok=True)
        out_file = checkpoint_dir / "valid_final_loss.txt"
        with open(out_file, "w") as f:
            # `.item()` so the file holds a number, not a tensor repr.
            f.write(f"{args.model_name}\t{val_loss.item()}\n")


def _report_flops(fabric, args, model) -> None:
    """Print estimated and measured TFLOPs for one micro-batch using a meta-device copy of the model."""
    with torch.device("meta"):
        meta_model = GPT(model.config)
        # "estimated" is not as precise as "measured". Estimated is optimistic but widely used in the wild.
        # When comparing MFU or FLOP numbers with other projects that use estimated FLOPs,
        # consider passing `SpeedMonitor(flops_per_batch=estimated_flops)` instead
        estimated_flops = estimate_flops(meta_model) * args.micro_batch_size
        fabric.print(f"Estimated TFLOPs: {estimated_flops * fabric.world_size / 1e12:.2f}")
        ids = torch.randint(0, 1, (args.micro_batch_size, model.config.block_size))
        pos = torch.randint(0, 1, (args.micro_batch_size, model.config.block_size))
        measured_flops = measure_flops(meta_model, ids, pos)
        fabric.print(f"Measured TFLOPs: {measured_flops * fabric.world_size / 1e12:.2f}")


def _save_checkpoint(fabric, args, state, model) -> None:
    """Save the full training ``state`` plus the model config JSON under ``out_dir/step-<N>-ckpt``."""
    checkpoint_dir = args.out_dir / f"step-{state['step_count']}-ckpt"
    os.makedirs(checkpoint_dir, exist_ok=True)
    fabric.print(f"Saving checkpoint to {str(checkpoint_dir)!r}")
    fabric.save(checkpoint_dir / "lit_model.pth", state)
    if fabric.global_rank == 0:
        config_path = checkpoint_dir / "lit_config.json"
        config_path.write_text(json.dumps(model.config.__dict__))
        

@torch.no_grad()
def validate(
    fabric: L.Fabric, args, model: torch.nn.Module, val_dataloader: DataLoader
) -> torch.Tensor:
    """Return the mean loss of ``model`` over ``args.eval_iters`` validation batches.

    A fresh iterator is drawn from ``val_dataloader`` on every call. The model is
    switched to eval mode for the duration and back to train mode afterwards,
    even if an exception escapes. The mean covers only this process's batches;
    it is not reduced across ranks.

    Returns:
        A 0-dim tensor on ``fabric.device``.
    """
    fabric.print("Validating ...")
    model.eval()
    val_iter = iter(val_dataloader)

    losses = torch.zeros(args.eval_iters, device=fabric.device)
    eval_progress_bar = tqdm(
        total=args.eval_iters,
        desc=f"Rank {fabric.local_rank} Evaluating",
        disable=not fabric.is_global_zero,
    )
    try:
        for k in range(args.eval_iters):
            input_ids, positions, targets = next(val_iter)
            logits = model(input_ids, positions)
            # Keep the loss on device; calling `.item()` here would force a
            # host/device sync on every batch.
            losses[k] = chunked_cross_entropy(logits, targets, chunk_size=0)
            eval_progress_bar.update(1)
    finally:
        eval_progress_bar.close()
        model.train()

    return losses.mean()


def load_datasets(data_dir: Path, block_size: int):
    """Create the training and validation ``Dataset`` objects stored under ``data_dir``.

    Reads ``train_ids.bin``/``train_pos.bin`` and ``valid_ids.bin``/``valid_pos.bin``.
    ``Dataset`` memory-maps these files as flat ``uint32`` arrays.

    Args:
        data_dir: directory containing the four ``.bin`` files; a ``str`` is
            accepted as well.
        block_size: length of each sampled window.

    Returns:
        ``(train_data, val_data)``.
    """
    # Accept plain strings too: the CLI default for data_dir is a str, and
    # `/` below requires a Path.
    data_dir = Path(data_dir)
    train_data = Dataset(str(data_dir / "train_ids.bin"), str(data_dir / "train_pos.bin"), block_size=block_size)
    val_data = Dataset(str(data_dir / "valid_ids.bin"), str(data_dir / "valid_pos.bin"), block_size=block_size)

    return train_data, val_data


class Dataset(IterableDataset):
    """Endless stream of randomly placed, fixed-length windows from two parallel token files.

    Each yielded item is ``(ids, positions, next_ids)``: three ``int64`` tensors
    of length ``block_size``. ``next_ids`` is the same window of ``ids`` shifted
    one element to the right. Both files are read as flat ``uint32`` arrays.
    """

    def __init__(self, data_file_ids: Path, data_file_pos: Path, block_size: int):
        super().__init__()
        self.data_file_ids = data_file_ids
        self.data_file_pos = data_file_pos
        self.block_size = block_size

    def __iter__(self):
        # The memmaps are opened here rather than in __init__, so every
        # iterator (e.g. one per DataLoader worker) maps the files itself.
        token_ids = np.memmap(self.data_file_ids, dtype=np.uint32, mode="r")
        token_pos = np.memmap(self.data_file_pos, dtype=np.uint32, mode="r")

        def as_long(window):
            return torch.from_numpy(window.astype(np.int64))

        while True:
            span = self.block_size
            # One draw from torch's global RNG per sample.
            start = torch.randint(len(token_ids) - span, (1,)).item()
            stop = start + span
            yield (
                as_long(token_ids[start:stop]),
                as_long(token_pos[start:stop]),
                as_long(token_ids[start + 1:stop + 1]),
            )


# learning rate scheduler: warmup -> stable -> decay (WSD)
def get_lr(it: int, warmup_iters: int, stable_iters: int, decay_iters: int, learning_rate: float) -> float:
    """Return the learning rate for iteration ``it`` under a warmup-stable-decay schedule.

    1. ``it < warmup_iters``: linear ramp from 0 towards ``learning_rate``.
    2. The next ``stable_iters`` iterations: constant ``learning_rate``.
    3. Afterwards: exponential decay that halves the rate every
       ``decay_iters`` iterations. It equals ``learning_rate / 2`` after
       ``decay_iters`` decay iterations and keeps shrinking beyond that.

    When ``decay_iters <= 0`` there is no decay phase, and the rate stays at
    ``learning_rate`` after the stable phase instead of dividing by zero.
    """
    if it < warmup_iters:
        return learning_rate * it / warmup_iters
    if it < warmup_iters + stable_iters or decay_iters <= 0:
        return learning_rate
    decay_factor = math.pow(0.5, (it - warmup_iters - stable_iters) / decay_iters)
    return learning_rate * decay_factor

if __name__ == "__main__":
    from jsonargparse import CLI

    # Allow float32 matmuls to use faster reduced-precision internal kernels
    # (e.g. TF32) where the hardware supports them.
    torch.set_float32_matmul_precision("high")
    # If you hit "Expected is_sm80 to be true, but got false", uncomment:
    # torch.backends.cuda.enable_flash_sdp(False)

    # Expose `setup`'s keyword arguments as command-line flags.
    CLI(setup)
