import shutil
from datetime import datetime
from pathlib import Path
from typing import Dict, Literal, Optional

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from hydra.utils import instantiate
from ito_vision.discretizations import (
    DDBMDiscretization,
    KarrasDiscretization,
    LinearDiscretization,
)
from ito_vision.samplers import (
    AncestralSampler,
    EIODESampler,
    EulerMaruyamaSampler,
    LangevinHeunSampler,
    MeanODESampler,
    RungeKutta2Sampler,
)
from omegaconf import OmegaConf
from PIL import Image
from torchmetrics.image import PeakSignalNoiseRatio as PSNR
from torchmetrics.image import StructuralSimilarityIndexMeasure as SSIM
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity as LPIPS
from tqdm import tqdm

from src.utils.fid import InceptionV3, wasserstein_distance
from src.utils.niqe import NIQE


@torch.no_grad()
def validate(
    model,
    dm,
    sampler=None,
    discretization=None,
    bsz_limit=10,
    real_imgs_gauss: Optional[Dict] = None,
    device="cuda:0",
):
    """Compute PSNR, SSIM, LPIPS, NIQE (and optionally FID) on the test split.

    Each degraded input ``batch["y"]`` is restored by sampling from the model and
    scored against ``batch["x0"]``. The degraded input itself is scored the same
    way (except FID) to provide a "do nothing" baseline.

    Args:
        model: model exposing ``method``, ``sampler``, ``discretization``,
            ``backbone`` and ``vae``; metric modules are attached to it as
            ``val_psnr`` / ``val_ssim`` / ``val_lpips``.
        dm: datamodule exposing ``setup`` and ``test_dataloader``.
        sampler: if given, installed as ``model.sampler`` (persists after return).
        discretization: if given, installed as ``model.discretization``
            (persists after return).
        bsz_limit: maximum number of test batches to evaluate; ``None`` means all.
        real_imgs_gauss: statistics of real-image Inception features with keys
            ``"mean"`` and ``"std"``. When ``None``, FID is skipped and reported
            as 0. NOTE(review): ``"std"`` is compared against ``np.cov`` output,
            so it presumably holds a covariance matrix; confirm.
        device: torch device string.

    Returns:
        ``((psnr, ssim, lpips, niqe, fid), (input_psnr, input_ssim, input_lpips,
        input_niqe))``, where every value is a mean over evaluated images.
    """
    dm.setup(stage="val")
    test_dl = dm.test_dataloader()

    calculate_inception = real_imgs_gauss is not None

    model.val_psnr = PSNR(reduction="none", data_range=(-1, 1), dim=(1, 2, 3))
    model.val_ssim = SSIM(reduction="none", data_range=(-1, 1))
    model.val_lpips = LPIPS(reduction="none").eval()

    if sampler is not None:
        model.sampler = sampler
    if discretization is not None:
        model.discretization = discretization

    # The metric modules are attached before .to(device) so they are moved with
    # the model (assuming ``model`` is a torch.nn.Module).
    model.to(device)
    model.eval()

    metrics = {"psnr": [], "ssim": [], "lpips": [], "niqe": []}
    input_metrics = {"psnr": [], "ssim": [], "lpips": [], "niqe": []}

    if calculate_inception:
        inception_embeddings = []
        inception_net = InceptionV3(normalize_input=False).to(device)

    total = len(test_dl) if bsz_limit is None else min(bsz_limit, len(test_dl))

    for idx, batch in tqdm(enumerate(test_dl), total=total):
        # Evaluate exactly `total` batches (the old `idx == total` check ran one extra).
        if idx >= total:
            break

        x0 = batch["x0"].to(device)
        y = batch["y"].to(device).clamp(-1, 1)

        if model.vae:
            # Reuse pre-computed latents when the datamodule provides them.
            if "y_latent" in batch:
                y_latent = batch["y_latent"].to(device)
            else:
                y_latent = model.vae.encode(y)
        else:
            y_latent = y.clone()

        # Every other batch entry is forwarded to the sampler; tensors are moved
        # to `device`, anything else is passed through unchanged.
        kwargs = {
            k: (v.to(device) if isinstance(v, torch.Tensor) else v)
            for k, v in batch.items()
            if k not in ("y", "x0")
        }

        pred_z0, _, _ = model.method.sample(
            model.discretization,
            model.sampler,
            model.backbone,
            model.method.base_distribution(y_latent),
            y_latent,
            return_trajectory=False,
            **kwargs,
        )

        # Clamp both branches alike so every metric sees the declared [-1, 1] range.
        if model.vae:
            pred_x0 = model.vae.decode(pred_z0).clamp(-1, 1)
        else:
            pred_x0 = pred_z0.clamp(-1, 1)

        metrics["psnr"].extend(model.val_psnr(pred_x0, x0).cpu().tolist())
        metrics["ssim"].extend(model.val_ssim(pred_x0, x0).cpu().tolist())
        metrics["lpips"].extend(model.val_lpips(pred_x0, x0).cpu().tolist())
        metrics["niqe"].extend(NIQE(pred_x0))

        if calculate_inception:
            inception_embeddings.append(inception_net(pred_x0).cpu())

        # Baseline: score the degraded input itself against the ground truth.
        input_metrics["psnr"].extend(model.val_psnr(y, x0).cpu().tolist())
        input_metrics["ssim"].extend(model.val_ssim(y, x0).cpu().tolist())
        input_metrics["lpips"].extend(model.val_lpips(y, x0).cpu().tolist())
        input_metrics["niqe"].extend(NIQE(y))

    total_psnr = float(np.mean(metrics["psnr"]))
    total_ssim = float(np.mean(metrics["ssim"]))
    total_lpips = float(np.mean(metrics["lpips"]))
    total_niqe = float(np.mean(metrics["niqe"]))

    if calculate_inception:
        inception_embeddings = torch.cat(inception_embeddings, dim=0).numpy()
        mu = inception_embeddings.mean(axis=0)
        sigma = np.cov(inception_embeddings.T)

        fid = wasserstein_distance(
            mu, sigma, real_imgs_gauss["mean"], real_imgs_gauss["std"]
        )
    else:
        fid = 0

    total_input_psnr = float(np.mean(input_metrics["psnr"]))
    total_input_ssim = float(np.mean(input_metrics["ssim"]))
    total_input_lpips = float(np.mean(input_metrics["lpips"]))
    total_input_niqe = float(np.mean(input_metrics["niqe"]))

    return (total_psnr, total_ssim, total_lpips, total_niqe, fid), (
        total_input_psnr,
        total_input_ssim,
        total_input_lpips,
        total_input_niqe,
    )


