""" 
General AP Statistics for Occupancy Heatmap
"""

from dataclasses import dataclass
from itertools import product

import numpy as np
import torch
from konductor.data import get_dataset_properties
from konductor.init import ExperimentInitConfig
from konductor.metadata.base_statistic import STATISTICS_REGISTRY, Statistic
from torch import Tensor

from .dataset.chasing_targets import OccupancyParams


@dataclass
class Confusion:
    """Confusion counts per threshold, with the batch dimension last.

    Every field is a tensor of shape ``(thresholds, batch)``.
    """

    tp: Tensor  # true positives
    fp: Tensor  # false positives
    tn: Tensor  # true negatives
    fn: Tensor  # false negatives

    @classmethod
    def preallocate(cls, batch: int, thresholds: int, device=None):
        """Allocate four separate float32 buffers of shape ``(thresholds, batch)``.

        The buffers come from ``torch.empty`` and are therefore uninitialised;
        the caller is expected to write every entry.
        """
        template = torch.empty(
            (thresholds, batch), dtype=torch.float32, device=device
        )
        return cls(
            tp=template,
            fp=template.clone(),
            tn=template.clone(),
            fn=template.clone(),
        )

    @property
    def device(self):
        """Device of the count tensors (read from ``tp``)."""
        return self.tp.device


def _div_no_nan(a: Tensor, b: Tensor) -> Tensor:
    """Divide and set nan/inf values to zero"""
    c = a / b
    c[~torch.isfinite(c)] = 0
    return c


