import tempfile
import logging
from pathlib import Path

import h5py
import hydra
import torch
import torch.multiprocessing as mp
import wandb
from einops import rearrange
from omegaconf import OmegaConf
from torch import nn
from torch.cuda.amp import GradScaler
from torch.distributed import destroy_process_group
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from tqdm import trange, tqdm

from experiments.neural_datasets.inr_utils import iou
from experiments.utils import (
    count_parameters,
    register_resolvers,
    ddp_setup,
    set_logger,
    set_seed,
)
from experiments.neural_datasets.visualization import wandb_scatter_plot_gt_pred_sdf
from experiments.visualization_utils import extract_mesh_from_neural_field

# Module-level setup: both calls run once at import time, before ``main`` is
# invoked by Hydra.
# NOTE(review): presumably these configure logging and register custom
# OmegaConf resolvers used by the configs -- confirm in ``experiments.utils``.
set_logger()
register_resolvers()


def finetune(
    model,
    dataloader,
    hidden_embeddings,
    output_embeddings,
    simclr_embeddings,
    criterion,
    device,
    cfg,
):
    """Optimise only the three embedding tensors while the network stays frozen.

    Every parameter of ``model`` gets ``requires_grad = False``; a fresh
    optimizer (instantiated from ``cfg.optim``) and a fresh ``GradScaler`` are
    created for the embeddings. Metrics are sent to ``wandb.log`` after every
    batch. The embedding tensors are updated in place; nothing is returned.

    Args:
        model: Object exposing ``parameters()``, ``train()`` and
            ``forward_embeddings(inputs, hidden_embeddings=...,
            output_embeddings=..., simclr_embeddings=...)``.
        dataloader: Yields ``(inputs, ground_truth, indices)`` batches.
        hidden_embeddings: Tensor indexed by ``indices``; optimised.
        output_embeddings: Tensor indexed by ``indices``; optimised.
        simclr_embeddings: Tensor indexed by ``indices``; optimised.
        criterion: Callable ``(outputs, ground_truth) -> loss``.
        device: Device the batches are moved to.
        cfg: Config providing ``optim``, ``gradscaler``, ``autocast``,
            ``num_finetune_epochs``, ``num_accum``, ``clip_grad``,
            ``clip_grad_max_norm`` and ``finetune_fraction``.
    """
    # Freeze the network so that only the embeddings receive updates.
    for weight in model.parameters():
        weight.requires_grad = False

    trainable = [hidden_embeddings, output_embeddings, simclr_embeddings]
    optimizer = hydra.utils.instantiate(cfg.optim, params=trainable)
    scaler = GradScaler(**cfg.gradscaler)

    # The config stores the dtype by name; resolve it to the torch dtype
    # object (falling back to float32 when the name is unknown).
    amp_kwargs = dict(cfg.autocast)
    amp_kwargs.update(dtype=getattr(torch, cfg.autocast.dtype, torch.float32))

    model.train()
    optimizer.zero_grad()

    progress = trange(0, cfg.num_finetune_epochs)
    for epoch in progress:
        for step, (inputs, ground_truth, indices) in enumerate(dataloader, start=1):
            num_batches = len(dataloader)

            inputs = inputs.to(device).float()
            ground_truth = ground_truth.to(device).float().unsqueeze(-1)
            indices = indices.to(device)

            with torch.autocast(**amp_kwargs):
                outputs = model.forward_embeddings(
                    inputs,
                    hidden_embeddings=hidden_embeddings[indices],
                    output_embeddings=output_embeddings[indices],
                    simclr_embeddings=simclr_embeddings[indices],
                )
                # Scale down so that accumulated gradients average over
                # ``cfg.num_accum`` batches.
                loss = criterion(outputs, ground_truth) / cfg.num_accum

            scaler.scale(loss).backward()

            # Report the un-scaled loss.
            log = {"inference/loss": loss.item() * cfg.num_accum}
            progress.set_description(
                f"[Epoch {epoch} {step}/{num_batches} "
                f"Inference loss: {log['inference/loss']:.3f}]"
            )

            # Step at the end of each accumulation window and on the last batch.
            if (step % cfg.num_accum == 0) or (step == num_batches):
                if cfg.clip_grad:
                    # Gradients must be unscaled before their norm is clipped.
                    scaler.unscale_(optimizer)
                    log["grad_norm"] = torch.nn.utils.clip_grad_norm_(
                        trainable, cfg.clip_grad_max_norm
                    )
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

            wandb.log(log)

            # Optionally stop the epoch after a fraction of the batches.
            if step / num_batches >= cfg.finetune_fraction:
                break