def plot_metrics(ax, metrics, names, NFEs, baseline, metric):
    """Draw one "metric vs. NFE" panel on ``ax``.

    Args:
        ax: matplotlib Axes to draw on.
        metrics: mapping ``name -> {metric -> [value per NFE]}``.
        names: series (sampler / discretization) names to plot, in legend order.
        NFEs: NFE budgets; they are shown as evenly spaced categorical ticks.
        baseline: y-value of the dashed horizontal "Baseline" reference line.
        metric: key selecting which metric to plot; also used as the title.
    """
    positions = list(range(len(NFEs)))

    for name in names:
        ax.plot(positions, metrics[name][metric], marker="o", label=name)

    ax.set_xticks(ticks=positions, labels=[str(n) for n in NFEs])

    ax.axhline(y=baseline, color="black", linestyle="--", label="Baseline")

    ax.set_ylabel("Metric value")
    # The legend sits to the right of the axes, so reserve space on the figure
    # that owns `ax` rather than whatever pyplot considers the current figure.
    ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
    ax.figure.subplots_adjust(right=0.75)
    ax.grid(axis="y")
    ax.set_title(metric)


def validate_with_different_samplers(
    root,
    model,
    dm,
    samplers,
    samplers_names,
    NFEs=(3, 5, 10, 30, 50, 100),
    bsz_limit=10,
    device="cuda:0",
):
    """Benchmark several samplers across NFE budgets and save plot + CSV.

    For every NFE ``N`` and every sampler factory, ``validate`` is run with
    ``sampler_init(N)``. Writes ``{root}/metrics_samplers.png`` (one panel per
    metric) and ``{root}/metrics_samplers.csv``.

    Args:
        root: output directory.
        model, dm: forwarded to ``validate`` (note: ``model.sampler`` is left set
            to the last sampler evaluated).
        samplers: callables ``N -> sampler``.
        samplers_names: display names, same length and order as ``samplers``.
        NFEs: NFE budgets to evaluate.
        bsz_limit, device: forwarded to ``validate``.

    Returns:
        DataFrame with an ``"input"`` baseline row (NFE 0, FID 0) followed by one
        row per (sampler, NFE).

    Raises:
        ValueError: if ``NFEs`` or ``samplers`` is empty, or the names do not
            match the samplers one-to-one.
    """
    samplers = list(samplers)
    samplers_names = list(samplers_names)
    NFEs = list(NFEs)

    if not NFEs or not samplers:
        raise ValueError("NFEs and samplers must both be non-empty")
    if len(samplers) != len(samplers_names):
        raise ValueError(
            f"got {len(samplers)} samplers but {len(samplers_names)} names"
        )

    metric_keys = ("psnr", "ssim", "lpips", "niqe", "fid")
    metrics = {name: {key: [] for key in metric_keys} for name in samplers_names}

    real_imgs_gauss = dm.get_inception_statistics()
    baselines = None

    for N in NFEs:
        print(f"Test N={N}")

        for sampler_init, name in zip(samplers, samplers_names):
            print(f"Method: {name}")
            sampler = sampler_init(N)

            vals, baselines = validate(
                model,
                dm,
                sampler=sampler,
                bsz_limit=bsz_limit,
                real_imgs_gauss=real_imgs_gauss,
                device=device,
            )
            # `vals` is ordered (psnr, ssim, lpips, niqe, fid), matching metric_keys.
            for key, value in zip(metric_keys, vals):
                metrics[name][key].append(value)

    # The input baseline (taken from the last validate run) has no FID; use 0.
    baseline_values = (*baselines, 0)

    _, axes = plt.subplots(len(metric_keys), 1, figsize=(30, 15))
    for ax, key, baseline in zip(axes, metric_keys, baseline_values):
        plot_metrics(ax, metrics, samplers_names, NFEs, baseline, metric=key)

    plt.tight_layout()
    plt.savefig(f"{root}/metrics_samplers.png")
    plt.close()

    records = [{"sampler": "input", "NFE": 0, **dict(zip(metric_keys, baseline_values))}]
    records.extend(
        {
            "sampler": name,
            "NFE": N,
            **{key: metrics[name][key][i] for key in metric_keys},
        }
        for name in samplers_names
        for i, N in enumerate(NFEs)
    )

    df = pd.DataFrame(records)
    df.to_csv(f"{root}/metrics_samplers.csv", index=False)

    return df


