import logging
from pathlib import Path

import hydra
import torch
import wandb
from omegaconf import OmegaConf
from torch.utils.data import DataLoader
from tqdm import trange

from experiments.high_res_inr.calculus_utils import gradient, laplace
from experiments.neural_datasets.inr_utils import psnr
from experiments.utils import count_parameters, register_resolvers, set_logger, set_seed
from experiments.high_res_inr.dataloader import MultiEpochsDataLoader

# Module-level setup, executed once at import time (i.e. before `main` runs).
# NOTE(review): presumably configures the logging format and registers custom
# OmegaConf resolvers referenced by the configs — confirm in experiments.utils.
set_logger()
register_resolvers()


def evaluate(model, dataloader, dataset_type, device, dataset, compute_grad=False):
    """Run ``model`` over ``dataloader`` and report reconstruction quality.

    The model is put in eval mode for the duration of the call and is always
    left in training mode on return, regardless of its mode on entry.

    Args:
        model: Module mapping a batch of input coordinates to predictions.
        dataloader: Iterable yielding ``(model_input, ground_truth, mask)``
            batches. The concatenated predictions are compared against
            ``dataset.data`` (and indexed with ``dataset.video_audio_mask``),
            so the batches presumably have to cover the dataset in its
            original order — NOTE(review): confirm with callers.
        dataset_type: Dataset kind. ``"video_audio"`` selects the masked
            two-part loss and separate video/audio PSNRs; ``"image"`` is the
            only kind for which derivatives are computed.
        device: Device the batches are moved to before the forward pass.
        dataset: Dataset object providing ``data`` and ``inv_normalize`` (or,
            for ``"video_audio"``, ``num_channels``, ``video_to_audio_ratio``,
            ``video_audio_mask``, ``inv_normalize_video`` and
            ``inv_normalize_audio``).
        compute_grad: When truthy *and* ``dataset_type == "image"``, also
            compute the gradient and Laplacian of the output w.r.t. the input
            coordinates.

    Returns:
        dict with keys:
            ``"loss"``: summed squared error divided by the number of
                evaluated points (rows), i.e. not averaged over channels.
            ``"output"``: detached predictions of all batches, concatenated
                along dim 0.
            ``"psnr"``: PSNR between the inverse-normalized ground truth and
                predictions; for ``"video_audio"`` the mean of the video and
                audio PSNRs.
            ``"video_psnr"``, ``"audio_psnr"``: only for ``"video_audio"``.
            ``"gradient"``, ``"laplacian"``: only when derivatives were
                computed; ``"gradient"`` holds the norm over the last
                dimension (kept as a size-1 dim).
    """
    model.eval()
    is_video_audio = dataset_type == "video_audio"
    # Derivatives are only produced for image datasets. Coerce to a real bool
    # so the flag can drive torch.set_grad_enabled below.
    compute_grad = bool(compute_grad) and dataset_type == "image"

    squared_error_sums = []
    num_points = 0
    predictions, grads, laplacians = [], [], []

    for eval_model_input, eval_ground_truth, eval_mask in dataloader:
        eval_model_input = eval_model_input.to(device)
        eval_ground_truth = eval_ground_truth.to(device)

        # Fresh leaf tensor so derivatives w.r.t. the coordinates can be taken.
        coords = eval_model_input.clone().detach().requires_grad_(True)

        # Only record an autograd graph when derivatives are requested;
        # otherwise the graph would be built and immediately thrown away.
        with torch.set_grad_enabled(compute_grad):
            eval_model_output = model(coords)
            residual = eval_model_output - eval_ground_truth
            if is_video_audio:
                eval_mask = eval_mask.to(device)
                # Rows selected by the mask contribute their first
                # `num_channels` columns, the remaining rows the rest,
                # with the second term weighted by `video_to_audio_ratio`.
                video_error = (
                    residual[eval_mask][:, : dataset.num_channels] ** 2
                ).sum()
                audio_error = (
                    residual[~eval_mask][:, dataset.num_channels :] ** 2
                ).sum()
                eval_loss = video_error + dataset.video_to_audio_ratio * audio_error
            else:
                eval_loss = (residual**2).sum()
            squared_error_sums.append(eval_loss.item())
            num_points += eval_model_output.shape[0]

            if compute_grad:
                grads.append(gradient(eval_model_output, coords).detach())
                laplacians.append(laplace(eval_model_output, coords).detach())

        predictions.append(eval_model_output.detach())

    eval_model_output = torch.cat(predictions, dim=0)
    if compute_grad:
        eval_img_grad = torch.cat(grads, dim=0).norm(dim=-1, keepdim=True)
        eval_img_laplacian = torch.cat(laplacians, dim=0)

    eval_loss = sum(squared_error_sums) / num_points

    if is_video_audio:
        normalized_gt_video = dataset.inv_normalize_video(
            dataset.data[dataset.video_audio_mask][:, : dataset.num_channels]
        )
        normalized_gt_audio = dataset.inv_normalize_audio(
            dataset.data[~dataset.video_audio_mask][:, dataset.num_channels :]
        )
        normalized_model_output_video = dataset.inv_normalize_video(
            eval_model_output[dataset.video_audio_mask][:, : dataset.num_channels]
        )
        normalized_model_output_audio = dataset.inv_normalize_audio(
            eval_model_output[~dataset.video_audio_mask][:, dataset.num_channels :]
        )

        # psnr receives both operands with an extra leading dimension.
        video_psnr = psnr(
            normalized_gt_video.to(device)[None, ...],
            normalized_model_output_video[None, ...],
        ).item()
        audio_psnr = psnr(
            normalized_gt_audio.to(device)[None, ...],
            normalized_model_output_audio[None, ...],
        ).item()
        logging.info(f"Video PSNR: {video_psnr:.3f} Audio PSNR: {audio_psnr:.3f}")

        signal_psnr = (video_psnr + audio_psnr) / 2
    else:
        normalized_gt = dataset.inv_normalize(dataset.data)
        normalized_model_output = dataset.inv_normalize(eval_model_output)
        signal_psnr = psnr(
            normalized_gt.to(device)[None, ...], normalized_model_output[None, ...]
        ).item()

    # Unconditionally switch back to training mode for the caller.
    model.train()
    outputs = {
        "loss": eval_loss,
        "output": eval_model_output,
        "psnr": signal_psnr,
    }
    if is_video_audio:
        outputs["video_psnr"] = video_psnr
        outputs["audio_psnr"] = audio_psnr
    if compute_grad:
        outputs["gradient"] = eval_img_grad
        outputs["laplacian"] = eval_img_laplacian

    return outputs


