from pathlib import Path
from pydoc import locate
from typing import Optional, Dict
import argparse
import sys
import logging
import os
from functools import partial
from einops import rearrange
from contextlib import contextmanager

from tqdm.auto import tqdm, trange
import hydra
import torch
import deepspeed
from deepspeed import get_accelerator
from deepspeed.runtime.config import DeepSpeedConfig
from deepspeed.monitor.monitor import MonitorMaster
from omegaconf import DictConfig, OmegaConf
from deepspeed.runtime.activation_checkpointing import checkpointing

import gc
import diffusion
from diffusion.utils import (
    set_seed,
    NullObject,
    dict_to,
    save_zero_three_model,
    get_zero_three_statedict,
    get_scheduler,
)
from diffusion.model.modular import layers


@contextmanager
def freeze(*modules):
    """Temporarily turn off grads for one or more modules.

    On entry every parameter of every given module gets
    ``requires_grad=False``; on exit (including when the body raises) each
    parameter's original ``requires_grad`` flag is restored.

    Args:
        *modules: Objects exposing ``parameters()`` (e.g. ``torch.nn.Module``).
            Modules may overlap or share parameters; each distinct parameter
            is snapshotted exactly once.
    """
    # Snapshot all flags BEFORE touching any of them. Recording per module while
    # freezing would make a later module that shares parameters with an earlier
    # one record the already-frozen state and "restore" it to False on exit.
    saved = {}
    for module in modules:
        for p in module.parameters():
            saved.setdefault(id(p), (p, p.requires_grad))
    try:
        for p, _ in saved.values():
            p.requires_grad_(False)
        yield
    finally:
        for p, flag in saved.values():
            p.requires_grad_(flag)