def validate_with_different_discretization(
    root,
    model,
    dm,
    discretizations,
    discretization_names,
    NFEs=(3, 5, 10, 30, 50, 100),
    bsz_limit=10,
    device="cuda:0",
):
    """Benchmark several time discretizations across NFE budgets.

    Every run uses ``AncestralSampler(N, True)`` and a fresh ``disc_init()``
    instance. Writes ``{root}/metrics_disc.png`` (one panel per metric) and
    ``{root}/metrics_disc.csv``.

    Args:
        root: output directory.
        model, dm: forwarded to ``validate`` (note: ``model.sampler`` and
            ``model.discretization`` are left set to the last ones evaluated).
        discretizations: zero-argument factories returning a discretization.
        discretization_names: display names, same length and order.
        NFEs: NFE budgets to evaluate.
        bsz_limit, device: forwarded to ``validate``.

    Returns:
        DataFrame with an ``"input"`` baseline row (NFE 0, FID 0) followed by one
        row per (discretization, NFE).

    Raises:
        ValueError: if ``NFEs`` or ``discretizations`` is empty, or the names do
            not match the discretizations one-to-one.
    """
    discretizations = list(discretizations)
    discretization_names = list(discretization_names)
    NFEs = list(NFEs)

    if not NFEs or not discretizations:
        raise ValueError("NFEs and discretizations must both be non-empty")
    if len(discretizations) != len(discretization_names):
        raise ValueError(
            f"got {len(discretizations)} discretizations "
            f"but {len(discretization_names)} names"
        )

    metric_keys = ("psnr", "ssim", "lpips", "niqe", "fid")
    metrics = {
        name: {key: [] for key in metric_keys} for name in discretization_names
    }

    real_imgs_gauss = dm.get_inception_statistics()
    baselines = None

    for N in NFEs:
        print(f"Test N={N}")

        for disc_init, name in zip(discretizations, discretization_names):
            print(f"Method: {name}")
            disc = disc_init()

            vals, baselines = validate(
                model,
                dm,
                sampler=AncestralSampler(N, True),
                discretization=disc,
                bsz_limit=bsz_limit,
                real_imgs_gauss=real_imgs_gauss,
                device=device,
            )
            # `vals` is ordered (psnr, ssim, lpips, niqe, fid), matching metric_keys.
            for key, value in zip(metric_keys, vals):
                metrics[name][key].append(value)

    # The input baseline (taken from the last validate run) has no FID; use 0.
    baseline_values = (*baselines, 0)

    _, axes = plt.subplots(len(metric_keys), 1, figsize=(30, 15))
    for ax, key, baseline in zip(axes, metric_keys, baseline_values):
        plot_metrics(ax, metrics, discretization_names, NFEs, baseline, metric=key)

    plt.tight_layout()
    plt.savefig(f"{root}/metrics_disc.png")
    plt.close()

    records = [
        {"discretization": "input", "NFE": 0, **dict(zip(metric_keys, baseline_values))}
    ]
    records.extend(
        {
            "discretization": name,
            "NFE": N,
            **{key: metrics[name][key][i] for key in metric_keys},
        }
        for name in discretization_names
        for i, N in enumerate(NFEs)
    )

    df = pd.DataFrame(records)
    df.to_csv(f"{root}/metrics_disc.csv", index=False)

    return df


