"""Loss function and utilities for occupancy prediction for heatmap
"""

from dataclasses import dataclass
from typing import Any, Literal

import torch
from konductor.losses import REGISTRY, ExperimentInitConfig, LossConfig
from torch import Tensor, nn
from torch.nn.functional import (
    binary_cross_entropy_with_logits,
    cross_entropy,
    l1_loss,
    mse_loss,
    nll_loss,
)

from .dataset.sc2_dataset import TorchSC2Data
from .utils.position_transforms import get_unit_target_positions


class OccupancyFocal(nn.Module):
    """
    Focal loss for occupancy prediction based on matching
    keys between prediction and ground truth.

    Every entry of the prediction dictionary is treated as raw logits and is
    compared against the entry stored under the same key in the targets.

    Args:
        alpha: Weight applied to positive targets (``1 - alpha`` is applied to
            negative targets). A negative value disables this weighting.
        gamma: Exponent of the modulating factor ``(1 - p_t) ** gamma``.
        pos_weight: Positive class weight forwarded to
            ``binary_cross_entropy_with_logits``.
    """

    def __init__(
        self, alpha: float = 0.25, gamma: float = 2, pos_weight: float = 1.0
    ) -> None:
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        # Stored as a non-persistent buffer rather than calling ``.cuda()`` so the
        # module can be constructed without a GPU, follows ``module.to(device)``
        # and does not add an entry to the state dict.
        self.pos_weight: Tensor
        self.register_buffer(
            "pos_weight",
            torch.tensor(pos_weight, dtype=torch.float32),
            persistent=False,
        )

    def _focal(self, prediction: Tensor, target: Tensor) -> Tensor:
        """Apply loss to prediction, target pair.

        Args:
            prediction: Raw logits.
            target: Binary targets with the same shape as ``prediction``.

        Returns:
            Mean focal loss over all elements.
        """
        prob = prediction.sigmoid()
        ce_loss = binary_cross_entropy_with_logits(
            prediction,
            target,
            # The module itself may never have been moved to the data's device
            pos_weight=self.pos_weight.to(prediction.device),
            reduction="none",
        )
        # Probability assigned to the true class of each element
        p_t = prob * target + (1 - prob) * (1 - target)
        # eps keeps the base strictly positive when p_t saturates at 1
        loss = ce_loss * ((1 - p_t + torch.finfo(prob.dtype).eps) ** self.gamma)

        if self.alpha >= 0:
            alpha_t = self.alpha * target + (1 - self.alpha) * (1 - target)
            loss = alpha_t * loss

        return loss.mean()

    def forward(
        self, predictions: dict[str, Tensor], targets: dict[str, Tensor]
    ) -> dict[str, Tensor]:
        """Calculate loss by matching keys.

        Returns:
            One ``"focal_<name>"`` entry for every key in ``predictions``.
        """
        return {
            f"focal_{name}": self._focal(pred, targets[name])
            for name, pred in predictions.items()
        }


@dataclass
@REGISTRY.register_module("occupancy-focal")
class OccupancyFocalLoss(LossConfig):
    """Configuration registered as ``"occupancy-focal"`` that builds an
    ``OccupancyFocal`` loss module."""

    # NOTE(review): these defaults differ from OccupancyFocal's own defaults
    # (alpha=0.25, gamma=2). With alpha=1.0 the alpha weighting gives negative
    # targets a weight of zero - confirm this is intended.
    alpha: float = 1.0
    gamma: float = 0.25
    pos_weight: float = 1.0

    def get_instance(self, *args, **kwargs) -> Any:
        """Create the loss module from the fields of this config.

        The fields are passed explicitly rather than via ``**self.__dict__`` so
        that any attribute contributed by the base config (not visible in this
        file) is never forwarded to ``OccupancyFocal``. Extra arguments are
        accepted for parity with the other loss configs and are ignored.
        """
        return OccupancyFocal(
            alpha=self.alpha, gamma=self.gamma, pos_weight=self.pos_weight
        )


