import math
import os
import random
from abc import ABC, abstractmethod
from collections import defaultdict

import fla  # noqa
import numpy as np
import torch
import wandb
from datasets import load_from_disk
from omegaconf import DictConfig, OmegaConf

# from fla.utils import print_master
from torch import distributed as dist
from torch.distributed import destroy_process_group, init_process_group
from torch.nn import CrossEntropyLoss
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
from transformers import AutoConfig, AutoModelForCausalLM


def print_master(msg: str):
    """Print `msg` unless this process is a non-zero rank.

    The rank is read from the ``RANK`` environment variable. When the
    variable is absent (treated as -1, i.e. no multi-process launch) or equal
    to 0, the message is printed; every other rank stays silent.
    """
    rank_id = int(os.environ.get("RANK", -1))
    if rank_id in (-1, 0):
        print(msg)


def get_dataloaders(cfg: DictConfig):
    """Build the training DataLoader.

    The dataset is loaded from ``cfg.cluster.data_home`` and formatted as
    torch tensors. Only a train loader is created here; no validation set is
    loaded. Batch size and worker count come from ``cfg.trainer``.

    Returns:
        A single ``DataLoader`` over the training set.
    """
    dataset = load_from_disk(cfg.cluster.data_home).with_format("torch")
    sampler = _get_sampler(dataset, cfg)

    n_workers = cfg.trainer.num_workers
    use_workers = n_workers > 0

    # only used with intra-document masking
    def _collate_with_doc_lengths(samples):
        token_ids = torch.stack([s["input_ids"] for s in samples], dim=0)
        # document lengths stay a plain (ragged) list of lists
        doc_lengths = [s["docs_lengths"].tolist() for s in samples]
        return {"input_ids": token_ids, "docs_lengths": doc_lengths}

    # Fall back to the default collation when the dataset has no
    # "docs_lengths" column.
    if "docs_lengths" in dataset.column_names:
        collate = _collate_with_doc_lengths
    else:
        collate = None

    return DataLoader(
        dataset,
        sampler=sampler,
        batch_size=cfg.trainer.micro_batch_size,
        num_workers=n_workers,
        pin_memory=True,
        # prefetching / persistent workers are only valid with worker processes
        prefetch_factor=2 if use_workers else None,
        persistent_workers=use_workers,
        collate_fn=collate,
    )


def _get_sampler(train_set, cfg: DictConfig):
    """Return the sampler used by the training DataLoader.

    Only in-order (non-shuffled) sampling is implemented:
      - with an initialized process group: a ``DistributedSampler`` with
        ``shuffle=False`` and ``drop_last=True``;
      - otherwise: a ``SequentialSampler``.

    ``cfg`` is currently unused.
    """
    if dist.is_initialized():
        return DistributedSampler(train_set, shuffle=False, drop_last=True)
    return SequentialSampler(train_set)


def log(
    cfg: DictConfig,
    metrics: dict,
    micro_step: int,
    train_loss: torch.Tensor,
    train_loss_array: list,
    optimizer: torch.optim.Optimizer,
    world_size: int,
    last_grnorm: float | torch.Tensor,
):
    """Update metrics, print to console, log on wandb."""

    if isinstance(train_loss_array, list):
        train_loss_avg = torch.stack(train_loss_array).mean().item()
    elif isinstance(train_loss_array, torch.Tensor):
        train_loss_avg = train_loss_array.item()

    new_metrics = {
        "micro_step": micro_step,
        "step": micro_step // cfg.trainer.grad_accumulation_steps,
        "tokens": micro_step
        * cfg.trainer.micro_batch_size
        * cfg.trainer.seq_len
        * world_size,
        "lr": optimizer.param_groups[0].get("lr", float("NaN")),
        "train/loss": train_loss.item(),
        "train/loss_avg": train_loss_avg,
        "train/ppl": math.exp(train_loss),
        "train/ppl_avg": math.exp(train_loss_avg),
        "train/grad_norm": last_grnorm.to("cpu").item()
        if not isinstance(last_grnorm, float)
        else last_grnorm,
    }

    for k, v in new_metrics.items():
        metrics[k].append(v)

    if cfg.logger.print_progress:
        msg = " | ".join(
            f"{key}: {value:.3e}" if isinstance(value, float) else f"{key}: {value}"
            for key, value in new_metrics.items()
        )
        print(msg)

    if cfg.logger.use_wandb:
        wandb.log(new_metrics)


