from typing import Tuple, List, Callable, Optional, Any

import torch
import torch.nn.functional as F
from torch import nn, Tensor
from torchvision.models import MaxVit, WeightsEnum, MaxVit_T_Weights
from torchvision.models._api import register_model
from torchvision.models._utils import _ovewrite_named_param, handle_legacy_interface

from .utils import UnitaryMatrixMultiplication, NonNegativeLinear

# Names exported by a star-import of this module: the customised model class
# and the two registered factory functions. The shared builder ``own_maxvit``
# is not listed here.
__all__ = [
    "CustomMaxVit",
    "base_maxvit_t",
    "own_maxvit_t",
]


class CustomMaxVit(MaxVit):
    """``MaxVit`` backbone with a sparsifying bottleneck and a replaced head.

    The stem and blocks are inherited unchanged from torchvision's ``MaxVit``.
    After the last block the features go through ``unitary_matrix``, are split
    into their positive and negative parts, each part is gated by
    :meth:`apply_gumbel_softmax`, and the recombined signed features are fed
    to a new ``classifier`` built from ``NonNegativeLinear`` layers.

    Args:
        input_size, stem_channels, partition_size, block_channels,
        block_layers, head_dim, stochastic_depth_prob, norm_layer,
        activation_layer, squeeze_ratio, expansion_ratio, mlp_ratio,
        mlp_dropout, attention_dropout, num_classes: forwarded unchanged to
            ``MaxVit.__init__``. ``block_channels[-1]`` additionally sizes the
            unitary layer and the classifier; ``num_classes`` sizes the
            classifier output.
        gumbel_dim: axis along which the gating competes. ``-1`` is special:
            the last two axes are merged first, so the competition runs over
            both of them jointly.
        tau: Gumbel-softmax temperature. The stochastic gate is only used
            while training and when ``tau > 0``.

    Attributes:
        changed_layers: prefixes of the state-dict keys that belong to layers
            this class adds or replaces; ``own_maxvit`` uses it to drop the
            matching entries from a pretrained checkpoint.
    """

    def __init__(
        self,
        input_size: Tuple[int, int],
        stem_channels: int,
        partition_size: int,
        block_channels: List[int],
        block_layers: List[int],
        head_dim: int,
        stochastic_depth_prob: float,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
        activation_layer: Callable[..., nn.Module] = nn.GELU,
        squeeze_ratio: float = 0.25,
        expansion_ratio: float = 4,
        mlp_ratio: int = 4,
        mlp_dropout: float = 0.0,
        attention_dropout: float = 0.0,
        num_classes: int = 1000,
        gumbel_dim: int = 1,
        tau: float = 1,
    ) -> None:
        super().__init__(
            input_size=input_size,
            stem_channels=stem_channels,
            partition_size=partition_size,
            block_channels=block_channels,
            block_layers=block_layers,
            head_dim=head_dim,
            stochastic_depth_prob=stochastic_depth_prob,
            norm_layer=norm_layer,
            activation_layer=activation_layer,
            squeeze_ratio=squeeze_ratio,
            expansion_ratio=expansion_ratio,
            mlp_ratio=mlp_ratio,
            mlp_dropout=mlp_dropout,
            attention_dropout=attention_dropout,
            num_classes=num_classes,
        )

        self.tau = tau
        self.gumbel_dim = gumbel_dim

        # Width of the feature map leaving the last block.
        feature_dim = block_channels[-1]

        # NOTE(review): presumably a learned unitary mixing of the channels;
        # its exact semantics live in .utils - confirm there.
        self.unitary_matrix = UnitaryMatrixMultiplication(feature_dim)

        # Replaces the classifier that MaxVit.__init__ just created.
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.LayerNorm(feature_dim),
            NonNegativeLinear(feature_dim, feature_dim),
            nn.Tanh(),
            NonNegativeLinear(feature_dim, num_classes, bias=False),
        )

        self.changed_layers = ["unitary_matrix", "classifier"]

    def apply_gumbel_softmax(self, x: Tensor) -> Tensor:
        """Gate ``x`` so that one entry along ``gumbel_dim`` dominates.

        While training with ``tau > 0`` the input is multiplied by a soft
        Gumbel-softmax sample computed from ``x`` itself. Otherwise only the
        arg-max entry along ``gumbel_dim`` is kept and every other entry is
        zeroed.

        Args:
            x: tensor to gate. With ``gumbel_dim == -1`` it needs at least two
                axes, because the last two are merged for the competition.

        Returns:
            Tensor with the same shape as ``x``.
        """
        shape = x.shape
        merge_trailing = self.gumbel_dim == -1
        if merge_trailing:
            # reshape instead of view: view raises on stride-incompatible
            # tensors, while reshape returns the identical view when possible
            # and falls back to a copy only when it must.
            x = x.reshape(*shape[:-2], -1)

        if self.training and self.tau > 0:
            x = x * F.gumbel_softmax(x, tau=self.tau, dim=self.gumbel_dim)
        else:
            # Deterministic path: one-hot mask at the arg-max position.
            index = x.argmax(dim=self.gumbel_dim, keepdim=True)
            mask = torch.zeros_like(x).scatter_(self.gumbel_dim, index, 1.0)
            x = x * mask

        if merge_trailing:
            x = x.reshape(shape)
        return x

    def forward(self, x: Tensor) -> Tensor:
        """Run stem, blocks, the signed gated bottleneck and the classifier.

        Args:
            x: input batch accepted by the inherited ``stem``.

        Returns:
            Output of ``self.classifier``.
        """
        x = self.stem(x)
        for block in self.blocks:
            x = block(x)

        # --- custom bottleneck -------------------------------------------
        x = self.unitary_matrix(x)

        # Gate the positive and the negative magnitudes independently, then
        # put the sign back by subtracting the gated negative part.
        v_positive = self.apply_gumbel_softmax(torch.relu(x))
        v_negative = self.apply_gumbel_softmax(torch.relu(torch.neg(x)))
        x = v_positive - v_negative
        # -----------------------------------------------------------------

        x = self.classifier(x)
        return x


