"""Common methods to transform position predictions to global cartesian format"""

from typing import Any

import torch
from konductor.init import ExperimentInitConfig
from torch import Tensor

from ..dataset.sc2_dataset import TorchSC2Data


def calculate_angle_from_prob(probs: Tensor) -> Tensor:
    """Calculate the expected angle from per-bin probabilities, taking care
    to wrap angles correctly.

    The last dimension of ``probs`` is treated as ``N`` equal-width angular
    bins covering [-pi, pi), with bin 0 starting at -pi. A plain weighted sum
    of the bin centers would be biased for distributions that straddle the
    +/-pi seam. Instead, each bin's center is unwrapped relative to the most
    likely (argmax) bin before taking the expectation. The result is not
    re-wrapped, so it can fall slightly outside [-pi, pi).

    Args:
        probs: Bin probabilities of shape ``(..., N)`` (e.g. a softmax over
            angle logits). Any number of leading batch dimensions is allowed.

    Returns:
        Expected angle in radians, with shape ``probs.shape[:-1]``.
    """
    n_bins = probs.shape[-1]
    half_idx = n_bins // 2
    bin_width = 2 * torch.pi / n_bins

    # Most likely bin per sample, kept as a trailing dim for broadcasting.
    pivot = torch.argmax(probs, dim=-1, keepdim=True)
    bin_idx = torch.arange(n_bins, device=probs.device)
    # Signed distance (in bins) of every bin from the pivot, wrapped into
    # [-half_idx, n_bins - 1 - half_idx] so each bin is counted exactly once.
    # Computing the offset from the pivot bin's *center* keeps odd bin counts
    # correctly aligned (no half-bin shift).
    offset = torch.remainder(bin_idx - pivot + half_idx, n_bins) - half_idx

    # Center of the pivot bin, then every bin center unwrapped around it.
    pivot_center = -torch.pi + (pivot.to(probs.dtype) + 0.5) * bin_width
    centers = pivot_center + offset.to(probs.dtype) * bin_width
    # Pure tensor ops (no ``out=``), so this stays differentiable w.r.t. probs.
    return torch.sum(centers * probs, dim=-1)


def _get_position_from_value(batch_data: TorchSC2Data):
    """Return the dense ground-truth positions together with a validity mask.

    An entry is valid when every component along the last dimension of
    ``batch_data.positions`` lies strictly inside (-1, 1) and the matching
    entry of ``batch_data.units_mask`` is set.
    """
    assert batch_data.positions is not None
    gt_positions = batch_data.positions
    valid = ((gt_positions > -1) & (gt_positions < 1)).all(dim=-1)
    valid &= batch_data.units_mask
    return gt_positions, valid


def _get_position_from_unique(batch_data: TorchSC2Data):
    """Look up each unit's ground-truth position from the unique-position table.

    ``position_targets`` holds per-unit indices into ``positions_unique``
    along dim 2, and -1 marks "no target". A zero row is prepended to the
    table and every index is shifted by one. This makes -1 select the dummy
    row instead of being an invalid gather index. The returned mask excludes
    those entries as well as any unit not set in ``units_mask``.
    """
    assert batch_data.position_targets is not None
    assert batch_data.positions_unique is not None
    unique_pos = batch_data.positions_unique
    targets = batch_data.position_targets

    # Dummy zero entry at index 0 so that (target + 1) is always a valid index.
    dummy_row = torch.zeros_like(unique_pos[:, :, 0, None])
    padded = torch.cat([dummy_row, unique_pos], dim=2)

    gather_idx = targets[..., None].expand(-1, -1, -1, 2) + 1
    pos_gt = padded.gather(2, gather_idx)

    valid = batch_data.units_mask & (targets != -1)
    return pos_gt, valid


def get_unit_target_positions(data: TorchSC2Data):
    """Return ``(positions, mask)`` for the unit target positions.

    Dense per-unit ``positions`` are used when present. Otherwise positions
    are gathered from ``positions_unique`` through ``position_targets``.
    """
    getter = (
        _get_position_from_unique
        if data.positions is None
        else _get_position_from_value
    )
    return getter(data)


