import re
from typing import Any, Optional, Tuple

import torch
import torch.nn.functional as F
from torch import Tensor
from torchvision.models import DenseNet, DenseNet121_Weights
from torchvision.models._api import register_model, WeightsEnum
from torchvision.models._utils import handle_legacy_interface

from .utils import UnitaryMatrixMultiplication, NonNegativeLinear

# Public names exported by ``from <module> import *``. The shared builder
# ``own_densenet`` is not listed here.
__all__ = [
    "CustomDenseNet",
    "base_densenet121",
    "own_densenet121",
]


class CustomDenseNet(DenseNet):
    """DenseNet whose head keeps only selected signed feature responses.

    The dense feature extractor is inherited unchanged from torchvision.
    Its output goes through a ``UnitaryMatrixMultiplication`` layer and is
    split into positive and negative parts. Each part is masked by
    :meth:`apply_gumbel_softmax`. The difference of the two parts is then
    average-pooled and fed to a ``NonNegativeLinear`` classifier.

    Args:
        growth_rate, block_config, num_init_features, bn_size, drop_rate,
        num_classes, memory_efficient: forwarded unchanged to
            :class:`torchvision.models.DenseNet`.
        gumbel_dim: dimension along which one entry is selected. ``-1``
            flattens the last two dimensions (presumably H and W -- confirm
            against ``UnitaryMatrixMultiplication``'s output) and selects
            over them jointly.
        tau: Gumbel-softmax temperature used in training mode. If ``tau`` is
            not positive, or the module is in eval mode, a hard argmax mask
            is used instead.
    """

    def __init__(
        self,
        growth_rate: int = 32,
        block_config: Tuple[int, int, int, int] = (6, 12, 24, 16),
        num_init_features: int = 64,
        bn_size: int = 4,
        drop_rate: float = 0,
        num_classes: int = 1000,
        memory_efficient: bool = False,
        gumbel_dim: int = 1,
        tau: float = 1,
    ) -> None:
        super().__init__(
            growth_rate=growth_rate,
            block_config=block_config,
            num_init_features=num_init_features,
            bn_size=bn_size,
            drop_rate=drop_rate,
            num_classes=num_classes,
            memory_efficient=memory_efficient,
        )

        self.tau = tau
        self.gumbel_dim = gumbel_dim

        # Read the width of the stock linear head before it is replaced.
        width = self.classifier.in_features
        self.unitary_matrix = UnitaryMatrixMultiplication(width)
        self.classifier = NonNegativeLinear(width, num_classes)

        # Submodule prefixes that differ from the stock DenseNet; pretrained
        # checkpoint keys starting with these are dropped before loading.
        self.changed_layers = ["unitary_matrix", "classifier"]

    def apply_gumbel_softmax(self, x):
        """Mask ``x`` so that (softly or hard) one entry per slice survives.

        In training mode with ``tau > 0``, ``x`` is scaled by a Gumbel-softmax
        sample of itself taken along ``self.gumbel_dim``. Otherwise ``x`` is
        multiplied by a one-hot mask at its argmax along that dimension.
        The returned tensor has the same shape as ``x``.
        """
        original_shape = x.shape
        merge_last_two = self.gumbel_dim == -1
        if merge_last_two:
            # Collapse the trailing two dims so selection spans both jointly.
            x = x.view(*original_shape[:-2], -1)

        if not self.training or self.tau <= 0:
            winners = x.argmax(dim=self.gumbel_dim, keepdim=True)
            weights = torch.zeros_like(x).scatter_(self.gumbel_dim, winners, 1.0)
        else:
            weights = F.gumbel_softmax(x, tau=self.tau, dim=self.gumbel_dim)
        x = x * weights

        return x.view(*original_shape) if merge_last_two else x

    def forward(self, x: Tensor) -> Tensor:
        """Compute class logits for a batch of images ``x``."""
        # Unlike the stock DenseNet head, no ReLU is applied to the feature
        # map here; its signed values are split into two halves below.
        rotated = self.unitary_matrix(self.features(x))

        # The positive branch must run first so the order of Gumbel draws
        # stays the same.
        positive = self.apply_gumbel_softmax(torch.relu(rotated))
        negative = self.apply_gumbel_softmax(torch.relu(-rotated))

        pooled = F.adaptive_avg_pool2d(positive - negative, (1, 1))
        return self.classifier(torch.flatten(pooled, 1))


# Older torchvision DenseNet checkpoints name dense-layer parameters like
# ``denselayer1.norm.1.weight``, while the current modules use
# ``denselayer1.norm1.weight``. Group 1 captures everything up to the layer
# kind (norm/relu/conv). Group 2 captures the index plus parameter name.
# Concatenating the two groups drops the offending dot.
_DENSELAYER_KEY_PATTERN = re.compile(
    r"^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$"
)


def _resolve_model_class(model: Any) -> Any:
    """Map the ``model`` selector to a constructor without using ``eval``.

    Raises:
        ValueError: ``model`` is a string naming an unsupported architecture.
        TypeError: ``model`` is neither a string nor callable.
    """
    if not isinstance(model, str):
        if not callable(model):
            raise TypeError(
                f"model must be a class name or a callable, got {model!r}"
            )
        return model
    supported = {"CustomDenseNet": CustomDenseNet, "DenseNet": DenseNet}
    try:
        return supported[model]
    except KeyError:
        raise ValueError(
            f"Unknown model {model!r}; expected one of {sorted(supported)}"
        ) from None


