import datetime
import itertools
from collections import defaultdict
from gzip import GzipFile
from pathlib import Path
from typing import Literal

import cv2
import lightning as pl
import numpy as np
import pyiqa.losses
import torch
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger
from scipy.stats import pearsonr, spearmanr
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset

from main import get_conf_dict, get_img_pathes_gt


# Root directory of the main image dataset (passed to get_img_pathes_gt).
IMAGES = "<path to dataset>"
# Folder with heatmaps computed with evaluate.py.
HEATMAPS_ROOT = Path("<path to heatmaps>")
# Root directory of the desra images dataset (passed to get_img_pathes_gt).
DESRA_IMAGES = "<path to desra images dataset>"

# Passed as `rf` to get_img_pathes_gt and embedded in the heatmap folder names
# (`..._rf{RF}`); presumably the resampling filter used for the images — confirm.
RF = "bicubic"

# Heatmap sub-folders used as input features, in channel order. Each one must
# exist inside every heatmap folder that LitMetric.setup points at.
FEATURES = ["dists", "ssm_jup", "bd_jup"]
# Batch size for training and the non-desra validation loaders; fits into 24 GB VRAM.
BATCH_SIZE = 10
# Index of the GPU to train on (passed to the Trainer as `devices=[DEVICE]`).
DEVICE = 2

# Validation metric monitored (maximized) by ModelCheckpoint.
MAIN_METRIC = "val-oi-gtRLFN/plcc_gtarea"
# Checkpoints, TensorBoard logs, etc.
OUTPUT_ROOT = Path("../ptl-logs")


def gz_npy_load(path: Path) -> np.ndarray:
    """Read one NumPy array from a gzip-compressed ``.npy`` file.

    The file is streamed through ``GzipFile`` into ``np.load``, and the
    handle is closed before the array is handed back.
    """
    with GzipFile(path) as compressed:
        array = np.load(compressed)
    return array


