import logging
from pathlib import Path

import hydra
import torch
import torch.multiprocessing as mp
import torch.nn.functional as F
import wandb
from omegaconf import OmegaConf
from torch import nn
from torch.cuda.amp import GradScaler
from torch.distributed import destroy_process_group
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import default_collate
from torch.utils.data.distributed import DistributedSampler
from torchvision.transforms import v2
from tqdm import trange

from experiments.mixup import CutMix, MixUp
from experiments.utils import (
    count_parameters,
    ddp_setup,
    register_resolvers,
    set_logger,
    set_seed,
)

# Module-level setup helpers from experiments.utils. These run at import time,
# i.e. before ``main`` is invoked by hydra. Presumably they configure logging
# and register custom OmegaConf resolvers used by the configs — confirm in
# experiments/utils.
set_logger()
register_resolvers()


def forward_pass(cfg, model, inputs):
    """Run ``model`` on ``inputs`` using the entry point selected by ``cfg``.

    If ``cfg.reuse_backbone`` is truthy, ``model.forward_representations`` is
    used. Otherwise the model itself is called. The result of that call is
    returned unchanged.
    """
    # ``forward_representations`` is only looked up when it will be used.
    entry_point = model.forward_representations if cfg.reuse_backbone else model
    return entry_point(inputs)


@torch.no_grad()
def evaluate(model, loader, cfg, device, num_batches=None):
    """Compute mean cross-entropy loss and top-1 accuracy of ``model`` on ``loader``.

    Args:
        model: Network evaluated through ``forward_pass``, so
            ``cfg.reuse_backbone`` selects ``forward_representations`` versus
            a plain call.
        loader: Iterable of ``(inputs, targets)`` batches. ``targets`` are
            passed to ``F.cross_entropy`` and compared against
            ``argmax(1)`` predictions.
        cfg: Config forwarded to ``forward_pass``.
        device: Device that inputs and targets are moved to.
        num_batches: If not None, only the first ``num_batches`` batches are used.

    Returns:
        dict with keys:
        - ``avg_loss``: per-sample mean loss.
        - ``avg_acc``: fraction of correct argmax predictions.
        - ``predicted``, ``gt``: flat Python lists of predictions and targets.
        If no samples were seen, ``avg_loss`` and ``avg_acc`` are
        ``float("nan")`` instead of raising ZeroDivisionError.

    The model's previous train/eval mode is restored on exit, even when an
    exception interrupts evaluation.
    """
    was_training = model.training
    model.eval()
    loss = 0.0
    correct = 0.0
    total = 0.0
    predicted, gt = [], []
    try:
        for i, (inputs, targets) in enumerate(loader):
            if num_batches is not None and i >= num_batches:
                break
            inputs = inputs.to(device)
            targets = targets.to(device)
            out = forward_pass(cfg, model, inputs)
            # Sum rather than mean, so the final average is exact even when
            # the last batch is smaller than the others.
            loss += F.cross_entropy(out, targets, reduction="sum")
            total += len(targets)
            pred = out.argmax(1)
            correct += pred.eq(targets).sum()
            predicted.extend(pred.cpu().tolist())
            gt.extend(targets.cpu().tolist())
    finally:
        # Restore the caller's mode instead of unconditionally switching to train.
        model.train(was_training)

    if total == 0:
        nan = float("nan")
        return dict(avg_loss=nan, avg_acc=nan, predicted=predicted, gt=gt)

    avg_loss = loss / total
    avg_acc = correct / total

    return dict(avg_loss=avg_loss, avg_acc=avg_acc, predicted=predicted, gt=gt)


def _unwrap_model(model):
    """Return the network wrapped by DDP, or ``model`` itself if it is not wrapped.

    This keeps checkpoints free of DDP's ``module.`` key prefix and lets
    rank-0-only evaluation avoid DDP's forward, which may issue collective
    calls that the other ranks never join.
    """
    return model.module if isinstance(model, DDP) else model