class GoalAssignmentLoss(nn.Module):
    """Cross entropy loss between predicted and real target of each agent.

    Args:
        use_nll: If true ``predictions["agent_target"]`` is passed through
            ``torch.log`` and ``nll_loss``, otherwise it is given directly to
            ``cross_entropy``.
    """

    def __init__(self, use_nll: bool):
        super().__init__()
        self.use_nll = use_nll

    def forward(
        self, predictions: dict[str, Tensor], targets: dict[str, Tensor]
    ) -> dict[str, Tensor]:
        """Apply loss for agent and target pairs.

        Samples are grouped by their number of valid goals so the prediction can
        be cropped to that many classes; the per-group mean losses are summed.

        Args:
            predictions: Must contain ``"agent_target"``.
            targets: Must contain ``"agents_valid"``, ``"targets_valid"`` and
                ``"agent_target"``.

        Returns:
            Dictionary with the single key ``"goal_ce_loss"``. The value is zero
            when there is no valid agent at all.
        """
        pred = predictions["agent_target"]
        agent_mask = targets["agents_valid"].bool()
        n_goals = targets["targets_valid"].sum(dim=1, keepdim=True, dtype=torch.long)

        group_losses: list[Tensor] = []
        for n_goal in n_goals.unique().tolist():
            mask = (n_goals == n_goal).expand_as(agent_mask) & agent_mask
            if not mask.any():
                # Mean reduction over an empty selection would give NaN
                continue
            pred_goals = pred[mask][..., :n_goal]
            gt_goals = targets["agent_target"][mask]
            if self.use_nll:
                group_losses.append(nll_loss(torch.log(pred_goals), gt_goals))
            else:
                group_losses.append(cross_entropy(pred_goals, gt_goals))

        if group_losses:
            loss = torch.stack(group_losses).sum()
        else:
            # Every group was empty, hence agent_mask selects nothing: the sum of
            # that empty selection is an exact zero still attached to the graph.
            loss = pred[agent_mask].sum()

        return {"goal_ce_loss": loss}


@dataclass
@REGISTRY.register_module("goal-ce")
class GoalAssignConfig(LossConfig):
    """Configuration registered as ``"goal-ce"`` that builds a
    ``GoalAssignmentLoss``."""

    # Forwarded unchanged to GoalAssignmentLoss
    use_nll: bool = False

    def get_instance(self, *args, **kwargs) -> Any:
        """Create the loss module; positional and keyword extras are ignored."""
        loss_module = GoalAssignmentLoss(use_nll=self.use_nll)
        return loss_module


