from typing import Any, List, Optional, Type, Union, Callable

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.hub import load_state_dict_from_url

from .resnet import ResNet, BasicBlock, Bottleneck
from .utils import UnitaryMatrixMultiplication, NonNegativeLinear

# Names exported by `from <this module> import *`.
__all__ = [
    "CustomResNet",
    "own_resnet50",
]

# Download URLs of pretrained checkpoints, keyed by architecture name.
# `_resnet` looks the URL up by its `arch` argument when `pretrained=True`.
model_urls = {
    "resnet50": "https://download.pytorch.org/models/resnet50-0676ba61.pth",
}


class CustomResNet(ResNet):
    """ResNet whose last feature map is gated before classification.

    After ``layer4`` the feature map is passed through
    ``UnitaryMatrixMultiplication``, split into its positive and negative
    parts, each part is gated along ``gumbel_dim`` (Gumbel-softmax while
    training, hard arg-max otherwise) and the two parts are recombined as
    ``positive - negative``.  The result is pooled, flattened and classified
    by a ``NonNegativeLinear`` head.

    Args:
        block: Residual block class; forwarded to ``ResNet.__init__`` and
            also used for ``block.expansion`` to size the new layers.
        layers: Forwarded unchanged to ``ResNet.__init__``.
        num_classes: Output size of the classifier head.
        zero_init_residual: Forwarded unchanged to ``ResNet.__init__``.
        groups: Forwarded unchanged to ``ResNet.__init__``.
        width_per_group: Forwarded unchanged to ``ResNet.__init__``.
        replace_stride_with_dilation: Forwarded unchanged to
            ``ResNet.__init__``.
        norm_layer: Forwarded unchanged to ``ResNet.__init__``.
        gumbel_dim: Dimension along which the gating competes.  The special
            value ``-1`` first merges the last two dimensions into one, so
            the competition runs over that merged axis.
        tau: Temperature handed to ``F.gumbel_softmax``.  With ``tau <= 0``
            the hard arg-max gate is used even in training mode.
    """

    def __init__(
        self,
        block: Type[Union[BasicBlock, Bottleneck]],
        layers: List[int],
        num_classes: int = 1000,
        zero_init_residual: bool = False,
        groups: int = 1,
        width_per_group: int = 64,
        replace_stride_with_dilation: Optional[List[bool]] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
        gumbel_dim: int = 1,
        tau: float = 1,
    ) -> None:
        super().__init__(
            block=block,
            layers=layers,
            num_classes=num_classes,
            zero_init_residual=zero_init_residual,
            groups=groups,
            width_per_group=width_per_group,
            replace_stride_with_dilation=replace_stride_with_dilation,
            norm_layer=norm_layer,
        )

        self.tau = tau
        self.gumbel_dim = gumbel_dim

        # Channel count of the final feature map.
        in_features = 512 * block.expansion
        self.unitary_matrix = UnitaryMatrixMultiplication(in_features)
        # Rebinds `fc`; presumably this replaces the head registered by the
        # base class (the base class is not visible here -- confirm).
        self.fc = NonNegativeLinear(in_features, num_classes)

        # Sub-modules that differ from the base architecture.  `_resnet`
        # drops pretrained weights whose key starts with one of these names.
        self.changed_layers = ["unitary_matrix", "fc"]

    def apply_gumbel_softmax(self, x: torch.Tensor) -> torch.Tensor:
        """Gate ``x`` along ``self.gumbel_dim`` and return it in its input shape.

        In training mode with ``tau > 0`` the gate is a sample from
        ``F.gumbel_softmax(x, tau, dim)``, i.e. ``x`` itself serves as the
        logits.  Otherwise the gate is a one-hot mask at the arg-max
        position, so only the largest entry along the dimension survives.

        Args:
            x: Tensor to gate.  When ``gumbel_dim == -1`` it needs at least
                two dimensions, which are merged for the gating.

        Returns:
            ``x`` multiplied element-wise by the gate, same shape as ``x``.
        """
        original_shape = x.shape
        merge_trailing = self.gumbel_dim == -1
        if merge_trailing:
            # reshape (rather than view) also accepts non-contiguous inputs.
            x = x.reshape(*original_shape[:-2], -1)

        if self.training and self.tau > 0:
            gate = F.gumbel_softmax(x, tau=self.tau, dim=self.gumbel_dim)
        else:
            winners = x.argmax(dim=self.gumbel_dim, keepdim=True)
            gate = torch.zeros_like(x).scatter_(self.gumbel_dim, winners, 1.0)
        x = x * gate

        if merge_trailing:
            x = x.reshape(original_shape)
        return x

    def _signed_sparse_features(self, x: torch.Tensor) -> torch.Tensor:
        """Run the backbone and return the gated, signed final feature map."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.unitary_matrix(x)

        # Positive and negative parts are gated independently (positive
        # first, which fixes the order of the random Gumbel draws) and then
        # recombined with their original signs.
        positive = self.apply_gumbel_softmax(torch.relu(x))
        negative = self.apply_gumbel_softmax(torch.relu(torch.neg(x)))
        return positive - negative

    def _classify(self, features: torch.Tensor) -> torch.Tensor:
        """Pool, flatten and classify a feature map; returns the logits."""
        pooled = self.avgpool(features)
        return self.fc(torch.flatten(pooled, 1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the class logits for the input batch ``x``."""
        return self._classify(self._signed_sparse_features(x))

    def explain(
        self,
        x: torch.Tensor,
        target: Optional[Union[int, List[int], torch.Tensor]] = None,
    ) -> torch.Tensor:
        """Return a normalised heatmap for the target class of each sample.

        The heatmap is the gated feature map weighted per channel by the
        absolute classifier weights of the target class (plus the class
        bias, if any), normalised and resized to the input resolution.

        Args:
            x: 4-D input batch; its last two dimensions give the output
                heatmap size.
            target: Class index per sample.  ``None`` uses the predicted
                (arg-max) class.  A single index (int, 0-dim or one-element
                tensor) is applied to every sample of the batch.

        Returns:
            Tensor of shape ``(batch, 1, h, w)`` where ``h, w`` are the last
            two dimensions of ``x``.  Each sample is divided by its largest
            absolute value (an all-zero map stays zero), so values lie in
            ``[-1, 1]``.

        Raises:
            ValueError: If ``target`` holds neither one index nor one index
                per sample.
            AssertionError: If the spatial mean of the un-normalised heatmap
                does not reproduce the target logit.
        """
        _, _, h, w = x.shape

        feature_map = self._signed_sparse_features(x)
        logits = self._classify(feature_map)
        batch_size = logits.shape[0]

        if target is None:
            target = torch.argmax(logits, dim=1)
        else:
            # Accept ints, lists and tensors of any shape/device alike.
            target = torch.as_tensor(
                target, dtype=torch.long, device=logits.device
            ).reshape(-1)
            if target.numel() == 1:
                target = target.expand(batch_size)
            elif target.numel() != batch_size:
                raise ValueError(
                    f"target must hold 1 or {batch_size} class indices, "
                    f"got {target.numel()}"
                )

        # NOTE(review): using |weight| assumes NonNegativeLinear applies the
        # absolute weights in its forward pass; the assertion below verifies
        # that the two computations agree.
        weights = torch.abs(self.fc.weight[target])
        heatmap = torch.einsum("bc,bchw->bhw", weights, feature_map)
        if self.fc.bias is not None:
            heatmap = heatmap + self.fc.bias[target].view(-1, 1, 1)

        # Each sample's own target logit.  (Indexing `logits[0, target]`
        # would read sample 0 for every entry and break for batches > 1.)
        target_logits = logits.gather(1, target.unsqueeze(1)).squeeze(1)
        positions = heatmap.shape[-2] * heatmap.shape[-1]

        # Assert model output matches heatmap values
        assert torch.allclose(
            target_logits,
            heatmap.sum(dim=[-2, -1]).div(positions),
            rtol=1e-05,
            atol=1e-07,
        ), "The logit values of the model do not match the heatmap calculations"

        # Normalize the heatmap for visualization: scale to [-1, 1].  The
        # divisor is clamped so an all-zero heatmap yields zeros, not NaN.
        peak = heatmap.flatten(start_dim=1).abs().max(dim=-1).values.view(-1, 1, 1)
        heatmap = heatmap / peak.clamp_min(torch.finfo(heatmap.dtype).tiny)

        return F.interpolate(heatmap.unsqueeze(1), size=(h, w))