def own_maxvit(
    stem_channels: int,
    block_channels: List[int],
    block_layers: List[int],
    stochastic_depth_prob: float,
    partition_size: int,
    head_dim: int,
    weights: Optional[WeightsEnum] = None,
    progress: bool = False,
    **kwargs: Any,
) -> CustomMaxVit | MaxVit:
    """Build a ``CustomMaxVit`` or plain ``MaxVit`` and optionally load weights.

    Args:
        stem_channels, block_channels, block_layers, stochastic_depth_prob,
        partition_size, head_dim: architecture settings passed straight to the
            model constructor.
        weights: pretrained weights to load, or ``None`` for random init.
            When given, ``input_size`` is forced to ``weights.meta["min_size"]``.
        progress: forwarded to ``weights.get_state_dict``.
        **kwargs: extra constructor arguments. Two keys are consumed here:
            ``"model"`` (``"CustomMaxVit"`` by default, or ``"MaxVit"``)
            selects the class, and ``"input_size"`` defaults to ``(224, 224)``.

    Returns:
        The constructed model. If ``weights`` was given, every checkpoint
        entry whose key does not start with one of ``model.changed_layers``
        has been loaded non-strictly.

    Raises:
        ValueError: if ``kwargs["model"]`` is not a known model name.
        AssertionError: if the checkpoint's ``min_size`` is not square.
    """
    if weights is not None:
        # num_classes is left to the caller: the classifier entries are
        # dropped from the checkpoint below, so the head does not have to
        # match the checkpoint's categories.
        assert weights.meta["min_size"][0] == weights.meta["min_size"][1]
        _ovewrite_named_param(kwargs, "input_size", weights.meta["min_size"])

    input_size = kwargs.pop("input_size", (224, 224))

    # Resolve the class through an explicit table rather than eval(), so a
    # string arriving via **kwargs can never be executed as code.
    model_classes = {"CustomMaxVit": CustomMaxVit, "MaxVit": MaxVit}
    model_name = kwargs.pop("model", "CustomMaxVit")
    try:
        model_cls = model_classes[model_name]
    except KeyError:
        raise ValueError(
            f"Unknown model {model_name!r}; expected one of {sorted(model_classes)}"
        ) from None

    model = model_cls(
        stem_channels=stem_channels,
        block_channels=block_channels,
        block_layers=block_layers,
        stochastic_depth_prob=stochastic_depth_prob,
        head_dim=head_dim,
        partition_size=partition_size,
        input_size=input_size,
        **kwargs,
    )

    # Exact type check on purpose: CustomMaxVit subclasses MaxVit, so
    # isinstance() would also match the customised model.
    if type(model) is MaxVit:
        model.changed_layers = ["classifier"]
        print("\033[0;1;34mUse base MaxVit\033[0m")

    if weights is not None:
        pretrained_dict = weights.get_state_dict(progress=progress, check_hash=True)

        if hasattr(model, "changed_layers"):
            # str.startswith accepts a tuple of prefixes.
            changed_prefixes = tuple(model.changed_layers)
            keys_to_remove = [
                key for key in pretrained_dict if key.startswith(changed_prefixes)
            ]
            # Drop them so load_state_dict does not try to fit replaced layers.
            for key in keys_to_remove:
                pretrained_dict.pop(key, None)
            print(f"Removing keys: {keys_to_remove} from pretrained weights")

        missing_keys, unexpected_keys = model.load_state_dict(
            pretrained_dict, strict=False
        )

        # Report only what actually occurred; stay silent on a clean load.
        report = []
        if missing_keys:
            report.append(f"\033[0;1;33mMissing keys: {missing_keys}\033[0m")
        if unexpected_keys:
            report.append(f"\033[0;1;33mUnexpected keys: {unexpected_keys}\033[0m")
        if report:
            print(" ".join(report))

    return model


