import logging
from pathlib import Path

import h5py
import hydra
import torch
import torch.multiprocessing as mp
import wandb
from einops import rearrange
from omegaconf import OmegaConf
from torch import nn
from torch.cuda.amp import GradScaler
from torch.distributed import destroy_process_group
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from tqdm import trange, tqdm

from experiments.neural_datasets.inr_utils import psnr
from experiments.neural_datasets.visualization import (
    wandb_embedding_interpolation,
    wandb_embeddings_table,
    wandb_image_grid,
)
from experiments.utils import (
    count_parameters,
    register_resolvers,
    ddp_setup,
    set_logger,
    set_seed,
)

# Import-time setup: both helpers come from experiments.utils and run once,
# before Hydra invokes `main`.
# NOTE(review): presumably these configure the logging format and register
# custom OmegaConf resolvers used by the YAML configs — confirm in
# experiments/utils.
set_logger()
register_resolvers()


def finetune(
    model,
    dataloader,
    hidden_embeddings,
    output_embeddings,
    simclr_embeddings,
    criterion,
    device,
    cfg,
):
    """Optimize per-signal embeddings while the model weights stay frozen.

    Every parameter of ``model`` gets ``requires_grad = False``; only the three
    embedding tensors are handed to the optimizer built from ``cfg.optim``.
    The embeddings are updated in place, nothing is returned.

    Args:
        model: Network exposing ``forward_embeddings(inputs, hidden_embeddings=...,
            output_embeddings=..., simclr_embeddings=...)``.
        dataloader: Yields ``(inputs, ground_truth, indices)`` batches; ``indices``
            selects the rows of the embedding tensors that belong to each sample.
        hidden_embeddings, output_embeddings, simclr_embeddings: Tensors that are
            optimized (assumed to be leaf tensors with ``requires_grad=True`` —
            TODO confirm against ``model.generate_embeddings``).
        criterion: Loss called as ``criterion(outputs, ground_truth)``.
        device: Device the batches are moved to.
        cfg: Config providing ``optim``, ``gradscaler``, ``autocast``,
            ``num_finetune_epochs``, ``num_accum``, ``clip_grad``,
            ``clip_grad_max_norm`` and ``finetune_fraction``.

    Note:
        Metrics are sent with ``wandb.log`` on every batch, so an active wandb
        run is expected.
    """
    for param in model.parameters():
        param.requires_grad = False
    parameters = [hidden_embeddings, output_embeddings, simclr_embeddings]
    finetune_optimizer = hydra.utils.instantiate(cfg.optim, params=parameters)
    finetune_scaler = GradScaler(**cfg.gradscaler)

    autocast_kwargs = dict(cfg.autocast)
    # The config stores the dtype by name; unknown names fall back to float32.
    autocast_kwargs["dtype"] = getattr(torch, cfg.autocast.dtype, torch.float32)
    model.train()
    finetune_optimizer.zero_grad()
    num_batches = len(dataloader)
    epoch_iter = trange(0, cfg.num_finetune_epochs)
    for epoch in epoch_iter:
        for idx, (inputs, ground_truth, indices) in enumerate(dataloader):
            inputs = inputs.to(device)
            ground_truth = ground_truth.to(device)
            indices = indices.to(device)

            with torch.autocast(**autocast_kwargs):
                outputs = model.forward_embeddings(
                    inputs,
                    hidden_embeddings=hidden_embeddings[indices],
                    output_embeddings=output_embeddings[indices],
                    simclr_embeddings=simclr_embeddings[indices],
                )
                # Scale down so that accumulated gradients average over
                # `num_accum` micro-batches.
                loss = criterion(outputs, ground_truth) / cfg.num_accum

            finetune_scaler.scale(loss).backward()
            log = {
                # Undo the accumulation scaling for reporting.
                "inference/loss": loss.item() * cfg.num_accum,
            }
            epoch_iter.set_description(
                f"[Epoch {epoch} {idx+1}/{num_batches} "
                f"Inference loss: {log['inference/loss']:.3f}]"
            )

            batches_done = idx + 1
            # The epoch ends either at the real end of the dataloader or as soon
            # as the configured fraction of it has been consumed.
            is_last_batch = (
                batches_done == num_batches
                or batches_done / num_batches >= cfg.finetune_fraction
            )

            # Flush on the last batch of a truncated epoch too; otherwise the
            # pending gradients would leak into the next epoch or be discarded
            # after the final one.
            if batches_done % cfg.num_accum == 0 or is_last_batch:
                if cfg.clip_grad:
                    # Gradients must be unscaled before their norm is clipped.
                    finetune_scaler.unscale_(finetune_optimizer)
                    grad_norm = torch.nn.utils.clip_grad_norm_(
                        parameters, cfg.clip_grad_max_norm
                    )
                    log["grad_norm"] = grad_norm
                finetune_scaler.step(finetune_optimizer)
                finetune_scaler.update()
                finetune_optimizer.zero_grad()

            wandb.log(log)

            if is_last_batch:
                break