@torch.no_grad()
def save_results(root, model, dm, sampler=None, save_images=25, device="cuda:0"):
    """Restore test images with the model and write the first ones as PNGs.

    Files are named ``{root}/results/<k>.png`` with ``k`` counting from 1, and
    at most ``save_images`` are written. Sampling stops as soon as enough images
    exist, so no batch is generated only to be discarded.

    Side effects: sets ``model.sampler`` (when given), moves the model to
    ``device``, switches it to eval mode and sets ``requires_grad = False`` on
    all of its parameters.
    """
    out_dir = f"{root}/results"
    Path(out_dir).mkdir(parents=True, exist_ok=True)
    dm.setup(stage="val")

    test_dl = dm.test_dataloader()

    if sampler is not None:
        model.sampler = sampler

    model.to(device)
    model.eval()
    for param in model.parameters():
        param.requires_grad = False

    saved = 0

    for batch in test_dl:
        # Checked before sampling: covers save_images <= 0 and avoids wasted work.
        if saved >= save_images:
            break

        # Same input preparation as `validate`, so saved images match the
        # evaluated ones.
        y = batch["y"].to(device).clamp(-1, 1)

        if model.vae:
            if "y_latent" in batch:
                y_latent = batch["y_latent"].to(device)
            else:
                y_latent = model.vae.encode(y)
        else:
            y_latent = y.clone()

        kwargs = {
            k: (v.to(device) if isinstance(v, torch.Tensor) else v)
            for k, v in batch.items()
            if k not in ("y", "x0")
        }

        pred_z0, _, _ = model.method.sample(
            model.discretization,
            model.sampler,
            model.backbone,
            model.method.base_distribution(y_latent),
            y_latent,
            return_trajectory=False,
            **kwargs,
        )

        if model.vae:
            pred_x0 = model.vae.decode(pred_z0)
        else:
            pred_x0 = pred_z0.clamp(-1, 1)

        # [-1, 1] -> [0, 255] uint8; the clamp also bounds un-clamped VAE output.
        pred_x0 = ((pred_x0 + 1.0) * 127.5).clamp(0, 255).byte().cpu()

        for image in pred_x0:
            if saved >= save_images:
                break

            saved += 1

            pixels = image.permute(1, 2, 0).numpy()
            if pixels.shape[-1] == 1:
                # Save single-channel output as a 2-D grayscale image.
                pixels = pixels[..., 0]
            Image.fromarray(pixels).save(f"{out_dir}/{saved}.png")