def _save_checkpoint(path, model, optimizer, scheduler, epoch, cfg, global_step):
    """Write model/optimizer (and scheduler, when present) state to ``path``.

    The model is unwrapped first, so the checkpoint can be reloaded through
    ``cfg.load_ckpt``, which loads into the model before DDP wrapping.
    """
    state = {
        "model": _unwrap_model(model).state_dict(),
        "optimizer": optimizer.state_dict(),
        "epoch": epoch,
        "cfg": cfg,
        "global_step": global_step,
    }
    if scheduler is not None:
        state["scheduler"] = scheduler.state_dict()
    torch.save(state, path)


def _build_loaders(cfg, train_set, val_set, test_set):
    """Create the (train, train_eval, val, test) DataLoaders.

    Only the training loader shuffles, uses a DistributedSampler under DDP,
    and applies MixUp when ``cfg.mixup`` is set. The evaluation loaders
    iterate in order.
    """
    if cfg.mixup:
        # A random CutMix/MixUp choice (CutMix with ``cfg.beta``) was tried
        # previously. Only MixUp is currently enabled.
        mixup_aug = MixUp(alpha=cfg.alpha, num_classes=cfg.data.num_classes)

        def collate_fn(batch):
            return mixup_aug(*default_collate(batch))

    else:
        collate_fn = None

    def make_eval_loader(dataset):
        return torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=cfg.batch_size,
            shuffle=False,
            num_workers=cfg.num_workers,
            pin_memory=True,
        )

    train_loader = torch.utils.data.DataLoader(
        dataset=train_set,
        batch_size=cfg.batch_size,
        # DataLoader forbids shuffle=True together with a sampler.
        shuffle=not cfg.distributed,
        num_workers=cfg.num_workers,
        pin_memory=True,
        sampler=DistributedSampler(train_set) if cfg.distributed else None,
        collate_fn=collate_fn,
    )
    return (
        train_loader,
        make_eval_loader(train_set),
        make_eval_loader(val_set),
        make_eval_loader(test_set),
    )


def _build_model(cfg, embedding_shapes, device, rank):
    """Instantiate the model described by ``cfg.model`` on ``device``.

    Without ``cfg.reuse_backbone`` the model is built from scratch. Models
    other than NeoMLP receive ``input_dim=embedding_shapes[1]``.

    With ``cfg.reuse_backbone`` the model is built and then:
    - the backbone checkpoint ``cfg.data.load_backbone`` is loaded;
    - the stored embeddings from ``cfg.data.dataset_file`` are reloaded;
    - all existing parameters are frozen;
    - a classification head is added via ``add_cls_head``.

    Raises:
        ValueError: if ``cfg.reuse_backbone`` is set but
            ``cfg.data.load_backbone`` is empty.
    """
    if not cfg.reuse_backbone:
        if cfg.model._target_.split(".")[-1] == "NeoMLP":
            return hydra.utils.instantiate(cfg.model).to(device)
        return hydra.utils.instantiate(
            cfg.model, input_dim=embedding_shapes[1]
        ).to(device)

    if not cfg.data.load_backbone:
        raise ValueError("No backbone model loaded")

    model = hydra.utils.instantiate(
        cfg.model, signals_to_fit=cfg.data.num_train_images
    ).to(device)
    dataset_dir = Path(cfg.data.dataset_dir)
    ckpt = torch.load(dataset_dir / cfg.data.load_backbone)
    model.load_state_dict(ckpt["model"])
    if rank == 0:
        logging.info(f"loaded checkpoint {cfg.data.load_backbone}")

    # Re-load train & test embeddings. Along dim 1:
    # - the last slot is the SimCLR embedding;
    # - the preceding ``num_outputs`` slots are output embeddings;
    # - everything before that is hidden embeddings.
    all_embeddings = torch.load(dataset_dir / cfg.data.dataset_file)[0]
    num_outputs = cfg.model.num_outputs
    model.hidden_embeddings.data = torch.from_numpy(
        all_embeddings[:, : -num_outputs - 1]
    ).to(device)
    model.output_embedding.data = torch.from_numpy(
        all_embeddings[:, -num_outputs - 1 : -1]
    ).to(device)
    model.simclr_embedding.data = torch.from_numpy(all_embeddings[:, [-1]]).to(
        device
    )

    # Freeze embeddings and all backbone parameters. Presumably the head added
    # afterwards is trainable — confirm in add_cls_head.
    model.hidden_embeddings.requires_grad_(False)
    model.output_embedding.requires_grad_(False)
    model.simclr_embedding.requires_grad_(False)
    for p in model.parameters():
        p.requires_grad = False
    model.add_cls_head(cfg.data.num_classes, True)
    return model.to(device)