def load_ann(path: Path, erode: bool) -> np.ndarray:
    """Load an annotation image as a boolean mask of its pure-white (255) pixels.

    Args:
        path: Annotation image; it is read as 8-bit grayscale.
        erode: If True, shrink the mask with an elliptical erosion whose
            diameter is 1/12 of the shorter image side (but at least 64 px).
            If the erosion would remove the whole mask, the un-eroded mask is
            returned instead.

    Returns:
        Boolean array with the image's height and width, True where annotated.

    Raises:
        OSError: If ``cv2.imread`` cannot read ``path``. It returns None for
            missing or undecodable files.
    """
    ann = cv2.imread(str(path), cv2.IMREAD_GRAYSCALE)
    # Explicit check rather than `assert`, which `python -O` strips; otherwise a
    # failed read would slip through as `None` and surface much later.
    if ann is None:
        raise OSError(f"could not read annotation image: {path}")

    if not erode:
        return ann == 255

    ann = np.where(ann == 255, np.uint8(255), np.uint8(0))

    kernel_size = max(64, min(ann.shape[0], ann.shape[1]) // 12)
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
    # borderValue=0 treats pixels beyond the image edge as background, so masks
    # touching the edge are eroded from that side as well.
    ann_eroded = cv2.morphologyEx(ann, cv2.MORPH_ERODE, kernel, borderValue=[0])

    # Erosion can get rid of the entire annotation; in that case keep the original.
    if np.count_nonzero(ann_eroded) > 0:
        ann = ann_eroded

    return ann == 255


# One sample as returned by HeatmapDataset.__getitem__: (heatmaps stacked on the
# last dim as float32, boolean annotation mask, float32 ratio_bi target). Also
# used to annotate the collated batches received by LitMetric's step methods.
Batch = tuple[torch.Tensor, torch.Tensor, torch.Tensor]


class HeatmapDataset(Dataset):
    """Per-image samples of stacked metric heatmaps, annotation mask and target.

    Args:
        conf_dict: Maps a mask file name to a per-image record. Only the
            record's ``"ratio_bi"`` entry is read; it becomes the target.
        img_paths: Maps a mask file name to an ``(hr, sr, rf)`` path triple.
            Only the SR path is used: its directory contains the mask file,
            and its stem names the heatmap files.
        heatmaps_root: Directory with one sub-folder per metric, each holding
            ``<sr stem>.npy.gz`` heatmaps.
        metrics: Metric sub-folders to load, in channel order.
        erode_ann: Forwarded to ``load_ann`` as ``erode``.
    """

    def __init__(
        self,
        conf_dict: dict[str, dict],
        img_paths: dict[str, tuple[str, str, str]],
        heatmaps_root: Path,
        metrics: list[str],
        erode_ann: bool,
    ):
        self.conf_dict = conf_dict
        # Sorted by mask file name so that index -> sample is deterministic.
        self.img_paths = sorted(img_paths.items())
        self.heatmaps_root = heatmaps_root
        self.metrics = metrics
        self.erode_ann = erode_ann

    def __len__(self) -> int:
        return len(self.img_paths)

    def __getitem__(self, idx: int) -> Batch:
        mask_fn, (_hr, sr_path, _rf) = self.img_paths[idx]
        sr_path = Path(sr_path)

        mask = torch.tensor(load_ann(sr_path.parent / mask_fn, erode=self.erode_ann))

        # One channel per metric, stacked along a new trailing dimension.
        channels = [
            torch.tensor(
                gz_npy_load(self.heatmaps_root / metric / f"{sr_path.stem}.npy.gz"),
                dtype=torch.float32,
            )
            for metric in self.metrics
        ]
        stacked = torch.stack(channels, dim=-1)

        target = torch.tensor(self.conf_dict[mask_fn]["ratio_bi"], dtype=torch.float32)

        return stacked, mask, target


class ArtifactMetric(nn.Module):
    """MLP scoring each feature vector: ``n_features -> 128 -> 128 -> 1``, ReLU between.

    ``nn.Linear`` acts on the last dimension, so any leading dimensions of the
    input are preserved and the trailing feature dimension becomes size 1.
    """

    def __init__(self, n_features: int):
        super().__init__()
        width = 128
        layers = [
            nn.Linear(n_features, width),
            nn.ReLU(),
            nn.Linear(width, width),
            nn.ReLU(),
            nn.Linear(width, 1),
        ]
        # The attribute name and layer order define the state_dict keys that
        # saved checkpoints rely on.
        self.linear_relu_stack = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear_relu_stack(x)


class LitMetric(pl.LightningModule):
    """Lightning module that trains an ``ArtifactMetric`` on stacked metric heatmaps.

    The per-pixel feature vectors (one channel per entry of ``features``) are
    mapped to per-pixel scores. The loss regresses the per-image mean score
    inside the annotation onto the dataset target. Unless ``gt_area_only`` is
    set, it also pushes the mean score outside the annotation towards zero.
    Optionally it adds pyiqa's PLCC loss.

    Args:
        gt_area_only: If True, drop the outside-annotation loss term.
        train_gt: ``gt=`` value for the training image paths; it also selects
            the heatmap folder ``heatmaps_gt{train_gt}_rf{RF}``.
        plcc_loss: If True, add the PLCC loss on the in-annotation means.
        features: Heatmap sub-folder names stacked as input channels.
    """

    # Filter out desra images too big to fit into 24 GB VRAM.
    _DESRA_EXCLUDED_STEMS = frozenset(
        {
            "00002552@SR@SwinIR",
            "00005200@SR@SwinIR",
            "00009133@SR@SwinIR",
            "00020673@SR@SwinIR",
            "00027141@SR@SwinIR",
            "00036007@SR@SwinIR",
            "00041068@SR@SwinIR",
        }
    )

    def __init__(
        self,
        *,
        gt_area_only: bool,
        train_gt: Literal["gt", "RLFN", "SPAN"],
        plcc_loss: bool,
        features: list[str],
    ):
        super().__init__()
        self.save_hyperparameters()
        self.gt_area_only = gt_area_only
        self.train_gt = train_gt
        self.plcc_loss = plcc_loss
        self.features = features

        self.model = ArtifactMetric(len(features))

        # Must match the dataloaders returned from val_dataloader().
        self.dl_names = ["oi-gtRLFN", "oi-gtgt", "desra-gtRLFN", "desra-gtMSESR"]
        # Per-dataloader buffers filled in validation_step and drained at epoch end.
        self.val_outputs = [defaultdict(list) for _ in range(len(self.dl_names))]

    def setup(self, stage: str):
        """Build the training and validation datasets (only for ``stage == "fit"``)."""
        if stage == "fit":
            conf_dict = get_conf_dict("gt_conf.csv")

            img_pathes_train = get_img_pathes_gt(
                str(IMAGES), conf_dict, subset="train", gt=self.train_gt, rf=RF
            )
            heatmaps_train = HEATMAPS_ROOT / f"heatmaps_gt{self.train_gt}_rf{RF}"
            self.train_ds = HeatmapDataset(
                conf_dict, img_pathes_train, heatmaps_train, self.features, erode_ann=True
            )

            self.val_ds = []
            for val_gt in ["RLFN", "gt"]:
                img_pathes_val = get_img_pathes_gt(
                    str(IMAGES), conf_dict, subset="test", gt=val_gt, rf=RF
                )
                heatmaps_val = HEATMAPS_ROOT / f"heatmaps_gt{val_gt}_rf{RF}"
                ds_val = HeatmapDataset(
                    conf_dict, img_pathes_val, heatmaps_val, self.features, erode_ann=True
                )
                self.val_ds.append(ds_val)

            conf_dict_desra = get_conf_dict("gt_conf_desra.csv")
            img_pathes_desra_og = get_img_pathes_gt(
                str(DESRA_IMAGES), conf_dict_desra, subset="full", gt="RLFN", rf=RF
            )
            img_pathes_desra = {
                mask_fn: paths
                for mask_fn, paths in img_pathes_desra_og.items()
                if Path(paths[1]).stem not in self._DESRA_EXCLUDED_STEMS
            }
            # Sanity check that the exclusion list still matches this dataset's names.
            assert len(img_pathes_desra) < len(img_pathes_desra_og)
            heatmaps_desra = HEATMAPS_ROOT / f"desra_heatmaps_gtRLFN_rf{RF}"
            ds_desra_rlfn = HeatmapDataset(
                conf_dict_desra, img_pathes_desra, heatmaps_desra, self.features, erode_ann=False
            )
            heatmaps_desra = HEATMAPS_ROOT / f"desra_heatmaps_gtMSESR_rf{RF}"
            ds_desra_msesr = HeatmapDataset(
                conf_dict_desra, img_pathes_desra, heatmaps_desra, self.features, erode_ann=False
            )
            self.val_ds_desra = [ds_desra_rlfn, ds_desra_msesr]

    def train_dataloader(self):
        return DataLoader(
            self.train_ds,
            batch_size=BATCH_SIZE,
            num_workers=8,
            shuffle=True,
            drop_last=True,
        )

    def val_dataloader(self):
        """Validation loaders, in the same order as ``self.dl_names``."""
        return [
            DataLoader(
                val_ds,
                batch_size=BATCH_SIZE,
                num_workers=8,
            )
            for val_ds in self.val_ds
        ] + [
            # Desra images are loaded one at a time (they are the large ones).
            DataLoader(val_ds_desra, batch_size=1, num_workers=8)
            for val_ds_desra in self.val_ds_desra
        ]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)

    def compute_loss(
        self, batch: Batch
    ) -> tuple[torch.Tensor, torch.Tensor, dict[str, torch.Tensor]]:
        """Run the model on ``batch`` and compute every loss term.

        Returns:
            ``(pred, mean_gtarea, losses)``. ``pred`` is the per-pixel
            prediction with the trailing singleton dim removed. ``mean_gtarea``
            is the per-image mean prediction inside the annotation. ``losses``
            holds the individual terms plus the optimized total under ``"loss"``.
        """
        x, ann, y = batch

        pred = self.model(x).squeeze(-1)

        # NaN-out pixels outside the annotation so nanmean averages only inside it.
        pred_gtarea = torch.where(ann, pred, torch.nan)
        mean_gtarea = torch.nanmean(pred_gtarea, dim=(1, 2))
        loss_gtarea = F.mse_loss(mean_gtarea, y)

        loss_gtarea_plcc = pyiqa.losses.iqa_losses.plcc_loss(mean_gtarea, y)

        pred_outside = torch.where(~ann, pred, torch.nan)
        mean_outside = torch.nanmean(pred_outside, dim=(1, 2))
        loss_outside = F.mse_loss(mean_outside, torch.zeros_like(mean_outside))

        # Accumulate out-of-place: `loss += ...` would add in-place into the
        # `loss_gtarea` tensor itself (same object), corrupting the logged term.
        loss = loss_gtarea
        if not self.gt_area_only:
            loss = loss + loss_outside
        if self.plcc_loss:
            loss = loss + loss_gtarea_plcc

        return (
            pred,
            mean_gtarea,
            {
                "loss_gtarea": loss_gtarea,
                "loss_gtarea_plcc": loss_gtarea_plcc,
                "loss_outside": loss_outside,
                "loss": loss,
            },
        )

    def on_train_epoch_start(self):
        torch.cuda.empty_cache()

    def training_step(self, batch: Batch, batch_idx: int):
        _, _, losses = self.compute_loss(batch)
        self.log_dict({f"train/{name}": value for name, value in losses.items()})
        return losses["loss"]

    def validation_step(self, batch: Batch, batch_idx: int, dataloader_idx: int):
        """Log losses and buffer per-image statistics for epoch-level scoring.

        For 10 thresholds ``t`` in ``0.0, 0.1, ..., 0.9`` this stores the
        per-image mean of predictions above ``t`` (NaN if no pixel exceeds it).
        For desra loaders it also stores per-image intersection/union pixel
        counts between ``pred > t`` and the annotation.
        """
        _, ann, y = batch
        pred, mean_gtarea, losses = self.compute_loss(batch)

        thresholds = torch.linspace(0, 0.9, 10, device=pred.device)
        pred = pred.unsqueeze(1)
        pred_bin_mask = pred > thresholds.view(1, -1, 1, 1)
        pred_bin = torch.where(pred_bin_mask, pred, torch.nan)
        mean_bin = torch.nanmean(pred_bin, dim=(2, 3))

        dl_name = self.dl_names[dataloader_idx]
        self.log_dict(
            {f"val-{dl_name}/{name}": value for name, value in losses.items()},
            add_dataloader_idx=False,
        )
        val_outputs = self.val_outputs[dataloader_idx]
        val_outputs["mean_gtarea"].append(mean_gtarea)
        val_outputs["mean_bin"].append(mean_bin)
        val_outputs["y"].append(y)

        if "desra" in dl_name:
            ann = ann.unsqueeze(1)
            intersection = torch.count_nonzero(pred_bin_mask & ann, dim=(2, 3))
            union = torch.count_nonzero(pred_bin_mask | ann, dim=(2, 3))
            val_outputs["intersection"].append(intersection)
            val_outputs["union"].append(union)

    @staticmethod
    def _nan_safe_argmax(values) -> int:
        """Index of the largest non-NaN entry of ``values`` (0 if all are NaN).

        ``np.argmax`` returns the position of the first NaN whenever one is
        present, e.g. ``pearsonr`` on a constant column at a threshold no pixel
        exceeds. It would therefore report NaN as the best score.
        """
        arr = np.asarray(values, dtype=np.float64)
        if arr.size > 0 and np.isnan(arr).all():
            # Same result np.argmax gives on all-NaN input.
            return 0
        return int(np.nanargmax(arr))

    @staticmethod
    def _iou_scores(
        intersection: torch.Tensor, union: torch.Tensor, y: torch.Tensor
    ) -> dict[str, float]:
        """Best dataset-level IoU over thresholds, overall and for images with y > 0.5."""
        iu = torch.stack((intersection, union))
        iou = torch.sum(iu[0], dim=0) / torch.sum(iu[1], dim=0)
        iou_max, iou_max_idx = torch.max(iou, dim=0)

        # Zero the counts of images with y <= 0.5 so they drop out of both sums.
        iu_50 = torch.where((y > 0.5).unsqueeze(1), iu, torch.zeros_like(iu))
        iou_50 = torch.sum(iu_50[0], dim=0) / torch.sum(iu_50[1], dim=0)
        iou_50_max, iou_50_max_idx = torch.max(iou_50, dim=0)

        return {
            "iou_max": iou_max.item(),
            "iou_max_threshold": iou_max_idx.item() / 10,
            "iou_50_max": iou_50_max.item(),
            "iou_50_max_threshold": iou_50_max_idx.item() / 10,
        }

    @classmethod
    def _binned_scores(cls, mean_bin: torch.Tensor, y: np.ndarray) -> dict:
        """Best PLCC/SRCC over thresholds between per-threshold means and ``y``.

        Column ``i`` of ``mean_bin`` belongs to threshold ``i / 10``. It is NaN
        where an image had no pixel above the threshold. Two variants are
        reported:

        - ``*_bin_zero_max`` treats those NaNs as 0.
        - ``*_bin_max`` drops those images and skips thresholds where more
          than half would be dropped. Its keys are omitted if every threshold
          is skipped.
        """
        scores = {}

        plccs_zero, srccs_zero = [], []
        for column in torch.nan_to_num(mean_bin).unbind(1):
            column = column.numpy(force=True)
            plccs_zero.append(pearsonr(column, y).statistic)
            srccs_zero.append(spearmanr(column, y).statistic)

        plcc_idx = cls._nan_safe_argmax(plccs_zero)
        srcc_idx = cls._nan_safe_argmax(srccs_zero)
        scores.update(
            {
                "plcc_bin_zero_max": plccs_zero[plcc_idx],
                "plcc_bin_zero_max_threshold": torch.tensor(plcc_idx / 10),
                "srcc_bin_zero_max": srccs_zero[srcc_idx],
                "srcc_bin_zero_max_threshold": torch.tensor(srcc_idx / 10),
            }
        )

        idxs, plccs, srccs = [], [], []
        for idx, column in enumerate(mean_bin.unbind(1)):
            column = column.numpy(force=True)
            nans = np.isnan(column)
            # If more than 50% of the dataset is NaNs, skip.
            if np.count_nonzero(nans) > len(column) / 2:
                continue
            idxs.append(idx)
            plccs.append(pearsonr(column[~nans], y[~nans]).statistic)
            srccs.append(spearmanr(column[~nans], y[~nans]).statistic)

        if plccs:
            plcc_idx = cls._nan_safe_argmax(plccs)
            srcc_idx = cls._nan_safe_argmax(srccs)
            scores.update(
                {
                    "plcc_bin_max": plccs[plcc_idx],
                    "plcc_bin_max_threshold": torch.tensor(idxs[plcc_idx] / 10),
                    "srcc_bin_max": srccs[srcc_idx],
                    "srcc_bin_max_threshold": torch.tensor(idxs[srcc_idx] / 10),
                }
            )

        return scores

    def on_validation_epoch_end(self):
        """Turn the buffered validation outputs into dataset-level scores and log them.

        Everything is logged per dataloader under ``val-<name>/``:

        - PLCC/SRCC of the in-annotation mean against the target.
        - The best binned PLCC/SRCC over thresholds.
        - For desra loaders, the best IoU over thresholds.

        ``hp_metric`` is the ``oi-gtRLFN`` in-annotation PLCC.
        """
        for val_outputs, dl_name in zip(self.val_outputs, self.dl_names):
            if not val_outputs["y"]:
                # No batches were collected for this loader; torch.cat([]) would raise.
                val_outputs.clear()
                continue

            mean_gtarea = torch.cat(val_outputs["mean_gtarea"])
            y = torch.cat(val_outputs["y"])
            mean_bin = torch.cat(val_outputs["mean_bin"])

            scores = {}
            if "desra" in dl_name:
                scores.update(
                    self._iou_scores(
                        torch.cat(val_outputs["intersection"]),
                        torch.cat(val_outputs["union"]),
                        y,
                    )
                )

            # Reset the buffers for the next validation epoch.
            for v in val_outputs.values():
                v.clear()

            plcc_gtarea = torch.corrcoef(torch.stack((mean_gtarea, y)))[0, 1].item()
            y_np = y.numpy(force=True)
            scores["plcc_gtarea"] = plcc_gtarea
            scores["srcc_gtarea"] = spearmanr(mean_gtarea.numpy(force=True), y_np).statistic
            scores.update(self._binned_scores(mean_bin, y_np))

            self.log_dict(
                {f"val-{dl_name}/{name}": value for name, value in scores.items()},
                add_dataloader_idx=False,
            )

            if dl_name == "oi-gtRLFN":
                self.log("hp_metric", plcc_gtarea)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)