@STATISTICS_REGISTRY.register_module("occupancy")
class Occupancy(Statistic):
    """Soft IoU and precision-recall AUC for occupancy heatmaps.

    Predictions are logits: a sigmoid is applied before scoring. Any non-zero
    target cell counts as positive when building the confusion counts for AUC.
    Scores can be reported per timestep (``time_idxs``) and/or per class
    (``classes``).
    """

    @classmethod
    def from_config(cls, cfg: ExperimentInitConfig, **extras):
        """Build the statistic from an experiment config.

        Timesteps are the union of ``params.times`` and, when set, every index
        in the inclusive ``params.random_range``. Class names come from the
        first model's decoder adapter ``names`` argument, falling back to the
        dataset's ``classes`` property. ``extras`` are forwarded to ``__init__``.
        """
        dataset_props = get_dataset_properties(cfg)
        params: OccupancyParams = dataset_props["occupancy"]
        time_idxs = set(params.times)
        if params.random_range is not None:
            _range = params.random_range
            time_idxs.update(range(_range.min, _range.max + 1))

        classes = (
            cfg.model[0]
            .args["decoder"]["adapter"]["args"]
            .get("names", dataset_props.get("classes", None))
        )

        return cls(
            time_idxs=sorted(time_idxs) if time_idxs else None,
            classes=classes,
            **extras,
        )

    def __init__(
        self,
        time_idxs: list[int] | None = None,
        classes: list[str] | None = None,
        auc_thresholds: int = 100,
    ):
        """
        Args:
            time_idxs: Timestep labels to report. ``None`` scores a single,
                time-less heatmap per prediction key.
            classes: Class names used to prefix result keys, or ``None``.
            auc_thresholds: Number of evenly spaced thresholds used for AUC.
        """
        super().__init__()
        self.auc_thresholds = auc_thresholds
        self.time_idxs = time_idxs
        self.classes = classes

    def get_keys(self) -> list[str]:
        """Return every result key ``__call__`` can emit.

        Keys have the form ``[{class}_]{IoU|AUC}[_{timestep}]``.
        """
        data_keys = ["IoU", "AUC"]
        if self.time_idxs is not None:
            data_keys = [f"{s}_{t}" for s, t in product(data_keys, self.time_idxs)]
        if self.classes is not None:
            data_keys = [f"{c}_{s}" for c, s in product(self.classes, data_keys)]
        return data_keys

    def calculate_soft_iou(self, pred: Tensor, target: Tensor) -> np.ndarray:
        """Per-sample soft IoU: ``sum(p*t) / sum(p + t - p*t)`` over dims 1 and 2.

        Returns a numpy array with one value per entry along dim 0.
        NOTE(review): a sample whose soft union is zero yields NaN, which then
        propagates through the ``.mean()`` taken by the callers.
        """
        soft_intersection = (pred * target).sum(dim=(1, 2))
        soft_union = (pred + target - pred * target).sum(dim=(1, 2))
        return (soft_intersection / soft_union).cpu().numpy()

    def make_thresholds(self) -> np.ndarray:
        """Return ``auc_thresholds`` float32 thresholds evenly spaced from 0 to 1.

        The end points are nudged just outside [0, 1]. For predictions in
        [0, 1], the first threshold therefore marks every cell positive and
        the last marks none, regardless of float rounding.
        """
        thresh = np.linspace(0, 1, self.auc_thresholds, dtype=np.float32)
        eps = np.finfo(thresh.dtype).eps
        thresh[0] = -eps
        thresh[-1] = 1 + eps
        return thresh

    def calculate_confusion(self, pred: Tensor, target: Tensor) -> Confusion:
        """Count TP/FP/TN/FN for every threshold and every sample.

        A cell is predicted positive when ``pred > threshold``. It is actually
        positive when ``target`` is non-zero. Counts are summed over dims 1
        and 2, so each field of the result has shape ``(thresholds, batch)``.
        """
        target_pos = target.bool()
        target_neg = ~target_pos  # loop-invariant, computed once
        thresholds = self.make_thresholds()
        conf = Confusion.preallocate(pred.shape[0], thresholds.shape[0], pred.device)

        # Thresholds ascend 0 -> 1, as required by interpolate_pr_auc
        for idx, threshold in enumerate(thresholds):
            pred_pos: Tensor = pred > threshold
            pred_neg = ~pred_pos
            conf.tp[idx] = (pred_pos & target_pos).sum(dim=(1, 2))
            conf.fp[idx] = (pred_pos & target_neg).sum(dim=(1, 2))
            conf.tn[idx] = (pred_neg & target_neg).sum(dim=(1, 2))
            conf.fn[idx] = (pred_neg & target_pos).sum(dim=(1, 2))

        return conf

    def interpolate_pr_auc(self, confusion: Confusion) -> np.ndarray:
        """Interpolated area under the precision-recall curve, per sample.

        Port of the PR interpolation in ``tf.keras.metrics.AUC``. Between
        adjacent thresholds, true positives are interpolated linearly in the
        number of predicted positives (``tp + fp``). Thresholds must be
        ordered from low to high along dim 0.
        """
        # Float zero matching the count dtype for the clamps below
        zero_ = confusion.tp.new_zeros(())

        # Drop in TP and in predicted positives between adjacent thresholds
        dtp = confusion.tp[:-1] - confusion.tp[1:]
        p = confusion.tp + confusion.fp
        dp = p[:-1] - p[1:]
        prec_slope = _div_no_nan(dtp, torch.maximum(dp, zero_))
        intercept = confusion.tp[1:] - prec_slope * p[1:]

        # log(p[i] / p[i+1]) is only used when both counts are positive
        safe_p_ratio = torch.where(
            torch.logical_and(p[:-1] > 0, p[1:] > 0),
            _div_no_nan(p[:-1], torch.maximum(p[1:], zero_)),
            torch.ones_like(p[1:]),
        )

        # Each increment is normalised by the number of actual positives
        pr_auc_increment = _div_no_nan(
            prec_slope * (dtp + intercept * torch.log(safe_p_ratio)),
            torch.maximum(confusion.tp[1:] + confusion.fn[1:], zero_),
        )

        # Sum over the threshold dim, leaving one AUC per sample
        return pr_auc_increment.sum(dim=0).cpu().numpy()

    def calculate_auc(self, pred: Tensor, target: Tensor) -> np.ndarray:
        """Per-sample interpolated PR-AUC of ``pred`` against ``target``."""
        conf = self.calculate_confusion(pred, target)
        return self.interpolate_pr_auc(conf)

    def run_over_timesteps(
        self, prediction: Tensor, target: Tensor, timesteps: Tensor, classname: str = ""
    ):
        """Score each timestep slice ``prediction[:, tidx]`` separately.

        Result keys use the labels in ``timesteps[0, 0]``. These must be the
        same for every batch element (dim 0), and ``timesteps`` needs at
        least three dims. Values are batch means.
        """
        assert all(
            (t == timesteps[0]).all() for t in timesteps
        ), "Found different timesteps in batches"

        results: dict[str, float] = {}
        for tidx, timestep in enumerate(timesteps[0, 0]):
            pred_t = prediction[:, tidx]
            target_t = target[:, tidx]
            results[f"{classname}IoU_{timestep}"] = self.calculate_soft_iou(
                pred_t, target_t
            ).mean()
            results[f"{classname}AUC_{timestep}"] = self.calculate_auc(
                pred_t, target_t
            ).mean()

        return results

    def run_over_class(
        self,
        prediction: Tensor,
        target: Tensor,
        timesteps: Tensor | None,
        classname: str = "",
    ):
        """Apply a sigmoid to the logits, then score one map or every timestep.

        ``timesteps`` is only read when ``self.time_idxs`` is set.
        """
        prediction = prediction.sigmoid()

        if self.time_idxs is not None:
            return self.run_over_timesteps(prediction, target, timesteps, classname)

        # Drop a singleton dim 1 (channel) if present; squeeze is a no-op otherwise
        prediction = prediction.squeeze(1)
        target = target.squeeze(1)

        results: dict[str, float] = {}
        results[f"{classname}IoU"] = self.calculate_soft_iou(prediction, target).mean()
        results[f"{classname}AUC"] = self.calculate_auc(prediction, target).mean()
        return results

    def _class_prefix(self, key: str) -> str:
        """Return the result-key prefix ``"{class}_"`` for ``key``, or ``""``."""
        if self.classes is None:
            return ""
        # First class (in self.classes order) whose name occurs in the key
        name = next((c for c in self.classes if c in key), None)
        if name is None:
            raise ValueError(
                f"No class from {self.classes} found in prediction key '{key}'"
            )
        return f"{name}_"

    def __call__(self, predictions: dict[str, Tensor], targets: dict[str, Tensor]):
        """Score every prediction key against the same key in ``targets``.

        ``targets["time_idx"]`` is only required when ``time_idxs`` is set.

        Raises:
            ValueError: if ``classes`` is set and a prediction key contains
                none of the class names.
        """
        results: dict[str, float] = {}
        for key, prediction in predictions.items():
            prefix = self._class_prefix(key)
            # Time labels are only consumed by the per-timestep path
            timesteps = targets["time_idx"] if self.time_idxs is not None else None
            results.update(
                self.run_over_class(prediction, targets[key], timesteps, prefix)
            )
        return results


