from typing import Any, List, Optional, Type, Union, Callable

import torch
import torch.nn.functional as F
from torch import nn
from torchvision.models import ResNet
from torchvision.models._api import register_model, WeightsEnum
from torchvision.models._utils import handle_legacy_interface
from torchvision.models.resnet import (
    BasicBlock,
    Bottleneck,
    ResNet34_Weights,
    ResNet50_Weights,
)

from .utils import UnitaryMatrixMultiplication, NonNegativeLinear

# Public API of this module: the model class and the registered builders.
__all__ = [
    "CustomResNet",
    # Builders that construct the plain torchvision ``ResNet`` class.
    "base_resnet34",
    "base_resnet50",
    # Builders that construct ``CustomResNet``.
    "own_resnet34",
    "own_resnet50",
]


class CustomResNet(ResNet):
    """torchvision ResNet with a custom head between ``layer4`` and ``avgpool``.

    The stem and the four residual stages are the stock ResNet ones. After
    ``layer4`` the features go through ``unitary_matrix``. The result is split
    into its positive part ``relu(x)`` and negative part ``relu(-x)``. Each part
    is masked by :meth:`apply_gumbel_softmax`, and the two are recombined as
    ``positive - negative`` before pooling. The final classifier ``fc`` is a
    ``NonNegativeLinear`` instead of ``nn.Linear``.

    Args:
        block, layers, num_classes, zero_init_residual, groups,
        width_per_group, replace_stride_with_dilation, norm_layer:
            Forwarded unchanged to :class:`torchvision.models.ResNet`.
        gumbel_dim: Dimension along which one entry is selected. ``-1`` is
            special-cased: the last two dimensions are flattened first, so the
            selection runs jointly over them (presumably the H and W of the
            feature map — TODO confirm against ``UnitaryMatrixMultiplication``).
        tau: Gumbel-softmax temperature used in training mode. When
            ``tau <= 0``, or in eval mode, a hard argmax mask is used instead.
    """

    def __init__(
        self,
        block: Type[Union[BasicBlock, Bottleneck]],
        layers: List[int],
        num_classes: int = 1000,
        zero_init_residual: bool = False,
        groups: int = 1,
        width_per_group: int = 64,
        replace_stride_with_dilation: Optional[List[bool]] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
        gumbel_dim: int = 1,
        tau: float = 1,
    ) -> None:
        super().__init__(
            block=block,
            layers=layers,
            num_classes=num_classes,
            zero_init_residual=zero_init_residual,
            groups=groups,
            width_per_group=width_per_group,
            replace_stride_with_dilation=replace_stride_with_dilation,
            norm_layer=norm_layer,
        )

        self.tau = tau
        self.gumbel_dim = gumbel_dim
        # Channel count produced by layer4 (same value ResNet uses for its fc).
        in_features = 512 * block.expansion
        self.unitary_matrix = UnitaryMatrixMultiplication(in_features)
        # Replaces the nn.Linear classifier created by ResNet.__init__.
        self.fc = NonNegativeLinear(in_features, num_classes)

        # Submodules whose pretrained tensors must not be loaded; own_resnet
        # strips state-dict keys belonging to these names before loading.
        self.changed_layers = ["unitary_matrix", "fc"]

    def apply_gumbel_softmax(self, x: torch.Tensor) -> torch.Tensor:
        """Keep (softly or hard) one entry of ``x`` along ``self.gumbel_dim``.

        In training mode with ``tau > 0``, ``x`` is scaled element-wise by a
        Gumbel-softmax sample that uses ``x`` itself as the logits. Otherwise a
        one-hot mask at the argmax is applied, zeroing every other entry along
        the selection dimension.

        Args:
            x: Input tensor. When ``gumbel_dim == -1`` it must have at least
                two dimensions, and its last two are treated as one axis.

        Returns:
            Tensor of the same shape as ``x``.
        """
        shape = x.shape
        flatten_last_two = self.gumbel_dim == -1
        if flatten_last_two:
            # ``-1`` means "select over the last two dims jointly". ``reshape``
            # (unlike ``view``) also accepts inputs whose strides do not allow
            # merging those dims in place.
            x = x.reshape(*shape[:-2], -1)

        if self.training and self.tau > 0:
            x = x * F.gumbel_softmax(x, tau=self.tau, dim=self.gumbel_dim)
        else:
            # Deterministic path: one-hot mask at the maximum entry.
            index = x.argmax(dim=self.gumbel_dim, keepdim=True)
            mask = torch.zeros_like(x).scatter_(self.gumbel_dim, index, 1.0)
            x = x * mask

        if flatten_last_two:
            x = x.reshape(shape)
        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the ResNet backbone, the sign-split selection head and ``fc``.

        Returns:
            The output of ``self.fc`` on the pooled, flattened features
            (presumably ``(batch, num_classes)`` logits — depends on
            ``NonNegativeLinear``; TODO confirm).
        """
        # Stock ResNet stem.
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        # Stock residual stages.
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        # Custom head: transform the features, then mask the positive and the
        # negative parts independently before recombining them.
        x = self.unitary_matrix(x)

        v_positive = torch.relu(x)
        v_positive = self.apply_gumbel_softmax(v_positive)

        v_negative = torch.relu(torch.neg(x))
        v_negative = self.apply_gumbel_softmax(v_negative)

        x = v_positive - v_negative

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x


def own_resnet(
    block: Type[Union[BasicBlock, Bottleneck]],
    layers: List[int],
    weights: Optional[WeightsEnum],
    progress: bool,
    **kwargs: Any,
) -> Union[CustomResNet, ResNet]:
    """Build a ``ResNet`` or ``CustomResNet`` and optionally load pretrained weights.

    Args:
        block: Residual block class (``BasicBlock`` or ``Bottleneck``).
        layers: Number of blocks in each of the four residual stages.
        weights: Pretrained weights to load, or ``None`` to keep the random
            initialisation.
        progress: Passed to ``weights.get_state_dict`` as its ``progress`` flag.
        **kwargs: Forwarded to the model constructor. The special key ``model``
            selects the architecture. It may be ``"CustomResNet"`` (default) or
            ``"ResNet"``, or a ``ResNet`` subclass passed directly.

    Returns:
        The constructed model. If ``weights`` is given, every pretrained tensor
        that belongs to a layer listed in ``model.changed_layers`` is dropped.
        The rest is loaded with ``strict=False``.

    Raises:
        ValueError: If ``model`` is neither a known name nor a ``ResNet`` subclass.
    """
    model_spec = kwargs.pop("model", "CustomResNet")
    if isinstance(model_spec, type) and issubclass(model_spec, ResNet):
        model_cls = model_spec
    else:
        # Explicit dispatch table instead of eval() on a caller-supplied string.
        known_models = {"CustomResNet": CustomResNet, "ResNet": ResNet}
        try:
            model_cls = known_models[model_spec]
        except (KeyError, TypeError) as err:
            raise ValueError(
                f"Unknown model {model_spec!r}; expected one of "
                f"{sorted(known_models)} or a ResNet subclass"
            ) from err

    model = model_cls(block, layers, **kwargs)
    # Exact-type check on purpose: CustomResNet is a ResNet subclass but defines
    # its own changed_layers.
    if type(model) is ResNet:
        model.changed_layers = ["fc"]
        print("\033[0;1;34mUse base ResNet\033[0m")

    if weights is not None:
        pretrained_dict = weights.get_state_dict(progress=progress, check_hash=True)

        changed_layers = getattr(model, "changed_layers", None)
        if changed_layers is not None:
            # Match the layer itself or its dotted children only, so that e.g.
            # "layer1" does not also strip "layer10.*".
            keys_to_remove = [
                key
                for key in pretrained_dict
                if any(
                    key == layer or key.startswith(f"{layer}.")
                    for layer in changed_layers
                )
            ]
            for key in keys_to_remove:
                pretrained_dict.pop(key, None)
            print(f"Removing keys: {keys_to_remove} from pretrained weights")

        missing_keys, unexpected_keys = model.load_state_dict(
            pretrained_dict, strict=False
        )
        # Report only when there is something to report (avoids a blank line).
        if missing_keys:
            print(f"\033[0;1;33mMissing keys: {missing_keys}\033[0m")
        if unexpected_keys:
            print(f"\033[0;1;33mUnexpected keys: {unexpected_keys}\033[0m")
    return model


@register_model()
@handle_legacy_interface(weights=("pretrained", ResNet34_Weights.IMAGENET1K_V1))
def base_resnet34(
    *, weights: Optional[ResNet34_Weights] = None, progress: bool = True, **kwargs: Any
) -> ResNet:
    """Build the plain torchvision ``ResNet`` in the ResNet-34 layout.

    ``weights`` is validated with ``ResNet34_Weights.verify``. Any ``model``
    entry in ``kwargs`` is overridden with ``"ResNet"`` before delegating to
    ``own_resnet``.
    """
    checked_weights = ResNet34_Weights.verify(weights)
    stage_depths = [3, 4, 6, 3]
    return own_resnet(
        BasicBlock,
        stage_depths,
        checked_weights,
        progress,
        **{**kwargs, "model": "ResNet"},
    )


@register_model()
@handle_legacy_interface(weights=("pretrained", ResNet34_Weights.IMAGENET1K_V1))
def own_resnet34(
    *, weights: Optional[ResNet34_Weights] = None, progress: bool = True, **kwargs: Any
) -> CustomResNet:
    """Build ``CustomResNet`` in the ResNet-34 layout (BasicBlock, 3-4-6-3).

    ``weights`` is validated with ``ResNet34_Weights.verify``. Any ``model``
    entry in ``kwargs`` is overridden with ``"CustomResNet"`` before
    delegating to ``own_resnet``.
    """
    checked_weights = ResNet34_Weights.verify(weights)
    stage_depths = [3, 4, 6, 3]
    return own_resnet(
        BasicBlock,
        stage_depths,
        checked_weights,
        progress,
        **{**kwargs, "model": "CustomResNet"},
    )


@register_model()
@handle_legacy_interface(weights=("pretrained", ResNet50_Weights.IMAGENET1K_V1))
def base_resnet50(
    *, weights: Optional[ResNet50_Weights] = None, progress: bool = True, **kwargs: Any
) -> ResNet:
    """Build the plain torchvision ``ResNet`` in the ResNet-50 layout.

    ``weights`` is validated with ``ResNet50_Weights.verify``. Any ``model``
    entry in ``kwargs`` is overridden with ``"ResNet"`` before delegating to
    ``own_resnet``.
    """
    checked_weights = ResNet50_Weights.verify(weights)
    stage_depths = [3, 4, 6, 3]
    return own_resnet(
        Bottleneck,
        stage_depths,
        checked_weights,
        progress,
        **{**kwargs, "model": "ResNet"},
    )


@register_model()
@handle_legacy_interface(weights=("pretrained", ResNet50_Weights.IMAGENET1K_V1))
def own_resnet50(
    *, weights: Optional[ResNet50_Weights] = None, progress: bool = True, **kwargs: Any
) -> CustomResNet:
    """Build ``CustomResNet`` in the ResNet-50 layout (Bottleneck, 3-4-6-3).

    ``weights`` is validated with ``ResNet50_Weights.verify``. Any ``model``
    entry in ``kwargs`` is overridden with ``"CustomResNet"`` before
    delegating to ``own_resnet``.
    """
    checked_weights = ResNet50_Weights.verify(weights)
    stage_depths = [3, 4, 6, 3]
    return own_resnet(
        Bottleneck,
        stage_depths,
        checked_weights,
        progress,
        **{**kwargs, "model": "CustomResNet"},
    )