class SC2UnitLoss(nn.Module):
    """Cross-entropy between predicted and real unit target assignment.

    Args:
        null_weight: Class weight applied to class index 0, the null-assignment
            class. A value of exactly 1.0 disables class weighting.
        use_ce: If true ``model_out["unit-target"]`` is given directly to
            ``cross_entropy``, otherwise it is passed through ``torch.log`` and
            ``nll_loss``.
    """

    def __init__(self, null_weight: float, use_ce: bool) -> None:
        super().__init__()
        self.null_weight = null_weight
        self.use_ce = use_ce

    def forward(
        self, model_out: dict[str, Tensor], batch_data: TorchSC2Data
    ) -> dict[str, Tensor]:
        """Average the assignment loss over groups with an equal target count.

        Samples are grouped by their number of candidate targets so that the
        prediction can be cropped to that many classes before the loss.

        Args:
            model_out: Must contain ``"unit-target"``.
            batch_data: Provides ``units_mask``, ``unit_targets`` and optionally
                ``enemy_mask``.

        Returns:
            Dictionary with the single key ``"unit-assignment"``. The value is
            zero when the batch contains no valid unit.
        """
        pred: Tensor = model_out["unit-target"]

        weight: Tensor | None = None
        if self.null_weight != 1.0:
            weight = torch.ones(pred.shape[-1], device=pred.device, dtype=pred.dtype)
            weight[0] = self.null_weight

        # Candidate targets are the enemies when available, else the units themselves
        if batch_data.enemy_mask is not None:
            candidate_mask = batch_data.enemy_mask
        else:
            candidate_mask = batch_data.units_mask
        # + 1 to add the null target
        n_targets = candidate_mask.sum(dim=-1, keepdim=True, dtype=torch.long) + 1

        # Add 1 to unit_targets to include null-assignment token which is -1 in gt
        target = batch_data.unit_targets + 1

        losses: list[Tensor] = []
        for n_target in n_targets.unique().tolist():
            mask = batch_data.units_mask & (n_targets == n_target).expand_as(
                batch_data.units_mask
            )
            if not mask.any():
                continue
            pred_ = pred[mask][..., :n_target]
            target_ = target[mask]
            weight_ = weight[:n_target] if weight is not None else None

            if self.use_ce:
                losses.append(cross_entropy(pred_, target_, weight=weight_))
            else:
                # model output is 'softmax', hence log and nll loss to get ce
                losses.append(nll_loss(torch.log(pred_), target_, weight=weight_))

        if losses:
            loss = torch.stack(losses).mean()
        else:
            # Every group was empty, hence units_mask selects nothing: the sum of
            # that empty selection is an exact zero still attached to the graph.
            # This replaces the ZeroDivisionError of averaging an empty list.
            loss = pred[batch_data.units_mask].sum()

        return {"unit-assignment": loss}


@dataclass
@REGISTRY.register_module("sc2-unit-target")
class SC2UnitCfg(LossConfig):
    """Configuration registered as ``"sc2-unit-target"`` that builds an
    ``SC2UnitLoss``."""

    # Both fields are forwarded unchanged to SC2UnitLoss
    null_weight: float = 1.0
    use_ce: bool = True

    def get_instance(self, *args, **kwargs) -> Any:
        """Create the loss module; positional and keyword extras are ignored."""
        loss_module = SC2UnitLoss(null_weight=self.null_weight, use_ce=self.use_ce)
        return loss_module


def _make_relative(pos: Tensor, mask: Tensor, batch_data: TorchSC2Data):
    """Express ``pos`` relative to the units selected by ``mask``.

    The first two feature channels of ``batch_data.units[mask]`` are subtracted
    from ``pos`` (presumably the unit x,y position - confirm against the dataset).
    """
    selected_units = batch_data.units[mask]
    unit_origin = selected_units[:, :2]
    return pos - unit_origin


def _get_target_positions(batch_data: TorchSC2Data):
    """Look up the position of the target assigned to every unit.

    Candidate targets are ``batch_data.enemy_units`` when present, otherwise
    ``batch_data.units``; only their first two feature channels are used.

    Returns:
        Tuple of (gathered positions of units that have a target, boolean mask
        marking which units are valid and have a target).
    """
    if batch_data.enemy_units is not None:
        assert batch_data.enemy_mask is not None
        candidates = batch_data.enemy_units
    else:
        candidates = batch_data.units
    candidates_xy = candidates[..., :2]

    # Prepend one all-zero entry so that gather index 0 corresponds to the "no
    # target" ground truth value of -1 once the indices are shifted by one.
    null_entry = torch.zeros_like(candidates_xy[..., 0, None, :])
    padded_xy = torch.cat([null_entry, candidates_xy], dim=-2)

    gather_index = batch_data.unit_targets[..., None].expand(-1, -1, -1, 2) + 1
    gathered_xy = torch.gather(padded_xy, 2, gather_index)

    valid = batch_data.unit_targets > -1
    valid &= batch_data.units_mask
    return gathered_xy[valid], valid