@torch.inference_mode()
def evaluate_full(
    model,
    dataloader,
    hidden_embeddings,
    output_embeddings,
    simclr_embeddings,
    criterion,
    device,
    cfg,
):
    """Evaluate ``model.forward_embeddings`` over the whole ``dataloader``.

    The model is put in eval mode for the duration of the call and switched
    back to train mode before returning.

    Args:
        model: Object exposing ``eval()``, ``train()`` and
            ``forward_embeddings``.
        dataloader: Yields ``(inputs, ground_truth, indices)`` batches whose
            leading dimension is regrouped into chunks of
            ``cfg.data.num_points`` rows for the IoU computation.
        hidden_embeddings: Tensor indexed by the batch indices.
        output_embeddings: Tensor indexed by the batch indices.
        simclr_embeddings: Tensor indexed by the batch indices.
        criterion: Callable ``(outputs, ground_truth) -> loss``.
        device: Device the batches are moved to.
        cfg: Config providing ``data.num_points``.

    Returns:
        dict with ``eval_loss`` (batch losses weighted by the number of rows
        in each batch) and ``eval_iou`` (batch IoUs weighted by the number of
        ``num_points``-sized groups in each batch).
    """
    model.eval()

    def _per_shape(flat):
        # Regroup the flattened leading axis into (group, num_points, channel).
        return rearrange(flat, "(b n) c -> b n c", n=cfg.data.num_points)

    weighted_losses = []
    weighted_ious = []
    total_points = 0
    total_shapes = 0

    for batch_inputs, batch_targets, batch_indices in tqdm(dataloader):
        batch_inputs = batch_inputs.to(device).float()
        batch_targets = batch_targets.to(device).float().unsqueeze(-1)
        batch_indices = batch_indices.to(device)

        predictions = model.forward_embeddings(
            batch_inputs,
            hidden_embeddings=hidden_embeddings[batch_indices],
            output_embeddings=output_embeddings[batch_indices],
            simclr_embeddings=simclr_embeddings[batch_indices],
        )

        # Weight each batch loss by its row count so the final value is a
        # per-row average even if batches differ in size.
        batch_points = batch_inputs.shape[0]
        batch_loss = criterion(predictions, batch_targets)
        weighted_losses.append(batch_loss.item() * batch_points)
        total_points += batch_points

        batch_iou = iou(_per_shape(batch_targets), _per_shape(predictions))

        # Same weighting for IoU, using the number of groups in the batch.
        batch_shapes = batch_targets.shape[0] // cfg.data.num_points
        weighted_ious.append(batch_iou.item() * batch_shapes)
        total_shapes += batch_shapes

    model.train()
    return {
        "eval_loss": sum(weighted_losses) / total_points,
        "eval_iou": sum(weighted_ious) / total_shapes,
    }


@torch.no_grad()
def wandb_shapes(model, dataloader, criterion, device, cfg):
    """Extract a mesh for every index in ``cfg.data.plot_indices`` for wandb.

    ``dataloader`` and ``criterion`` are accepted but not used by this
    function.

    Args:
        model: Object whose ``forward`` is handed to
            ``extract_mesh_from_neural_field`` as the field function.
        dataloader: Unused.
        criterion: Unused.
        device: Forwarded to ``extract_mesh_from_neural_field``.
        cfg: Config providing ``data.plot_indices``.

    Returns:
        dict mapping ``"eval/reconstruction/shape_{index}"`` to a
        ``wandb.Object3D`` built from the exported ``.obj`` file.
    """
    field_fn = model.forward
    meshes_by_key = {}

    for shape_index in cfg.data.plot_indices:
        scene = extract_mesh_from_neural_field(
            field_fn,
            points_batch_size=32000,
            shape_index=shape_index,
            threshold=0,
            resolution0=32,
            upsampling_steps=3,  # 0 for lower training time
            padding=0.1,
            device=device,
        ).scene()

        # ``delete=False`` leaves the exported file on disk after the
        # ``with`` block closes it.
        # NOTE(review): presumably required because wandb reads the path
        # later, when the object is logged -- confirm before adding cleanup.
        with tempfile.NamedTemporaryFile(suffix='.obj', delete=False) as obj_file:
            obj_path = obj_file.name
            scene.export(obj_path)
            logged_mesh = wandb.Object3D(obj_path, caption=f'shape_{shape_index}')

        meshes_by_key[f"eval/reconstruction/shape_{shape_index}"] = logged_mesh

    return meshes_by_key


