from collections.abc import Callable

import torch
from torch import nn
from torch import IntTensor
from torch.nn import functional as F

from open_clip.transform import image_transform
from open_clip.model import CLIPTextCfg, CLIPVisionCfg

from PIL.Image import Image

from ..clip import CLIP

class TransformerBasicHead(nn.Module):
    """Three-layer MLP head mapping ``dim_in`` features to ``num_classes`` sigmoid scores.

    Layout: Linear(dim_in -> 128) -> ReLU -> BatchNorm1d(128)
            -> Linear(128 -> 64) -> ReLU -> BatchNorm1d(64)
            -> Linear(64 -> num_classes) -> sigmoid.

    ``bn1`` is constructed (so its parameters/buffers are part of the state dict)
    but is never applied in ``forward``.
    """

    def __init__(self, dim_in, num_classes):
        super().__init__()
        first_width, second_width = 128, 64
        # Registration order is kept (projections first, then norms) so that state-dict
        # key order and parameter-init RNG consumption are unchanged.
        self.projection1 = nn.Linear(dim_in, first_width, bias=True)
        self.projection2 = nn.Linear(first_width, second_width, bias=True)
        self.projection3 = nn.Linear(second_width, num_classes, bias=True)
        self.bn1 = nn.BatchNorm1d(dim_in)
        self.bn2 = nn.BatchNorm1d(first_width)
        self.bn3 = nn.BatchNorm1d(second_width)

    def forward(self, x):
        # Each hidden stage: linear -> in-place ReLU -> batch norm of that stage's width.
        hidden_stages = ((self.projection1, self.bn2), (self.projection2, self.bn3))
        for linear, norm in hidden_stages:
            x = norm(F.relu(linear(x), inplace=True))
        return torch.sigmoid(self.projection3(x))