@STATISTICS_REGISTRY.register_module("goal-accuracy")
class GoalAccuracy(Statistic):
    """Top-1 accuracy of ``predictions["agent_target"]``, reported per timestep."""

    @classmethod
    def from_config(cls, cfg: ExperimentInitConfig, **extras):
        """Build from an experiment config.

        ``max_time`` comes from the dataset's ``n_iter`` property. An optional
        ``eval_stride`` may be passed through ``extras`` (default 1). Any other
        extras are ignored.
        """
        props = get_dataset_properties(cfg)
        return cls(
            max_time=props["n_iter"], eval_stride=extras.get("eval_stride", 1)
        )

    def __init__(self, max_time: int, eval_stride: int = 1) -> None:
        """
        Args:
            max_time: Exclusive upper bound of the timestep indices reported.
            eval_stride: Step between reported timestep indices.
        """
        super().__init__()
        self.max_time = max_time
        self.eval_stride = eval_stride

    def get_keys(self) -> list[str] | None:
        """Return ``Top1_{t}`` for ``t`` in ``range(0, max_time, eval_stride)``."""
        return [f"Top1_{t}" for t in range(0, self.max_time, self.eval_stride)]

    def __call__(
        self, predictions: dict[str, Tensor], data: dict[str, Tensor]
    ) -> dict[str, float]:
        """Compute top-1 accuracy over valid entries for each reported timestep.

        The argmax is taken over the last dim of the predictions. Correct and
        valid counts are summed over dims 0 and 1, and the remaining dim is
        indexed by timestep. Presumably the data tensors are
        (batch, agents, time); TODO confirm. A timestep with no valid entries
        yields NaN (0 / 0).
        """
        # Argmax picks the same index whether the inputs are logits or softmax
        pred = torch.argmax(predictions["agent_target"], dim=-1)
        valid = data["agents_valid"]
        num_true = torch.sum((pred == data["agent_target"]) & valid.bool(), dim=(0, 1))
        num_valid = torch.sum(valid, dim=(0, 1))
        acc = (num_true / num_valid).tolist()
        return {
            f"Top1_{tidx}": acc[tidx]
            for tidx in range(0, self.max_time, self.eval_stride)
        }