def wandb_scatter(eval_inputs, eval_ground_truth, eval_outputs, cfg):
    """Build scatter figures comparing ground-truth and predicted SDF points.

    The three flattened inputs are regrouped into chunks of
    ``cfg.data.num_points`` rows. For at most the first four chunks, the
    points whose value is negative are selected (once using the ground truth,
    once using the prediction) and passed to
    ``wandb_scatter_plot_gt_pred_sdf``.

    Args:
        eval_inputs: Flattened coordinates; leading dimension must be a
            multiple of ``cfg.data.num_points``.
        eval_ground_truth: Flattened ground-truth values with a trailing
            channel dimension that is squeezed away per sample.
        eval_outputs: Flattened predictions, laid out like
            ``eval_ground_truth``.
        cfg: Config providing ``data.num_points``.

    Returns:
        dict mapping ``"eval/scatter_reconstruction/shape_{i}"`` to the figure
        returned by ``wandb_scatter_plot_gt_pred_sdf``. Empty when there are
        no samples.
    """

    def _per_shape(flat):
        # Regroup the flattened leading axis into (group, num_points, channel).
        return rearrange(flat, "(b n) c -> b n c", n=cfg.data.num_points)

    eval_inputs = _per_shape(eval_inputs)
    eval_ground_truth = _per_shape(eval_ground_truth)
    eval_outputs = _per_shape(eval_outputs)

    # Plot at most four samples. Using ``range`` keeps this iterable when
    # fewer than four samples are available (a bare int here would make the
    # loop below raise ``TypeError``).
    num_plotted = min(4, eval_inputs.shape[0])

    figs = {}
    for sample_idx in range(num_plotted):
        coords = eval_inputs[sample_idx]
        gt_sdf = eval_ground_truth[sample_idx].squeeze(-1)
        pred_sdf = eval_outputs[sample_idx].squeeze(-1)

        # Keep only the points with negative value, separately for the
        # ground truth and for the prediction.
        gt_mask = gt_sdf < 0.0
        pred_mask = pred_sdf < 0.0

        fig = wandb_scatter_plot_gt_pred_sdf(
            coords[gt_mask],
            gt_sdf[gt_mask],
            coords[pred_mask],
            pred_sdf[pred_mask],
        )
        figs[f"eval/scatter_reconstruction/shape_{sample_idx}"] = fig
    return figs

    # # 2d central slice
    # central_slice_coords = jnp.linspace(-1, 1, 100)
    # central_slice_coords = jnp.stack(jnp.meshgrid(central_slice_coords, central_slice_coords, jnp.array([0.0])), axis=-1)
    # central_slice_coords = central_slice_coords.reshape(-1, 3)[None, ...]
    # central_slice_sdf = forward(enf_params, central_slice_coords, p[0:1], c[0:1], g[0:1], cur_rng)
    # central_slice_sdf = central_slice_sdf.reshape(100, 100)
    # plt.imshow(central_slice_sdf, cmap='coolwarm')
    # plt.colorbar()
    # wandb.log({"train/central_slice": wandb.Image(plt)}, step=global_step)
    # plt.close()