def _convert_cart_to_polar(cartesian: Tensor):
    """Convert cartesian x,y to polar r,t.

    Args:
        cartesian (Tensor): Cartesian coordinates to convert, shape [..., 2].
            Any number of leading dimensions is supported.

    Returns:
        Tensor: Tensor of the same shape but with (radius, theta) as the last
            dimension, theta being ``atan2(y, x)``.
    """
    assert cartesian.shape[-1] == 2, "Last dimension must equal 2 (x,y)"
    # Index with an ellipsis and a length-1 slice so the coordinate axis is kept
    # and inputs with extra leading dimensions are handled correctly.
    x_coord = cartesian[..., 0:1]
    y_coord = cartesian[..., 1:2]
    radius = torch.linalg.vector_norm(cartesian, ord=2, dim=-1, keepdim=True)
    theta = torch.atan2(y_coord, x_coord)
    return torch.cat([radius, theta], dim=-1)


class SC2PosRegression(nn.Module):
    """Combination of regression loss for position prediction and cross-entropy of position
    assignment.

    Args:
        loss_type: ``"l1"`` selects ``l1_loss`` and ``"l2"`` selects ``mse_loss``.
        include_targets: Also apply the position loss against the position of
            each unit's assigned target.
        relative_pos: Make ground truth positions relative to the unit.
        cartesian_pos: If true use the cartesian loss, otherwise the polar loss.
        true_pos_weight: Optional ``pos_weight`` of the ``"pos-logit"`` binary
            cross-entropy.
    """

    def __init__(
        self,
        loss_type: Literal["l1", "l2"],
        include_targets: bool,
        relative_pos: bool,
        cartesian_pos: bool,
        true_pos_weight: float | None,
    ):
        super().__init__()
        self.loss = {"l1": l1_loss, "l2": mse_loss}[loss_type]
        self.include_targets = include_targets
        self.relative_pos = relative_pos
        self.cartesian_pos = cartesian_pos
        self._pos_weight: Tensor | None
        if true_pos_weight is None:
            self._pos_weight = None
        else:
            # Non-persistent buffer: follows the module's device but is not saved
            self.register_buffer(
                "_pos_weight",
                torch.tensor(true_pos_weight, dtype=torch.float32),
                persistent=False,
            )

    def apply_cartesian_loss(self, pred: Tensor, pos_gt: Tensor):
        """Apply distance loss between prediction and ground truth.

        When ``pred`` has more than two channels, the channels after the first
        two (called variance in this code) scale the element-wise loss by their
        exponential and are subtracted as a regulariser.
        """
        if pred.shape[-1] > 2:  # 1-ch variance will be broadcasted
            pos_xy = pred[..., :2]
            pos_var = pred[..., 2:]
            loss = (
                torch.exp(pos_var) * self.loss(pos_xy, pos_gt, reduction="none")
                - pos_var
            )
            loss = loss.mean()
        else:
            loss = self.loss(pred, pos_gt)
        return loss

    def apply_polar_loss(self, pred: Tensor, pos_gt: Tensor):
        """Apply distance loss between predicted (r,sin(t),cos(t)) and (r,t).

        Args:
            pred: Prediction with 3 channels (r, sin(t), cos(t)) or 5 channels
                (r, sin(t), cos(t), var_r, var_t).
            pos_gt: Cartesian ground truth, converted to polar internally.

        Raises:
            RuntimeError: If ``pred`` has neither 3 nor 5 channels.
        """
        pos_rt = _convert_cart_to_polar(pos_gt[..., :2])

        # 1 - cos(delta) is zero when the angles agree and is periodic in delta
        angle_loss = 1 - torch.cos(
            pos_rt[..., 1] - torch.atan2(pred[..., 1], pred[..., 2])
        )
        if pred.shape[-1] == 3:  # r,sin(t),cos(t)
            loss = self.loss(pred[..., 0], pos_rt[..., 0])
            loss += torch.mean(angle_loss)
        elif pred.shape[-1] == 5:  # r,sin(t),cos(t),var_r,var_t
            var_r = pred[..., 3]
            var_t = pred[..., 4]
            loss = (
                torch.exp(var_r)
                * self.loss(pred[..., 0], pos_rt[..., 0], reduction="none")
                - var_r
            ).mean()
            loss += torch.mean(torch.exp(var_t) * angle_loss - var_t)
        else:
            raise RuntimeError(f"Incorrect prediction dimensionality: {pred.shape[-1]}")

        return loss

    def _position_loss(self, pred: Tensor, pos_gt: Tensor) -> Tensor:
        """Dispatch to the cartesian or polar loss, guarding empty selections."""
        if pred.numel() == 0:
            # The mean over an empty selection is NaN; the sum of the empty
            # selection is an exact zero that is still attached to the graph.
            return pred.sum()
        if self.cartesian_pos:
            return self.apply_cartesian_loss(pred, pos_gt)
        return self.apply_polar_loss(pred, pos_gt)

    def forward(self, model_out: dict[str, Tensor], batch_data: TorchSC2Data):
        """Apply positional loss to model predictions.

        Args:
            model_out: Must contain ``"position"`` and may contain ``"pos-logit"``.
            batch_data: Batch providing the ground truth.

        Returns:
            Dictionary with ``"pos-l2"`` (this key is used for either loss type)
            and, if ``"pos-logit"`` is predicted, ``"pos-logit"``.
        """
        loss: dict[str, Tensor] = {}
        position = model_out["position"]

        pos_gt, mask = get_unit_target_positions(batch_data)
        pos_gt = pos_gt[mask]
        if self.relative_pos:
            pos_gt = _make_relative(pos_gt, mask, batch_data)

        loss["pos-l2"] = self._position_loss(position[mask], pos_gt)

        # Re-use mask as position logit
        if "pos-logit" in model_out:
            _units_mask = batch_data.units_mask  # we still mask out invalid units
            loss["pos-logit"] = binary_cross_entropy_with_logits(
                model_out["pos-logit"][_units_mask],
                mask[_units_mask].unsqueeze(-1).to(torch.float32),
                pos_weight=self._pos_weight,
            )

        if self.include_targets:
            pos_gt, mask = _get_target_positions(batch_data)
            if self.relative_pos:
                pos_gt = _make_relative(pos_gt, mask, batch_data)
            loss["pos-l2"] = loss["pos-l2"] + self._position_loss(
                position[mask], pos_gt
            )

        return loss