# Run-directory prefix: today's local date as YYYYMMDD.
date_str = datetime.datetime.now().strftime("%Y%m%d")

# Hyperparameter grid; one training run per combination, in itertools.product order.
loops = [
    [False],  # gt_area_only
    ["gt"],  # train_gt
]

for gt_area_only, train_gt in itertools.product(*loops):
    # Fixed seed, re-applied per run (before the model is built) for reproducibility.
    seed = 42
    pl.seed_everything(seed, workers=True)

    plcc_loss = False
    features = list(FEATURES)

    # Directory name: <date>-gt<train_gt>[-gtareaonly][-plccloss][-s<seed>].
    optional_tags = (
        ("gtareaonly", gt_area_only),
        ("plccloss", plcc_loss),
        (f"s{seed}", seed != 42),
    )
    components = [f"gt{train_gt}"] + [tag for tag, wanted in optional_tags if wanted]

    save_dir = OUTPUT_ROOT / "-".join([date_str, *components])
    logger = TensorBoardLogger(save_dir, name=None)

    model = LitMetric(
        gt_area_only=gt_area_only,
        train_gt=train_gt,
        plcc_loss=plcc_loss,
        features=features,
    )

    # Keep a checkpoint every epoch (save_top_k=-1); MAIN_METRIC is put in the file name.
    checkpoint_callback = ModelCheckpoint(
        filename=f"epoch={{epoch}}-step={{step}}-val_plcc_gtarea={{{MAIN_METRIC}:.3f}}",
        auto_insert_metric_name=False,
        monitor=MAIN_METRIC,
        mode="max",
        save_top_k=-1,
        every_n_epochs=1,
    )
    # Optional extras: TQDMProgressBar(leave=True),
    # EarlyStopping(monitor=MAIN_METRIC, mode="max", patience=5).
    callbacks: list[pl.Callback] = [checkpoint_callback]

    trainer = pl.Trainer(
        max_epochs=50,
        log_every_n_steps=1,
        callbacks=callbacks,
        logger=logger,
        devices=[DEVICE],
        deterministic=True,
        profiler="simple",
        # For debugging: accelerator="cpu", fast_dev_run=True.
    )
    trainer.fit(model)
