from collections import OrderedDict
from functools import partial
from typing import Optional, Callable, List, Any

import torch
import torch.nn.functional as F
from torch import nn
from torchvision.models import VisionTransformer, WeightsEnum
from torchvision.models._api import register_model
from torchvision.models._utils import handle_legacy_interface, _ovewrite_named_param
from torchvision.models.vision_transformer import ConvStemConfig, ViT_B_16_Weights

from .utils import UnitaryMatrixMultiplication, NonNegativeLinear

# Public API of this module: the custom model class and the two registered
# ViT-B/16 builder functions.
__all__ = ["CustomVisionTransformer", "base_vit_b_16", "own_vit_b_16"]


class CustomVisionTransformer(VisionTransformer):
    """Vision Transformer with a sparse, sign-split selection stage before the head.

    After the standard ViT encoder, the token features are passed through
    ``UnitaryMatrixMultiplication`` (applied with the last two dims swapped).
    They are then split into positive and negative parts, and each part is
    sparsified with :meth:`apply_gumbel_softmax`. The result is pooled over
    tokens and classified by a head built from ``NonNegativeLinear`` layers
    (which replaces the head created by the base class).

    Constructor arguments mirror :class:`torchvision.models.VisionTransformer`,
    plus:
        gumbel_dim: Dimension along which one entry per slice is selected.
            ``-1`` is special-cased: the last two dims are flattened so the
            selection runs jointly over tokens and features.
        tau: Gumbel-softmax temperature used in training mode. When the
            module is in eval mode, or ``tau <= 0``, a hard arg-max mask is
            used instead.
    """

    def __init__(
        self,
        image_size: int,
        patch_size: int,
        num_layers: int,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        dropout: float = 0.0,
        attention_dropout: float = 0.0,
        num_classes: int = 1000,
        representation_size: Optional[int] = None,
        norm_layer: Callable[..., nn.Module] = partial(nn.LayerNorm, eps=1e-6),
        conv_stem_configs: Optional[List[ConvStemConfig]] = None,
        gumbel_dim: int = 1,
        tau: float = 1,
    ):
        super().__init__(
            image_size=image_size,
            patch_size=patch_size,
            num_layers=num_layers,
            num_heads=num_heads,
            hidden_dim=hidden_dim,
            mlp_dim=mlp_dim,
            dropout=dropout,
            attention_dropout=attention_dropout,
            num_classes=num_classes,
            representation_size=representation_size,
            norm_layer=norm_layer,
            conv_stem_configs=conv_stem_configs,
        )

        self.tau = tau
        self.gumbel_dim = gumbel_dim
        self.unitary_matrix = UnitaryMatrixMultiplication(hidden_dim)

        # Rebuild the classification head with NonNegativeLinear layers,
        # following the same layout as the base class head.
        heads_layers: OrderedDict[str, nn.Module] = OrderedDict()
        if representation_size is None:
            heads_layers["head"] = NonNegativeLinear(hidden_dim, num_classes)
        else:
            heads_layers["pre_logits"] = NonNegativeLinear(
                hidden_dim, representation_size
            )
            heads_layers["act"] = nn.Tanh()
            heads_layers["head"] = NonNegativeLinear(representation_size, num_classes)

        self.heads = nn.Sequential(heads_layers)
        # Top-level submodules that differ from a stock ViT; the builder in
        # this module drops pretrained tensors under these names before loading.
        self.changed_layers = ["unitary_matrix", "heads"]

    def apply_gumbel_softmax(self, x: torch.Tensor) -> torch.Tensor:
        """Sparsify ``x`` along ``self.gumbel_dim``.

        In training mode with ``tau > 0``, ``x`` is scaled by a (soft)
        Gumbel-softmax sample computed from ``x`` itself. Otherwise ``x`` is
        multiplied by a one-hot mask at its arg-max, keeping one entry per
        slice along ``gumbel_dim``.

        Args:
            x: Tensor to sparsify. When ``gumbel_dim == -1`` its last two dims
                are flattened so the selection runs over both jointly.

        Returns:
            Tensor with the same shape as ``x``.
        """
        shape = x.shape
        if self.gumbel_dim == -1:
            # reshape rather than view: in forward() x derives from a permuted
            # (non-contiguous) tensor, on which .view() would raise.
            x = x.reshape(*shape[:-2], -1)

        if self.training and self.tau > 0:
            x = x * F.gumbel_softmax(x, tau=self.tau, dim=self.gumbel_dim)
        else:
            index = x.argmax(dim=self.gumbel_dim, keepdim=True)
            mask = torch.zeros_like(x).scatter_(self.gumbel_dim, index, 1.0)
            x = x * mask

        if self.gumbel_dim == -1:
            x = x.reshape(shape)
        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the encoder, the sparse selection stage, token pooling and the head.

        Args:
            x: Image batch, as accepted by ``VisionTransformer._process_input``.

        Returns:
            Output of ``self.heads`` on the token-pooled features.
        """
        # Reshape and permute the input tensor into a token sequence.
        x = self._process_input(x)
        n = x.shape[0]

        # Expand the class token to the full batch and prepend it (dim 1 = tokens).
        batch_class_token = self.class_token.expand(n, -1, -1)
        x = torch.cat([batch_class_token, x], dim=1)

        x = self.encoder(x)

        # Apply the unitary mixing with the last two dims swapped, then swap back.
        x = self.unitary_matrix(x.permute(0, 2, 1)).permute(0, 2, 1)

        # Sparsify positive and negative parts independently, then recombine.
        v_positive = self.apply_gumbel_softmax(torch.relu(x))
        v_negative = self.apply_gumbel_softmax(torch.relu(torch.neg(x)))
        x = v_positive - v_negative

        # Always pool over the token dimension (dim 1): the head expects
        # hidden_dim features, and pooling over gumbel_dim only produced that
        # shape when gumbel_dim == 1 (for -1/2 the head received the wrong width).
        x = torch.mean(x, dim=1)

        return self.heads(x)


def own_vision_transformer(
    patch_size: int,
    num_layers: int,
    num_heads: int,
    hidden_dim: int,
    mlp_dim: int,
    weights: Optional[WeightsEnum],
    progress: bool,
    **kwargs: Any,
) -> CustomVisionTransformer | VisionTransformer:
    """Build a (Custom)VisionTransformer and optionally load pretrained weights.

    Args:
        patch_size, num_layers, num_heads, hidden_dim, mlp_dim: Architecture
            hyper-parameters forwarded to the model constructor.
        weights: Pretrained weights to load, or ``None``. When given,
            ``image_size`` is taken from ``weights.meta["min_size"]``.
        progress: Forwarded to ``weights.get_state_dict``.
        **kwargs: Extra constructor arguments. Two keys are consumed here:
            ``image_size`` (default 224) and ``model``. ``model`` is either
            ``"CustomVisionTransformer"`` (default), ``"VisionTransformer"``,
            or a model class / factory callable.

    Returns:
        The constructed model. When ``weights`` is given, pretrained tensors
        under the model's ``changed_layers`` are dropped, and the rest are
        loaded with ``strict=False``.

    Raises:
        ValueError: If ``model`` names an unknown architecture, or the
            pretrained weights' ``min_size`` is not square.
    """
    if weights is not None:
        # num_classes is deliberately NOT taken from the weights' categories:
        # the head tensors are discarded below, so a different head size works.
        min_size = weights.meta["min_size"]
        if min_size[0] != min_size[1]:
            raise ValueError(
                f"Expected a square min_size in weights meta, got {min_size}"
            )
        _ovewrite_named_param(kwargs, "image_size", min_size[0])
    image_size = kwargs.pop("image_size", 224)

    model_spec = kwargs.pop("model", "CustomVisionTransformer")
    if isinstance(model_spec, str):
        # Explicit allow-list instead of eval(): only known architectures
        # can be instantiated from a string.
        known_models = {
            "CustomVisionTransformer": CustomVisionTransformer,
            "VisionTransformer": VisionTransformer,
        }
        try:
            model_cls = known_models[model_spec]
        except KeyError:
            raise ValueError(
                f"Unknown model {model_spec!r}; expected one of {sorted(known_models)}"
            ) from None
    else:
        model_cls = model_spec

    model = model_cls(
        image_size=image_size,
        patch_size=patch_size,
        num_layers=num_layers,
        num_heads=num_heads,
        hidden_dim=hidden_dim,
        mlp_dim=mlp_dim,
        **kwargs,
    )

    # Exact type check on purpose: CustomVisionTransformer subclasses
    # VisionTransformer but defines its own changed_layers.
    if type(model) is VisionTransformer:
        model.changed_layers = ["heads"]
        print("\033[0;1;34mUse base VisionTransformer\033[0m")

    if weights is not None:
        pretrained_dict = weights.get_state_dict(progress=progress, check_hash=True)

        if hasattr(model, "changed_layers"):
            # Match whole module-path components so e.g. "encoder_layer_1"
            # does not also strip "encoder_layer_10" / "encoder_layer_11".
            keys_to_remove = [
                key
                for key in pretrained_dict
                if any(
                    key == layer or key.startswith(f"{layer}.")
                    for layer in model.changed_layers
                )
            ]
            for key in keys_to_remove:
                pretrained_dict.pop(key, None)
            print(f"Removing keys: {keys_to_remove} from pretrained weights")

        missing_keys, unexpected_keys = model.load_state_dict(
            pretrained_dict, strict=False
        )
        # Only report when there is something to report (avoids a blank line).
        if missing_keys:
            print(f"\033[0;1;33mMissing keys: {missing_keys}\033[0m")
        if unexpected_keys:
            print(f"\033[0;1;33mUnexpected keys: {unexpected_keys}\033[0m")

    return model


@register_model()
@handle_legacy_interface(weights=("pretrained", ViT_B_16_Weights.IMAGENET1K_V1))
def base_vit_b_16(
    *, weights: Optional[ViT_B_16_Weights] = None, progress: bool = True, **kwargs: Any
) -> VisionTransformer:
    """Build ViT-B/16 as a plain torchvision ``VisionTransformer``.

    Any ``model`` entry in ``kwargs`` is overridden with ``"VisionTransformer"``;
    all other keyword arguments are forwarded to ``own_vision_transformer``.
    """
    # ViT-B/16 architecture hyper-parameters.
    vit_b_16_arch = {
        "patch_size": 16,
        "num_layers": 12,
        "num_heads": 12,
        "hidden_dim": 768,
        "mlp_dim": 3072,
    }
    checked_weights = ViT_B_16_Weights.verify(weights)
    forwarded = {**kwargs, "model": "VisionTransformer"}
    return own_vision_transformer(
        **vit_b_16_arch,
        weights=checked_weights,
        progress=progress,
        **forwarded,
    )


@register_model()
@handle_legacy_interface(weights=("pretrained", ViT_B_16_Weights.IMAGENET1K_V1))
def own_vit_b_16(
    *, weights: Optional[ViT_B_16_Weights] = None, progress: bool = True, **kwargs: Any
) -> CustomVisionTransformer:
    """Build ViT-B/16 as a ``CustomVisionTransformer``.

    Any ``model`` entry in ``kwargs`` is overridden with
    ``"CustomVisionTransformer"``; all other keyword arguments are forwarded
    to ``own_vision_transformer``.
    """
    # ViT-B/16 architecture hyper-parameters.
    arch = dict(
        patch_size=16,
        num_layers=12,
        num_heads=12,
        hidden_dim=768,
        mlp_dim=3072,
    )
    checked_weights = ViT_B_16_Weights.verify(weights)
    forwarded = {**kwargs, "model": "CustomVisionTransformer"}
    return own_vision_transformer(
        **arch,
        weights=checked_weights,
        progress=progress,
        **forwarded,
    )