def train(cfg, hydra_cfg):
    """Fit a model to one dataset and log training/eval metrics to W&B.

    Instantiates the dataset, model, optimizer and optional LR scheduler from
    ``cfg`` via Hydra, then runs ``cfg.data.num_epochs`` epochs of MSE
    training. Every ``cfg.data.steps_till_eval`` global steps (including step
    0) the model is evaluated with ``evaluate`` and the results are logged.

    Args:
        cfg: Experiment configuration. Fields read here: ``matmul_precision``,
            ``seed``, ``wandb``, ``data`` (``dataset_name``, ``dataset_type``,
            ``dataset``, ``batch_size``, ``eval_batch_size``, ``num_epochs``,
            ``steps_till_eval``), ``model``, ``optim``, optional ``scheduler``,
            ``pre_load_to_device``, ``pre_load_shuffle``, ``num_workers``,
            ``clip_grad``, ``clip_grad_max_norm``, ``log_weight_norms`` and
            ``compute_grad``. ``cfg.wandb.name`` is filled in with a default
            when it is ``None``.
        hydra_cfg: Hydra runtime configuration. Accepted for interface
            compatibility; not used by this function.

    Returns:
        None. Results are reported through ``wandb`` and the progress bar.
    """
    torch.set_float32_matmul_precision(cfg.matmul_precision)
    if cfg.seed is not None:
        set_seed(cfg.seed)

    if cfg.wandb.name is None:
        # Default run name derived from dataset, model class name and seed.
        model_name = cfg.model._target_.split(".")[-1]
        cfg.wandb.name = (
            f"{cfg.data.dataset_name}_high_resolution_inr_{model_name}_seed_{cfg.seed}"
        )
    wandb.init(
        **OmegaConf.to_container(cfg.wandb, resolve=True),
        settings=wandb.Settings(start_method="fork"),
        config=OmegaConf.to_container(cfg, resolve=True),
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logging.info(f"Using device {device}")

    dataset_type = cfg.data.dataset_type
    dataset = hydra.utils.instantiate(cfg.data.dataset)
    logging.info(f"Dataset size {len(dataset):_}")

    # When pre-loading, a single batch spanning the whole dataset is pulled
    # once and kept on `device`; the loader is then replaced by a plain list.
    batch_size = cfg.data.batch_size if not cfg.pre_load_to_device else len(dataset)
    dataloader = MultiEpochsDataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=not cfg.pre_load_to_device,
        num_workers=cfg.num_workers,
        pin_memory=not cfg.pre_load_to_device,
    )
    if cfg.pre_load_to_device:
        model_input, ground_truth, mask = next(iter(dataloader))
        model_input = model_input.to(device)
        ground_truth = ground_truth.to(device)
        mask = mask.to(device)
        if not cfg.pre_load_shuffle:
            dataloader = [(model_input, ground_truth, mask)]

    eval_dataloader = DataLoader(
        dataset,
        batch_size=cfg.data.eval_batch_size,
        shuffle=False,
        num_workers=cfg.num_workers,
        pin_memory=True,
    )

    model = hydra.utils.instantiate(cfg.model)
    logging.info(model)

    model.to(device)

    parameters = [p for p in model.parameters() if p.requires_grad]
    optimizer = hydra.utils.instantiate(cfg.optim, parameters)

    if hasattr(cfg, "scheduler"):
        scheduler = hydra.utils.instantiate(cfg.scheduler, optimizer=optimizer)
    else:
        scheduler = None

    criterion = torch.nn.MSELoss()

    logging.info("-------------------")
    logging.info(
        f"Initialized model. Number of parameters {count_parameters(model):_}"
    )
    for name, param in model.named_parameters():
        logging.info(f"{name}: {param.shape}, {param.numel()}")
    logging.info("-------------------")

    # The directory is created up front; nothing in this function writes to it.
    output_dir = Path(f"outputs/{dataset_type}")
    output_dir.mkdir(exist_ok=True, parents=True)

    # Latest evaluation results, shown in the progress bar between evals.
    signal_psnr = -1.0
    best_eval_psnr = -1.0
    eval_loss = float("inf")
    global_step = 0
    epoch_iter = trange(0, cfg.data.num_epochs)

    for epoch in epoch_iter:
        if cfg.pre_load_to_device and cfg.pre_load_shuffle:
            # Re-permute the device-resident full batch at every epoch.
            perm = torch.randperm(len(dataset)).to(device)
            dataloader = [(model_input[perm], ground_truth[perm], mask[perm])]

        for idx, (model_input, ground_truth, mask) in enumerate(dataloader):
            model.train()
            model_input, ground_truth = model_input.to(device), ground_truth.to(device)
            model_output = model(model_input)
            if dataset_type == "video_audio":
                mask = mask.to(device)
                # Masked rows are scored on the first `num_channels` columns,
                # the other rows on the remaining columns, weighted by
                # `video_to_audio_ratio`.
                loss = criterion(
                    model_output[mask][:, : dataset.num_channels],
                    ground_truth[mask][:, : dataset.num_channels],
                ) + dataset.video_to_audio_ratio * criterion(
                    model_output[~mask][:, dataset.num_channels :],
                    ground_truth[~mask][:, dataset.num_channels :],
                )
            else:
                loss = criterion(model_output, ground_truth)

            optimizer.zero_grad()
            loss.backward()

            log = {
                "train/loss": loss.item(),
                "global_step": global_step,
                "epoch": epoch,
            }

            if cfg.clip_grad:
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    parameters, cfg.clip_grad_max_norm
                )
                # Log a plain float, like every other scalar in `log`.
                log["grad_norm"] = float(grad_norm)
            if cfg.log_weight_norms:
                # Runs after clipping, so these are the (possibly clipped)
                # gradient norms that the optimizer step will see.
                for name, param in model.named_parameters():
                    if param.grad is None:
                        continue
                    # Each `.item()` forces a host sync; compute each norm once.
                    param_grad_norm = param.grad.norm().item()
                    param_weight_norm = param.norm().item()
                    log[f"grad_norm/{name}"] = param_grad_norm
                    log[f"lars/{name}"] = param_weight_norm / (
                        param_grad_norm + 1e-8
                    )
                    log[f"weight_norm/{name}"] = param_weight_norm

            optimizer.step()

            if scheduler is not None:
                # Record the LR that was used for the step just taken.
                log["lr"] = scheduler.get_last_lr()[0]
                scheduler.step()

            wandb.log(log)
            epoch_iter.set_description(
                f"[Epoch {epoch} {idx+1}/{len(dataloader)}] "
                f"Train loss: {log['train/loss']:.3f}, "
                f"eval loss: {eval_loss:.3f}, eval PSNR: {signal_psnr:.3f}, "
                f"best eval PSNR: {best_eval_psnr:.3f}"
            )

            if global_step % cfg.data.steps_till_eval == 0:
                eval_results = evaluate(
                    model,
                    eval_dataloader,
                    dataset_type,
                    device,
                    dataset,
                    compute_grad=cfg.compute_grad,
                )
                # Keep the progress-bar values in sync with the latest eval;
                # without this they stay at their initial inf / -1 forever.
                eval_loss = eval_results["loss"]
                signal_psnr = eval_results["psnr"]

                if dataset_type == "image":
                    wandb_log = dataset.wandb(
                        eval_results["output"],
                        eval_results.get("gradient", None),
                        eval_results.get("laplacian", None),
                        dataset_type,
                    )
                else:
                    wandb_log = dataset.wandb(eval_results["output"], dataset_type)

                if eval_results["psnr"] >= best_eval_psnr:
                    best_eval_psnr = eval_results["psnr"]

                eval_log = {
                    "eval/loss": eval_results["loss"],
                    "eval/psnr": eval_results["psnr"],
                    "eval/best_psnr": best_eval_psnr,
                    "global_step": global_step,
                }
                if dataset_type == "video_audio":
                    eval_log["eval/video_psnr"] = eval_results["video_psnr"]
                    eval_log["eval/audio_psnr"] = eval_results["audio_psnr"]
                # Entries returned by `dataset.wandb` are merged in last and
                # therefore override same-named keys above.
                eval_log |= wandb_log

                wandb.log(eval_log)

            global_step += 1

@hydra.main(config_path="configs", config_name="base", version_base=None)
def main(cfg):
    """Hydra entry point: fetch the Hydra runtime config and start training.

    Args:
        cfg: Configuration composed by Hydra from ``configs/base``.
    """
    train(cfg, hydra.core.hydra_config.HydraConfig.get())


# Start the Hydra-decorated entry point only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    main()