@hydra.main(config_path="configs", config_name="train", version_base="1.1")
def main(cfg: DictConfig):
    """Train a model with DeepSpeed, configured through Hydra.

    Steps performed, in order:

    1. Resolve the config, seed the RNGs, pick the accelerator device and
       (when ``local_rank != -1``) initialize the distributed backend.
    2. Instantiate the config inside ``deepspeed.zero.Init`` (this constructs
       the model) and move the model to the device.
    3. Optionally load an ``inference.pt`` state dict into the model.
    4. Build optimizer and LR scheduler; optionally create an EMA copy of the
       model on rank 0.
    5. Wrap everything with ``deepspeed.initialize`` and optionally load a
       DeepSpeed checkpoint (including an ``"ema"`` client state).
    6. Iterate epochs over ``cfg.data.train`` doing forward/backward/step, and
       after each applied optimizer step: log the grad norm, update the EMA,
       save checkpoints every ``checkpoint_freq`` steps and validate every
       ``val_freq`` steps. Training stops once ``max_steps`` is reached.

    Args:
        cfg: Hydra config. It is resolved in place and later replaced by the
            result of ``hydra.utils.instantiate(cfg)``.
    """
    OmegaConf.resolve(cfg)

    set_seed(cfg.seed)

    local_rank = cfg.deepspeed_cli_args.local_rank
    if local_rank == -1:
        device = torch.device(get_accelerator().device_name())
    else:
        get_accelerator().set_device(local_rank)
        device = torch.device(get_accelerator().device_name(), local_rank)
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        deepspeed.init_distributed()

    torch.distributed.barrier()
    global_rank = torch.distributed.get_rank()

    logger = logging.getLogger(f"{__name__}:rank{global_rank}")
    is_rank0 = global_rank in [-1, 0]
    # NullObject stands in for the logger on non-zero ranks so call sites need no rank checks.
    rank0logger = logger if is_rank0 else NullObject()

    monitor = MonitorMaster(DeepSpeedConfig(OmegaConf.to_container(cfg.deepspeed.config)).monitor_config)
    if is_rank0 and cfg.wandb.enabled:
        import wandb

        # Init will be called by DeepSpeed
        wandb.config.update(OmegaConf.to_container(cfg))
    else:
        wandb = NullObject()

    with deepspeed.zero.Init(enabled=(cfg.deepspeed.config.zero_optimization.stage == 3)):
        # Instantiate constructs the models, so we have to call it in this context
        cfg = hydra.utils.instantiate(cfg)
        model = cfg.model.to(device)
    rank0logger.info(model)

    if cfg.deepspeed.config.get("checkpointing", None) is not None:
        from deepspeed.runtime.activation_checkpointing import checkpointing

        checkpointing.configure(mpu_=None, deepspeed_config=cfg.deepspeed.config)

    logger.info(f"Rank local {local_rank} / global {global_rank} initialized.")

    torch.distributed.barrier()
    # Re-seed with a per-rank offset so that ranks diverge from here on.
    set_seed(cfg.seed + global_rank)

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    rank0logger.info(f"Total Param Count: {total_params / 1e6:.3f} M ({trainable_params / 1e6:.3f} M trainable)")

    os.environ["LOCAL_WORLD_SIZE"] = str(cfg.deepspeed.local_world_size)
    for varname in [
        "RANK",
        "LOCAL_WORLD_SIZE",
        "WORLD_SIZE",
        "MASTER_ADDR",
        "MASTER_PORT",
    ]:
        assert os.environ.get(varname) is not None, f"Rank {global_rank}: {varname} is not set"
        if not varname == "RANK":
            rank0logger.info(f"{varname}: {os.environ.get(varname)}")
    # Read only after the check above so a missing variable yields the descriptive assertion message.
    world_size = int(os.environ["WORLD_SIZE"])

    data = cfg.data
    dataloader_train = data.train

    torch.distributed.barrier()

    # Load checkpoint if available (variant 1: simple inference checkpoint)
    ckpt_path = cfg.get("model_ckpt_path", None)
    ckpt_tag = cfg.get("model_ckpt_tag", None)
    print(f"ckpt_path: {ckpt_path}")
    print(f"ckpt_tag: {ckpt_tag}")
    if ckpt_path is not None and ckpt_path.endswith("inference.pt"):
        assert not cfg.get("model_ckpt_load_optim", True), "Inference checkpoints do not contain optimizer states."
        assert not cfg.get(
            "model_ckpt_load_lr_scheduler", True
        ), "Inference checkpoints do not contain LR scheduler states."
        assert ckpt_tag is None, "Cannot specify tag for inference checkpoint."
        model.load_state_dict(
            torch.load(ckpt_path, map_location="cpu"),
            strict=cfg.get("model_ckpt_strict", True),
        )
        rank0logger.info("Inference checkpoint loaded!")

    optimizer: torch.optim.Optimizer = locate(cfg.optim_class)(
        model.parameters(), **OmegaConf.to_container(cfg.optim_params)
    )

    lr_scheduler = get_scheduler(
        name=cfg.lr_scheduler.name,
        optimizer=optimizer,
        num_warmup_steps=cfg.lr_scheduler.num_warmup_steps,
        num_training_steps=cfg.lr_scheduler.num_training_steps,
        scheduler_specific_kwargs=OmegaConf.to_container(cfg.lr_scheduler.scheduler_specific_kwargs),
    )

    ema: bool = cfg.ema.enabled
    ema_device = device if cfg.ema.device is None else torch.device(cfg.ema.device)
    ema_dtype = None if cfg.ema.dtype is None else locate(cfg.ema.dtype)
    if ema:
        # If loading an inference checkpoint anyways, initialize from the corresponding EMA if available
        state_dict = get_zero_three_statedict(
            model,
            global_rank=global_rank,
            zero_stage=cfg.deepspeed.config.zero_optimization.stage,
        )
        if ckpt_path is not None and ckpt_path.endswith("inference.pt"):
            ema_ckpt_path = Path(ckpt_path).parent / "ema.pt"
            if ema_ckpt_path.exists():
                state_dict = torch.load(ema_ckpt_path, map_location="cpu")
            else:
                if cfg.get("model_ckpt_strict", True):
                    raise Exception("No EMA in state dict.")
                else:
                    rank0logger.warning("No EMA in state dict.")
        # The EMA model only exists on rank 0; every later use is guarded by `ema and is_rank0`.
        if is_rank0:
            ema_model = cfg.ema.model
            ema_model.load_state_dict(state_dict, strict=cfg.get("model_ckpt_strict", True))
            ema_model.requires_grad_(False)
            ema_model = ema_model.to(ema_device)
            if ema_dtype is not None:
                ema_model = ema_model.to(ema_dtype)

    model.train()
    model_engine, optimizer, _, lr_scheduler = deepspeed.initialize(
        model=model,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        config_params=OmegaConf.to_container(cfg.deepspeed.config),
    )

    # dtype that floating-point batch entries are cast to (via dict_to) before the forward pass.
    if cfg.deepspeed.config.bfloat16.enabled:
        dtype = torch.bfloat16
    elif cfg.deepspeed.config.get("fp16") is not None and cfg.deepspeed.config.fp16.get("enabled", False):
        dtype = torch.float16
    else:
        dtype = torch.float32
    rank0logger.info(f"Using dtype {dtype}.")

    # Load checkpoint if available (variant 2: DeepSpeed checkpoint)
    if ckpt_path is not None and not ckpt_path.endswith("inference.pt"):
        load_path, client_states = model_engine.load_checkpoint(
            load_dir=ckpt_path,
            tag=ckpt_tag,
            load_module_strict=cfg.get("model_ckpt_strict", True),
            load_optimizer_states=cfg.get("model_ckpt_load_optim", True),
            load_lr_scheduler_states=cfg.get("model_ckpt_load_lr_scheduler", True),
        )
        assert load_path is not None, "Checkpoint load failed."
        if ema and is_rank0:
            if "ema" in client_states:
                ema_model.load_state_dict(client_states["ema"], strict=cfg.get("model_ckpt_strict", True))
                ema_model.requires_grad_(False)
                ema_model = ema_model.to(ema_device)
            else:
                if cfg.get("model_ckpt_strict", True):
                    raise Exception("No EMA in state dict.")
                else:
                    rank0logger.warning("No EMA in state dict.")

        rank0logger.info("DeepSpeed checkpoint loaded!")

    i_epoch = -1
    stop = False
    max_steps: Optional[int] = cfg.max_steps

    val_freq: Optional[int] = cfg.val_freq
    if val_freq is not None:
        dataloader_val = data.validation
    max_val_steps: Optional[int] = cfg.max_val_steps
    checkpoint_freq: Optional[int] = cfg.checkpoint_freq
    save_inference_checkpoints: bool = cfg.save_inference_checkpoints

    # Keep the enable flag under its own name: importing torch.profiler.profile under the
    # name `profile` would otherwise overwrite the boolean with the profiler class.
    profiling_enabled: bool = cfg.profiling.enabled and is_rank0
    if profiling_enabled:
        from torch.profiler import profile as torch_profile, ProfilerActivity, record_function

        profile_fn = partial(
            torch_profile,
            activities=[
                *((ProfilerActivity.CPU,) if cfg.profiling.cpu else ()),
                *((ProfilerActivity.CUDA,) if cfg.profiling.cuda else ()),
            ],
            record_shapes=cfg.profiling.record_shapes,
            profile_memory=cfg.profiling.profile_memory,
            with_flops=cfg.profiling.with_flops,
            with_stack=True,
        )
        profile_ctx_fn = record_function
        profile_step: int = cfg.profiling.step
    else:
        profile_fn = NullObject()
        profile_ctx_fn = NullObject()

    # Counts loop iterations (micro-batches) on this process, as opposed to
    # model_engine.global_steps which is maintained by the DeepSpeed engine.
    i_step_local = 0
    loss_weights: Optional[Dict[str, float]] = cfg.loss_weights
    use_loss_weights: bool = loss_weights is not None
    with layers.checkpointing(enable=cfg.get("modular_checkpointing", False)):
        while not stop:  # Epochs
            i_epoch += 1
            for batch in tqdm(
                dataloader_train,
                desc=f"Optimizing (Epoch {i_epoch + 1})",
                disable=(not is_rank0),
            ):
                # Periodically release cached allocator memory and report usage.
                if i_step_local % 1000 == 0 and i_step_local > 0:
                    gc.collect()
                    torch.cuda.empty_cache()
                    print(f"Epoch {i_epoch} - {i_step_local} - {torch.cuda.memory_allocated() / 1e6:.2f} MB")
                    print("Empty cache")

                # Only the single iteration `profile_step` is profiled; short-circuiting keeps
                # `profile_step` from being evaluated when profiling is disabled.
                profile_this_step = profiling_enabled and i_step_local == profile_step
                with (profile_fn() if profile_this_step else NullObject()) as prof:
                    with profile_ctx_fn(f"step_{model_engine.global_steps}/fwd"):
                        losses = model_engine(**dict_to(batch, device=device, dtype=dtype))
                        # The model may return `losses` or a `(losses, metrics)` tuple.
                        if isinstance(losses, tuple):
                            losses, metrics = losses
                            if metrics is not None:
                                monitor.write_events(
                                    [
                                        (
                                            f"Train/Samples/{k}",
                                            v.float(),
                                            model_engine.global_samples,
                                        )
                                        for k, v in metrics.items()
                                    ]
                                )
                        # `losses` is either a dict of named loss tensors or a single tensor.
                        if isinstance(losses, dict):
                            if use_loss_weights:
                                assert set(losses) == set(
                                    loss_weights
                                ), f"Loss weights must match losses. {set(losses)} != {set(loss_weights)}"
                                loss = sum([(losses[k] * loss_weights[k]).mean() for k in losses])
                            else:
                                loss = sum(v.mean() for v in losses.values())
                            monitor.write_events(
                                [
                                    (
                                        f"Train/Samples/train_loss_{k}",
                                        losses[k].float().mean(),
                                        model_engine.global_samples,
                                    )
                                    for k in losses
                                ]
                            )
                        else:
                            assert not use_loss_weights, "Loss weights are only supported for dict losses."
                            loss = losses.mean()

                    with profile_ctx_fn(f"step_{model_engine.global_steps}/bwd"):
                        model_engine.backward(loss)

                    with profile_ctx_fn(f"step_{model_engine.global_steps}/step"):
                        model_engine.step()

                if profile_this_step:
                    logger.info(f"Exporting profile to {Path(cfg.profiling.out_path).absolute()}")
                    prof.export_chrome_trace(cfg.profiling.out_path)

                monitor.write_events(
                    [
                        (
                            "Train/global_steps",
                            model_engine.global_steps,
                            model_engine.global_samples,
                        )
                    ]
                )

                # Everything below runs only when the engine reports that the optimizer step was applied.
                if model_engine.was_step_applied():
                    monitor.write_events(
                        [
                            (
                                "Train/grad_norm",
                                model_engine.get_global_grad_norm(),
                                model_engine.global_samples,
                            )
                        ]
                    )

                    if ema:
                        with torch.no_grad():
                            # Called on every rank; only rank 0 consumes the result.
                            state_dict = get_zero_three_statedict(
                                model,
                                global_rank=global_rank,
                                zero_stage=cfg.deepspeed.config.zero_optimization.stage,
                            )
                            if is_rank0:
                                # In-place update: ema = decay * ema + (1 - decay) * param
                                ema_state_dict = ema_model.state_dict()
                                p_ema = [ema_state_dict[k].data for k in state_dict]
                                p = [state_dict[k].data.detach() for k in state_dict]
                                torch._foreach_mul_(p_ema, cfg.ema.decay)
                                torch._foreach_add_(p_ema, p, alpha=1 - cfg.ema.decay)
                                # # This only works if dtypes match, otherwise, an expensive cast is needed
                                # # The mul+add version doesn't care and just promotes dtypes
                                # # As we expect to primarily use bf16 training and keep an fp32 EMA, we'll use the other version
                                # torch._foreach_lerp_(p_ema, p, weight=1 - cfg.ema.decay)

                    if checkpoint_freq is not None and model_engine.global_steps % checkpoint_freq == 0:
                        Path(cfg.out_dir).mkdir(parents=True, exist_ok=True)
                        step_ckpt_dir = Path(cfg.out_dir) / f"step_{model_engine.global_steps}"
                        step_ckpt_dir.mkdir(parents=True, exist_ok=True)
                        if save_inference_checkpoints:
                            save_zero_three_model(
                                model_engine,
                                global_rank,
                                str(step_ckpt_dir / "inference.pt"),
                                cfg.deepspeed.config.zero_optimization.stage,
                            )
                            if ema and is_rank0:
                                torch.save(
                                    ema_model.state_dict(),
                                    str(step_ckpt_dir / "ema.pt"),
                                )
                        # The EMA weights travel in the client state, from rank 0 only.
                        model_engine.save_checkpoint(
                            save_dir=str(cfg.out_dir),
                            tag=f"step_{model_engine.global_steps}",
                            client_state={
                                "global_step": model_engine.global_steps,
                            }
                            | ({"ema": ema_model.state_dict()} if ema and is_rank0 else {}),
                        )

                    if val_freq is not None and model_engine.global_steps % val_freq == 0:
                        model_engine.eval()
                        # If the module has a custom validate method, call that, otherwise, just do validation by computing the train loss on validation samples
                        validated = False
                        if hasattr(model_engine.module, "validate"):
                            try:
                                model_engine.module.validate(
                                    dataloader_val=dataloader_val,
                                    dataloader_train=dataloader_train,
                                    global_rank=global_rank,
                                    global_samples=model_engine.global_samples,
                                    max_steps=max_val_steps,
                                    device=device,
                                    dtype=dtype,
                                    monitor=monitor,
                                    wandb=wandb,
                                    ema_model=ema_model if ema and is_rank0 else None,
                                )
                                validated = True
                            except NotImplementedError:
                                # A `validate` that raises NotImplementedError falls back to the generic loop below.
                                pass
                        if not validated:
                            with torch.no_grad():
                                # One entry per validation batch: a float, or a dict of floats
                                # with the combined loss stored under "__overall__".
                                val_losses_accumulated = []
                                for i, val_batch in enumerate(
                                    tqdm(
                                        dataloader_val,
                                        desc="Validating",
                                        disable=(not is_rank0),
                                        total=max_val_steps,
                                    )
                                ):
                                    val_losses = model_engine.module(**dict_to(val_batch, device=device, dtype=dtype))
                                    if isinstance(val_losses, tuple):
                                        val_losses, metrics = val_losses
                                        if metrics is not None:
                                            monitor.write_events(
                                                [
                                                    (
                                                        f"Val/Samples/{k}",
                                                        v.float(),
                                                        model_engine.global_samples,
                                                    )
                                                    for k, v in metrics.items()
                                                ]
                                            )
                                    if isinstance(val_losses, dict):
                                        if use_loss_weights:
                                            assert set(val_losses) == set(
                                                loss_weights
                                            ), f"Loss weights must match losses. {set(val_losses)} != {set(loss_weights)}"
                                            val_loss = sum(
                                                [(val_losses[k] * loss_weights[k]).mean() for k in val_losses]
                                            )
                                        else:
                                            val_loss = sum(v.mean() for v in val_losses.values())
                                        monitor.write_events(
                                            [
                                                (
                                                    f"Val/Samples/val_loss_{k}",
                                                    val_losses[k].float(),
                                                    model_engine.global_samples,
                                                )
                                                for k in val_losses
                                            ]
                                        )
                                        val_losses_accumulated.append({})
                                        for k in val_losses:
                                            l = val_losses[k].mean()
                                            # all_reduce followed by division by world_size gives the mean across ranks.
                                            torch.distributed.all_reduce(l)
                                            val_losses_accumulated[-1][k] = (l / world_size).cpu().item()
                                        torch.distributed.all_reduce(val_loss)
                                        val_losses_accumulated[-1]["__overall__"] = (val_loss / world_size).cpu().item()
                                    else:
                                        assert not use_loss_weights, "Loss weights are only supported for dict losses."
                                        val_loss = val_losses.mean()
                                        torch.distributed.all_reduce(val_loss)
                                        val_losses_accumulated.append((val_loss / world_size).cpu().item())

                                    if max_val_steps is not None and i + 1 >= max_val_steps:
                                        break
                                # NOTE(review): the [-1] lookup raises IndexError if the validation
                                # loader yielded no batches.
                                if isinstance(val_losses_accumulated[-1], dict):
                                    # Average each named loss over the validation batches.
                                    val_loss_individual = {
                                        k: sum(v[k] for v in val_losses_accumulated) / len(val_losses_accumulated)
                                        for k in val_losses_accumulated[-1]
                                    }
                                    val_loss = val_loss_individual.pop("__overall__")
                                    monitor.write_events(
                                        [
                                            (
                                                f"Val/Samples/val_loss_{k}",
                                                val_loss_individual[k],
                                                model_engine.global_samples,
                                            )
                                            for k in val_loss_individual
                                        ]
                                    )
                                else:
                                    val_loss = sum(val_losses_accumulated) / len(val_losses_accumulated)
                                rank0logger.info(f"Validation loss: {val_loss}")
                                monitor.write_events(
                                    [
                                        (
                                            "Val/Samples/val_loss",
                                            val_loss,
                                            model_engine.global_samples,
                                        )
                                    ]
                                )

                        # put model into train mode
                        model_engine.train()

                # `>=` rather than `==`: a run resumed from a checkpoint that is already past
                # max_steps must still terminate instead of looping forever.
                if max_steps is not None and model_engine.global_steps >= max_steps:
                    stop = True
                    break

                i_step_local += 1