class PositionTransform:
    """Contains configuration for how to transform prediction to global cartesian coordinates

    ``__call__`` handles three prediction formats:

    * logits (``pos_bins != 0``): binned distributions decoded by
      :meth:`get_value_from_logits`;
    * polar values: a radius channel plus two channels fed to ``atan2``;
    * cartesian values: the first two channels are taken as (x, y).

    When ``rel_pos`` is set, the decoded value is added to ``units_pos``.
    """

    @classmethod
    def from_config(cls, cfg: ExperimentInitConfig):
        """Create PositionTransform based on model config in experiment config

        Reads ``cfg.model[0].args["pos_decoder"]["args"]``:
        ``relative_pos`` (default False), ``cartesian`` (default True; polar
        mode is its negation) and ``logit_out`` (number of bins; the default
        0 disables logit decoding).
        """
        pos_dec_cfg: dict[str, Any] = cfg.model[0].args["pos_decoder"]["args"]
        rel_pos = pos_dec_cfg.get("relative_pos", False)
        polar_pos = not pos_dec_cfg.get("cartesian", True)
        pos_bins = pos_dec_cfg.get("logit_out", 0)
        return cls(rel_pos, polar_pos, pos_bins)

    def __init__(
        self, rel_pos: bool = False, polar_pos: bool = False, pos_bins: int = 0
    ):
        self.rel_pos = rel_pos
        self.polar_pos = polar_pos
        self.pos_bins = pos_bins
        # Bin edges (pos_bins + 1 values) used to decode logits; stays None
        # when predictions are values rather than logits.
        self.support: Tensor | None = None
        if pos_bins != 0:
            self._initialize_supports()

    def _initialize_supports(self):
        """Build the ``pos_bins + 1`` bin edges used to decode logits.

        Polar mode uses radius edges over [0, sqrt(8)]. Presumably sqrt(8)
        is the diagonal of the [-1, 1]^2 square. Cartesian mode uses per-axis
        edges over [-1, 1], or over [-2, 2] for relative positions (presumably
        because an offset between two points in [-1, 1] spans up to 2).
        """
        if self.polar_pos:
            self.support = torch.linspace(
                0, torch.sqrt(torch.tensor(8)), self.pos_bins + 1, dtype=torch.float32
            )
        else:
            limit = 2 if self.rel_pos else 1
            self.support = torch.linspace(
                -limit, limit, self.pos_bins + 1, dtype=torch.float32
            )
        if torch.cuda.is_available():
            self.support = self.support.cuda()

    def get_value_from_logits(self, pred: Tensor):
        """Convert logit prediction to values

        ``pred`` must be 2D, ``(batch, 2 * pos_bins)``. It is split into two
        groups of ``pos_bins`` logits; each group is softmax-normalized and
        decoded as an expectation over the bin centers.

        * Cartesian: the two groups are the x and y coordinates.
        * Polar: group 0 is the radius and group 1 the angle (decoded with
          ``calculate_angle_from_prob``); the pair is converted to (x, y).

        Returns:
            A ``(batch, 2)`` tensor.
        """
        assert self.support is not None
        assert pred.ndim == 2, f"Got {pred.ndim=}, expected 2"
        pred = torch.softmax(pred.reshape(pred.shape[0], 2, pred.shape[1] // 2), dim=-1)
        # self.support is moved to CUDA whenever CUDA is available, which may
        # not match pred's device; align it (no-op when already matching).
        support = self.support.to(pred.device)
        centers = (support[:-1] + support[1:]) / 2
        if self.polar_pos:
            pred_radius = torch.sum(pred[:, 0] * centers, dim=-1)
            pred_theta = calculate_angle_from_prob(pred[:, 1])
            pred_value = pred_radius[:, None] * torch.stack(
                [torch.cos(pred_theta), torch.sin(pred_theta)], dim=-1
            )
        else:
            pred_value = torch.sum(pred * centers, dim=-1)

        return pred_value

    def __call__(self, pred: Tensor, units_pos: Tensor | None = None):
        """Convert prediction format to [-1,1] normalized global cartesian coordinates

        Args:
            pred: Decoder output. These are logits when ``pos_bins != 0``;
                otherwise they are position values (polar or cartesian).
                ``pred`` itself is never modified.
            units_pos: Positions to add to the decoded value. Required when
                ``rel_pos`` is set.

        Returns:
            Cartesian (x, y) positions.
        """
        if self.support is not None:
            pred_cart = self.get_value_from_logits(pred)
        elif self.polar_pos:
            # NOTE(review): channel 0 is the radius and channels 1, 2 are passed
            # to atan2 as (y, x); confirm this matches the decoder's ordering.
            pred_theta = torch.atan2(pred[..., 1], pred[..., 2])
            pred_cart = pred[..., 0, None] * torch.stack(
                (torch.cos(pred_theta), torch.sin(pred_theta)), dim=-1
            )
        else:
            pred_cart = pred[:, :2]

        if self.rel_pos:
            assert units_pos is not None, "units_pos required for relative position"
            # NOTE: relative predictions have been observed to look ~2x too
            # large visually; no rescaling is applied here.
            # Out-of-place add: in the cartesian branch pred_cart is a view of
            # pred, so `+=` would silently overwrite the caller's tensor (and
            # can break autograd for tensors needed in backward).
            pred_cart = pred_cart + units_pos

        return pred_cart