def _load_pretrained(
    model: torch.nn.Module, weights: WeightsEnum, progress: bool
) -> None:
    """Load ``weights`` into ``model``, skipping replaced layers and fixing old key names."""
    state_dict = weights.get_state_dict(progress=progress, check_hash=True)

    changed_layers = getattr(model, "changed_layers", None)
    if changed_layers is not None:
        # Layers listed here were replaced on the model, so their checkpoint
        # tensors are discarded rather than loaded.
        prefixes = tuple(changed_layers)
        keys_to_remove = [key for key in state_dict if key.startswith(prefixes)]
        for key in keys_to_remove:
            del state_dict[key]
        print(f"Removing keys: {keys_to_remove} from pretrained weights")

    # Iterate over a snapshot because keys are renamed in place.
    for key in list(state_dict):
        match = _DENSELAYER_KEY_PATTERN.match(key)
        if match:
            state_dict[match.group(1) + match.group(2)] = state_dict.pop(key)

    # strict=False: the removed (replaced) layers are expected to be missing.
    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
    report = []
    if missing_keys:
        report.append(f"\033[0;1;33mMissing keys: {missing_keys}\033[0m")
    if unexpected_keys:
        report.append(f"\033[0;1;33mUnexpected keys: {unexpected_keys}\033[0m")
    if report:
        print(*report)


def own_densenet(
    growth_rate: int,
    block_config: Tuple[int, int, int, int],
    num_init_features: int,
    weights: Optional[WeightsEnum],
    progress: bool,
    **kwargs: Any,
) -> CustomDenseNet | DenseNet:
    """Build a stock or custom DenseNet and optionally load pretrained weights.

    Args:
        growth_rate: DenseNet growth rate.
        block_config: number of layers in each dense block.
        num_init_features: channels produced by the stem convolution.
        weights: torchvision weights enum to load, or ``None`` for random init.
        progress: show a download progress bar when fetching ``weights``.
        **kwargs: forwarded to the model constructor. The special key
            ``model`` selects the architecture: ``"CustomDenseNet"``
            (default), ``"DenseNet"``, or a class/callable used directly.

    Returns:
        The constructed model. A plain ``DenseNet`` gets a ``changed_layers``
        attribute of ``["classifier"]`` so its head is not loaded from the
        checkpoint.

    Raises:
        ValueError: ``model`` is an unknown architecture name.
        TypeError: ``model`` is neither a string nor callable.
    """
    model_cls = _resolve_model_class(kwargs.pop("model", "CustomDenseNet"))
    model = model_cls(growth_rate, block_config, num_init_features, **kwargs)

    # Exact type check: CustomDenseNet subclasses DenseNet but already
    # defines its own changed_layers.
    if type(model) is DenseNet:
        model.changed_layers = ["classifier"]
        print("\033[0;1;34mUse base DenseNet\033[0m")

    if weights is not None:
        _load_pretrained(model, weights, progress)
    return model


@register_model()
@handle_legacy_interface(weights=("pretrained", DenseNet121_Weights.IMAGENET1K_V1))
def base_densenet121(
    *,
    weights: Optional[DenseNet121_Weights] = None,
    progress: bool = True,
    **kwargs: Any,
) -> DenseNet:
    """Build a stock torchvision DenseNet-121 through ``own_densenet``.

    Args:
        weights: optional ``DenseNet121_Weights`` to load. The legacy
            ``pretrained=True`` flag maps to ``IMAGENET1K_V1``.
        progress: show a download progress bar when fetching weights.
        **kwargs: extra constructor arguments such as ``num_classes``. Any
            caller-supplied ``model`` entry is overridden.
    """
    checked_weights = DenseNet121_Weights.verify(weights)
    # Always select the unmodified architecture.
    build_kwargs = {**kwargs, "model": "DenseNet"}
    return own_densenet(
        32,  # growth_rate
        (6, 12, 24, 16),  # block_config
        64,  # num_init_features
        checked_weights,
        progress,
        **build_kwargs,
    )


@register_model()
@handle_legacy_interface(weights=("pretrained", DenseNet121_Weights.IMAGENET1K_V1))
def own_densenet121(
    *,
    weights: Optional[DenseNet121_Weights] = None,
    progress: bool = True,
    **kwargs: Any,
) -> CustomDenseNet:
    """Build a DenseNet-121-sized ``CustomDenseNet`` through ``own_densenet``.

    Args:
        weights: optional ``DenseNet121_Weights`` to load. The legacy
            ``pretrained=True`` flag maps to ``IMAGENET1K_V1``.
        progress: show a download progress bar when fetching weights.
        **kwargs: extra constructor arguments (e.g. ``num_classes``,
            ``gumbel_dim``, ``tau``). Any caller-supplied ``model`` entry is
            overridden.
    """
    checked_weights = DenseNet121_Weights.verify(weights)
    # Always select the customised architecture.
    build_kwargs = {**kwargs, "model": "CustomDenseNet"}
    return own_densenet(
        32,  # growth_rate
        (6, 12, 24, 16),  # block_config
        64,  # num_init_features
        checked_weights,
        progress,
        **build_kwargs,
    )