def pytorch_setup(cfg):
    """Initialize DDP if requested, pick the device, seed RNGs, set TF32 flags.

    DDP is considered enabled when the ``RANK`` environment variable is set
    (to anything other than -1). In that case the NCCL process group is
    initialized and the device is ``cuda:<LOCAL_RANK>``. Otherwise the device
    is "cuda", "mps" or "cpu", in that order of preference.

    The ``random``, NumPy and torch RNGs are seeded with
    ``cfg.trainer.seed`` plus the rank (0 without DDP).

    Returns:
        Tuple ``(local_rank, world_size, device, master_process)``;
        ``local_rank`` is None without DDP.
    """
    rank = int(os.environ.get("RANK", -1))

    if rank != -1:
        init_process_group(backend="nccl")
        local_rank = int(os.environ["LOCAL_RANK"])
        world_size = int(os.environ["WORLD_SIZE"])
        device = f"cuda:{local_rank}"
        torch.cuda.set_device(device)
        master_process = rank == 0
        seed_offset = rank
    else:
        local_rank, world_size = None, 1
        master_process, seed_offset = True, 0
        if torch.cuda.is_available():
            device = "cuda"
        elif torch.backends.mps.is_available():
            device = "mps"  # NOTE: macOS metal support to be tested
        else:
            device = "cpu"

    # same base seed everywhere, shifted by the rank
    seed = cfg.trainer.seed + seed_offset
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    # allow TF32, if not specified, we follow PyTorch 2.0 default
    # https://pytorch.org/docs/stable/notes/cuda.html#tf32-on-ampere
    matmul_tf32 = getattr(cfg, "cuda_matmul_allow_tf32", False)
    torch.backends.cuda.matmul.allow_tf32 = matmul_tf32
    cudnn_tf32 = getattr(cfg, "cudnn_allow_tf32", True)
    torch.backends.cudnn.allow_tf32 = cudnn_tf32

    return local_rank, world_size, device, master_process


def destroy_ddp():
    """Tear down the distributed process group, if one was initialized.

    Does nothing when no process group is initialized.
    """
    if not torch.distributed.is_initialized():
        return

    # finish GPU work
    torch.cuda.synchronize()
    # wait for all ranks
    torch.distributed.barrier()
    # cleanly tear down comms
    destroy_process_group()


class CustomLRSchedule(ABC):
    """Base class for hand-written learning-rate schedules.

    Subclasses implement `step`. The schedule's state is every instance
    attribute except the optimizer reference.
    """

    def __init__(self, optimizer):
        self.optimizer = optimizer

    def set_optim_lr(self, lr):
        """Write `lr` into every parameter group of the optimizer."""
        for param_group in self.optimizer.param_groups:
            param_group["lr"] = lr

    def state_dict(self):
        """Return the instance attributes, leaving out the optimizer."""
        state = dict(vars(self))
        state.pop("optimizer", None)
        return state

    def load_state_dict(self, state_dict):
        """Overwrite instance attributes with the entries of `state_dict`."""
        vars(self).update(state_dict)

    @abstractmethod
    def step(self):
        """Advance the schedule by one step."""


class WarmupCosine(CustomLRSchedule):
    """Linear warmup followed by Cosine Decay.

    The lr goes linearly from `lr_start` to `lr_max` over `warmup_steps`
    steps, then follows a half cosine from `lr_max` down to `lr_end` until
    step `T`, and stays at `lr_end` afterwards.
    """

    def __init__(self, optimizer, lr_start, lr_max, lr_end, warmup_steps, T):
        super().__init__(optimizer)
        self.lr_start = lr_start
        self.lr_max = lr_max
        self.lr_end = lr_end
        self.warmup_steps = warmup_steps
        self.T = T
        self.iter = 0
        # the optimizer starts at lr_start until the first call to step()
        self.set_optim_lr(lr_start)

    def get_lr(self, t):
        """Computes and returns lr(t), where t is the current step."""
        if t <= self.warmup_steps:
            if self.warmup_steps == 0:
                # No warmup phase: the slope below would divide by zero, so
                # the schedule starts directly at its peak.
                return self.lr_max
            return self.lr_start + (self.lr_max - self.lr_start) / self.warmup_steps * t
        elif t <= self.T:
            # This branch is only reached when warmup_steps < t <= T, so the
            # denominator is strictly positive.
            progress = (t - self.warmup_steps) / (self.T - self.warmup_steps)
            return self.lr_end + 0.5 * (self.lr_max - self.lr_end) * (
                1 + math.cos(math.pi * progress)
            )
        return self.lr_end

    def step(self):
        """Advance the step counter and write the new lr to the optimizer."""
        self.iter += 1
        lr = self.get_lr(self.iter)
        self.set_optim_lr(lr)


