from typing import Any, List, Optional, Callable

import torch
import torch.nn.functional as F
from torch import nn, Tensor
from torchvision.models import ConvNeXt_Large_Weights, ConvNeXt
from torchvision.models._api import register_model, WeightsEnum
from torchvision.models._utils import handle_legacy_interface
from torchvision.models.convnext import CNBlockConfig, ConvNeXt_Tiny_Weights

from .utils import UnitaryMatrixMultiplication, NonNegativeLinear

# Public API of this module: the custom model class plus the four model
# builders registered below via ``@register_model()``. The shared
# ``own_convnext`` builder helper is not listed here.
__all__ = [
    "CustomConvNeXt",
    "base_convnext_tiny",
    "base_convnext_large",
    "own_convnext_tiny",
    "own_convnext_large",
]


class CustomConvNeXt(ConvNeXt):
    """ConvNeXt with a unitary projection and sign-split, sparsified features.

    The backbone (``self.features``) is the torchvision ConvNeXt one. Between
    the backbone and the pooling/classifier head, the feature map is passed
    through ``UnitaryMatrixMultiplication`` and split into its positive and
    negative parts. Each part is sparsified independently with
    :meth:`apply_gumbel_softmax`, and the two are recombined as
    ``positive - negative``. The last classifier layer is replaced by a
    ``NonNegativeLinear`` of the same input width.

    Args:
        block_setting: Stage configuration forwarded to ``ConvNeXt``.
        stochastic_depth_prob: Forwarded to ``ConvNeXt``.
        layer_scale: Forwarded to ``ConvNeXt``.
        num_classes: Output size of the replacement final linear layer.
        block: Forwarded to ``ConvNeXt``.
        norm_layer: Forwarded to ``ConvNeXt``.
        gumbel_dim: Dimension along which a single entry is selected. The
            special value ``-1`` flattens the last two dimensions (presumably
            the spatial H and W of the feature map -- confirm against
            ``UnitaryMatrixMultiplication``) and selects over them jointly.
        **kwargs: Forwarded to ``ConvNeXt``. ``tau`` (default ``1.0``) is also
            read here as the Gumbel-softmax temperature; ``tau <= 0`` disables
            the stochastic path even in training mode.
    """

    def __init__(
        self,
        block_setting: List[CNBlockConfig],
        stochastic_depth_prob: float = 0.0,
        layer_scale: float = 1e-6,
        num_classes: int = 1000,
        block: Optional[Callable[..., nn.Module]] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
        gumbel_dim: int = 1,
        **kwargs: Any,
    ) -> None:
        super().__init__(
            block_setting=block_setting,
            stochastic_depth_prob=stochastic_depth_prob,
            layer_scale=layer_scale,
            num_classes=num_classes,
            block=block,
            norm_layer=norm_layer,
            **kwargs,
        )

        # Gumbel-softmax temperature (see apply_gumbel_softmax).
        self.tau = kwargs.get("tau", 1.0)
        self.gumbel_dim = gumbel_dim

        # Keep the width of the original final linear layer so the new
        # projection and the non-negative head plug in without other changes.
        in_features = self.classifier[-1].in_features
        self.unitary_matrix = UnitaryMatrixMultiplication(in_features)
        self.classifier[-1] = NonNegativeLinear(in_features, num_classes)

        # State-dict key prefixes that ``own_convnext`` drops from pretrained
        # weights before loading, since these layers were replaced/added here.
        self.changed_layers = ["unitary_matrix", "classifier"]

    def apply_gumbel_softmax(self, x: Tensor) -> Tensor:
        """Keep (softly or exactly) one entry of ``x`` along ``self.gumbel_dim``.

        In training mode with ``tau > 0``, ``x`` is multiplied by a
        Gumbel-softmax sample that uses ``x`` itself as the logits, giving a
        differentiable, stochastic soft selection. Otherwise, a hard one-hot
        mask at the argmax keeps only the largest entry and zeroes the rest.

        Args:
            x: Input tensor. When ``gumbel_dim == -1`` it must have at least
                two dimensions; the last two are selected over jointly.

        Returns:
            A tensor with the same shape as ``x``.
        """
        shape = x.shape
        flatten_last_two = self.gumbel_dim == -1
        if flatten_last_two:
            # ``reshape`` instead of ``view``: identical result when a view is
            # possible, but does not raise on non-contiguous layouts.
            x = x.reshape(*shape[:-2], -1)

        if self.training and self.tau > 0:
            x = x * F.gumbel_softmax(x, tau=self.tau, dim=self.gumbel_dim)
        else:
            index = x.argmax(dim=self.gumbel_dim, keepdim=True)
            mask = torch.zeros_like(x).scatter_(self.gumbel_dim, index, 1.0)
            x = x * mask

        if flatten_last_two:
            x = x.reshape(shape)
        return x

    def forward(self, x: Tensor) -> Tensor:
        """Backbone -> unitary projection -> sign-split sparsification -> head."""
        x = self.features(x)

        ########################################
        x = self.unitary_matrix(x)

        # Sparsify the positive and negative parts separately so that each
        # sign gets its own selection, then recombine them.
        v_positive = self.apply_gumbel_softmax(torch.relu(x))
        v_negative = self.apply_gumbel_softmax(torch.relu(-x))
        x = v_positive - v_negative
        ########################################

        x = self.avgpool(x)
        return self.classifier(x)