@torch.inference_mode()
def evaluate_full(
    model,
    dataloader,
    hidden_embeddings,
    output_embeddings,
    simclr_embeddings,
    criterion,
    device,
    cfg,
):
    """Evaluate externally supplied embeddings over the whole dataloader.

    Each batch is decoded with ``model.forward_embeddings`` using the rows of
    the embedding tensors selected by the batch indices. Loss and PSNR are
    computed per batch and averaged, weighted by the number of samples in the
    batch. The model is put in eval mode for the pass and switched back to
    train mode before returning.

    Returns:
        dict with ``eval_loss`` and ``eval_psnr`` (sample-weighted means) and
        ``eval_outputs`` (all batch outputs concatenated along dim 0).
    """

    def as_images(flat):
        # Fold the flattened sample axis back into (batch, height, width).
        return rearrange(
            flat,
            "(b h w) c -> b h w c",
            h=cfg.data.img_size[0],
            w=cfg.data.img_size[1],
        )

    model.eval()

    loss_total = 0
    psnr_total = 0
    sample_count = 0
    collected_outputs = []

    for batch_inputs, batch_targets, batch_indices in tqdm(dataloader):
        batch_inputs = batch_inputs.to(device)
        batch_targets = batch_targets.to(device)
        batch_indices = batch_indices.to(device)

        batch_outputs = model.forward_embeddings(
            batch_inputs,
            hidden_embeddings=hidden_embeddings[batch_indices],
            output_embeddings=output_embeddings[batch_indices],
            simclr_embeddings=simclr_embeddings[batch_indices],
        )
        batch_size = batch_inputs.shape[0]

        batch_loss = criterion(batch_outputs, batch_targets)
        loss_total += batch_loss.item() * batch_size

        batch_psnr = psnr(as_images(batch_targets), as_images(batch_outputs))
        collected_outputs.append(batch_outputs)
        psnr_total += batch_psnr.item() * batch_size

        sample_count += batch_size

    model.train()
    return {
        "eval_loss": loss_total / sample_count,
        "eval_psnr": psnr_total / sample_count,
        "eval_outputs": torch.cat(collected_outputs, 0),
    }