@torch.no_grad()
def evaluate(model, dataloader, criterion, device, cfg):
    """Evaluate ``model`` on the first batches of ``dataloader``.

    Iteration stops once the number of processed rows divided by
    ``cfg.data.num_points`` reaches ``cfg.num_full_eval_images``. The model is
    put in eval mode for the duration of the call and switched back to train
    mode before returning.

    Args:
        model: Callable as ``model(inputs, indices)``; also exposes ``eval()``
            and ``train()``.
        dataloader: Yields ``(inputs, ground_truth, indices)`` batches.
        criterion: Callable ``(outputs, ground_truth) -> loss``.
        device: Device the batches are moved to.
        cfg: Config providing ``data.num_points`` and ``num_full_eval_images``.

    Returns:
        dict with ``eval_loss`` (row-weighted average over all processed
        batches), ``eval_iou`` (computed on the truncated tensors), and the
        CPU tensors ``eval_outputs``, ``eval_ground_truth`` and
        ``eval_inputs``, each truncated to
        ``data.num_points * num_full_eval_images`` rows.
    """
    model.eval()

    def _per_shape(flat):
        # Regroup the flattened leading axis into (group, num_points, channel).
        return rearrange(flat, "(b n) c -> b n c", n=cfg.data.num_points)

    weighted_losses = []
    points_seen = 0
    outputs_cpu = []
    targets_cpu = []
    inputs_cpu = []

    for batch_inputs, batch_targets, batch_indices in dataloader:
        batch_inputs = batch_inputs.to(device).float()
        batch_targets = batch_targets.to(device).float().unsqueeze(-1)
        batch_indices = batch_indices.to(device)

        predictions = model(batch_inputs, batch_indices)
        batch_loss = criterion(predictions, batch_targets)

        # Collect everything on the CPU so device memory is not held.
        outputs_cpu.append(predictions.cpu())
        targets_cpu.append(batch_targets.cpu())
        inputs_cpu.append(batch_inputs.cpu())

        batch_points = batch_inputs.shape[0]
        weighted_losses.append(batch_loss.item() * batch_points)
        points_seen += batch_points

        # Stop as soon as enough rows have been processed.
        if points_seen / cfg.data.num_points >= cfg.num_full_eval_images:
            break

    # The last batch may overshoot the requested amount; trim the excess.
    keep = cfg.data.num_points * cfg.num_full_eval_images
    all_eval_outputs = torch.cat(outputs_cpu, 0)[:keep]
    all_eval_ground_truth = torch.cat(targets_cpu, 0)[:keep]
    all_eval_inputs = torch.cat(inputs_cpu, 0)[:keep]

    eval_iou = iou(_per_shape(all_eval_ground_truth), _per_shape(all_eval_outputs))

    model.train()
    return {
        "eval_loss": sum(weighted_losses) / points_seen,
        "eval_iou": eval_iou.item(),
        "eval_outputs": all_eval_outputs,
        "eval_ground_truth": all_eval_ground_truth,
        "eval_inputs": all_eval_inputs,
    }


def _unwrap_ddp(model):
    """Return the module wrapped by ``DistributedDataParallel``, else ``model``."""
    return model.module if isinstance(model, DDP) else model


def _save_checkpoint(path, model, optimizer, scheduler, epoch, global_step, cfg):
    """Write the training state to ``path`` with ``torch.save``."""
    state = {
        # Always store the bare module's weights so the checkpoint can be
        # loaded both before and after the model is wrapped in DDP.
        "model": _unwrap_ddp(model).state_dict(),
        "optimizer": optimizer.state_dict(),
        "epoch": epoch,
        "cfg": cfg,
        "global_step": global_step,
    }
    if scheduler is not None:
        # The resume path restores a "scheduler" entry when one is present.
        state["scheduler"] = scheduler.state_dict()
    torch.save(state, path)


