import torch
from konductor.data import get_dataset_properties
from konductor.init import ExperimentInitConfig
from konductor.metadata.base_statistic import STATISTICS_REGISTRY, Statistic
from torch import Tensor

from .dataset.sc2_dataset import TorchSC2Data
from .utils.position_transforms import PositionTransform, get_unit_target_positions


@STATISTICS_REGISTRY.register_module("sc2-accuracy")
class SC2Accuracy(Statistic):
    """Calculates accuracy metrics for StarCraft Motion Intention task.

    Metrics are evaluated at every ``eval_stride``-th timestep in
    ``[0, max_time)`` and keyed with the timestep index as a suffix. Which
    metric groups are produced depends on the enabled decoders:

    - unit target (``enable_unit``): top-1/top-5 assignment accuracy over
      valid non-null targets, and over all valid units (``-null`` suffix).
    - position (``enable_pos``): mean L2 distance between the transformed
      predicted position and the target position.
    - position score (``enable_pos_score``): accuracy / precision / recall /
      F1 of the thresholded position-validity logit.
    """

    @classmethod
    def from_config(cls, cfg: ExperimentInitConfig, **extras):
        """Build the statistic from an experiment config.

        Unit / position metrics are enabled when the first model's args contain
        ``unit_decoder`` / ``pos_decoder``. The position-score metric is enabled
        only when the position decoder's args set ``decode_logit``.
        """
        props = get_dataset_properties(cfg)
        model_args = cfg.model[0].args
        enable_pos = "pos_decoder" in model_args
        enable_unit = "unit_decoder" in model_args
        if enable_pos:
            enable_pos_score = model_args["pos_decoder"]["args"].get(
                "decode_logit", False
            )
            pos_transform = PositionTransform.from_config(cfg)
        else:
            enable_pos_score = False
            pos_transform = None
        return cls(
            max_time=props["clip_length"],
            enable_unit=enable_unit,
            enable_pos=enable_pos,
            enable_pos_score=enable_pos_score,
            pos_transform=pos_transform,
        )

    def __init__(
        self,
        max_time: int,
        enable_unit: bool,
        enable_pos: bool,
        enable_pos_score: bool,
        pos_transform: PositionTransform | None,
        eval_stride: int = 5,
        pos_score_threshold: float = 0.5,
    ) -> None:
        """
        Args:
            max_time: Exclusive upper bound on evaluated timestep indices.
            enable_unit: Report unit-target top-1/top-5 metrics.
            enable_pos: Report position L2 metrics (requires ``pos_transform``).
            enable_pos_score: Report position-validity logit metrics.
            pos_transform: Maps raw position predictions to target space.
            eval_stride: Evaluate every ``eval_stride``-th timestep.
            pos_score_threshold: Threshold applied to ``sigmoid(logit)``.
        """
        super().__init__()
        self.max_time = max_time
        self.eval_stride = eval_stride
        self.pos_transform = pos_transform
        self.enable_unit = enable_unit
        self.enable_pos = enable_pos
        self.enable_pos_score = enable_pos_score
        self.pos_score_threshold = pos_score_threshold

    def get_keys(self) -> list[str]:
        """Return every metric key ``__call__`` can produce, in timestep order."""
        keys: list[str] = []
        for tidx in range(0, self.max_time, self.eval_stride):
            if self.enable_unit:
                keys.extend(
                    [
                        f"top1-{tidx}",
                        f"top5-{tidx}",
                        f"top1-{tidx}-null",
                        f"top5-{tidx}-null",
                    ]
                )
            if self.enable_pos:
                keys.append(f"l2-{tidx}")
            if self.enable_pos_score:
                keys.extend(
                    [
                        f"pos-score-acc-{tidx}",
                        f"pos-score-precision-{tidx}",
                        f"pos-score-recall-{tidx}",
                        f"pos-score-f1-{tidx}",
                    ]
                )
        return keys

    @staticmethod
    def _float_mean(values: Tensor) -> float:
        """Mean of ``values`` computed in float32 and returned as a Python float.

        Returns NaN when ``values`` is empty.
        """
        return torch.mean(values, dtype=torch.float32).item()

    def position_accuracy(self, pred: Tensor, data: TorchSC2Data):
        """Calculate position accuracy as l2 distance to target per-unit.

        For each evaluated timestep, the predictions of units selected by the
        target-position mask are mapped through ``pos_transform``. The
        transform is given the units' first two features as the reference
        position when ``pos_transform.rel_pos`` is set. The mean Euclidean
        distance to the target position is then reported as ``l2-{tidx}``
        (NaN if no unit is selected at that timestep).
        """
        assert self.pos_transform is not None, "position metrics need a transform"
        truth, mask = get_unit_target_positions(data)
        results: dict[str, float] = {}
        for tidx in range(0, self.max_time, self.eval_stride):
            step_mask = mask[tidx]
            # NOTE(review): assumes the first two unit features are x, y.
            ref_pos = (
                data.units[tidx, step_mask][..., :2]
                if self.pos_transform.rel_pos
                else None
            )
            pred_xy = self.pos_transform(pred[tidx, step_mask], ref_pos)
            diff = torch.norm(truth[tidx, step_mask] - pred_xy, 2, dim=-1)
            results[f"l2-{tidx}"] = diff.mean().item()
        return results

    def target_accuracy(self, pred: Tensor, data: TorchSC2Data):
        """Calculate target accuracy as top1 and top5 assignment.

        ``pred[tidx]`` holds per-unit scores over classes along the last dim.
        Ground truth is ``unit_targets + 1``, so class 0 is treated as the null
        assignment. Presumably ``unit_targets`` uses -1 for "no target";
        TODO confirm.

        Keys per evaluated timestep:
            ``top1-{t}`` / ``top5-{t}``: over valid units with a non-null target.
            ``top1-{t}-null`` / ``top5-{t}-null``: over all valid units.
        """
        results: dict[str, float] = {}
        # topk raises if k exceeds the class dimension, so clamp it.
        k = min(5, pred.shape[-1])
        for tidx in range(0, self.max_time, self.eval_stride):
            target = data.unit_targets[tidx] + 1
            valid = data.units_mask[tidx]
            _, top_ind = torch.topk(pred[tidx], k, sorted=True)
            # topk indices are distinct, so this sum is 0 or 1 per unit.
            top5_acc = torch.sum(target.unsqueeze(-1) == top_ind, dim=-1)
            top1_acc = target == top_ind[..., 0]
            # Parentheses are required: `&` binds tighter than `!=`. Without
            # them the valid mask is ignored.
            not_null = (target != 0) & valid
            results[f"top1-{tidx}"] = self._float_mean(top1_acc[not_null])
            results[f"top5-{tidx}"] = self._float_mean(top5_acc[not_null])
            results[f"top1-{tidx}-null"] = self._float_mean(top1_acc[valid])
            results[f"top5-{tidx}-null"] = self._float_mean(top5_acc[valid])

        return results

    def position_score_accuracy(self, preds: Tensor, data: TorchSC2Data):
        """Calculates the accuracy of the position validity logit.

        For each evaluated timestep, over valid units (``units_mask``), the
        prediction is ``sigmoid(logit) > pos_score_threshold``. It is compared
        against the target-position mask from ``get_unit_target_positions``.

        Precision (recall) is NaN when there are no positive predictions
        (targets). F1 is ``2*TP / (n_pred + n_target)``. This equals
        ``2PR/(P+R)`` when TP > 0, is 0 when TP == 0, and is NaN only when
        both prediction and target are empty.
        """
        preds = preds.squeeze(-1)
        result: dict[str, Tensor] = {}
        _, position_mask = get_unit_target_positions(data)
        for tidx in range(0, self.max_time, self.eval_stride):
            mask = data.units_mask[tidx]
            target = position_mask[tidx, mask]
            pred = preds[tidx, mask].sigmoid() > self.pos_score_threshold
            correct = target == pred
            true_positive = correct[target].sum()
            n_pred = pred.sum()
            n_target = target.sum()
            result[f"pos-score-acc-{tidx}"] = correct.sum() / target.nelement()
            result[f"pos-score-precision-{tidx}"] = true_positive / n_pred
            result[f"pos-score-recall-{tidx}"] = true_positive / n_target
            result[f"pos-score-f1-{tidx}"] = 2 * true_positive / (n_pred + n_target)
        return {k: v.item() for k, v in result.items()}

    def __call__(self, pred: dict[str, Tensor], data: TorchSC2Data) -> dict[str, float]:
        """Compute all enabled metric groups.

        Reads ``pred["unit-target"]``, ``pred["position"]`` and
        ``pred["pos-logit"]`` for the unit, position and position-score
        metrics respectively.
        """
        res: dict[str, float] = {}
        if self.enable_unit:
            res.update(self.target_accuracy(pred["unit-target"], data))
        if self.enable_pos:
            res.update(self.position_accuracy(pred["position"], data))
        if self.enable_pos_score:
            res.update(self.position_score_accuracy(pred["pos-logit"], data))
        return res