def _move_to_device(batch, seq_len, device):
    """Slice batch to get inputs and targets, and move them to device.

    The token ids are first truncated to at most `seq_len` positions; inputs
    are all but the last of those tokens and targets are the same tokens
    shifted left by one. Inputs and targets therefore always have the same
    shape, whether the stored sequences are exactly `seq_len` long or not.

    Args:
        batch: Mapping with an "input_ids" 2-D tensor.
        seq_len: Maximum number of tokens taken from each sequence.
        device: Target device, as a string ("cuda:0", "cpu", ...) or a
            ``torch.device``.

    Returns:
        Tuple ``(inputs, targets, attn_mask)``; ``attn_mask`` is always None.
    """
    tokens = batch["input_ids"][:, :seq_len]
    inputs = tokens[:, :-1]
    targets = tokens[:, 1:]

    attn_mask = None

    # torch.device() normalizes both strings and device objects, so the check
    # does not depend on how the caller spelled the device.
    if torch.device(device).type == "cuda":
        # pin arrays allows to move them to GPU asynchronously (non_blocking=True)
        inputs = inputs.pin_memory().to(device, non_blocking=True)
        targets = targets.pin_memory().to(device, non_blocking=True)
    else:
        inputs, targets = inputs.to(device), targets.to(device)

    return inputs, targets, attn_mask