@torch.no_grad()
def evaluate(model, dataloader, criterion, device, cfg):
    """Evaluate the model with its own embeddings on a prefix of the dataloader.

    Batches are run through ``model(inputs, indices)`` until at least
    ``cfg.num_full_eval_images`` images' worth of samples have been seen, one
    image counting as ``cfg.data.num_pixels`` samples. The loss is a
    sample-weighted mean over the processed batches; PSNR is computed once on
    the concatenation of all processed outputs. The model is switched back to
    train mode before returning.

    Returns:
        dict with ``eval_loss``, ``eval_psnr`` and ``eval_outputs`` (the
        concatenated outputs of the processed batches).
    """
    model.eval()

    loss_total = 0
    sample_count = 0
    collected_outputs = []
    collected_targets = []

    for batch_inputs, batch_targets, batch_indices in dataloader:
        batch_inputs = batch_inputs.to(device)
        batch_targets = batch_targets.to(device)
        batch_indices = batch_indices.to(device)

        batch_outputs = model(batch_inputs, batch_indices)
        batch_loss = criterion(batch_outputs, batch_targets)

        collected_outputs.append(batch_outputs)
        collected_targets.append(batch_targets)

        batch_size = batch_inputs.shape[0]
        loss_total += batch_loss.item() * batch_size
        sample_count += batch_size

        # Stop as soon as enough whole images have been covered.
        if sample_count / cfg.data.num_pixels >= cfg.num_full_eval_images:
            break

    stacked_outputs = torch.cat(collected_outputs, 0)
    stacked_targets = torch.cat(collected_targets, 0)

    def as_images(flat):
        # Fold the flattened sample axis back into (batch, height, width).
        return rearrange(
            flat,
            "(b h w) c -> b h w c",
            h=cfg.data.img_size[0],
            w=cfg.data.img_size[1],
        )

    eval_psnr = psnr(as_images(stacked_targets), as_images(stacked_outputs))

    model.train()
    return {
        "eval_loss": loss_total / sample_count,
        "eval_psnr": eval_psnr.item(),
        "eval_outputs": stacked_outputs,
    }


def _checkpoint_state(model, optimizer, scheduler, epoch, cfg, global_step):
    """Build the dictionary that is serialized as a training checkpoint."""
    state = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "epoch": epoch,
        "cfg": cfg,
        "global_step": global_step,
    }
    if scheduler is not None:
        # Stored so the resume path in `train` can restore the LR schedule.
        state["scheduler"] = scheduler.state_dict()
    return state


def _strip_ddp_prefix(state_dict, model):
    """Drop the ``module.`` key prefix of a DDP-saved state dict when needed.

    Checkpoints written while the model was wrapped in DDP carry the prefix,
    but on resume the weights are loaded before wrapping. The state dict is
    modified in place and returned. Nothing is changed when the target
    model's own keys already start with ``module.``.
    """
    from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present

    prefix = "module."
    if any(key.startswith(prefix) for key in model.state_dict()):
        return state_dict
    consume_prefix_in_state_dict_if_present(state_dict, prefix)
    return state_dict