if __name__ == "__main__":
    # Script entry point: set global torch flags, patch the layers module's
    # checkpoint hook, forward DeepSpeed CLI arguments to Hydra, then run main().
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    # torch.autograd.set_detect_anomaly(True)
    # Raise the dynamo cache limit to at least 64 (never lowering an existing larger value).
    torch._dynamo.config.cache_size_limit = max(64, torch._dynamo.config.cache_size_limit)
    torch._dynamo.config.suppress_errors = True

    # Patch k-diffusion so that we get the full DeepSpeed checkpointing
    def checkpoint(function, *args, **kwargs):
        """Call ``function`` through DeepSpeed activation checkpointing when
        ``layers.get_checkpointing()`` is truthy, otherwise call it directly."""
        if layers.get_checkpointing():
            return checkpointing.checkpoint(function, *args, **kwargs)
        return function(*args, **kwargs)

    layers.checkpoint = checkpoint

    # Enable parsing DeepSpeed args despite using Hydra
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_rank", type=int, help="Local rank for distributed training.")
    parser = deepspeed.add_config_arguments(parser)
    known_args, remaining_args = parser.parse_known_args(sys.argv[1:])
    # Re-expose every recognised argument that was actually set as a Hydra
    # "++deepspeed_cli_args.<name>=<value>" override; unrecognised arguments are
    # passed through to Hydra untouched. `vars()` is the documented way to view
    # an argparse Namespace as a dict.
    sys.argv = (
        sys.argv[:1]
        + remaining_args
        + [f"++deepspeed_cli_args.{k}={v}" for k, v in vars(known_args).items() if v is not None]
    )

    # If both variables are set but disagree, LOCAL_RANK wins.
    if (
        "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ
        and "LOCAL_RANK" in os.environ
        and os.environ["OMPI_COMM_WORLD_LOCAL_RANK"] != os.environ["LOCAL_RANK"]
    ):
        os.environ["OMPI_COMM_WORLD_LOCAL_RANK"] = os.environ["LOCAL_RANK"]

    main()