def own_convnext(
    block_setting: List[CNBlockConfig],
    stochastic_depth_prob: float,
    weights: Optional[WeightsEnum],
    progress: bool,
    **kwargs: Any,
) -> CustomConvNeXt | ConvNeXt:
    """Build a ConvNeXt-family model and optionally load pretrained weights.

    Args:
        block_setting: Stage configuration passed to the model constructor.
        stochastic_depth_prob: Passed to the model constructor.
        weights: Pretrained weights to load, or ``None`` to skip loading.
            State-dict entries whose keys start with one of the model's
            ``changed_layers`` are dropped first. The remainder is loaded
            with ``strict=False``, and missing/unexpected keys are printed.
        progress: Forwarded to ``weights.get_state_dict``.
        **kwargs: Forwarded to the model constructor, except ``model``, which
            selects the model class. It is either the name of a module-level
            class (``"CustomConvNeXt"`` by default, or ``"ConvNeXt"``) or the
            class/callable itself.

    Returns:
        The constructed (and possibly weight-initialised) model.

    Raises:
        NameError: If ``model`` is a string that does not name a callable in
            this module's namespace.
    """
    model_spec = kwargs.pop("model", "CustomConvNeXt")
    if isinstance(model_spec, str):
        # Look the name up in module globals instead of ``eval``-ing it:
        # same result for plain class names, without executing arbitrary
        # caller-supplied expressions.
        model_cls = globals().get(model_spec)
        if not callable(model_cls):
            raise NameError(f"Unknown model {model_spec!r}")
    else:
        model_cls = model_spec

    model = model_cls(
        block_setting, stochastic_depth_prob=stochastic_depth_prob, **kwargs
    )
    # Exact type check on purpose: CustomConvNeXt subclasses ConvNeXt but
    # defines its own ``changed_layers``.
    if type(model) is ConvNeXt:
        model.changed_layers = ["classifier"]
        print("\033[0;1;34mUse base ConvNeXt\033[0m")

    if weights is None:
        return model

    pretrained_dict = weights.get_state_dict(progress=progress, check_hash=True)

    if hasattr(model, "changed_layers"):
        # Drop weights for layers that were replaced/added, so strict=False
        # loading does not hit shape mismatches on them.
        keys_to_remove = [
            key
            for key in pretrained_dict
            if any(key.startswith(layer) for layer in model.changed_layers)
        ]
        for key in keys_to_remove:
            pretrained_dict.pop(key, None)
        print(f"Removing keys: {keys_to_remove} from pretrained weights")

    missing_keys, unexpected_keys = model.load_state_dict(
        pretrained_dict, strict=False
    )
    messages = []
    if missing_keys:
        messages.append(f"\033[0;1;33mMissing keys: {missing_keys}\033[0m")
    if unexpected_keys:
        messages.append(f"\033[0;1;33mUnexpected keys: {unexpected_keys}\033[0m")
    # Only print when there is something to report (previously a clean load
    # still printed a blank line).
    if messages:
        print(*messages)
    return model


@register_model()
@handle_legacy_interface(weights=("pretrained", ConvNeXt_Tiny_Weights.IMAGENET1K_V1))
def base_convnext_tiny(
    *,
    weights: Optional[ConvNeXt_Tiny_Weights] = None,
    progress: bool = True,
    **kwargs: Any,
) -> ConvNeXt:
    """Build the unmodified torchvision ``ConvNeXt`` with the Tiny stage layout.

    Args:
        weights: Optional ``ConvNeXt_Tiny_Weights``; validated with
            ``ConvNeXt_Tiny_Weights.verify`` before use.
        progress: Forwarded to ``own_convnext`` (used when fetching weights).
        **kwargs: Forwarded to the model constructor. ``stochastic_depth_prob``
            defaults to 0.1; any ``model`` entry is overridden with
            ``"ConvNeXt"``.
    """
    verified_weights = ConvNeXt_Tiny_Weights.verify(weights)

    # (input_channels, out_channels, num_layers) per stage.
    stage_specs = ((96, 192, 3), (192, 384, 3), (384, 768, 9), (768, None, 3))
    stages = [CNBlockConfig(c_in, c_out, depth) for c_in, c_out, depth in stage_specs]

    drop_path_rate = kwargs.pop("stochastic_depth_prob", 0.1)
    builder_kwargs = {**kwargs, "model": "ConvNeXt"}
    return own_convnext(
        stages, drop_path_rate, verified_weights, progress, **builder_kwargs
    )