def train(cfg, hydra_cfg):
    """Train the model, then fit and export embeddings for train and test data.

    Stage 1 trains the model (optionally under DDP) on ``cfg.data.fit_train``
    with gradient accumulation and AMP. Every ``cfg.steps_till_eval`` steps,
    rank 0 evaluates, logs to wandb and, if ``cfg.save_ckpt`` is set, writes
    ``latest.ckpt`` / ``best_eval.ckpt``.

    Stage 2 reloads ``best_eval.ckpt``, freezes the model and fits fresh
    embeddings for the training set (once per augmentation) and the test set
    with `finetune`. The embeddings and labels are written to per-split ``.pt``
    files and to ``best_eval_representations.h5`` in the checkpoint directory.

    Stage 2 needs the checkpoint directory, which only exists on rank 0 with
    ``cfg.save_ckpt`` enabled; otherwise the function returns after stage 1.

    Args:
        cfg: OmegaConf configuration of the run (model, data, optimizer,
            scheduler, wandb, distributed and training options).
        hydra_cfg: Hydra runtime config; ``hydra_cfg.runtime.output_dir`` is the
            parent of the checkpoint directory.
    """
    torch.set_float32_matmul_precision(cfg.matmul_precision)
    if cfg.seed is not None:
        set_seed(cfg.seed)

    rank = OmegaConf.select(cfg, "distributed.rank", default=0)

    if cfg.wandb.name is None:
        model_name = cfg.model._target_.split(".")[-1]
        cfg.wandb.name = (
            f"{cfg.data.dataset_name}_neural_datasets_{model_name}_seed_{cfg.seed}"
        )
    if rank == 0:
        wandb.init(
            **OmegaConf.to_container(cfg.wandb, resolve=True),
            settings=wandb.Settings(start_method="fork"),
            config=OmegaConf.to_container(cfg, resolve=True),
        )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if rank == 0:
        logging.info(f"Using device {device}")

    train_dataset = hydra.utils.instantiate(
        cfg.data.fit_train,
        train=True,
        use_test_transform=True,
    )

    train_dataloader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=int(cfg.num_train_images * cfg.data.num_pixels),
        # shuffle=not cfg.distributed,
        # Ordering is delegated to the sampler, so `shuffle` must stay False.
        shuffle=False,
        num_workers=cfg.num_workers,
        pin_memory=True,
        sampler=(
            DistributedSampler(train_dataset)
            if cfg.distributed
            else torch.utils.data.RandomSampler(train_dataset, replacement=True)
        ),
        drop_last=True,
    )

    train_eval_dataloader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=int(cfg.num_eval_images * cfg.data.num_pixels),
        shuffle=False,
    )

    if rank == 0:
        logging.info(f"Dataset size {len(train_dataset):_}")

    num_train_signals = train_dataset.num_images
    model = hydra.utils.instantiate(cfg.model, signals_to_fit=num_train_signals).to(
        device
    )

    if rank == 0:
        full_num_parameters = count_parameters(model)
        model_parameters = model.count_parameters()
        params_per_signal = (full_num_parameters - model_parameters) // num_train_signals
        logging.info(
            f"Initialized model. Number of parameters {full_num_parameters:_}, "
            f"Num signals: {num_train_signals:_}, Params per signal: {params_per_signal:_}"
        )
        logging.info(f"Number of parameters w/o embeddings: {model_parameters:_}")

    parameters = [p for p in model.parameters() if p.requires_grad]
    optimizer = hydra.utils.instantiate(cfg.optim, params=parameters)

    if hasattr(cfg, "scheduler"):
        scheduler = hydra.utils.instantiate(cfg.scheduler, optimizer=optimizer)
    else:
        scheduler = None

    criterion = nn.MSELoss()
    eval_psnr = -1.0
    best_eval_psnr = -1.0
    eval_loss = float("inf")
    best_eval_results = None
    global_step = 0
    start_epoch = 0

    if cfg.load_ckpt:
        ckpt = torch.load(cfg.load_ckpt)
        # The model is still unwrapped here, so a checkpoint written under DDP
        # needs its "module." key prefix removed first.
        model.load_state_dict(_strip_ddp_prefix(ckpt["model"], model))
        if "optimizer" in ckpt:
            optimizer.load_state_dict(ckpt["optimizer"])
        # A checkpoint may carry scheduler state even if this run has none.
        if "scheduler" in ckpt and scheduler is not None:
            scheduler.load_state_dict(ckpt["scheduler"])
        if "epoch" in ckpt:
            start_epoch = ckpt["epoch"]
        if "global_step" in ckpt:
            global_step = ckpt["global_step"]
        if rank == 0:
            logging.info(f"loaded checkpoint {cfg.load_ckpt}")

    epoch_iter = trange(start_epoch, cfg.num_epochs, disable=rank != 0)
    if cfg.distributed:
        model = DDP(
            model, device_ids=cfg.distributed.device_ids, find_unused_parameters=False
        )
    model.train()

    # Only rank 0 with checkpointing enabled gets a checkpoint directory.
    ckpt_dir = None
    if rank == 0 and cfg.save_ckpt:
        ckpt_dir = Path(hydra_cfg.runtime.output_dir) / wandb.run.path.split("/")[-1]
        ckpt_dir.mkdir(parents=True, exist_ok=True)
        logging.info(f"Saving checkpoints to {ckpt_dir}")

    scaler = GradScaler(**cfg.gradscaler)
    autocast_kwargs = dict(cfg.autocast)
    # The config stores the dtype by name; unknown names fall back to float32.
    autocast_kwargs["dtype"] = getattr(torch, cfg.autocast.dtype, torch.float32)
    optimizer.zero_grad()
    for epoch in epoch_iter:
        if cfg.distributed:
            train_dataloader.sampler.set_epoch(epoch)

        for idx, (inputs, ground_truth, indices) in enumerate(train_dataloader):
            inputs = inputs.to(device)
            ground_truth = ground_truth.to(device)
            indices = indices.to(device)

            with torch.autocast(**autocast_kwargs):
                outputs = model(inputs, indices)
                loss = criterion(outputs, ground_truth) / cfg.num_accum

            scaler.scale(loss).backward()
            log = {
                # Undo the accumulation scaling for reporting.
                "train/loss": loss.item() * cfg.num_accum,
                "global_step": global_step,
            }

            # NOTE(review): when `cfg.train_fraction` ends the epoch early,
            # gradients of an unfinished accumulation window carry over into
            # the next epoch; left unchanged because an extra flush would also
            # advance the LR scheduler.
            if ((idx + 1) % cfg.num_accum == 0) or (idx + 1 == len(train_dataloader)):
                if cfg.clip_grad:
                    # Gradients must be unscaled before their norm is clipped.
                    scaler.unscale_(optimizer)
                    grad_norm = torch.nn.utils.clip_grad_norm_(
                        parameters, cfg.clip_grad_max_norm
                    )
                    log["grad_norm"] = grad_norm
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

                if scheduler is not None:
                    log["lr"] = scheduler.get_last_lr()[0]
                    scheduler.step()

            global_step += 1

            if rank == 0:
                wandb.log(log)
                epoch_iter.set_description(
                    f"[Epoch {epoch} {idx+1}/{len(train_dataloader)} "
                    f"# images {cfg.num_train_images * global_step}], "
                    f"Train loss: {log['train/loss']:.6f}, "
                    f"eval loss: {eval_loss:.3f}, eval PSNR: {eval_psnr:.3f}, "
                    f"best eval PSNR: {best_eval_psnr:.3f}"
                )

            if rank == 0 and global_step % cfg.steps_till_eval == 0:
                if cfg.save_ckpt:
                    torch.save(
                        _checkpoint_state(
                            model, optimizer, scheduler, epoch, cfg, global_step
                        ),
                        ckpt_dir / "latest.ckpt",
                    )

                eval_dict = evaluate(
                    model, train_eval_dataloader, criterion, device, cfg
                )
                eval_loss = eval_dict["eval_loss"]
                eval_psnr = eval_dict["eval_psnr"]
                eval_outputs = eval_dict["eval_outputs"]

                # The first evaluation always becomes the best one, so that
                # `best_eval_results` is populated even when its PSNR is below
                # the -1.0 sentinel.
                best_eval_criteria = (
                    best_eval_results is None or eval_psnr >= best_eval_psnr
                )

                if best_eval_criteria:
                    if cfg.save_ckpt:
                        torch.save(
                            _checkpoint_state(
                                model, optimizer, scheduler, epoch, cfg, global_step
                            ),
                            ckpt_dir / "best_eval.ckpt",
                        )

                    best_eval_psnr = eval_psnr
                    best_eval_results = eval_dict

                wdb_image_grid = wandb_image_grid(
                    eval_outputs,
                    cfg.num_full_eval_images,
                    cfg.data.img_size,
                    title=f"PSNR: {eval_psnr:.6f}",
                )

                # # Plot t-SNE of the embeddings
                # wdb_tsne_with_images = wandb_embeddings_table(
                #     model,
                #     dataset,
                #     random_tsne_indices,
                #     random_tsne_images,
                # )

                # # Plot interpolation of embeddings
                # wdb_embedding_interpolation = wandb_embedding_interpolation(
                #     model,
                #     cfg.data.img_size,
                # )

                log = {
                    "eval/loss": eval_loss,
                    "eval/psnr": eval_psnr,
                    "eval/best_loss": best_eval_results["eval_loss"],
                    "eval/best_psnr": best_eval_results["eval_psnr"],
                    "epoch": epoch,
                    "global_step": global_step,
                    # "Embedding Interpolation": wdb_embedding_interpolation,
                    "Reconstruction": wdb_image_grid,
                }
                wandb.log(log)

            if (idx + 1) / len(train_dataloader) >= cfg.train_fraction:
                break

    # Train embeddings from scratch, both for train and test

    if ckpt_dir is None:
        # Without a checkpoint directory there is neither a best checkpoint to
        # reload nor a place to write the representations.
        if rank == 0:
            logging.warning(
                "No checkpoint directory available (cfg.save_ckpt is disabled); "
                "skipping the embedding fitting stage."
            )
        return

    ckpt = torch.load(ckpt_dir / "best_eval.ckpt")
    model.load_state_dict(ckpt["model"])
    if rank == 0:
        logging.info(f"loaded checkpoint {ckpt_dir / 'best_eval.ckpt'}")

    for param in model.parameters():
        param.requires_grad = False

    # The DDP wrapper does not expose custom attributes or methods of the
    # wrapped network, so this stage works on the underlying module.
    embed_model = model.module if isinstance(model, DDP) else model

    num_augmentations = cfg.num_augmentations
    if num_augmentations == 1:
        train_dataset = hydra.utils.instantiate(cfg.data.train, train=True)
    else:
        train_dataset = hydra.utils.instantiate(
            cfg.data.train,
            train=True,
            use_test_transform=False,
        )

    test_dataset = hydra.utils.instantiate(cfg.data.train, train=False)
    if rank == 0:
        logging.info(f"Test dataset size {len(test_dataset):_}")

    fit_train_dataloader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=int(cfg.num_train_images * cfg.data.num_pixels),
        shuffle=True,
        num_workers=cfg.num_workers,
        pin_memory=True,
    )
    train_eval_dataloader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=int(cfg.num_eval_images * cfg.data.num_pixels),
        shuffle=False,
    )

    num_representations = (
        train_dataset.num_images * num_augmentations + test_dataset.num_images
    )
    # The context manager closes the HDF5 file even if fitting fails midway.
    with h5py.File(ckpt_dir / "best_eval_representations.h5", "w") as f:
        # NOTE(review): 14 is presumably the size of dim 1 after concatenating
        # the hidden/output/simclr embeddings — confirm against
        # `generate_embeddings`.
        embeddings = f.create_dataset(
            "embeddings",
            (num_representations, 14, embed_model.compressed_embed_dim),
            chunks=(1, 14, embed_model.compressed_embed_dim),
        )
        labels = f.create_dataset(
            "labels",
            (num_representations,),
            dtype="i",
        )
        # Rows are laid out as [aug 0 train | aug 1 train | ... | test].
        offset = 0
        for aug_idx in range(num_augmentations):
            train_dataset.transform_dataset()
            fit_train_dataloader.dataset.dset = train_dataset.dset
            train_eval_dataloader.dataset.dset = train_dataset.dset

            train_hidden_embeddings, train_output_embeddings, train_simclr_embeddings = (
                embed_model.generate_embeddings(
                    train_dataset.num_images,
                    device=device,
                    # hidden_mu=model.hidden_embeddings.detach().mean(),
                    # hidden_sigma=model.hidden_embeddings.detach().std(),
                    # out_mu=model.output_embedding.detach().mean(),
                    # out_sigma=model.output_embedding.detach().std(),
                )
            )

            finetune(
                embed_model,
                fit_train_dataloader,
                train_hidden_embeddings,
                train_output_embeddings,
                train_simclr_embeddings,
                criterion,
                device,
                cfg,
            )
            logging.info("Finished finetuning the embeddings")

            train_eval_dict = evaluate_full(
                embed_model,
                train_eval_dataloader,
                train_hidden_embeddings,
                train_output_embeddings,
                train_simclr_embeddings,
                criterion,
                device,
                cfg,
            )
            logging.info(
                f"Train loss: {train_eval_dict['eval_loss']:.6f}, Train PSNR: {train_eval_dict['eval_psnr']:.6f}, "
            )

            # Save the learned image representations & labels
            train_embeddings = torch.cat(
                [train_hidden_embeddings, train_output_embeddings, train_simclr_embeddings],
                dim=1,
            )
            train_embeddings_np = train_embeddings.detach().cpu().numpy()
            train_labels_np = train_dataset.labels.numpy()
            torch.save(
                [train_embeddings_np, train_labels_np],
                ckpt_dir / f"best_eval_representations_{aug_idx}.pt",
            )
            embeddings[offset : offset + train_dataset.num_images] = train_embeddings_np
            labels[offset : offset + train_dataset.num_images] = train_labels_np
            offset += train_dataset.num_images

        fit_test_dataloader = torch.utils.data.DataLoader(
            dataset=test_dataset,
            batch_size=int(cfg.num_train_images * cfg.data.num_pixels),
            shuffle=True,
            num_workers=cfg.num_workers,
        )
        test_dataloader = torch.utils.data.DataLoader(
            dataset=test_dataset,
            batch_size=int(cfg.num_eval_images * cfg.data.num_pixels),
            shuffle=False,
        )

        test_hidden_embeddings, test_output_embeddings, test_simclr_embeddings = (
            embed_model.generate_embeddings(
                test_dataset.num_images,
                device=device,
                # hidden_mu=model.hidden_embeddings.detach().mean(),
                # hidden_sigma=model.hidden_embeddings.detach().std(),
                # out_mu=model.output_embedding.detach().mean(),
                # out_sigma=model.output_embedding.detach().std(),
            )
        )

        finetune(
            embed_model,
            fit_test_dataloader,
            test_hidden_embeddings,
            test_output_embeddings,
            test_simclr_embeddings,
            criterion,
            device,
            cfg,
        )
        logging.info("Finished finetuning the test embeddings")
        test_dict = evaluate_full(
            embed_model,
            test_dataloader,
            test_hidden_embeddings,
            test_output_embeddings,
            test_simclr_embeddings,
            criterion,
            device,
            cfg,
        )
        logging.info(
            f"Test loss: {test_dict['eval_loss']:.6f}, Test PSNR: {test_dict['eval_psnr']:.6f}"
        )
        test_embeddings = torch.cat(
            [test_hidden_embeddings, test_output_embeddings, test_simclr_embeddings], dim=1
        )
        test_embeddings_np = test_embeddings.detach().cpu().numpy()
        test_labels_np = test_dataset.labels.numpy()
        torch.save(
            [test_embeddings_np, test_labels_np],
            ckpt_dir / "best_eval_test_representations.pt",
        )
        embeddings[offset : offset + test_dataset.num_images] = test_embeddings_np
        labels[offset : offset + test_dataset.num_images] = test_labels_np