def train(cfg, hydra_cfg):
    """Train the model, then fit fresh embeddings for the train and test sets.

    Stage 1 optimises the model instantiated from ``cfg.model`` on
    ``cfg.data.fit_train`` with gradient accumulation, optional gradient
    clipping, autocast and a ``GradScaler``. Every ``cfg.steps_till_eval``
    steps rank 0 evaluates on ``cfg.data.chunked_train``, logs to wandb and,
    if ``cfg.save_ckpt`` is set, writes ``latest.ckpt`` / ``best_eval.ckpt``.

    Stage 2 (rank 0 only) reloads ``best_eval.ckpt`` when it exists, freezes
    the model, fits newly generated embeddings with ``finetune`` for the train
    and test datasets, evaluates them with ``evaluate_full`` and stores the
    embeddings and labels in ``best_eval_representations.h5`` plus ``.pt``
    files inside the run's output directory.

    Args:
        cfg: Hydra/OmegaConf experiment configuration.
        hydra_cfg: Hydra runtime config; ``runtime.output_dir`` is used as the
            parent of the run's output directory.
    """
    torch.set_float32_matmul_precision(cfg.matmul_precision)
    if cfg.seed is not None:
        set_seed(cfg.seed)

    rank = OmegaConf.select(cfg, "distributed.rank", default=0)

    if cfg.wandb.name is None:
        model_name = cfg.model._target_.split(".")[-1]
        cfg.wandb.name = (
            f"{cfg.data.dataset_name}_neural_datasets_{model_name}_seed_{cfg.seed}"
        )
    if rank == 0:
        wandb.init(
            **OmegaConf.to_container(cfg.wandb, resolve=True),
            settings=wandb.Settings(start_method="fork"),
            config=OmegaConf.to_container(cfg, resolve=True),
        )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if rank == 0:
        logging.info(f"Using device {device}")

    train_dataset = hydra.utils.instantiate(
        cfg.data.fit_train,
        train=True,
    )

    train_dataloader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=cfg.num_train_images,
        shuffle=False,
        num_workers=cfg.num_workers,
        pin_memory=True,
        sampler=(
            DistributedSampler(train_dataset)
            if cfg.distributed
            else torch.utils.data.RandomSampler(train_dataset, replacement=True)
        ),
        drop_last=True,
    )

    chunked_train_dataset = hydra.utils.instantiate(
        cfg.data.chunked_train,
        train=True,
    )

    def collate_fn(batch):
        # Merge the first two axes of each of the three collated fields.
        batch = torch.utils.data.default_collate(batch)
        return batch[0].flatten(0, 1), batch[1].flatten(0, 1), batch[2].flatten(0, 1)

    train_eval_dataloader = torch.utils.data.DataLoader(
        dataset=chunked_train_dataset,
        batch_size=cfg.num_eval_images,
        shuffle=False,
        num_workers=cfg.num_workers,
        pin_memory=True,
        collate_fn=collate_fn,
    )

    if rank == 0:
        logging.info(f"Dataset size {len(train_dataset):_}")

    num_train_signals = train_dataset.num_signals
    model = hydra.utils.instantiate(cfg.model, signals_to_fit=num_train_signals).to(
        device
    )

    if rank == 0:
        full_num_parameters = count_parameters(model)
        model_parameters = model.count_parameters()
        params_per_signal = (full_num_parameters - model_parameters) // num_train_signals
        logging.info(
            f"Initialized model. Number of parameters {full_num_parameters:_}, "
            f"Num signals: {num_train_signals:_}, Params per signal: {params_per_signal:_}"
        )
        logging.info(f"Number of parameters w/o embeddings: {model_parameters:_}")

    parameters = [p for p in model.parameters() if p.requires_grad]
    optimizer = hydra.utils.instantiate(cfg.optim, params=parameters)

    if hasattr(cfg, "scheduler"):
        scheduler = hydra.utils.instantiate(cfg.scheduler, optimizer=optimizer)
    else:
        scheduler = None

    criterion = nn.MSELoss()
    eval_iou = -1.0
    best_eval_iou = -1.0
    eval_loss = float("inf")
    best_eval_results = None
    global_step = 0
    start_epoch = 0

    if cfg.load_ckpt:
        # Map onto the current device so a checkpoint written on another
        # device can still be loaded.
        ckpt = torch.load(cfg.load_ckpt, map_location=device)
        model.load_state_dict(ckpt["model"])
        if "optimizer" in ckpt:
            optimizer.load_state_dict(ckpt["optimizer"])
        # A scheduler entry can only be restored when a scheduler is configured.
        if "scheduler" in ckpt and scheduler is not None:
            scheduler.load_state_dict(ckpt["scheduler"])
        if "epoch" in ckpt:
            start_epoch = ckpt["epoch"]
        if "global_step" in ckpt:
            global_step = ckpt["global_step"]
        if rank == 0:
            logging.info(f"loaded checkpoint {cfg.load_ckpt}")

    epoch_iter = trange(start_epoch, cfg.num_epochs, disable=rank != 0)
    if cfg.distributed:
        model = DDP(
            model, device_ids=cfg.distributed.device_ids, find_unused_parameters=False
        )
    model.train()

    # The run directory is needed on rank 0 even when ``cfg.save_ckpt`` is
    # off, because the embedding stage below always writes its outputs there.
    ckpt_dir = None
    if rank == 0:
        ckpt_dir = Path(hydra_cfg.runtime.output_dir) / wandb.run.path.split("/")[-1]
        ckpt_dir.mkdir(parents=True, exist_ok=True)
        if cfg.save_ckpt:
            logging.info(f"Saving checkpoints to {ckpt_dir}")

    scaler = GradScaler(**cfg.gradscaler)
    # The config stores the autocast dtype by name; resolve it to the torch
    # dtype object (float32 when the name is unknown).
    autocast_kwargs = dict(cfg.autocast)
    autocast_kwargs["dtype"] = getattr(torch, cfg.autocast.dtype, torch.float32)
    optimizer.zero_grad()
    for epoch in epoch_iter:
        if cfg.distributed:
            train_dataloader.sampler.set_epoch(epoch)

        for idx, (inputs, ground_truth, indices) in enumerate(train_dataloader):
            inputs = inputs.to(device).float()
            ground_truth = ground_truth.to(device).float().unsqueeze(-1)
            indices = indices.to(device)

            with torch.autocast(**autocast_kwargs):
                outputs = model(inputs, indices)
                # Scale down so accumulated gradients average over
                # ``cfg.num_accum`` batches.
                loss = criterion(outputs, ground_truth) / cfg.num_accum

            scaler.scale(loss).backward()
            log = {
                "train/loss": loss.item() * cfg.num_accum,
                "global_step": global_step,
            }

            # Step at the end of each accumulation window and on the last batch.
            if ((idx + 1) % cfg.num_accum == 0) or (idx + 1 == len(train_dataloader)):
                if cfg.clip_grad:
                    # Gradients must be unscaled before their norm is clipped.
                    scaler.unscale_(optimizer)
                    grad_norm = torch.nn.utils.clip_grad_norm_(
                        parameters, cfg.clip_grad_max_norm
                    )
                    log["grad_norm"] = grad_norm
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

                if scheduler is not None:
                    log["lr"] = scheduler.get_last_lr()[0]
                    scheduler.step()

            global_step += 1

            if rank == 0:
                wandb.log(log)
                epoch_iter.set_description(
                    f"[Epoch {epoch} {idx+1}/{len(train_dataloader)} "
                    f"# signals {cfg.num_train_images / cfg.data.num_points * global_step:.3f}], "
                    f"Train loss: {log['train/loss']:.6f}, "
                    f"eval loss: {eval_loss:.3f}, eval IoU: {eval_iou:.3f}, "
                    f"best eval IoU: {best_eval_iou:.3f}"
                )

            if rank == 0 and global_step % cfg.steps_till_eval == 0:
                if cfg.save_ckpt:
                    _save_checkpoint(
                        ckpt_dir / "latest.ckpt",
                        model,
                        optimizer,
                        scheduler,
                        epoch,
                        global_step,
                        cfg,
                    )

                eval_dict = evaluate(
                    model, train_eval_dataloader, criterion, device, cfg
                )
                eval_loss = eval_dict["eval_loss"]
                eval_iou = eval_dict["eval_iou"]

                best_eval_criteria = eval_iou >= best_eval_iou

                if best_eval_criteria:
                    if cfg.save_ckpt:
                        _save_checkpoint(
                            ckpt_dir / "best_eval.ckpt",
                            model,
                            optimizer,
                            scheduler,
                            epoch,
                            global_step,
                            cfg,
                        )

                    best_eval_iou = eval_iou
                    best_eval_results = eval_dict

                wdb_shapes = wandb_shapes(model, train_eval_dataloader, criterion,
                                          device, cfg)

                wdb_scatter_shapes = wandb_scatter(
                    eval_dict["eval_inputs"],
                    eval_dict["eval_ground_truth"],
                    eval_dict["eval_outputs"],
                    cfg,
                )

                log = {
                    "eval/loss": eval_loss,
                    "eval/iou": eval_iou,
                    "epoch": epoch,
                    "global_step": global_step,
                }
                # No best result exists yet if the comparison above has never
                # succeeded (e.g. the IoU was NaN).
                if best_eval_results is not None:
                    log["eval/best_loss"] = best_eval_results["eval_loss"]
                    log["eval/best_iou"] = best_eval_results["eval_iou"]
                log |= wdb_shapes
                log |= wdb_scatter_shapes
                wandb.log(log)

            # Optionally stop the epoch after a fraction of the batches.
            if (idx + 1) / len(train_dataloader) >= cfg.train_fraction:
                break

    # Train embeddings from scratch, both for train and test

    if rank != 0:
        # The wandb run and the output directory only exist on rank 0, and
        # the stage below uses no distributed sampler, so other ranks stop.
        return

    best_ckpt_path = ckpt_dir / "best_eval.ckpt"
    if best_ckpt_path.exists():
        ckpt = torch.load(best_ckpt_path, map_location=device)
        _unwrap_ddp(model).load_state_dict(ckpt["model"])
        logging.info(f"loaded checkpoint {best_ckpt_path}")
    else:
        # Happens when checkpointing is disabled or no evaluation step ran.
        logging.warning(
            f"No checkpoint found at {best_ckpt_path}; "
            "fitting embeddings with the current model weights"
        )

    for param in model.parameters():
        param.requires_grad = False

    # ``generate_embeddings``, ``compressed_embed_dim`` and
    # ``forward_embeddings`` live on the bare module, not on the DDP wrapper.
    inner_model = _unwrap_ddp(model)

    # NOTE: the augmentation count is fixed to 1 here; ``cfg.num_augmentations``
    # is not read and the train dataset from stage 1 is reused.
    num_augmentations = 1

    test_dataset = hydra.utils.instantiate(cfg.data.train, train=False)
    chunked_test_dataset = hydra.utils.instantiate(
        cfg.data.chunked_train,
        train=False,
    )
    logging.info(f"Test dataset size {len(test_dataset):_}")

    fit_train_dataloader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=cfg.num_train_images,
        shuffle=False,
        num_workers=cfg.num_workers,
        pin_memory=True,
        sampler=torch.utils.data.RandomSampler(train_dataset, replacement=True),
    )

    train_eval_dataloader = torch.utils.data.DataLoader(
        dataset=chunked_train_dataset,
        batch_size=cfg.num_eval_images,
        shuffle=False,
        num_workers=cfg.num_workers,
        pin_memory=True,
        collate_fn=collate_fn,
    )

    total_signals = (
        train_dataset.num_signals * num_augmentations + test_dataset.num_signals
    )
    embed_dim = inner_model.compressed_embed_dim

    # The context manager closes the HDF5 file even if fitting fails midway.
    with h5py.File(ckpt_dir / 'best_eval_representations.h5', 'w') as f:
        # NOTE(review): 13 is presumably the size of dim 1 after the three
        # embedding tensors are concatenated below -- confirm against the model.
        embeddings = f.create_dataset(
            "embeddings",
            (total_signals, 13, embed_dim),
            chunks=(1, 13, embed_dim),
        )
        labels = f.create_dataset(
            "labels",
            (total_signals,),
            dtype="i",
        )
        offset = 0
        for aug_idx in range(num_augmentations):
            if num_augmentations > 1:
                train_dataset.transform_dataset()
                fit_train_dataloader.dataset.dset = train_dataset.dset
                train_eval_dataloader.dataset.dset = train_dataset.dset

            train_hidden_embeddings, train_output_embeddings, train_simclr_embeddings = (
                inner_model.generate_embeddings(
                    train_dataset.num_signals,
                    device=device,
                    # hidden_mu=model.hidden_embeddings.detach().mean(),
                    # hidden_sigma=model.hidden_embeddings.detach().std(),
                    # out_mu=model.output_embedding.detach().mean(),
                    # out_sigma=model.output_embedding.detach().std(),
                )
            )

            finetune(
                inner_model,
                fit_train_dataloader,
                train_hidden_embeddings,
                train_output_embeddings,
                train_simclr_embeddings,
                criterion,
                device,
                cfg,
            )
            logging.info("Finished finetuning the embeddings")

            train_eval_dict = evaluate_full(
                inner_model,
                train_eval_dataloader,
                train_hidden_embeddings,
                train_output_embeddings,
                train_simclr_embeddings,
                criterion,
                device,
                cfg,
            )
            logging.info(
                f"Train loss: {train_eval_dict['eval_loss']:.6f}, Train IoU: {train_eval_dict['eval_iou']:.6f}, "
            )

            # Save the learned image representations & labels
            train_embeddings = torch.cat(
                [train_hidden_embeddings, train_output_embeddings, train_simclr_embeddings],
                dim=1,
            )
            train_embeddings_np = train_embeddings.detach().cpu().numpy()
            torch.save(
                [
                    train_embeddings_np,
                    train_dataset.labels,
                ],
                ckpt_dir / f"best_eval_representations_{aug_idx}.pt",
            )
            embeddings[offset : offset + train_dataset.num_signals] = train_embeddings_np
            labels[offset : offset + train_dataset.num_signals] = train_dataset.labels
            offset += train_dataset.num_signals

        fit_test_dataloader = torch.utils.data.DataLoader(
            dataset=test_dataset,
            batch_size=cfg.num_train_images,
            shuffle=False,
            num_workers=cfg.num_workers,
            pin_memory=True,
            sampler=torch.utils.data.RandomSampler(test_dataset, replacement=True),
        )
        test_dataloader = torch.utils.data.DataLoader(
            dataset=chunked_test_dataset,
            batch_size=cfg.num_eval_images,
            shuffle=False,
            num_workers=cfg.num_workers,
            pin_memory=True,
            collate_fn=collate_fn,
        )

        test_hidden_embeddings, test_output_embeddings, test_simclr_embeddings = (
            inner_model.generate_embeddings(
                test_dataset.num_signals,
                device=device,
                # hidden_mu=model.hidden_embeddings.detach().mean(),
                # hidden_sigma=model.hidden_embeddings.detach().std(),
                # out_mu=model.output_embedding.detach().mean(),
                # out_sigma=model.output_embedding.detach().std(),
            )
        )

        finetune(
            inner_model,
            fit_test_dataloader,
            test_hidden_embeddings,
            test_output_embeddings,
            test_simclr_embeddings,
            criterion,
            device,
            cfg,
        )
        logging.info("Finished finetuning the test embeddings")
        test_dict = evaluate_full(
            inner_model,
            test_dataloader,
            test_hidden_embeddings,
            test_output_embeddings,
            test_simclr_embeddings,
            criterion,
            device,
            cfg,
        )
        logging.info(
            f"Test loss: {test_dict['eval_loss']:.6f}, Test IoU: {test_dict['eval_iou']:.6f}"
        )
        test_embeddings = torch.cat(
            [test_hidden_embeddings, test_output_embeddings, test_simclr_embeddings], dim=1
        )
        test_embeddings_np = test_embeddings.detach().cpu().numpy()
        torch.save(
            [
                test_embeddings_np,
                test_dataset.labels,
            ],
            ckpt_dir / "best_eval_test_representations.pt",
        )
        # Test entries are stored after all train entries in the same datasets.
        embeddings[offset : offset + test_dataset.num_signals] = test_embeddings_np
        labels[offset : offset + test_dataset.num_signals] = test_dataset.labels