@torch.no_grad()
def explore_trajectories(
    root,
    model,
    dm,
    sampler=None,
    save_trajectories=10,
    calculate_n_batches=10,
    device="cuda:0",
):
    """Save sampling trajectories and plot LPIPS of the x0-prediction per step.

    The first ``calculate_n_batches`` test batches are sampled with
    ``return_trajectory=True``.

    - For the first ``save_trajectories`` samples, every frame is written to
      ``{root}/xt_trajectories/<k>/<step>.png`` and
      ``{root}/x0_trajectories/<k>/<step>.png`` (``k`` counts from 1).
    - LPIPS between each step's x0-prediction and the ground truth is computed
      for *every* sample in those batches. The values are averaged over samples
      and saved as ``{root}/pred_x0_lpips.png`` / ``{root}/pred_x0_lpips.csv``.
    - If no batch is sampled, the LPIPS summary is skipped.

    Trajectories are indexed ``traj[step, sample]`` below. This assumes
    ``model.method.sample`` returns them as (steps, batch, C, H, W); TODO confirm.
    """
    xt_root = f"{root}/xt_trajectories"
    x0_root = f"{root}/x0_trajectories"
    Path(xt_root).mkdir(exist_ok=True)
    Path(x0_root).mkdir(exist_ok=True)
    dm.setup(stage="val")

    test_dl = dm.test_dataloader()

    # Only LPIPS is used below; PSNR/SSIM are still attached to the model as before.
    model.val_psnr = PSNR(reduction="none", data_range=(-1, 1), dim=0)
    model.val_ssim = SSIM(reduction="none")
    model.val_lpips = LPIPS(reduction="none").eval()

    if sampler is not None:
        model.sampler = sampler

    model.to(device)
    model.eval()

    def to_uint8_image(frame):
        # (C, H, W) tensor in [-1, 1] -> (H, W, C) uint8 array.
        pixels = (frame.permute(1, 2, 0).cpu().clamp(-1, 1).numpy() + 1.0) * 127.5
        return pixels.clip(0, 255).astype(np.uint8)

    saved = 0
    per_sample_lpips = []

    for idx, batch in tqdm(
        enumerate(test_dl), total=calculate_n_batches, desc="Trajectories"
    ):
        if idx >= calculate_n_batches:
            break

        x0 = batch["x0"].to(device)
        # Same input preparation as `validate`.
        y = batch["y"].to(device).clamp(-1, 1)

        if model.vae:
            if "y_latent" in batch:
                y_latent = batch["y_latent"].to(device)
            else:
                y_latent = model.vae.encode(y)
        else:
            y_latent = y.clone()

        # Non-tensor entries are passed through (previously .to() crashed on them).
        kwargs = {
            k: (v.to(device) if isinstance(v, torch.Tensor) else v)
            for k, v in batch.items()
            if k not in ("y", "x0")
        }

        _, traj_xt, traj_x0 = model.method.sample(
            model.discretization,
            model.sampler,
            model.backbone,
            model.method.base_distribution(y_latent),
            y_latent,
            return_trajectory=True,
            **kwargs,
        )

        if model.vae:
            # Decode one sample's whole trajectory at a time (its frames form the
            # decoder batch), then restack along the sample axis.
            decoded_xt, decoded_x0 = [], []
            for i in range(traj_xt.shape[1]):
                decoded_xt.append(
                    model.vae.decode(traj_xt[:, i].to(device)).clamp(-1, 1).cpu()
                )
                decoded_x0.append(
                    model.vae.decode(traj_x0[:, i].to(device)).clamp(-1, 1).cpu()
                )
            traj_xt = torch.stack(decoded_xt, dim=1)
            traj_x0 = torch.stack(decoded_x0, dim=1)

        n_steps, n_samples = traj_xt.shape[0], traj_xt.shape[1]

        for i in range(n_samples):
            if saved < save_trajectories:
                saved += 1

                Path(f"{x0_root}/{saved}").mkdir(exist_ok=True)
                Path(f"{xt_root}/{saved}").mkdir(exist_ok=True)

                for j in range(n_steps):
                    Image.fromarray(to_uint8_image(traj_xt[j, i])).save(
                        f"{xt_root}/{saved}/{j}.png"
                    )
                    Image.fromarray(to_uint8_image(traj_x0[j, i])).save(
                        f"{x0_root}/{saved}/{j}.png"
                    )

            # LPIPS of every step's x0-prediction against this sample's ground
            # truth. It is gathered for all samples, not only the saved ones.
            pred_steps = traj_x0[:, i].to(device).clamp(-1.0, 1.0)
            target = x0[i].unsqueeze(0).repeat(len(pred_steps), 1, 1, 1)
            target = target.clamp(-1.0, 1.0)

            per_sample_lpips.append(model.val_lpips(pred_steps, target).cpu())

    if not per_sample_lpips:
        print("No trajectories were sampled; skipping the LPIPS summary.")
        return

    # Average over samples -> one LPIPS value per trajectory step.
    lpipses = torch.stack(per_sample_lpips).mean(dim=0).numpy()

    plt.figure(figsize=(18, 8))
    plt.plot(lpipses)
    plt.title("LPIPS of x0 prediction across timesteps")
    plt.savefig(f"{root}/pred_x0_lpips.png")
    plt.close()

    df = pd.DataFrame({"idx": np.arange(len(lpipses)), "value": lpipses})

    df.to_csv(f"{root}/pred_x0_lpips.csv", index=False)