def train(cfg, hydra_cfg):
    """Train a classifier described by ``cfg``, logging to wandb on rank 0.

    Builds datasets and loaders, the model, the optimizer, an optional
    scheduler and the AMP grad scaler. It then runs the epoch loop with
    gradient accumulation over ``cfg.num_accum`` batches. Whenever
    ``(global_step + 1) % cfg.eval_every == 0``, rank 0:
    - writes ``latest.ckpt``;
    - evaluates on the train/val/test sets;
    - writes ``best_val.ckpt`` when validation accuracy matches or exceeds
      the best so far.

    Args:
        cfg: Hydra/OmegaConf training config. ``cfg.wandb.name`` is filled in
            when it is None.
        hydra_cfg: Hydra runtime config. Checkpoints are written under
            ``hydra_cfg.runtime.output_dir``.
    """
    torch.set_float32_matmul_precision(cfg.matmul_precision)
    if cfg.seed is not None:
        set_seed(cfg.seed)

    rank = OmegaConf.select(cfg, "distributed.rank", default=0)

    if cfg.wandb.name is None:
        model_name = cfg.model._target_.split(".")[-1]
        cfg.wandb.name = (
            f"{cfg.data.dataset_name}_neoneural_representations_{model_name}"
            f"_seed_{cfg.seed}"
        )
    if rank == 0:
        wandb.init(
            **OmegaConf.to_container(cfg.wandb, resolve=True),
            settings=wandb.Settings(start_method="fork"),
            config=OmegaConf.to_container(cfg, resolve=True),
        )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if rank == 0:
        logging.info(f"Using device {device}")

    # Load dataset
    train_set = hydra.utils.instantiate(cfg.data.train)
    val_set = hydra.utils.instantiate(cfg.data.val)
    test_set = hydra.utils.instantiate(cfg.data.test)

    if cfg.data.normalize:
        # Statistics come from the training split only and are shared by all splits.
        train_mu = train_set.dataset.tensors[0].mean(dim=0)
        train_std = train_set.dataset.tensors[0].std(dim=0)
        train_set.set_stats(train_mu, train_std)
        val_set.set_stats(train_mu, train_std)
        test_set.set_stats(train_mu, train_std)

    train_loader, train_eval_loader, val_loader, test_loader = _build_loaders(
        cfg, train_set, val_set, test_set
    )

    embedding_shapes = train_set[0][0].shape
    if rank == 0:
        logging.info(
            f"train size {len(train_set):_}, "
            f"val size {len(val_set):_}, "
            f"test size {len(test_set):_}"
        )
        logging.info(f"Embedding shapes: {embedding_shapes}")

    model = _build_model(cfg, embedding_shapes, device, rank)

    if rank == 0:
        logging.info(
            f"Initialized model. Number of parameters {count_parameters(model):_}"
        )

    parameters = [p for p in model.parameters() if p.requires_grad]
    if rank == 0:
        logging.info(
            "Trainable parameters: %s",
            [name for name, param in model.named_parameters() if param.requires_grad],
        )
    optimizer = hydra.utils.instantiate(cfg.optim, params=parameters)
    if hasattr(cfg, "scheduler"):
        scheduler = hydra.utils.instantiate(cfg.scheduler, optimizer=optimizer)
    else:
        scheduler = None

    criterion = nn.CrossEntropyLoss()
    train_eval_acc = 0.0
    best_val_acc = 0.0
    best_train_eval_results, best_test_results, best_val_results = None, None, None
    val_acc = 0.0
    val_loss = float("inf")
    test_acc = 0.0
    test_loss = float("inf")
    global_step = 0
    start_epoch = 0

    if cfg.load_ckpt:
        # Loaded into the unwrapped model, before any DDP wrapping below.
        ckpt = torch.load(cfg.load_ckpt)
        model.load_state_dict(ckpt["model"])
        if "optimizer" in ckpt:
            optimizer.load_state_dict(ckpt["optimizer"])
        if "scheduler" in ckpt and scheduler is not None:
            scheduler.load_state_dict(ckpt["scheduler"])
        start_epoch = ckpt.get("epoch", start_epoch)
        global_step = ckpt.get("global_step", global_step)
        if rank == 0:
            logging.info(f"loaded checkpoint {cfg.load_ckpt}")

    epoch_iter = trange(start_epoch, cfg.num_epochs, disable=rank != 0)
    if cfg.distributed:
        model = DDP(
            model, device_ids=cfg.distributed.device_ids, find_unused_parameters=False
        )
    model.train()

    if rank == 0:
        ckpt_dir = Path(hydra_cfg.runtime.output_dir) / wandb.run.path.split("/")[-1]
        ckpt_dir.mkdir(parents=True, exist_ok=True)

    scaler = GradScaler(**cfg.gradscaler)
    autocast_kwargs = dict(cfg.autocast)
    autocast_kwargs["dtype"] = getattr(torch, cfg.autocast.dtype, torch.float32)
    optimizer.zero_grad()
    for epoch in epoch_iter:
        if cfg.distributed:
            # Reshuffle differently each epoch across ranks.
            train_loader.sampler.set_epoch(epoch)

        for idx, (inputs, targets) in enumerate(train_loader):
            inputs = inputs.to(device)
            targets = targets.to(device)

            with torch.autocast(**autocast_kwargs):
                # NOTE(review): with cfg.reuse_backbone under DDP,
                # forward_pass looks up ``forward_representations`` on the DDP
                # wrapper, which does not expose it — confirm this combination
                # is unsupported.
                outputs = forward_pass(cfg, model, inputs)
                # Scale the loss so accumulated gradients average over micro-batches.
                loss = criterion(outputs, targets) / cfg.num_accum

            scaler.scale(loss).backward()
            log = {
                "train/loss": loss.item() * cfg.num_accum,
                "global_step": global_step,
            }

            # Step every num_accum micro-batches, and on the last batch of the epoch.
            if ((idx + 1) % cfg.num_accum == 0) or (idx + 1 == len(train_loader)):
                if cfg.clip_grad:
                    # Unscale first so clipping sees the true gradient norm.
                    scaler.unscale_(optimizer)
                    grad_norm = torch.nn.utils.clip_grad_norm_(
                        parameters, cfg.clip_grad_max_norm
                    )
                    log["grad_norm"] = grad_norm
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

                if scheduler is not None:
                    log["lr"] = scheduler.get_last_lr()[0]
                    scheduler.step()

            # Only rank 0 called wandb.init, so only rank 0 may log (once per step).
            if rank == 0:
                wandb.log(log)
                epoch_iter.set_description(
                    f"[{epoch} {idx+1}], train loss: {log['train/loss']:.3f}, "
                    f"train acc: {train_eval_acc:.3f}, val_loss: {val_loss:.3f}, "
                    f"val_acc: {val_acc:.3f}, test_loss: {test_loss:.3f}, "
                    f"test_acc: {test_acc:.3f}"
                )
            global_step += 1

            if rank == 0 and (global_step + 1) % cfg.eval_every == 0:
                _save_checkpoint(
                    ckpt_dir / "latest.ckpt",
                    model,
                    optimizer,
                    scheduler,
                    epoch,
                    cfg,
                    global_step,
                )

                # Evaluate the unwrapped network. Only rank 0 gets here, so
                # going through DDP's forward could block on collectives.
                eval_model = _unwrap_model(model)
                train_eval_dict = evaluate(eval_model, train_eval_loader, cfg, device)
                val_dict = evaluate(eval_model, val_loader, cfg, device)
                test_dict = evaluate(eval_model, test_loader, cfg, device)
                train_eval_loss = train_eval_dict["avg_loss"]
                train_eval_acc = train_eval_dict["avg_acc"]
                val_loss = val_dict["avg_loss"]
                val_acc = val_dict["avg_acc"]
                test_loss = test_dict["avg_loss"]
                test_acc = test_dict["avg_acc"]

                # ">=" so ties prefer the most recent checkpoint.
                if val_acc >= best_val_acc:
                    _save_checkpoint(
                        ckpt_dir / "best_val.ckpt",
                        model,
                        optimizer,
                        scheduler,
                        epoch,
                        cfg,
                        global_step,
                    )
                    best_val_acc = val_acc
                    best_train_eval_results = train_eval_dict
                    best_val_results = val_dict
                    best_test_results = test_dict

                log = {
                    "train_eval/loss": train_eval_loss,
                    "train_eval/acc": train_eval_acc,
                    "train_eval/best_loss": best_train_eval_results["avg_loss"],
                    "train_eval/best_acc": best_train_eval_results["avg_acc"],
                    "val/loss": val_loss,
                    "val/acc": val_acc,
                    "val/best_loss": best_val_results["avg_loss"],
                    "val/best_acc": best_val_results["avg_acc"],
                    "test/loss": test_loss,
                    "test/acc": test_acc,
                    "test/best_loss": best_test_results["avg_loss"],
                    "test/best_acc": best_test_results["avg_acc"],
                    "epoch": epoch,
                    "global_step": global_step,
                }

                wandb.log(log)