def train_ddp(rank, cfg, hydra_cfg):
    """Per-process entry point used with ``mp.spawn`` for distributed runs.

    Calls ``ddp_setup`` for this rank, records the rank in
    ``cfg.distributed.rank`` and runs ``train``. The process group is torn
    down afterwards even when ``train`` raises.

    Args:
        rank: Index of this process, supplied by ``mp.spawn``.
        cfg: Experiment configuration; ``distributed.world_size`` is read and
            ``distributed.rank`` is written.
        hydra_cfg: Hydra runtime config, forwarded to ``train``.
    """
    ddp_setup(rank, cfg.distributed.world_size)
    try:
        cfg.distributed.rank = rank
        train(cfg, hydra_cfg)
    finally:
        # Always release the process group, including on failure.
        destroy_process_group()


@hydra.main(config_path="configs", config_name="base_block", version_base=None)
def main(cfg):
    """Hydra entry point: run ``train`` directly or spawn one process per rank.

    When ``cfg.distributed`` is truthy, ``cfg.distributed.world_size``
    processes are spawned, each running ``train_ddp``; otherwise ``train`` is
    called in the current process.
    """
    hydra_cfg = hydra.core.hydra_config.HydraConfig.get()

    if not cfg.distributed:
        train(cfg, hydra_cfg)
        return

    # ``mp.spawn`` passes the process index as the first argument of
    # ``train_ddp``; ``join=True`` blocks until every process has finished.
    mp.spawn(
        train_ddp,
        args=(cfg, hydra_cfg),
        nprocs=cfg.distributed.world_size,
        join=True,
    )


# Script entry point: run the Hydra-decorated ``main`` only when this file is
# executed directly, not when it is imported.
if __name__ == "__main__":
    main()