class TorchEngine(torch.nn.Module):
    """
    A module containing model, optimizer, scheduler, grad scaler.
    Wraps together a training step. Takes care of grad accumulation.

    Attributes set here and read by callers:
      - ``optimizer``: the AdamW optimizer.
      - ``last_grad_norm``: value returned by the most recent gradient
        clipping call; NaN until the first clipped optimizer step, and
        always NaN when ``cfg.trainer.grad_clip`` is falsy.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        cfg: DictConfig,
        device: str | torch.device,
        local_rank: int | None,
    ):
        super().__init__()

        self.micro_steps = 0
        self.accumulated_samples = 0

        # Always define the attribute: step() only assigns it when gradient
        # clipping is enabled, and reading a missing attribute on an
        # nn.Module raises AttributeError.
        self.last_grad_norm = float("nan")

        self.seq_len = cfg.trainer.seq_len
        self.accumulation_steps = cfg.trainer.grad_accumulation_steps
        self.grad_clip = cfg.trainer.grad_clip
        self.dtype = cfg.trainer.dtype
        self.device = device

        # Move model to device and to DDP
        self.model = model.to(self.device)
        if torch.distributed.is_initialized():
            self.model = DDP(self.model, device_ids=[local_rank])

        # Compile
        if cfg.trainer.torch_compile:
            print("Compiling the model...")
            self.model = torch.compile(self.model, dynamic=True)

        # AMP
        # NOTE(review): autocast is hard-coded to CUDA + bfloat16 and does not
        # look at self.dtype or self.device, while the grad scaler below is
        # keyed on self.dtype == "float16" — confirm this is intended.
        self.ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)

        # Grad scaler if training in fp16, if enabled=False, scaler is a no-op
        self.scaler = torch.amp.GradScaler(enabled=(self.dtype == "float16"))

        # Loss
        self.criterion = CrossEntropyLoss()

        # Optimizer
        self.optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=cfg.trainer.lr_start,
            betas=(cfg.trainer.beta1, cfg.trainer.beta2),
            weight_decay=cfg.trainer.weight_decay,
            eps=cfg.trainer.eps,
            fused=cfg.trainer.fused_optim,
        )

        # Scheduler
        # cfg.trainer.scheduler.warmup_steps is multiplied by the total number
        # of steps, i.e. it is used as a fraction of the run.
        self.scheduler = WarmupCosine(
            self.optimizer,
            lr_start=cfg.trainer.lr_start,
            lr_max=cfg.trainer.learning_rate,
            lr_end=cfg.trainer.lr_end,
            warmup_steps=cfg.trainer.scheduler.warmup_steps * cfg.trainer.steps,
            T=cfg.trainer.steps,
        )

    def step(self, batch):
        """Wraps a fwd pass, bwd pass, and optimization step.

        Every call runs forward and backward on one micro-batch. The
        optimizer, grad scaler and scheduler are stepped only on the call
        that completes ``accumulation_steps`` micro-batches.

        Args:
            batch: Mapping with an "input_ids" tensor, as consumed by
                ``_move_to_device``.

        Returns:
            The detached loss of this micro-batch (not divided by the number
            of accumulation steps).

        Raises:
            ValueError: If the loss is NaN.
        """

        self.model.train()

        self.micro_steps += 1
        self.accumulated_samples += 1

        inputs, targets, attn_mask = _move_to_device(batch, self.seq_len, self.device)

        # sync (reduce) gradients at the last accumulation step
        if torch.distributed.is_initialized():
            self.model.require_backward_grad_sync = (
                self.accumulated_samples == self.accumulation_steps
            )

        # forward pass with autocasting
        with self.ctx:
            output = self.model(inputs, attn_mask)
            # accept both model outputs exposing `.logits` and raw logits
            logits = getattr(output, "logits", output)
            loss = self.criterion(logits.view(-1, logits.size(-1)), targets.view(-1))
            # scale so that the accumulated gradient is an average
            loss = loss / self.accumulation_steps

        # detach for logging (scale up to undo the division above)
        loss_val = loss.detach() * self.accumulation_steps
        if torch.isnan(loss_val):
            raise ValueError("Train loss is nan")

        # backward pass, with gradient scaling if training in fp16
        self.scaler.scale(loss).backward()

        # step after accumulation
        if self.accumulated_samples == self.accumulation_steps:
            self.accumulated_samples = 0

            if self.grad_clip:
                # gradients must be unscaled before their norm is measured
                self.scaler.unscale_(self.optimizer)
                self.last_grad_norm = torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(), self.grad_clip
                )

            # step the optimizer, step the scaler if training in fp16
            self.scaler.step(self.optimizer)
            self.scaler.update()

            # flush the gradients
            self.optimizer.zero_grad(set_to_none=True)

            # step the scheduler
            if self.scheduler:
                self.scheduler.step()

        return loss_val


def main(cfg: DictConfig):
    """Train a causal LM, randomly initialized from a model config, as described by `cfg`.

    Sets up devices / DDP and RNG seeds, builds the model from the config at
    ``cfg.trainer.model_name_or_path`` (weights are freshly initialized, not
    loaded), and runs ``cfg.trainer.steps`` optimizer steps. Logging and
    checkpointing happen on the master process only.

    Raises:
        ValueError: If the train loader holds fewer micro-batches than the
            step budget requires.
    """
    # Seed the RNGs (done inside pytorch_setup) *before* the model is built,
    # so that cfg.trainer.seed also controls the random weight initialization.
    local_rank, world_size, device, master_process = pytorch_setup(cfg)

    # load the model
    model = AutoModelForCausalLM.from_config(
        AutoConfig.from_pretrained(cfg.trainer.model_name_or_path)
    )

    if cfg.logger.use_wandb and master_process:
        wandb.init(
            project=cfg.logger.wandb_project,
            name=cfg.logger.wandb_run_name,
            dir=cfg.trainer.output_dir,
            config=OmegaConf.to_container(cfg),
        )

    trainloader = get_dataloaders(cfg)

    engine = TorchEngine(model, cfg, device, local_rank)

    print_master(f"Model num params: {model.num_parameters()}")

    grad_accum = cfg.trainer.grad_accumulation_steps

    # Budget, expressed both in optimizer steps and in micro-batches
    steps_budget = cfg.trainer.steps
    micro_step_budget = steps_budget * grad_accum
    if micro_step_budget > len(trainloader):
        raise ValueError("trainloader too short!")

    # Start the dataloader from the correct micro-batch
    step_start = 0
    micro_step_start = step_start * grad_accum
    print_master(
        f"=== Start Training from step: {step_start}/{steps_budget}, micro_step: {micro_step_start}/{micro_step_budget} ==="
    )

    # Bookkeeping
    metrics = defaultdict(list)
    train_loss_array = []
    # Last optimizer step reached; defined up-front so the final checkpoint
    # name is valid even if the loop body never runs.
    step = step_start

    # Training
    for micro_step, micro_batch in enumerate(trainloader, micro_step_start + 1):
        # Stop exactly at the budget: never start a partial accumulation
        # window that could not be followed by an optimizer step.
        if micro_step > micro_step_budget:
            break

        step = micro_step // grad_accum
        is_step = micro_step % grad_accum == 0

        train_loss = engine.step(micro_batch)
        train_loss_array.append(train_loss)

        if is_step and step % cfg.trainer.logging_steps == 0:
            if master_process:
                log(
                    cfg,
                    metrics,
                    micro_step,
                    train_loss,
                    train_loss_array,
                    engine.optimizer,
                    world_size,
                    last_grnorm=engine.last_grad_norm,
                )
            # Reset the window on every rank, otherwise non-master ranks
            # keep every loss tensor alive for the whole run.
            train_loss_array = []

        # Checkpoint
        if (
            master_process
            and cfg.trainer.save_intermediate_checkpoints
            and step % cfg.trainer.save_every_steps == 0
            and is_step
        ):
            model.save_pretrained(os.path.join(cfg.trainer.output_dir, f"step_{step}"))

    # End of training: log and save checkpoint
    print_master("=== Training Completed! ===")
    if master_process and cfg.trainer.save_last_checkpoint:
        # `step` is the last step actually trained, not the one that
        # triggered the break.
        model.save_pretrained(os.path.join(cfg.trainer.output_dir, f"step_{step}"))

    # DDP slaughtering
    destroy_ddp()


if __name__ == "__main__":
    # Script entry point: load the config file given by --cfg and train.
    from argparse import ArgumentParser

    cli = ArgumentParser()
    cli.add_argument("--cfg", type=str, required=True)

    main(OmegaConf.load(cli.parse_args().cfg))