class Adapter(nn.Module):
    """Bias-free bottleneck MLP that keeps the feature width.

    Linear(c_in -> c_in // reduction) -> ReLU -> Linear(-> c_in) -> ReLU,
    held in ``self.fc`` (state-dict keys ``fc.0.weight`` and ``fc.2.weight``).
    """

    def __init__(self, c_in, reduction=4):
        super().__init__()
        bottleneck = c_in // reduction
        stages = [
            nn.Linear(c_in, bottleneck, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(bottleneck, c_in, bias=False),
            nn.ReLU(inplace=True),
        ]
        self.fc = nn.Sequential(*stages)

    def forward(self, x):
        return self.fc(x)


class InCTRL(nn.Module):
    """Few-shot anomaly scorer built on a CLIP backbone.

    Each query image's score combines:
      * a patch-level residual map: for every query patch, the smallest
        ``0.5 * (1 - cosine similarity)`` to any reference-image patch, averaged
        over three intermediate transformer layers;
      * an image-level residual: ``diff_head_ref`` applied to
        (mean adapted reference embedding - adapted query embedding);
      * a text prior: the softmax probability (temperature 100) that the query
        embedding matches the averaged anomaly prompts rather than the averaged
        normal prompts.
    The three are summed into a per-patch map scored by ``diff_head``; the final
    score is the mean of that and the maximum of the patch residual map.
    """

    # Resblocks whose *inputs* (i.e. outputs of blocks 5, 7 and 9) are captured as patch features.
    _PATCH_LAYERS = (6, 8, 10)
    # Patch tokens per image for the configuration built in load_inctrl (240px, 16px patches -> 15 * 15).
    _NUM_PATCHES = 225

    def __init__(
        self,
        clip: CLIP,
    ):
        super().__init__()

        self.clip = clip
        # 640 matches load_inctrl's CLIP embed_dim; 225 is the per-image patch count.
        self.adapter = Adapter(640, 4)
        self.diff_head = TransformerBasicHead(self._NUM_PATCHES, 1)
        self.diff_head_ref = TransformerBasicHead(640, 1)

    @property
    def dtype(self) -> torch.dtype:
        return self.clip.logit_scale.dtype

    @property
    def device(self) -> torch.device:
        return self.clip.logit_scale.device

    @staticmethod
    def _unit(x: torch.Tensor) -> torch.Tensor:
        """Scale ``x`` to unit L2 norm along its last dimension."""
        return x / x.norm(dim=-1, keepdim=True)

    def encode_image(self, batch: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Encode ``batch`` with CLIP and capture intermediate patch tokens.

        Returns:
            image_embeds: the result of ``self.clip.encode_image(batch)``.
            patch_embeds: for each block in ``_PATCH_LAYERS`` (stacked on dim 0), the
                block's detached input with dims 0/1 swapped and the first entry of the
                new dim 0 dropped. Assuming the resblocks receive batch-first
                ``(N, L, D)`` tokens with a leading class token (TODO confirm against
                ``..clip``), this is ``(3, L - 1, N, D)``: layer, patch, image, dim.
        """
        captured: dict[int, torch.Tensor] = {}

        def capture_input(layer: int) -> Callable:
            def hook(_module, inputs: tuple[torch.Tensor, ...], _output: torch.Tensor):
                captured[layer] = inputs[0].detach()
            return hook

        resblocks = self.clip.visual.transformer.resblocks
        handles = [resblocks[layer].register_forward_hook(capture_input(layer)) for layer in self._PATCH_LAYERS]
        try:
            image_embeds = self.clip.encode_image(batch)
        finally:
            # Without removal every call would stack three more hooks on the backbone.
            for handle in handles:
                handle.remove()

        # (N, L, D) -> (L, N, D), then drop the class token at position 0.
        patch_embeds = torch.stack([captured[layer].permute(1, 0, 2)[1:] for layer in self._PATCH_LAYERS])
        return image_embeds, patch_embeds

    def query_forward(self, image: torch.Tensor):
        """Encode ``b`` query images.

        Returns:
            token: image embeddings from ``encode_image``.
            token_ad: ``self.adapter(token)``.
            fp_list: patch features per image, ``(b, 3, 225, D)``.
        """
        b = image.size(0)
        token, patch_embeds = self.encode_image(image)
        token_ad = self.adapter(token)
        # (layer, patch, image, dim) -> (image, layer, patch, dim). A bare reshape would
        # interleave different images' data whenever b > 1.
        fp_list = patch_embeds.permute(2, 0, 1, 3).reshape(b, len(self._PATCH_LAYERS), self._NUM_PATCHES, -1)
        return token, token_ad, fp_list

    def fewshot_forward(self, image: torch.Tensor, batch_size: int):
        """Encode the ``shot`` reference images shared by ``batch_size`` queries.

        Returns:
            token_n: mean adapted reference embedding, repeated per query: ``(batch_size, D)``.
            fp_list_n: every reference patch per layer, repeated per query:
                ``(batch_size, 3, 225 * shot, D)``.
        """
        shot = image.size(0)
        # All queries share the same references, so encode them once and repeat the result.
        # (repeat_interleave lays copies out shot-major; reshaping that batch-major would
        # hand each query a different, wrong mix of shots when batch_size > 1.)
        token_n, patch_embeds = self.encode_image(image)
        token_n = self.adapter(token_n).mean(dim=0)
        # (layer, patch, shot, dim) -> (layer, shot * patch, dim): pool all reference patches per layer.
        fp_n = patch_embeds.permute(0, 2, 1, 3).reshape(len(self._PATCH_LAYERS), shot * self._NUM_PATCHES, -1)
        return token_n.repeat(batch_size, 1), fp_n.repeat(batch_size, 1, 1, 1)

    def encode_text(self, text: IntTensor | list[IntTensor] | list[str] | list[list[str]]) -> torch.Tensor:
        return self.clip.encode_text(text)

    def _text_prototypes(self, normal_text_features: torch.Tensor, anomal_text_features: torch.Tensor) -> torch.Tensor:
        """Return ``(2, D)``: row 0 the normalised mean normal prompt, row 1 the anomalous one."""
        prototypes = [
            self._unit(self._unit(features).mean(dim=0, keepdim=True))
            for features in (normal_text_features, anomal_text_features)
        ]
        return torch.cat(prototypes, dim=0)

    def _patch_ref_map(self, fp_list: torch.Tensor, fp_list_n: torch.Tensor) -> torch.Tensor:
        """Per query patch, distance to the nearest reference patch, averaged over layers: ``(b, P)``.

        Distance is ``0.5 * (1 - cosine similarity)`` on L2-normalised features.
        """
        maps = []
        # One image at a time keeps the (layers, P, M) distance tensor small.
        for query, reference in zip(self._unit(fp_list), self._unit(fp_list_n)):
            dist = 0.5 * (1 - query @ reference.transpose(-1, -2))
            maps.append(dist.min(dim=-1).values.mean(dim=0))
        return torch.stack(maps)

    def score(self, token, token_ad, fp_list, token_n, fp_list_n, normal_text_features, anomal_text_features):
        """Combine query, reference and text features into one anomaly score per query.

        Args:
            token, token_ad, fp_list: outputs of ``query_forward`` for ``b`` queries.
            token_n, fp_list_n: outputs of ``fewshot_forward``; only the first ``b`` rows are used.
            normal_text_features, anomal_text_features: one text embedding per prompt.

        Returns:
            ``(b,)`` tensor: mean of the holistic ``diff_head`` score and the maximum patch residual.
        """
        b = token_ad.size(0)
        token_n = token_n[:b]
        fp_list_n = fp_list_n[:b]

        # Image-level residual between the reference prototype and each query.
        token_ref = token_n - token_ad
        img_ref_score = self.diff_head_ref(token_ref)  # (b, 1)

        # Patch-level residual map and its per-image maximum.
        patch_ref_map = self._patch_ref_map(fp_list[:b], fp_list_n)  # (b, P)
        fg_score = patch_ref_map.max(dim=-1).values  # (b,)

        # Zero-shot prior: probability of the anomaly prototype (row 1 of text_features).
        text_features = self._text_prototypes(normal_text_features, anomal_text_features)
        logits = 100 * self._unit(token[:b]) @ text_features.T  # (b, 2)
        text_score = logits.softmax(dim=-1)[:, 1:]  # (b, 1)

        # (b, 1) + (b, 1) broadcast onto the (b, P) patch map.
        holistic_map = text_score + img_ref_score + patch_ref_map
        hl_score = self.diff_head(holistic_map).squeeze(1)
        return (hl_score + fg_score) / 2

    def forward(
        self,
        image: torch.Tensor,
        normal_prompts,
        anomaly_prompts,
        reference_image: torch.Tensor,
    ):  # type: ignore
        """Score ``image`` (b queries) against ``reference_image`` (shot normal images) and text prompts."""
        token, token_ad, fp_list = self.query_forward(image)
        token_n, fp_list_n = self.fewshot_forward(reference_image, image.size(0))
        # Prompt embeddings do not depend on the query, so encode them once per call.
        normal_text_features = self.encode_text(normal_prompts)
        anomal_text_features = self.encode_text(anomaly_prompts)
        return self.score(
            token, token_ad, fp_list, token_n, fp_list_n, normal_text_features, anomal_text_features
        )


def load_inctrl(
    checkpoint: str,
    device: str | torch.device = "cuda" if torch.cuda.is_available() else "cpu",
) -> tuple[InCTRL, Callable[[Image], torch.Tensor]]:
    """Build an InCTRL model, load ``checkpoint`` into it, and return it with its eval transform.

    Checkpoint keys that start with ``adapter.``, ``diff_head.`` or ``diff_head_ref.``
    are loaded as-is; every other key is treated as a CLIP weight and gets a
    ``clip.`` prefix. Loading uses ``strict=False``, so missing or unexpected keys
    are ignored silently. The model is put in eval mode and moved to ``device``.
    The returned transform is open_clip's non-training ``image_transform`` at 240x240.
    """
    transform = image_transform(image_size=(240, 240), is_train=False)

    backbone = CLIP(
        embed_dim=640,
        vision_cfg=CLIPVisionCfg(image_size=240, layers=12, width=896, patch_size=16),
        text_cfg=CLIPTextCfg(context_length=77, vocab_size=49408, width=640, heads=10, layers=12),
    )
    model = InCTRL(backbone)

    head_prefixes = ("adapter.", "diff_head.", "diff_head_ref.")
    raw_state = torch.load(checkpoint, map_location=device, weights_only=True)
    remapped = {}
    for key, tensor in raw_state.items():
        target_key = key if key.startswith(head_prefixes) else f"clip.{key}"
        remapped[target_key] = tensor

    model.load_state_dict(remapped, strict=False)
    model = model.eval().to(device)

    return model, transform  # type: ignore
