from typing import List, Optional, Callable, Any

import torch
import torch.nn.functional as F
from torch import nn
from torchvision.models import SwinTransformer, WeightsEnum
from torchvision.models._api import register_model
from torchvision.models._utils import handle_legacy_interface
from torchvision.models.swin_transformer import (
    PatchMerging,
    Swin_V2_S_Weights,
    SwinTransformerBlockV2,
    PatchMergingV2,
)

from .utils import UnitaryMatrixMultiplication, NonNegativeLinear

# Public API: the custom model class and the two registered Swin-V2-S builders.
__all__ = ["CustomSwinTransformer", "base_swin_v2_s", "own_swin_v2_s"]


class CustomSwinTransformer(SwinTransformer):
    """torchvision ``SwinTransformer`` with a sparsifying bottleneck before the head.

    The backbone (``features``, ``norm``, ``permute``, ``avgpool``,
    ``flatten``) is inherited unchanged. Between the permuted feature map and
    average pooling, features pass through ``unitary_matrix``. They are then
    split into positive and negative parts, and each part is masked along
    ``gumbel_dim`` (see :meth:`apply_gumbel_softmax`). The classification
    head is replaced by a ``NonNegativeLinear`` layer.

    Args:
        patch_size, embed_dim, depths, num_heads, window_size, mlp_ratio,
        dropout, attention_dropout, stochastic_depth_prob, num_classes,
        norm_layer, block, downsample_layer: forwarded unchanged to
            ``SwinTransformer.__init__``.
        gumbel_dim: dimension of the ``(B, C, H, W)`` feature map along which
            the masking selects. ``-1`` flattens the last two dims (``H * W``)
            and selects over all spatial positions for each channel.
        tau: Gumbel-softmax temperature used in training mode. With
            ``tau <= 0`` the hard argmax mask is always used.
    """

    def __init__(
        self,
        patch_size: List[int],
        embed_dim: int,
        depths: List[int],
        num_heads: List[int],
        window_size: List[int],
        mlp_ratio: float = 4.0,
        dropout: float = 0.0,
        attention_dropout: float = 0.0,
        stochastic_depth_prob: float = 0.1,
        num_classes: int = 1000,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
        block: Optional[Callable[..., nn.Module]] = None,
        downsample_layer: Callable[..., nn.Module] = PatchMerging,
        gumbel_dim: int = 1,
        tau: float = 1,
    ):
        super().__init__(
            patch_size=patch_size,
            embed_dim=embed_dim,
            depths=depths,
            num_heads=num_heads,
            window_size=window_size,
            mlp_ratio=mlp_ratio,
            dropout=dropout,
            attention_dropout=attention_dropout,
            stochastic_depth_prob=stochastic_depth_prob,
            num_classes=num_classes,
            norm_layer=norm_layer,
            block=block,
            downsample_layer=downsample_layer,
        )

        # Channel width of the last stage (same formula torchvision uses to
        # size its own final norm and head).
        num_features = embed_dim * 2 ** (len(depths) - 1)

        self.tau = tau
        self.gumbel_dim = gumbel_dim
        # NOTE(review): assumes UnitaryMatrixMultiplication keeps the
        # (B, C, H, W) shape, since avgpool follows it in forward() — confirm.
        self.unitary_matrix = UnitaryMatrixMultiplication(num_features)
        # Replaces the head that SwinTransformer.__init__ just created.
        self.head = NonNegativeLinear(num_features, num_classes)

        # Parameter-name prefixes that are new or replaced here;
        # own_swin_transformer drops them from pretrained state dicts.
        self.changed_layers = ["unitary_matrix", "head"]

    def apply_gumbel_softmax(self, x: torch.Tensor) -> torch.Tensor:
        """Mask ``x`` along ``self.gumbel_dim``.

        In training mode with ``tau > 0``, ``x`` is multiplied by a soft
        Gumbel-softmax sample computed from ``x`` itself. Otherwise it is
        multiplied by a hard one-hot mask at its argmax. With
        ``gumbel_dim == -1`` the last two dims are flattened first and
        restored afterwards. The output has the same shape as ``x``.
        """
        original_shape = x.shape
        flatten_spatial = self.gumbel_dim == -1
        if flatten_spatial:
            # reshape, not view: the feature map comes out of permute() and may
            # be non-contiguous, which makes view() raise.
            x = x.reshape(*original_shape[:-2], -1)

        if self.training and self.tau > 0:
            x = x * F.gumbel_softmax(x, tau=self.tau, dim=self.gumbel_dim)
        else:
            index = x.argmax(dim=self.gumbel_dim, keepdim=True)
            mask = torch.zeros_like(x).scatter_(self.gumbel_dim, index, 1.0)
            x = x * mask

        if flatten_spatial:
            x = x.reshape(original_shape)
        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run backbone, sparsifying bottleneck, average pooling and head."""
        x = self.features(x)
        x = self.norm(x)
        x = self.permute(x)  # B H W C -> B C H W

        # Bottleneck: mix with unitary_matrix, then mask the positive and the
        # negative parts independently and recombine them.
        # (Positive part is processed first to keep the original RNG order.)
        x = self.unitary_matrix(x)
        v_positive = self.apply_gumbel_softmax(torch.relu(x))
        v_negative = self.apply_gumbel_softmax(torch.relu(torch.neg(x)))
        x = v_positive - v_negative

        x = self.avgpool(x)
        x = self.flatten(x)
        return self.head(x)


def _resolve_swin_model_class(model: Any) -> type:
    """Return the ``SwinTransformer`` (sub)class selected by ``model``.

    ``model`` is either such a class or the name of one visible in this
    module's globals. Names are looked up rather than ``eval``-ed, so an
    arbitrary string arriving through ``**kwargs`` is never executed.

    Raises:
        ValueError: if ``model`` does not resolve to ``SwinTransformer`` or a
            subclass of it.
    """
    candidate = globals().get(model) if isinstance(model, str) else model
    if isinstance(candidate, type) and issubclass(candidate, SwinTransformer):
        return candidate
    raise ValueError(
        f"Unsupported model {model!r}: expected a SwinTransformer subclass or "
        "its name (e.g. 'CustomSwinTransformer' or 'SwinTransformer')"
    )


def _drop_changed_layers(state_dict: dict, changed_layers: List[str]) -> List[str]:
    """Delete, in place, every ``state_dict`` key owned by a ``changed_layers`` module.

    A key is owned by ``layer`` when it equals ``layer`` or starts with
    ``layer + "."``. A bare prefix test would also hit unrelated siblings
    (e.g. ``"head_extra.weight"`` for ``"head"``).

    Returns:
        The removed keys, in their original order.
    """
    removed = [
        key
        for key in state_dict
        if any(key == layer or key.startswith(layer + ".") for layer in changed_layers)
    ]
    for key in removed:
        del state_dict[key]
    return removed


def own_swin_transformer(
    patch_size: List[int],
    embed_dim: int,
    depths: List[int],
    num_heads: List[int],
    window_size: List[int],
    stochastic_depth_prob: float,
    weights: Optional[WeightsEnum],
    progress: bool,
    **kwargs: Any,
) -> CustomSwinTransformer | SwinTransformer:
    """Build a Swin Transformer and optionally load pretrained weights into it.

    Args:
        patch_size, embed_dim, depths, num_heads, window_size,
        stochastic_depth_prob: architecture arguments forwarded to the model.
        weights: pretrained weights to load, or ``None`` to skip loading.
        progress: forwarded to ``weights.get_state_dict``.
        **kwargs: remaining constructor arguments. The special key ``model``
            (default ``"CustomSwinTransformer"``) selects the class to build
            and is removed before the constructor is called.

    Returns:
        The constructed model. Pretrained entries belonging to its
        ``changed_layers`` are stripped before loading, so those layers keep
        their fresh initialization.

    Raises:
        ValueError: if ``model`` does not resolve to a SwinTransformer class.
    """
    model_cls = _resolve_swin_model_class(kwargs.pop("model", "CustomSwinTransformer"))
    model = model_cls(
        patch_size=patch_size,
        embed_dim=embed_dim,
        depths=depths,
        num_heads=num_heads,
        window_size=window_size,
        stochastic_depth_prob=stochastic_depth_prob,
        **kwargs,
    )
    # Exact type check on purpose: subclasses (CustomSwinTransformer) define
    # their own changed_layers.
    if type(model) is SwinTransformer:
        model.changed_layers = ["head"]
        print("\033[0;1;34mUse base SwinTransformer\033[0m")

    if weights is None:
        return model

    pretrained_dict = weights.get_state_dict(progress=progress, check_hash=True)

    if hasattr(model, "changed_layers"):
        keys_to_remove = _drop_changed_layers(pretrained_dict, model.changed_layers)
        print(f"Removing keys: {keys_to_remove} from pretrained weights")

    # strict=False: stripped layers show up as missing instead of raising.
    missing_keys, unexpected_keys = model.load_state_dict(pretrained_dict, strict=False)
    report = []
    if missing_keys:
        report.append(f"\033[0;1;33mMissing keys: {missing_keys}\033[0m")
    if unexpected_keys:
        report.append(f"\033[0;1;33mUnexpected keys: {unexpected_keys}\033[0m")
    if report:  # avoid printing an empty/space-only line when all keys matched
        print(*report)

    return model


@register_model()
@handle_legacy_interface(weights=("pretrained", Swin_V2_S_Weights.IMAGENET1K_V1))
def base_swin_v2_s(
    *, weights: Optional[Swin_V2_S_Weights] = None, progress: bool = True, **kwargs: Any
) -> SwinTransformer:
    """Build a plain ``SwinTransformer`` with the Swin-V2-S configuration.

    The architecture is fixed: 4x4 patches, embed dim 96, depths 2/2/18/2,
    heads 3/6/12/24, 8x8 windows, and V2 blocks and patch merging.
    ``model`` is forced to ``"SwinTransformer"``, overriding any value the
    caller put in ``kwargs``.

    Args:
        weights: optional ``Swin_V2_S_Weights``, validated with ``verify``.
        progress: forwarded to ``own_swin_transformer``.
        **kwargs: extra constructor arguments forwarded to the model.
    """
    architecture = dict(
        patch_size=[4, 4],
        embed_dim=96,
        depths=[2, 2, 18, 2],
        num_heads=[3, 6, 12, 24],
        window_size=[8, 8],
        stochastic_depth_prob=0.3,
        block=SwinTransformerBlockV2,
        downsample_layer=PatchMergingV2,
    )
    return own_swin_transformer(
        weights=Swin_V2_S_Weights.verify(weights),
        progress=progress,
        **architecture,
        **{**kwargs, "model": "SwinTransformer"},
    )


@register_model()
@handle_legacy_interface(weights=("pretrained", Swin_V2_S_Weights.IMAGENET1K_V1))
def own_swin_v2_s(
    *, weights: Optional[Swin_V2_S_Weights] = None, progress: bool = True, **kwargs: Any
) -> CustomSwinTransformer:
    """Build a ``CustomSwinTransformer`` with the Swin-V2-S configuration.

    The architecture is fixed: 4x4 patches, embed dim 96, depths 2/2/18/2,
    heads 3/6/12/24, 8x8 windows, and V2 blocks and patch merging.
    ``model`` is forced to ``"CustomSwinTransformer"``, overriding any value
    the caller put in ``kwargs``. Remaining kwargs (e.g. ``gumbel_dim``,
    ``tau``) reach the model constructor.

    Args:
        weights: optional ``Swin_V2_S_Weights``, validated with ``verify``.
        progress: forwarded to ``own_swin_transformer``.
        **kwargs: extra constructor arguments forwarded to the model.
    """
    architecture = dict(
        patch_size=[4, 4],
        embed_dim=96,
        depths=[2, 2, 18, 2],
        num_heads=[3, 6, 12, 24],
        window_size=[8, 8],
        stochastic_depth_prob=0.3,
        block=SwinTransformerBlockV2,
        downsample_layer=PatchMergingV2,
    )
    return own_swin_transformer(
        weights=Swin_V2_S_Weights.verify(weights),
        progress=progress,
        **architecture,
        **{**kwargs, "model": "CustomSwinTransformer"},
    )
