import logging
import math
from functools import partial

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils import model_zoo

from .ViT.ViT_new import VisionTransformer, _cfg, _conv_filter
from .utils import UnitaryMatrixMultiplication, NonNegativeLinear

# Names exported by ``from <this module> import *``.
__all__ = ["CustomVisionTransformer", "own_vit_b_16"]


# Pretrained checkpoint configurations, keyed by timm-style model name.
# The ViT-B/16 JAX checkpoint expects inputs normalized with mean = std = 0.5.
default_cfgs = dict(
    vit_base_patch16_224=_cfg(
        url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth",
        mean=(0.5,) * 3,
        std=(0.5,) * 3,
    ),
)


class CustomVisionTransformer(VisionTransformer):
    """Vision Transformer with a sparsified, explainable classification path.

    After the transformer blocks, the token features (``(batch, tokens,
    embed_dim)``; the class token is prepended at index 0) are transposed to
    ``(batch, embed_dim, tokens)`` and passed through a
    ``UnitaryMatrixMultiplication``. The result is split into positive and
    negative parts. Each part is sparsified along ``gumbel_dim`` (see
    :meth:`apply_gumbel_softmax`), and the parts are recombined and normalized.
    The features are then averaged over tokens and classified by a
    ``NonNegativeLinear`` head, so each logit is a mean of per-token
    contributions (verified in :meth:`explain`).

    Extra args (all others are forwarded to ``VisionTransformer``):
        gumbel_dim: dimension of the token-feature tensor along which entries
            are selected; ``-1`` selects over the flattened
            ``tokens * embed_dim`` axis.
        tau: Gumbel-softmax temperature used in training mode; ``tau <= 0``
            (or eval mode) uses a hard arg-max mask instead.
    """

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        num_classes=1000,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=False,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        norm_layer=nn.LayerNorm,
        gumbel_dim: int = 1,
        tau: float = 1,
    ):
        super().__init__(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            num_classes=num_classes,
            embed_dim=embed_dim,
            depth=depth,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            drop_rate=drop_rate,
            attn_drop_rate=attn_drop_rate,
            norm_layer=norm_layer,
        )

        self.tau = tau
        self.gumbel_dim = gumbel_dim
        self.unitary_matrix = UnitaryMatrixMultiplication(embed_dim)

        # Replaces the parent's classifier head.
        self.head = NonNegativeLinear(embed_dim, num_classes)
        # Submodules that must NOT be initialized from pretrained weights;
        # consumed by ``load_pretrained``.
        self.changed_layers = ["unitary_matrix", "head"]

    def apply_gumbel_softmax(self, x):
        """Mask ``x`` so selected entries along ``gumbel_dim`` survive.

        In training mode with ``tau > 0``, ``x`` is multiplied by a
        Gumbel-softmax sample (``x`` used as logits) along ``gumbel_dim``.
        Otherwise it is multiplied by a one-hot arg-max mask along that
        dimension. When ``gumbel_dim == -1``, the last two dimensions are
        flattened first, so selection happens across all tokens and features
        at once. The output has the same shape as ``x``.
        """
        shape = x.shape
        if self.gumbel_dim == -1:
            # reshape (not view): the features come out of a permute and may
            # be non-contiguous, which .view cannot flatten.
            x = x.reshape(*shape[:-2], -1)

        if self.training and self.tau > 0:
            x = x * F.gumbel_softmax(x, tau=self.tau, dim=self.gumbel_dim)
        else:
            index = x.argmax(dim=self.gumbel_dim, keepdim=True)
            mask = torch.zeros_like(x).scatter_(self.gumbel_dim, index, 1.0)
            x = x * mask

        if self.gumbel_dim == -1:
            x = x.reshape(shape)
        return x

    def _sparse_token_features(self, x, register_hook=False):
        """Run backbone + unitary mixing + sparsification; return normed tokens.

        Shared by :meth:`forward` and :meth:`explain` so both see exactly the
        same features. The result is ``(batch, tokens, embed_dim)`` with the
        class token at index 0 of the token axis.
        """
        B = x.shape[0]
        x = self.patch_embed(x)

        # stole cls_tokens impl from Phil Wang, thanks
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x, register_hook=register_hook)

        # The unitary layer is applied with embed_dim as the second axis.
        x = self.unitary_matrix(x.permute(0, 2, 1)).permute(0, 2, 1)

        # Positive and negative parts are sparsified independently, so each
        # sign keeps its own selected entries.
        v_positive = self.apply_gumbel_softmax(torch.relu(x))
        v_negative = self.apply_gumbel_softmax(torch.relu(torch.neg(x)))

        return self.norm(v_positive - v_negative)

    def forward(self, x: torch.Tensor, register_hook=False) -> torch.Tensor:
        """Classify ``x``; returns the output of ``self.head``.

        Token features are mean-pooled over the token axis (dim 1) because the
        head consumes ``embed_dim`` features. This holds for every
        ``gumbel_dim``; pooling along ``gumbel_dim`` only fit the head when
        ``gumbel_dim`` addressed the token axis.
        """
        features = self._sparse_token_features(x, register_hook=register_hook)
        return self.head(features.mean(dim=1))

    def explain(self, x, target=None, register_hook=False):
        """Return a per-patch contribution heatmap for ``target``.

        The logit of class ``c`` is recomputed as the mean over tokens of
        ``|W[c]| . token_feature + bias[c]``. The ``abs`` mirrors how
        ``NonNegativeLinear`` presumably uses its weight
        (NOTE(review): confirm). This is asserted to match the model's logit
        before the class token is dropped and the patch tokens are laid out on
        the patch grid.

        Args:
            x: input images, unpacked as ``(batch, _, h, w)``; ``h`` and ``w``
                are assumed to be multiples of the patch size.
            target: class index per sample (tensor or sequence of length
                ``batch``) or a single int applied to all samples. Defaults
                to the predicted class.
            register_hook: forwarded to every transformer block.

        Returns:
            Tensor ``(batch, 1, h, w)`` scaled per sample into ``[-1, 1]``
            and upsampled from the patch grid with ``F.interpolate``'s default
            mode.
        """
        _, _, h, w = x.shape
        B = x.shape[0]

        features = self._sparse_token_features(x, register_hook=register_hook)
        logits = self.head(features.mean(dim=1))

        if target is None:
            target = torch.argmax(logits, dim=1)
        else:
            # Accept an int, a sequence or a tensor; broadcast to one index
            # per sample so the einsum below always sees (batch, embed_dim).
            target = torch.as_tensor(
                target, dtype=torch.long, device=logits.device
            ).reshape(-1).expand(B)

        weights = torch.abs(self.head.weight[target])  # (batch, embed_dim)
        heatmap = torch.einsum("bi,bai->ba", weights, features)
        if self.head.bias is not None:
            heatmap = heatmap + self.head.bias[target].view(-1, 1)

        # Each sample's own target logit must equal the mean of its heatmap.
        batch_index = torch.arange(B, device=logits.device)
        assert torch.allclose(
            logits[batch_index, target],
            heatmap.mean(dim=-1),
            rtol=1e-05,
            atol=1e-07,
        ), "The logit values of the model do not match the heatmap calculations"

        n_h = h // self.patch_embed.patch_size[0]
        n_w = w // self.patch_embed.patch_size[1]
        # Drop the class token (index 0); the rest form the patch grid.
        heatmap = heatmap[:, 1:].reshape(-1, n_h, n_w)

        # Normalize each sample to [-1, 1]. This is out-of-place so autograd
        # through the abs/max stays valid, and clamped so an all-zero map
        # gives 0 instead of NaN.
        scale = (
            torch.flatten(heatmap, start_dim=1, end_dim=-1)
            .abs()
            .max(dim=-1)
            .values.clamp_min(torch.finfo(heatmap.dtype).tiny)
        )
        heatmap = heatmap / scale.view(-1, 1, 1)

        return F.interpolate(heatmap.unsqueeze(1), (h, w))


