import enum
from dataclasses import dataclass
from itertools import pairwise
from typing import Any, Literal

import torch
from konductor.data import get_dataset_properties
from konductor.init import ExperimentInitConfig, ModuleInitConfig
from konductor.models import MODEL_REGISTRY
from konductor.models._pytorch import TorchModelConfig
from konductor.registry import Registry
from konductor.utilities import comm
from sc2_serializer.unit_features import UnitOH
from torch import Tensor, nn
from torchvision.models.resnet import BasicBlock, ResNet, ResNet18_Weights

from ..dataset.sc2_dataset import TorchSC2Data
from . import perceiver_io as pio
from ._encoders import EncodingMethod
from ._iadapter import (
    AgentIA,
    generate_position_encodings,
    generate_positions_for_encoding,
)
from .sc2_encoder import UNIT_ENCODERS

# Registry of minimap encoder implementations. Classes are added under a short
# name with ``@MINIMAP_ENCODERS.register_module(name)`` and looked up by that
# name in ``SC2IntentModelCfg.get_instance``.
MINIMAP_ENCODERS = Registry("minimap-encoders")


@MINIMAP_ENCODERS.register_module("v1")
class MinimapEncoderV1(ResNet):
    """ResNet18-based minimap encoder that emits a grid of latent tokens.

    The ResNet18 stem is replaced so it accepts ``input_ch`` channels, the
    pooling stage produces an ``output_size`` grid and the classification head
    becomes a per-cell projection to ``latent_dim``. A position encoding is
    added to every output token.

    Args:
        latent_dim: Channel dimension of the output tokens.
        output_size: Spatial size (H, W) the feature map is pooled to.
        static_pos: Keep the position encoding as a fixed, non-persistent
            buffer instead of a learnable parameter. Defaults to False.
        input_ch: Number of input minimap channels. Defaults to 1.
        pretrained: Load ImageNet ResNet18 weights before the stem and head
            are replaced. Defaults to False.
        norm_output: Apply LayerNorm to the projected tokens before the
            position encoding is added. Defaults to False.
    """

    def __init__(
        self,
        latent_dim: int,
        output_size: tuple[int, int],
        static_pos: bool = False,
        input_ch: int = 1,
        pretrained: bool = False,
        norm_output: bool = False,
    ) -> None:
        super().__init__(BasicBlock, [2, 2, 2, 2])
        if pretrained:
            # Must happen before the head/stem are swapped, the checkpoint
            # matches the stock ResNet18 layout
            self._load_pretrained()

        self.fc = nn.Linear(512 * BasicBlock.expansion, latent_dim)
        # NOTE: keeps ResNet's ``avgpool`` attribute name but is a max pool
        self.avgpool = nn.AdaptiveMaxPool2d(output_size)
        stem = nn.Conv2d(input_ch, 64, kernel_size=7, stride=2, padding=3, bias=False)
        nn.init.kaiming_normal_(stem.weight, mode="fan_out", nonlinearity="relu")
        self.conv1 = stem

        grid = generate_positions_for_encoding(output_size)
        pos_enc = generate_position_encodings(
            grid, latent_dim // 4, include_positions=False
        ).flatten(0, 1)  # [H*W,C]
        self.position_encoding: Tensor
        if not static_pos:
            self.position_encoding = nn.Parameter(pos_enc)
        else:
            self.register_buffer("position_encoding", pos_enc, persistent=False)

        if norm_output:
            self.out_norm = nn.LayerNorm(latent_dim)
        else:
            self.out_norm = None

    def _load_pretrained(self):
        """Load ImageNet weights, letting only local rank 0 do the download"""
        is_local_main = comm.get_local_rank() == 0
        if not is_local_main:
            # Wait until the main local process has fetched the weights
            comm.synchronize()
        state = ResNet18_Weights.IMAGENET1K_V1.get_state_dict(
            progress=True, check_hash=True
        )
        if is_local_main:
            # Release the waiting processes now the weights are available
            comm.synchronize()
        self.load_state_dict(state)

    def _forward_impl(self, x: Tensor) -> Tensor:
        """Encode minimap [B,C,H,W] into position-encoded tokens [B,H*W,C]"""
        # See note [TorchScript super()]
        stem = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        feats = self.layer4(self.layer3(self.layer2(self.layer1(stem))))

        pooled = self.avgpool(feats)
        tokens = pooled.flatten(2).transpose(1, 2)  # [B,C,H,W] -> [B,H*W,C]
        tokens = self.fc(tokens)
        if self.out_norm is not None:
            tokens = self.out_norm(tokens)

        return tokens + self.position_encoding[None]