@register_model()
@handle_legacy_interface(weights=("pretrained", MaxVit_T_Weights.IMAGENET1K_V1))
def base_maxvit_t(
    *, weights: Optional[MaxVit_T_Weights] = None, progress: bool = True, **kwargs: Any
) -> MaxVit:
    """Build the stock ``MaxVit`` with the fixed configuration listed below.

    Args:
        weights: pretrained weights to load, or ``None``; validated with
            ``MaxVit_T_Weights.verify`` before use.
        progress: passed through to ``own_maxvit``.
        **kwargs: extra arguments for ``own_maxvit``. Any ``"model"`` entry is
            overwritten with ``"MaxVit"``.

    Returns:
        The model built by ``own_maxvit``.
    """
    verified_weights = MaxVit_T_Weights.verify(weights)

    # Always select the unmodified torchvision class.
    kwargs["model"] = "MaxVit"

    architecture = {
        "stem_channels": 64,
        "block_channels": [64, 128, 256, 512],
        "block_layers": [2, 2, 5, 2],
        "head_dim": 32,
        "stochastic_depth_prob": 0.2,
        "partition_size": 7,
    }
    return own_maxvit(
        **architecture,
        weights=verified_weights,
        progress=progress,
        **kwargs,
    )


@register_model()
@handle_legacy_interface(weights=("pretrained", MaxVit_T_Weights.IMAGENET1K_V1))
def own_maxvit_t(
    *, weights: Optional[MaxVit_T_Weights] = None, progress: bool = True, **kwargs: Any
) -> CustomMaxVit:
    """Build a ``CustomMaxVit`` with the fixed configuration listed below.

    Args:
        weights: pretrained weights to load, or ``None``; validated with
            ``MaxVit_T_Weights.verify`` before use.
        progress: passed through to ``own_maxvit``.
        **kwargs: extra arguments for ``own_maxvit``. Any ``"model"`` entry is
            overwritten with ``"CustomMaxVit"``.

    Returns:
        The model built by ``own_maxvit``.
    """
    verified_weights = MaxVit_T_Weights.verify(weights)

    # Always select the customised class.
    kwargs["model"] = "CustomMaxVit"

    architecture = {
        "stem_channels": 64,
        "block_channels": [64, 128, 256, 512],
        "block_layers": [2, 2, 5, 2],
        "head_dim": 32,
        "stochastic_depth_prob": 0.2,
        "partition_size": 7,
    }
    return own_maxvit(
        **architecture,
        weights=verified_weights,
        progress=progress,
        **kwargs,
    )