def examine_model(
    root: str,
    NFEs: Optional[list] = None,
    bsz_limit: int = 10,
    save_images: int = 25,
    save_trajectories: int = 10,
    examined_checkpoint: Literal["best", "last"] = "best",
    device: str = "cuda:0",
    new_bsz: Optional[int] = None,
    ds_root: Optional[str] = None,
    num_workers: int = 64,
):
    """Run the full evaluation suite for a trained run directory.

    The model and datamodule are rebuilt from ``{root}/.hydra/config.yaml``, and
    weights are loaded from ``{root}/checkpoints/{examined_checkpoint}.ckpt``.
    Every output goes into a fresh folder ``{root}/results/<YYYY-mm-dd_HH-MM-SS>``:

    - sample restorations (``save_results``);
    - trajectories and the per-step LPIPS curve (``explore_trajectories``);
    - metric-vs-NFE plots/CSVs for several samplers and discretizations.

    Args:
        root: training run directory.
        NFEs: NFE budgets to benchmark; defaults to ``[3, 5, 10, 30, 50, 100]``.
        bsz_limit: test batches per metric evaluation. Trajectory exploration
            uses ``max(1, bsz_limit // 20)`` batches.
        save_images: number of restored images to write.
        save_trajectories: number of trajectories to write frame by frame.
        examined_checkpoint: which checkpoint file to load.
        device: torch device string.
        new_bsz: if given, overrides ``dm.val_bsz``.
        ds_root: if given, overrides ``dm.data_dir``.
        num_workers: value assigned to ``dm.num_workers``.
    """
    if NFEs is None:
        NFEs = [3, 5, 10, 30, 50, 100]

    current_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    path = Path(f"{root}/results/{current_time}")

    # A rerun within the same second would reuse the folder; start from scratch.
    if path.exists():
        shutil.rmtree(path)
    path.mkdir(parents=True)

    cfg = OmegaConf.load(f"{root}/.hydra/config.yaml")

    # Recreate model from config
    model = instantiate(cfg.model)

    # NOTE: weights_only=False unpickles arbitrary objects; only load trusted
    # checkpoints.
    checkpoint = torch.load(
        f"{root}/checkpoints/{examined_checkpoint}.ckpt",
        map_location=device,
        weights_only=False,
    )
    model.load_state_dict(checkpoint["state_dict"])

    dm = instantiate(cfg.data)
    dm.num_workers = num_workers

    if new_bsz is not None:
        dm.val_bsz = new_bsz

    if ds_root is not None:
        print("change root to:", ds_root)
        dm.data_dir = ds_root

    sampler = AncestralSampler(N=35, quiet=True)
    save_results(str(path), model, dm, sampler, save_images, device)

    # Written into this run's folder (not the shared `{root}/results`), so
    # successive runs do not overwrite each other's trajectories.
    explore_trajectories(
        str(path),
        model,
        dm,
        sampler=sampler,
        save_trajectories=save_trajectories,
        # Never 0: with the default bsz_limit=10, `bsz_limit // 20` alone would
        # sample nothing.
        calculate_n_batches=max(1, bsz_limit // 20),
        device=device,
    )

    # Each factory maps an NFE budget N to a sampler. Multi-evaluation schemes
    # divide the budget per step. NOTE(review): N=3 gives Langevin-Heun
    # N=(3 - 1) // 3 == 0 steps; confirm the sampler accepts that.
    samplers = {
        "Euler ODE": lambda N: EulerMaruyamaSampler(N, Lambda=0, quiet=True),
        "Euler SDE": lambda N: EulerMaruyamaSampler(N, Lambda=1, quiet=True),
        "Ancestral": lambda N: AncestralSampler(N, quiet=True),
        "EI-ODE": lambda N: EIODESampler(N, quiet=True),
        "mean-ODE": lambda N: MeanODESampler(N, quiet=True),
        "Langevin-Heun": lambda N: LangevinHeunSampler(N=int((N - 1) // 3), quiet=True),
        "2nd Heun": lambda N: RungeKutta2Sampler(
            N=int((N - 1) // 2), quiet=True, variant="heun"
        ),
        "2nd Midpoint": lambda N: RungeKutta2Sampler(
            N=int((N - 1) // 2), quiet=True, variant="midpoint"
        ),
        "2nd Ralston": lambda N: RungeKutta2Sampler(
            N=int((N - 1) // 2), quiet=True, variant="ralston"
        ),
    }

    validate_with_different_samplers(
        str(path),
        model,
        dm,
        list(samplers.values()),
        list(samplers.keys()),
        NFEs,
        bsz_limit,
        device,
    )

    discretizations = {
        "karras": KarrasDiscretization,
        "ddbm": DDBMDiscretization,
        "linear": LinearDiscretization,
    }

    validate_with_different_discretization(
        str(path),
        model,
        dm,
        list(discretizations.values()),
        list(discretizations.keys()),
        NFEs,
        bsz_limit,
        device,
    )


def extract_path(name, date, logs_root="logs/runs"):
    """Return the run directory ``{logs_root}/{name}/{date}`` as a string.

    When ``date`` is ``"latest"``, the newest sub-directory of
    ``{logs_root}/{name}`` is chosen. Only directories whose name parses as
    ``%Y-%m-%d_%H-%M-%S`` count. Stray files or helper folders are ignored
    instead of crashing the lookup. Any other ``date`` is used verbatim
    without checking that it exists.

    Raises:
        FileNotFoundError: if ``{logs_root}/{name}`` does not exist, or contains
            no timestamped run directory when ``date == "latest"``.
    """
    if date != "latest":
        return f"{logs_root}/{name}/{date}"

    date_format = "%Y-%m-%d_%H-%M-%S"
    experiment_dir = Path(f"{logs_root}/{name}")

    candidates = []
    for entry in experiment_dir.iterdir():
        if not entry.is_dir():
            continue
        try:
            stamp = datetime.strptime(entry.name, date_format)
        except ValueError:
            continue  # not a timestamped run folder
        candidates.append((stamp, entry.name))

    if not candidates:
        raise FileNotFoundError(
            f"No timestamped run directories found in {experiment_dir}"
        )

    _, latest = max(candidates)

    return f"{logs_root}/{name}/{latest}"