def train_ddp(rank, cfg, hydra_cfg):
    """Per-process entry point launched by ``mp.spawn`` for distributed training.

    Args:
        rank: Process index passed by ``mp.spawn`` as the first argument.
        cfg: Training config. ``cfg.distributed.rank`` is overwritten with
            ``rank`` so that ``train`` knows which process it runs in.
        hydra_cfg: Hydra runtime config forwarded to ``train``.

    ``destroy_process_group`` runs even if training raises, so a failing
    rank still tears down its process group. Presumably ``ddp_setup``
    initializes that group — confirm in experiments/utils.
    """
    ddp_setup(rank, cfg.distributed.world_size)
    try:
        cfg.distributed.rank = rank
        train(cfg, hydra_cfg)
    finally:
        destroy_process_group()


@hydra.main(config_path="configs", config_name="base", version_base=None)
def main(cfg):
    """Hydra entry point: train in this process or spawn one process per rank.

    When ``cfg.distributed`` is falsy, ``train`` runs directly. Otherwise
    ``cfg.distributed.world_size`` processes are spawned, each running
    ``train_ddp``, and this call blocks until they all finish.
    """
    hydra_cfg = hydra.core.hydra_config.HydraConfig.get()
    if not cfg.distributed:
        train(cfg, hydra_cfg)
        return
    spawn_args = (cfg, hydra_cfg)
    mp.spawn(
        train_ddp,
        args=spawn_args,
        nprocs=cfg.distributed.world_size,
        join=True,
    )


if __name__ == "__main__":
    # Script entry point. The hydra.main decorator on ``main`` composes and
    # passes ``cfg`` from configs/base.
    main()