@register_model()
@handle_legacy_interface(weights=("pretrained", ConvNeXt_Tiny_Weights.IMAGENET1K_V1))
def own_convnext_tiny(
    *,
    weights: Optional[ConvNeXt_Tiny_Weights] = None,
    progress: bool = True,
    **kwargs: Any,
) -> CustomConvNeXt:
    """Build ``CustomConvNeXt`` with the Tiny stage layout.

    Args:
        weights: Optional ``ConvNeXt_Tiny_Weights``; validated with
            ``ConvNeXt_Tiny_Weights.verify`` before use.
        progress: Forwarded to ``own_convnext`` (used when fetching weights).
        **kwargs: Forwarded to the model constructor. ``stochastic_depth_prob``
            defaults to 0.1; any ``model`` entry is overridden with
            ``"CustomConvNeXt"``.
    """
    verified_weights = ConvNeXt_Tiny_Weights.verify(weights)

    # (input_channels, out_channels, num_layers) per stage.
    stage_specs = ((96, 192, 3), (192, 384, 3), (384, 768, 9), (768, None, 3))
    stages = [CNBlockConfig(c_in, c_out, depth) for c_in, c_out, depth in stage_specs]

    drop_path_rate = kwargs.pop("stochastic_depth_prob", 0.1)
    builder_kwargs = {**kwargs, "model": "CustomConvNeXt"}
    return own_convnext(
        stages, drop_path_rate, verified_weights, progress, **builder_kwargs
    )


@register_model()
@handle_legacy_interface(weights=("pretrained", ConvNeXt_Large_Weights.IMAGENET1K_V1))
def base_convnext_large(
    *,
    weights: Optional[ConvNeXt_Large_Weights] = None,
    progress: bool = True,
    **kwargs: Any,
) -> ConvNeXt:
    """Build the unmodified torchvision ``ConvNeXt`` with the Large stage layout.

    Args:
        weights: Optional ``ConvNeXt_Large_Weights``; validated with
            ``ConvNeXt_Large_Weights.verify`` before use.
        progress: Forwarded to ``own_convnext`` (used when fetching weights).
        **kwargs: Forwarded to the model constructor. ``stochastic_depth_prob``
            defaults to 0.5; any ``model`` entry is overridden with
            ``"ConvNeXt"``.
    """
    verified_weights = ConvNeXt_Large_Weights.verify(weights)

    # (input_channels, out_channels, num_layers) per stage.
    stage_specs = ((192, 384, 3), (384, 768, 3), (768, 1536, 27), (1536, None, 3))
    stages = [CNBlockConfig(c_in, c_out, depth) for c_in, c_out, depth in stage_specs]

    drop_path_rate = kwargs.pop("stochastic_depth_prob", 0.5)
    builder_kwargs = {**kwargs, "model": "ConvNeXt"}
    return own_convnext(
        stages, drop_path_rate, verified_weights, progress, **builder_kwargs
    )


@register_model()
@handle_legacy_interface(weights=("pretrained", ConvNeXt_Large_Weights.IMAGENET1K_V1))
def own_convnext_large(
    *,
    weights: Optional[ConvNeXt_Large_Weights] = None,
    progress: bool = True,
    **kwargs: Any,
) -> CustomConvNeXt:
    """Build ``CustomConvNeXt`` with the Large stage layout.

    Args:
        weights: Optional ``ConvNeXt_Large_Weights``; validated with
            ``ConvNeXt_Large_Weights.verify`` before use.
        progress: Forwarded to ``own_convnext`` (used when fetching weights).
        **kwargs: Forwarded to the model constructor. ``stochastic_depth_prob``
            defaults to 0.5; any ``model`` entry is overridden with
            ``"CustomConvNeXt"``.
    """
    verified_weights = ConvNeXt_Large_Weights.verify(weights)

    # (input_channels, out_channels, num_layers) per stage.
    stage_specs = ((192, 384, 3), (384, 768, 3), (768, 1536, 27), (1536, None, 3))
    stages = [CNBlockConfig(c_in, c_out, depth) for c_in, c_out, depth in stage_specs]

    drop_path_rate = kwargs.pop("stochastic_depth_prob", 0.5)
    builder_kwargs = {**kwargs, "model": "CustomConvNeXt"}
    return own_convnext(
        stages, drop_path_rate, verified_weights, progress, **builder_kwargs
    )