class SC2PosCrossEntropy(nn.Module):
    """Cross-entropy loss for both position value and position assignment.
    Based on "Stop Regressing: Training Value Functions via Classification for Scalable Deep RL"

    Each scalar ground truth coordinate is turned into a soft label over
    ``num_bins`` bins by integrating a Gaussian centred on it over every bin.

    Args:
        num_bins: Number of classification bins per coordinate.
        sigma: Standard deviation of the Gaussian used to smooth the label.
        include_targets: Also apply the loss against the position of each
            unit's assigned target.
        relative_pos: Make ground truth positions relative to the unit.
        cartesian_pos: If true coordinates are (x, y), otherwise (radius, theta).
        true_pos_weight: Optional ``pos_weight`` of the ``"pos-logit"`` binary
            cross-entropy.
    """

    def __init__(
        self,
        num_bins: int,
        sigma: float,
        include_targets: bool,
        relative_pos: bool,
        cartesian_pos: bool,
        true_pos_weight: float | None,
    ):
        super().__init__()
        self.include_targets = include_targets
        self.relative_pos = relative_pos
        self.cartesian_pos = cartesian_pos
        self._pos_weight: Tensor | None
        if true_pos_weight is None:
            self._pos_weight = None
        else:
            self.register_buffer(
                "_pos_weight",
                torch.tensor(true_pos_weight, dtype=torch.float32),
                persistent=False,
            )

        # sqrt(2) * sigma, the scale inside erf for a Gaussian CDF
        self._divisor: Tensor
        self.register_buffer(
            "_divisor",
            torch.sqrt(torch.tensor(2.0, dtype=torch.float32)) * sigma,
            persistent=False,
        )
        self.support: Tensor
        self.support_angle: Tensor | None
        if cartesian_pos:
            # relative may need to reach from one corner to the other
            limit = 2 if relative_pos else 1
            self.register_buffer(
                "support",
                torch.linspace(-limit, limit, num_bins + 1, dtype=torch.float32),
                persistent=False,
            )
            self.support_angle = None
        else:
            # Over by pi each end to enable angle wrapping in _transform_to_probs_w_wrap
            limit = 2 * torch.pi
            self.register_buffer(
                "support_angle",
                torch.linspace(-limit, limit, 2 * num_bins + 1, dtype=torch.float32),
                persistent=False,
            )
            # Radius "support" only positive and range up to opposite corners
            self.register_buffer(
                "support",
                torch.linspace(
                    0, torch.sqrt(torch.tensor(8)), num_bins + 1, dtype=torch.float32
                ),
                persistent=False,
            )

    def _transform_to_probs(self, target: Tensor) -> Tensor:
        """Transform target value to probability bins.

        The result has one extra trailing dimension with ``len(support) - 1``
        entries that is normalised by the mass inside the support.
        """
        cdf_evals: Tensor = torch.special.erf(
            (self.support - target.unsqueeze(-1)) / self._divisor
        )
        z = cdf_evals[..., -1] - cdf_evals[..., 0]
        bin_probs = torch.diff(cdf_evals, n=1, dim=-1)
        return bin_probs / z.unsqueeze(-1)

    def _transform_to_probs_w_wrap(self, target: Tensor) -> Tensor:
        """Transform target value to probability bins but wrap.

        The bins are first computed over the doubled range [-2pi, 2pi], then the
        outer quarter on each side is folded onto the opposite end of the
        central [-pi, pi] range.
        """
        assert self.support_angle is not None
        cdf_evals: Tensor = torch.special.erf(
            (self.support_angle - target.unsqueeze(-1)) / self._divisor
        )
        z = cdf_evals[..., -1] - cdf_evals[..., 0]
        bin_probs = torch.diff(cdf_evals, n=1, dim=-1)
        double_width = bin_probs / z.unsqueeze(-1)
        perimeter = double_width.shape[-1] // 4
        # NOTE: result is a view of double_width. The in-place additions below
        # only write to the central bins while reading from the outer bins, so
        # the aliasing does not corrupt the values being added.
        result = double_width[..., perimeter:-perimeter]
        result[..., -perimeter:] += double_width[..., :perimeter]
        result[..., :perimeter] += double_width[..., -perimeter:]
        return result

    def get_gt_label(self, gt_cart: Tensor):
        """Transform the cartesian ground truth into soft labels for both
        coordinates, flattened to one row per coordinate."""
        if self.cartesian_pos:
            labels = self._transform_to_probs(gt_cart)
        else:
            gt_polar = _convert_cart_to_polar(gt_cart)
            polar_labels = [
                self._transform_to_probs(gt_polar[..., 0]),
                self._transform_to_probs_w_wrap(gt_polar[..., 1]),
            ]
            labels = torch.stack(polar_labels, dim=1)

        # Flatten to [N*2, n_bins] so it lines up with the prediction, which
        # reshape_prediction brings into the same layout
        return labels.flatten(0, -2)

    def reshape_prediction(self, pred: Tensor):
        """Reshape prediction from logits fused in the last dim [N, [n_bins, n_bins]] to batched
        [N * 2, n_bins]"""
        assert pred.ndim == 2, f"Got {pred.ndim=}, expected 2"
        return pred.reshape(pred.shape[0], 2, pred.shape[1] // 2).flatten(0, -2)

    def _position_ce(self, pred: Tensor, pos_gt: Tensor) -> Tensor:
        """Cross-entropy between binned prediction and soft label, guarding
        empty selections."""
        if pred.numel() == 0:
            # The mean over an empty selection is NaN; the sum of the empty
            # selection is an exact zero that is still attached to the graph.
            return pred.sum()
        return cross_entropy(self.reshape_prediction(pred), self.get_gt_label(pos_gt))

    def forward(self, model_out: dict[str, Tensor], batch_data: TorchSC2Data):
        """Apply positional loss to model predictions.

        Args:
            model_out: Must contain ``"position"`` and may contain ``"pos-logit"``.
            batch_data: Batch providing the ground truth.

        Returns:
            Dictionary with ``"pos-ce"`` and, if ``"pos-logit"`` is predicted,
            ``"pos-logit"``.
        """
        loss: dict[str, Tensor] = {}
        position = model_out["position"]

        pos_gt, mask = get_unit_target_positions(batch_data)
        pos_gt = pos_gt[mask]
        if self.relative_pos:
            pos_gt = _make_relative(pos_gt, mask, batch_data)

        loss["pos-ce"] = self._position_ce(position[mask], pos_gt)

        # Re-use mask as position logit
        if "pos-logit" in model_out:
            _units_mask = batch_data.units_mask  # we still mask out invalid units
            loss["pos-logit"] = binary_cross_entropy_with_logits(
                model_out["pos-logit"][_units_mask],
                mask[_units_mask].unsqueeze(-1).to(torch.float32),
                pos_weight=self._pos_weight,
            )

        if self.include_targets:
            pos_gt, mask = _get_target_positions(batch_data)
            if self.relative_pos:
                pos_gt = _make_relative(pos_gt, mask, batch_data)
            loss["pos-ce"] = loss["pos-ce"] + self._position_ce(position[mask], pos_gt)

        return loss


@dataclass
@REGISTRY.register_module("sc2-pos-target")
class SC2PosConfig(LossConfig):
    """Configuration registered as ``"sc2-pos-target"`` that builds either the
    regression or the classification variant of the position loss."""

    loss_type: Literal["l1", "l2"] = "l1"
    include_targets: bool = False
    relative_pos: bool = False
    cartesian_pos: bool = True
    use_regression: bool = True
    num_bins: int = 0
    sigma: float = 0.2
    true_pos_weight: float | None = None

    @classmethod
    def from_config(cls, config: ExperimentInitConfig, idx: int, **kwargs):
        """Build the config, deriving several fields from the position decoder.

        ``relative_pos``, ``cartesian_pos`` and ``num_bins`` are read from the
        ``pos_decoder`` arguments of the first model, and regression is used
        whenever ``num_bins`` is zero. The derived values are written into the
        criterion's argument mapping itself.
        """
        decoder_args: dict[str, Any] = config.model[0].args["pos_decoder"]["args"]
        loss_args = config.criterion[idx].args

        relative = decoder_args.get("relative_pos", False)
        cartesian = decoder_args.get("cartesian", True)
        n_bins = decoder_args.get("logit_out", 0)
        loss_args.update(
            relative_pos=relative,
            cartesian_pos=cartesian,
            num_bins=n_bins,
            use_regression=n_bins == 0,
        )
        return cls(**loss_args, **kwargs)

    def get_instance(self, *args, **kwargs) -> Any:
        """Create the loss module selected by ``use_regression``.

        Raises:
            RuntimeError: If neither ``relative_pos`` nor ``cartesian_pos`` is set.
        """
        if not (self.relative_pos or self.cartesian_pos):
            raise RuntimeError(
                "Polar coordiantes don't make sense for unit decoder in global frame"
            )

        # Trailing constructor arguments shared by both loss variants
        shared_args = (
            self.include_targets,
            self.relative_pos,
            self.cartesian_pos,
            self.true_pos_weight,
        )
        if self.use_regression:
            return SC2PosRegression(self.loss_type, *shared_args)
        return SC2PosCrossEntropy(self.num_bins, self.sigma, *shared_args)