@MINIMAP_ENCODERS.register_module("v2")
class MinimapEncoderV2(nn.Module):
    """Smaller and simpler minimap encoder.

    A strided conv stem is followed by a stack of ``BasicBlock`` layers,
    adaptive average pooling to ``output_size`` and a linear projection with
    layer norm to ``latent_dim``. A position encoding is added to every token.

    Args:
        channels: Channel width of the stem followed by each residual layer,
            one ``BasicBlock`` is built per consecutive pair.
        latent_dim: Channel dimension of the output tokens.
        output_size: Spatial size (H, W) the feature map is pooled to.
        input_ch: Number of input minimap channels. Defaults to 1.
        pos_type: How the position encoding is applied. "learn" makes it a
            trainable parameter, "static" keeps it as a fixed buffer and
            "proj" keeps it fixed but passes it through a learned bias-free
            linear layer initialised to identity. Defaults to "proj".

    Raises:
        ValueError: If ``pos_type`` is not one of the supported options.
    """

    def __init__(
        self,
        channels: list[int],
        latent_dim: int,
        output_size: tuple[int, int],
        input_ch: int = 1,
        pos_type: Literal["learn", "static", "proj"] = "proj",
    ):
        super().__init__()
        # Fail loudly on typos, otherwise any unknown string would silently
        # be treated as "learn"
        if pos_type not in {"learn", "static", "proj"}:
            raise ValueError(
                f"Unknown pos_type {pos_type!r}, expected 'learn', 'static' or 'proj'"
            )

        self.in_layer = nn.Sequential(
            nn.Conv2d(
                input_ch, channels[0], kernel_size=5, stride=2, padding=2, bias=False
            ),
            nn.BatchNorm2d(channels[0]),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )

        def make_layer(in_ch: int, out_ch: int):
            """Modeled after resnet.BasicBlock"""
            # The residual branch needs a 1x1 projection when the width changes
            proj_identity = (
                nn.Conv2d(in_ch, out_ch, kernel_size=1) if in_ch != out_ch else None
            )
            return BasicBlock(in_ch, out_ch, downsample=proj_identity)

        self.layers = nn.Sequential(*(make_layer(*args) for args in pairwise(channels)))
        self.avgpool = nn.AdaptiveAvgPool2d(output_size)
        self.fc = nn.Linear(channels[-1], latent_dim)
        self.out_norm = nn.LayerNorm(latent_dim)

        pos = generate_positions_for_encoding(output_size)
        enc = generate_position_encodings(pos, latent_dim // 4, include_positions=False)
        enc = enc.flatten(0, 1)  # [H*W,C]
        self.position_encoding: Tensor
        if pos_type in {"static", "proj"}:
            # Fixed encoding, rebuilt at construction so not stored in checkpoints
            self.register_buffer("position_encoding", enc, persistent=False)
        else:
            self.position_encoding = nn.Parameter(enc)

        self.pos_proj = (
            nn.Linear(latent_dim, latent_dim, bias=False)
            if pos_type == "proj"
            else None
        )

        self._init_weights()

    def _init_weights(self):
        """Initialize model weights, similar to ResNet + zero_init_residual"""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

            if isinstance(m, BasicBlock):
                # Zero the last norm scale so each block starts as identity
                nn.init.constant_(m.bn2.weight, 0)

        if self.pos_proj is not None:
            # Start as a pass-through of the fixed position encoding
            nn.init.eye_(self.pos_proj.weight)

    def forward(self, x: Tensor):
        """Encode minimap [B,C,H,W] into position-encoded tokens [B,H*W,C]"""
        x = self.in_layer(x)
        x = self.layers(x)
        x = self.avgpool(x)
        x = x.flatten(2).transpose(1, 2)  # [B,C,H,W] -> [B,H*W,C]
        x = self.fc(x)
        x = self.out_norm(x)

        pos = self.position_encoding[None]
        if self.pos_proj is not None:
            pos = self.pos_proj(pos)
        x = x + pos

        return x


# Registry of unit-to-target assignment decoders. Entries are added with
# ``@UNIT_DECODERS.register_module(name)`` and looked up by name in
# ``SC2IntentModelCfg.get_instance``.
UNIT_DECODERS = Registry("unit-decoders")


@torch.jit.script
def correlation(source: Tensor, target: Tensor, mask: Tensor):
    """Score every source vector against every target vector, hiding invalid targets.

    Args:
        source (Tensor): Vectors to find the assignment of [B,Q,C]
        target (Tensor): Candidate vectors to assign to [B,K,C]
        mask (Tensor): Boolean [B,K], True if the target is invalid

    Returns:
        Tensor: Dot-product score [B,Q,K] from each source to each target,
            set to -inf wherever the target is masked
    """
    scores = torch.einsum("bqc,bkc->bqk", source, target)
    # The same target mask applies to every source row
    invalid = mask.unsqueeze(1).expand_as(scores)
    return scores.masked_fill(invalid, -torch.inf)


@UNIT_DECODERS.register_module("v1")
class UnitDecoderV1(nn.Module):
    """Decode the assignment between units and targets.

    Units attend to the latent, then units and targets are projected into a
    shared ``hidden_dim`` space and scored against each other by dot product.
    A learned "null" target is prepended so a unit can be assigned to nothing.

    Args:
        unit_dim: Feature dimension of the incoming unit and target vectors.
        latent_dim: Feature dimension of the latent.
        hidden_dim: Dimension units and targets are projected to for scoring.
        num_layers: Forwarded to ``pio.CrossAttentionBlock``. Defaults to 1.
        num_heads: Forwarded to ``pio.CrossAttentionBlock``. Defaults to 4.
        dropout: Forwarded to ``pio.CrossAttentionBlock``. Defaults to 0.0.
        post_project: Run the cross-attention at ``unit_dim`` and project the
            units afterwards, rather than projecting first. Defaults to False.
        layer_norm: Apply LayerNorm to the projected units and targets before
            scoring. Defaults to False.
        temperature_out: Softmax temperature of the output. Defaults to 1.0.
        latent_proj: Pass the latent through ``pio.mlp`` before it is used as
            cross-attention context. Defaults to False.
        logit_out: Return the raw masked scores instead of a softmax over
            targets. Defaults to False.
    """

    def __init__(
        self,
        unit_dim: int,
        latent_dim: int,
        hidden_dim: int,
        num_layers: int = 1,
        num_heads: int = 4,
        dropout: float = 0.0,
        post_project: bool = False,
        layer_norm: bool = False,
        temperature_out: float = 1.0,
        latent_proj: bool = False,
        logit_out: bool = False,
    ) -> None:
        super().__init__()
        self.logit_out = logit_out
        self.apply_latent_context = pio.CrossAttentionBlock(
            num_layers,
            # Attention runs before the unit projection when post_project is set
            unit_dim if post_project else hidden_dim,
            latent_dim,
            num_heads,
            dropout,
        )
        self.unit_proj = nn.Linear(unit_dim, hidden_dim, bias=False)
        self.unit_norm = nn.LayerNorm(hidden_dim) if layer_norm else None

        self.target_proj = nn.Linear(unit_dim, hidden_dim, bias=False)
        self.target_norm = nn.LayerNorm(hidden_dim) if layer_norm else None

        self.latent_proj = pio.mlp(latent_dim) if latent_proj else None

        self.post_project = post_project
        self.null_assignment = nn.Parameter(torch.empty(1, hidden_dim))
        self.temperature: Tensor
        self.register_buffer(
            "temperature",
            torch.tensor(temperature_out, dtype=torch.float32),
            persistent=False,
        )
        self._init_params()

    def _init_params(self):
        """Xavier-initialise the projections and the null assignment token"""
        nn.init.xavier_uniform_(self.unit_proj.weight)
        nn.init.xavier_uniform_(self.target_proj.weight)
        nn.init.xavier_uniform_(self.null_assignment)

    def forward(
        self, latent: Tensor, units: Tensor, targets: Tensor, targets_mask: Tensor
    ):
        """Score each unit against the null token and every target.

        Units mask is not needed, doesn't hurt to apply on invalid padding data.

        Args:
            latent: Latent context, a 4-dim input is treated as a sequence of
                batches and its two leading dims are folded together.
            units: Unit features, with the same leading dims as ``latent``.
            targets: Target features, with the same leading dims as ``latent``.
            targets_mask: Boolean mask over targets, True where a target is valid.

        Returns:
            dict with key "unit-target" holding the assignment over
            ``1 + num_targets`` entries per unit, index 0 is the null
            assignment. Softmax probabilities unless ``logit_out`` is set.
        """
        batch_sequence = latent.ndim == 4
        if batch_sequence:
            seq, batch = latent.shape[:2]
            # flatten copes with non-contiguous and zero-sized inputs, which
            # ``view(-1, ...)`` does not
            latent = latent.flatten(0, 1)
            units = units.flatten(0, 1)
            targets = targets.flatten(0, 1)
            targets_mask = targets_mask.flatten(0, 1)

        if self.latent_proj is not None:
            latent = self.latent_proj(latent)

        if not self.post_project:
            units = self.unit_proj(units)
        units = self.apply_latent_context(units, latent)
        if self.post_project:
            units = self.unit_proj(units)
        targets = self.target_proj(targets)

        # Add null assignment token as first token, it is always valid.
        # The mask column is built from the batch size rather than by indexing
        # the first target so it also works when there are no targets.
        null_assignment = self.null_assignment[None].expand(targets.shape[0], -1, -1)
        targets = torch.cat((null_assignment, targets), dim=1)
        null_mask = targets_mask.new_ones((targets_mask.shape[0], 1))
        targets_mask = torch.cat((null_mask, targets_mask), dim=1)

        if self.unit_norm is not None:
            assert self.target_norm is not None
            units = self.unit_norm(units)
            targets = self.target_norm(targets)

        # Calculate assignment probability, correlation expects True == invalid
        assignment = correlation(units, targets, ~targets_mask)
        if not self.logit_out:
            assignment = torch.softmax(assignment / self.temperature, dim=-1)

        if batch_sequence:
            assignment = assignment.unflatten(0, (seq, batch))

        return {"unit-target": assignment}


# Registry of unit position decoders. Entries are added with
# ``@POS_DECODER.register_module(name)`` and looked up by name in
# ``SC2IntentModelCfg.get_instance``.
POS_DECODER = Registry("pos-decoders")

# _FILE = "growth.csv"
# _IDX = 0


@POS_DECODER.register_module("v1")
class PosDecoderV1(nn.Module):
    """Decode estimate of unit position targets by regressing the individual units

    Args:
        unit_dim (int): Input unit feature vector dimension
        latent_dim (int): Input latent feature vector dimension
        hidden_dim (int): Internal hidden dimension
        num_layers (int, optional): Number of cross-attention layers from unit to latent. Defaults to 4.
        num_heads (int, optional): Number of heads in MHA. Defaults to 4.
        dropout (float, optional): Dropout used in MHA residual. Defaults to 0.0.
        variance_out (bool, optional): Output predicted variance of prediction value. Defaults to True.
        layer_norm (bool, optional): Apply layer_norm to latent at input. Defaults to False.
        decode_logit (bool, optional): Return logit that corresponds to likelihood that position value is valid. Defaults to False.
        relative_pos (bool, optional): Output is in unit's frame of reference. Defaults to False.
        cartesian (bool, optional): Model output is x y, otherwise yaw theta (polar). Defaults to True.
        logit_out (int, optional): Output is categorical HL-Gauss distribution with this resolution, rather than value. Defaults to 0.

    Raises:
        RuntimeError: If polar output is requested in the global frame, or if
            ``variance_out`` is combined with ``logit_out > 0``.
    """

    def __init__(
        self,
        unit_dim: int,
        latent_dim: int,
        hidden_dim: int,
        num_layers: int = 4,
        num_heads: int = 4,
        dropout: float = 0.0,
        variance_out: bool = True,
        layer_norm: bool = False,
        decode_logit: bool = False,
        relative_pos: bool = False,
        cartesian: bool = True,
        logit_out: int = 0,
    ):
        super().__init__()
        self.relative_pos = relative_pos
        self.cartesian = cartesian
        if not cartesian and not relative_pos:
            raise RuntimeError(
                "Yaw/theta output doesn't make sense in global frame for unit decoder"
            )

        self.unit_proj = nn.Linear(unit_dim, hidden_dim)
        self.latent_norm = nn.LayerNorm(latent_dim) if layer_norm else None
        self.latent_proj = nn.Linear(latent_dim, hidden_dim)
        self.encoder = pio.CrossAttentionBlock(
            num_layers, hidden_dim, hidden_dim, num_heads, dropout
        )
        self.logit_decoder = nn.Linear(hidden_dim, 1) if decode_logit else None
        if logit_out > 0:
            # Reject the incompatible combination before building the layer
            if variance_out:
                raise RuntimeError("Can't output variance of logit output")
            # either x,y or r,t where logit_out is resolution
            self.decoder = nn.Linear(hidden_dim, logit_out * 2)
        else:
            out_ch = 2 if cartesian else 3  # polar is [r,sin(t),cos(t)]
            if variance_out:
                out_ch += 2  # last two channel are var of xy or rt
            self.decoder = nn.Linear(hidden_dim, out_ch, bias=False)
        self._init_params()

    def _init_params(self):
        """Xavier-initialise projection/decoder weights and zero the projection biases"""
        nn.init.xavier_uniform_(self.unit_proj.weight)
        nn.init.xavier_uniform_(self.latent_proj.weight)
        nn.init.constant_(self.unit_proj.bias, 0)
        nn.init.constant_(self.latent_proj.bias, 0)
        nn.init.xavier_uniform_(self.decoder.weight)

    def forward(self, latent: Tensor, units: Tensor):
        """Predict a position output for every unit.

        Args:
            latent: Latent context, a 4-dim input is treated as a sequence of
                batches and its two leading dims are folded together.
            units: Unit features, with the same leading dims as ``latent``.

        Returns:
            dict with "position" holding the decoder output per unit and, when
            ``decode_logit`` was set, "pos-logit" holding one logit per unit.
            The two leading dims are restored for sequence input.
        """
        batch_sequence = latent.ndim == 4
        if batch_sequence:
            seq, batch = latent.shape[:2]
            # flatten copes with non-contiguous inputs, matching the reshape
            # used on the outputs below
            latent = latent.flatten(0, 1)
            units = units.flatten(0, 1)

        # NOTE: disabled debug instrumentation that dumps tensor statistics
        # global _IDX
        # if _IDX == 0:
        #     with open(_FILE, "w", encoding="utf-8") as f:
        #         f.write("idx,name,min,max,mean,var\n")
        # @torch.no_grad()
        # def write_tensor_stats(name: str, x: Tensor):
        #     with open(_FILE, "a", encoding="utf-8") as f:
        #         f.write(
        #             f"{_IDX},{name},{x.min().item():.2f},{x.max().item():.2f},"
        #             f"{x.mean().item():.2f},{x.var().item():.2f}\n"
        #         )

        units = self.unit_proj(units)
        if self.latent_norm is not None:
            latent = self.latent_norm(latent)
        latent = self.latent_proj(latent)
        inter = self.encoder(units, latent)
        position_target: Tensor = self.decoder(inter)

        output = {"position": position_target}
        if self.logit_decoder is not None:
            output["pos-logit"] = self.logit_decoder(inter)

        if batch_sequence:
            output = {k: v.reshape(seq, batch, *v.shape[1:]) for k, v in output.items()}

        # _IDX += 1
        return output


class MinimapCtxType(enum.Enum):
    """Method used to include minimap context"""

    # Minimap tokens are concatenated to the latent before the unit encoder runs
    append_start = "append_start"
    # Minimap tokens are concatenated to the encoded latent just before decoding
    pre_decoding = "pre_decoding"


class SC2IntentPredictor(nn.Module):
    """Predict the assignment between SC2 units or motion targets.

    Unit and enemy observations pass through a shared adapter, a unit encoder
    updates a learned latent from them, and the optional unit and position
    decoders read predictions out of that latent. Optional minimap tokens are
    joined to the latent either before encoding or just before decoding.
    """

    def __init__(
        self,
        latent_num: int,
        latent_dim: int,
        unit_adapter: AgentIA,
        unit_encoder: nn.Module,
        unit_decoder: nn.Module | None,
        pos_decoder: nn.Module | None,
        minimap_encoder: nn.Module | None,
        minimap_ctx: MinimapCtxType,
    ):
        super().__init__()
        self.unit_adapter = unit_adapter
        self.unit_update = unit_encoder
        self.minimap_encoder = minimap_encoder
        self.unit_decoder = unit_decoder
        self.pos_decoder = pos_decoder
        self.minimap_ctx = minimap_ctx

        # Learned initial latent, drawn from a normal clipped to [-2, 2]
        initial = torch.empty(latent_num, latent_dim).normal_().clamp_(-2, 2)
        self.latent = nn.Parameter(initial)
        # Running latent state, only used by the incremental (inc_*) methods
        self._latent = torch.empty(0)

    def forward(self, data: TorchSC2Data):
        """Run the full model on a batch and return the merged decoder outputs"""
        assert data.enemy_mask is not None
        assert data.enemy_units is not None

        mm_encoder = self.minimap_encoder
        latent = self.latent.expand(data.batch_size, -1, -1)
        if mm_encoder is not None:
            assert data.minimap is not None, "Missing minimap data"
            with torch.profiler.record_function("minimap-encoder"):
                minimap: Tensor = mm_encoder(data.minimap)
            if self.minimap_ctx is MinimapCtxType.append_start:
                latent = torch.cat([latent, minimap], dim=1)

        units, units_mask = self.unit_adapter(data.units, data.units_mask)
        targets, targets_mask = self.unit_adapter(data.enemy_units, data.enemy_mask)

        with torch.profiler.record_function("unit-update"):
            latent: Tensor = self.unit_update(
                latent, units, units_mask, targets, targets_mask
            )

        if mm_encoder is not None and self.minimap_ctx is MinimapCtxType.pre_decoding:
            # Repeat the minimap tokens along the leading dim of the encoded latent
            tiled = minimap.expand(latent.shape[0], -1, -1, -1)
            latent = torch.cat([latent, tiled], dim=2)

        outputs: dict[str, Tensor] = {}

        if self.unit_decoder is not None:
            with torch.profiler.record_function("unit-decoder"):
                assignment = self.unit_decoder(latent, units, targets, targets_mask)
                outputs.update(assignment)

        if self.pos_decoder is not None:
            with torch.profiler.record_function("pos-decoder"):
                position = self.pos_decoder(latent, units)
                outputs.update(position)

        return outputs

    def inc_reset(self, batch_size: int):
        """Reset incremental mode latent state"""
        fresh = self.latent.clone()[None]
        self._latent = fresh.expand(batch_size, -1, -1)

    def inc_minimap(self, minimap: Tensor):
        """Add minimap context in incremental mode"""
        assert self.minimap_encoder is not None
        if minimap.ndim == 3:
            # Add the missing batch dimension
            minimap = minimap.unsqueeze(0)
        tokens = self.minimap_encoder(minimap)
        self._latent = torch.cat([self._latent, tokens], dim=1)

    def inc_forward(self, data: TorchSC2Data):
        """Add units observation and decode prediction in incremental mode"""
        units, units_mask = self.unit_adapter(data.units, data.units_mask)
        if data.enemy_units is None:
            # Without enemy observations the units act as their own targets
            targets, targets_mask = units, units_mask
        else:
            assert data.enemy_mask is not None
            targets, targets_mask = self.unit_adapter(data.enemy_units, data.enemy_mask)

        self._latent = self.unit_update.inc_forward(
            self._latent, units, units_mask, targets, targets_mask
        )

        outputs: dict[str, Tensor] = {}
        if self.unit_decoder is not None:
            assignment = self.unit_decoder(
                self._latent[-1], units, targets, targets_mask
            )
            outputs.update(assignment)

        if self.pos_decoder is not None:
            position = self.pos_decoder(self._latent[-1], units)
            outputs.update(position)

        return outputs


@dataclass
@MODEL_REGISTRY.register_module("sc2-intent-predictor")
class SC2IntentModelCfg(TorchModelConfig):
    """Configuration used to build an ``SC2IntentPredictor``.

    Sub-module entries may be given as plain dicts and are converted to
    ``ModuleInitConfig`` in ``__post_init__``. ``latent_dim`` is injected into
    every configured sub-module so it only has to be specified once.
    """

    # Number of learned latent vectors
    latent_num: int
    # Channel dimension shared by the latent and all sub-modules
    latent_dim: int
    # Keyword arguments for ``AgentIA``
    unit_adapter: dict[str, Any]
    # Unit encoder, type is looked up in ``UNIT_ENCODERS``
    unit_encoder: ModuleInitConfig
    # Where minimap tokens are joined to the latent
    minimap_ctx: MinimapCtxType = MinimapCtxType.append_start
    # Optional minimap encoder, type is looked up in ``MINIMAP_ENCODERS``
    minimap_encoder: ModuleInitConfig | None = None
    # Optional assignment decoder, type is looked up in ``UNIT_DECODERS``
    unit_decoder: ModuleInitConfig | None = None
    # Optional position decoder, type is looked up in ``POS_DECODER``
    pos_decoder: ModuleInitConfig | None = None

    latent_type: Any | None = None  # Ignore old arg

    @classmethod
    def from_config(cls, config: ExperimentInitConfig, idx: int = 0):
        """Fill dataset-dependent arguments into the model config, then build it.

        Args:
            config: Experiment configuration holding the model definitions.
            idx: Index of the model entry in ``config.model`` to build from.

        Returns:
            Whatever the parent ``from_config`` returns for the patched entry.
        """
        props = get_dataset_properties(config)
        model_cfg = config.model[idx].args

        unit_adapter: dict = model_cfg["unit_adapter"]
        unit_adapter["pos_feats"] = props["num_pos_features"]
        unit_adapter["other_feats"] = props["num_other_features"]
        unit_adapter.setdefault("input_mode", "fpos")

        if unit_adapter.get("class_mode") == "embedding":
            assert (
                props["unit_features"][-1] is UnitOH.unitType
            ), "Unit type should be last feature"
            assert (
                props["contiguous_unit_type"] is True
            ), "Unit type should be remapped as contiguous"
            unit_adapter["n_classes"] = props["n_classes"]

        if model_cfg["unit_encoder"]["type"] in {"transformer", "transformer2"}:
            model_cfg["unit_encoder"]["args"]["max_time"] = props["clip_length"]

        # Forward idx so the parent builds from the same entry patched above
        return super().from_config(config, idx)

    def __post_init__(self):
        """Normalise raw config values and propagate shared arguments"""
        for name in ("unit_encoder", "unit_decoder", "pos_decoder", "minimap_encoder"):
            value = getattr(self, name)
            if isinstance(value, dict):
                setattr(self, name, ModuleInitConfig(**value))
        if isinstance(self.minimap_ctx, str):
            self.minimap_ctx = MinimapCtxType[self.minimap_ctx.lower()]

        for module in [
            self.unit_encoder,
            self.unit_decoder,
            self.pos_decoder,
            self.minimap_encoder,
        ]:
            if module is None:
                continue
            module.args["latent_dim"] = self.latent_dim

        if self.unit_encoder.type in {"recurrent", "mamba"}:
            self.unit_encoder.args["latent_num"] = self.latent_num

        # No need for transformer2 after refactor
        if self.unit_encoder.type == "transformer2":
            self.unit_encoder.type = "transformer"
            self.unit_encoder.args["encoding_method"] = EncodingMethod.fused

        # Accept the encoding method by name when it comes from a config file
        if isinstance(self.unit_encoder.args.get("encoding_method"), str):
            self.unit_encoder.args["encoding_method"] = EncodingMethod[
                self.unit_encoder.args["encoding_method"]
            ]

        assert self.unit_decoder or self.pos_decoder, "Model has no decoder!"

    def get_instance(self):
        """Instantiate every configured sub-module and assemble the predictor"""
        unit_adapter = AgentIA(**self.unit_adapter)
        unit_encoder = UNIT_ENCODERS[self.unit_encoder.type](
            unit_dim=unit_adapter.out_channels,
            **self.unit_encoder.args,
        )

        if self.unit_decoder is None:
            unit_decoder = None
        else:
            unit_decoder = UNIT_DECODERS[self.unit_decoder.type](
                unit_dim=unit_adapter.out_channels, **self.unit_decoder.args
            )

        if self.pos_decoder is None:
            pos_decoder = None
        else:
            pos_decoder = POS_DECODER[self.pos_decoder.type](
                unit_dim=unit_adapter.out_channels, **self.pos_decoder.args
            )

        if self.minimap_encoder is None:
            minimap_encoder = None
        else:
            minimap_encoder = MINIMAP_ENCODERS[self.minimap_encoder.type](
                **self.minimap_encoder.args
            )
        return self.init_auto_filter(
            SC2IntentPredictor,
            unit_adapter=unit_adapter,
            unit_encoder=unit_encoder,
            unit_decoder=unit_decoder,
            pos_decoder=pos_decoder,
            minimap_encoder=minimap_encoder,
        )