def train_ddp(rank, cfg, hydra_cfg):
    """Per-process entry point used by ``mp.spawn`` for distributed training.

    Initializes the process group through ``ddp_setup``, records this
    process's rank in ``cfg.distributed.rank`` (read back by `train`) and runs
    `train`. The process group is destroyed even if training raises.

    Args:
        rank: Index of this process, supplied by ``mp.spawn``.
        cfg: Run configuration; ``cfg.distributed.world_size`` is required.
        hydra_cfg: Hydra runtime config forwarded to `train`.
    """
    ddp_setup(rank, cfg.distributed.world_size)
    try:
        cfg.distributed.rank = rank
        train(cfg, hydra_cfg)
    finally:
        # Release the process group on failure as well, not only on success.
        destroy_process_group()


@hydra.main(config_path="configs", config_name="base", version_base=None)
def main(cfg):
    """Hydra entry point: run training in-process or spawn one process per rank.

    When ``cfg.distributed`` is falsy, `train` is called directly. Otherwise
    ``cfg.distributed.world_size`` processes are spawned, each running
    `train_ddp`, and this call blocks until all of them have finished.
    """
    hydra_cfg = hydra.core.hydra_config.HydraConfig.get()

    if not cfg.distributed:
        train(cfg, hydra_cfg)
        return

    spawn_args = (cfg, hydra_cfg)
    mp.spawn(
        train_ddp,
        args=spawn_args,
        nprocs=cfg.distributed.world_size,
        join=True,
    )


# Script entry point: `main` is wrapped by `hydra.main`, so it is called
# without arguments and receives the composed config from Hydra.
if __name__ == "__main__":
    main()