def _resnet(
    arch: str,
    block: Type[Union[BasicBlock, Bottleneck]],
    layers: List[int],
    pretrained: bool,
    progress: bool,
    **kwargs: Any,
) -> ResNet:
    """Build a ``CustomResNet`` and optionally load pretrained weights.

    Args:
        arch: Key into ``model_urls`` selecting the checkpoint to download.
        block: Residual block class passed to ``CustomResNet``.
        layers: Layer configuration passed to ``CustomResNet``.
        pretrained: When true, download the checkpoint for ``arch`` and load
            it non-strictly into the model.
        progress: Forwarded to ``load_state_dict_from_url``.
        **kwargs: Extra keyword arguments for ``CustomResNet``.

    Returns:
        The constructed model.
    """
    model = CustomResNet(block, layers, **kwargs)
    if not pretrained:
        return model

    state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)

    if hasattr(model, "changed_layers"):
        # Weights of layers the model replaced must not be loaded; drop every
        # checkpoint entry whose key begins with one of those layer names.
        replaced_prefixes = tuple(model.changed_layers)
        keys_to_remove = [
            name for name in state_dict if name.startswith(replaced_prefixes)
        ]
        for name in keys_to_remove:
            state_dict.pop(name, None)
        print(f"Removing keys: {keys_to_remove} from pretrained weights")

    # Non-strict load: the removed layers are expected to be reported missing.
    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)

    missing_report = ""
    if missing_keys:
        missing_report = f"\033[0;1;33mMissing keys: {missing_keys}\033[0m"

    unexpected_report = ""
    if unexpected_keys:
        unexpected_report = f"\033[0;1;33mUnexpected keys: {unexpected_keys}\033[0m"

    print(missing_report, unexpected_report)
    return model


def own_resnet50(
    pretrained: bool = False, progress: bool = True, **kwargs: Any
) -> ResNet:
    """Build the ``CustomResNet`` variant with the ResNet-50 layout.

    Uses ``Bottleneck`` blocks in a ``[3, 4, 6, 3]`` stage configuration.

    Args:
        pretrained: When true, load the ``"resnet50"`` checkpoint listed in
            ``model_urls`` (layers replaced by the model are skipped).
        progress: Forwarded to the checkpoint download.
        **kwargs: Extra keyword arguments for ``CustomResNet``.

    Returns:
        The constructed model.
    """
    stage_depths = [3, 4, 6, 3]
    return _resnet(
        arch="resnet50",
        block=Bottleneck,
        layers=stage_depths,
        pretrained=pretrained,
        progress=progress,
        **kwargs,
    )