_logger = logging.getLogger(__name__)


def load_pretrained(
    model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, strict=True
):
    """Load timm-style pretrained weights from ``cfg["url"]`` into ``model``.

    The checkpoint is downloaded to CPU and passed through ``filter_fn`` if
    given. The first conv is then adapted to ``in_chans`` and the classifier
    to ``num_classes``. Parameters belonging to any submodule named in
    ``model.changed_layers`` (if that attribute exists) are dropped so those
    layers keep their fresh initialization. The state dict is loaded
    non-strictly and missing / unexpected keys are printed.

    Args:
        model: module receiving the weights.
        cfg: config dict with ``url``, ``first_conv``, ``classifier`` and
            ``num_classes`` entries; defaults to ``model.default_cfg``. If no
            usable URL is available, a warning is logged and nothing is loaded.
        num_classes: number of classes of ``model``'s classifier.
        in_chans: number of input channels ``model`` expects.
        filter_fn: optional callable applied to the raw state dict.
        strict: accepted for backward compatibility but ignored; loading is
            always non-strict because ``changed_layers`` are removed.
    """
    if cfg is None:
        cfg = getattr(model, "default_cfg", None)
    if cfg is None or "url" not in cfg or not cfg["url"]:
        _logger.warning("Pretrained model URL is invalid, using random initialization.")
        return

    state_dict = model_zoo.load_url(cfg["url"], progress=False, map_location="cpu")

    if filter_fn is not None:
        state_dict = filter_fn(state_dict)

    if in_chans == 1:
        conv1_name = cfg["first_conv"]
        _logger.info(
            "Converting first conv (%s) pretrained weights from 3 to 1 channel",
            conv1_name,
        )
        conv1_weight = state_dict[conv1_name + ".weight"]
        # Some weights are in torch.half, ensure it's float for sum on CPU
        conv1_type = conv1_weight.dtype
        conv1_weight = conv1_weight.float()
        out_ch, in_ch, k_h, k_w = conv1_weight.shape
        if in_ch > 3:
            assert in_ch % 3 == 0
            # For models with space2depth stems
            conv1_weight = conv1_weight.reshape(
                out_ch, in_ch // 3, 3, k_h, k_w
            ).sum(dim=2)
        else:
            conv1_weight = conv1_weight.sum(dim=1, keepdim=True)
        state_dict[conv1_name + ".weight"] = conv1_weight.to(conv1_type)
    elif in_chans != 3:
        conv1_name = cfg["first_conv"]
        conv1_weight = state_dict[conv1_name + ".weight"]
        conv1_type = conv1_weight.dtype
        conv1_weight = conv1_weight.float()
        if conv1_weight.shape[1] != 3:
            _logger.warning(
                "Deleting first conv (%s) from pretrained weights.", conv1_name
            )
            del state_dict[conv1_name + ".weight"]
        else:
            # NOTE this strategy should be better than random init, but there could be other combinations of
            # the original RGB input layer weights that'd work better for specific cases.
            _logger.info(
                "Repeating first conv (%s) weights in channel dim.", conv1_name
            )
            repeat = math.ceil(in_chans / 3)
            conv1_weight = conv1_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :]
            conv1_weight *= 3 / float(in_chans)
            state_dict[conv1_name + ".weight"] = conv1_weight.to(conv1_type)

    classifier_name = cfg["classifier"]
    if num_classes == 1000 and cfg["num_classes"] == 1001:
        # special case for imagenet trained models with extra background class in pretrained weights
        for suffix in (".weight", ".bias"):
            key = classifier_name + suffix
            state_dict[key] = state_dict[key][1:]
    elif num_classes != cfg["num_classes"]:
        # completely discard fully connected for all other differences between pretrained and created model
        # (pop: filter_fn may already have removed these keys)
        state_dict.pop(classifier_name + ".weight", None)
        state_dict.pop(classifier_name + ".bias", None)

    if hasattr(model, "changed_layers"):
        # Match whole module paths only: "blocks.1" must not also remove
        # "blocks.11.*", as a bare prefix test would.
        keys_to_remove = [
            key
            for key in state_dict
            if any(
                key == layer or key.startswith(layer + ".")
                for layer in model.changed_layers
            )
        ]
        for key in keys_to_remove:
            del state_dict[key]
        print(f"Removing keys: {keys_to_remove} from pretrained weights")

    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
    if missing_keys:
        print(f"\033[0;1;33mMissing keys: {missing_keys}\033[0m")
    if unexpected_keys:
        print(f"\033[0;1;33mUnexpected keys: {unexpected_keys}\033[0m")


def own_vit_b_16(pretrained=False, **kwargs):
    """Build a ViT-B/16 ``CustomVisionTransformer``.

    The architecture is fixed to 16px patches, 768-dim embeddings, 12 blocks,
    12 heads, MLP ratio 4, QKV bias and LayerNorm with ``eps=1e-6``. Remaining
    keyword arguments (e.g. ``num_classes``, ``in_chans``, ``gumbel_dim``,
    ``tau``) are forwarded to the constructor. The model's ``default_cfg`` is
    set to the ``vit_base_patch16_224`` entry. When ``pretrained`` is true,
    weights are loaded through ``load_pretrained`` with ``_conv_filter``.
    """
    model = CustomVisionTransformer(
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )
    base_cfg = default_cfgs["vit_base_patch16_224"]
    model.default_cfg = base_cfg

    if not pretrained:
        return model

    channels = kwargs.get("in_chans", 3)
    load_pretrained(
        model,
        num_classes=model.num_classes,
        in_chans=channels,
        filter_fn=_conv_filter,
    )
    return model
