import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.models.embeddings import PatchEmbed
from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection


class TransformerEncoder(nn.Module):
    """Encode a stack of multi-view latents into one feature vector per sample.

    The views are merged with a learned linear combination, the merged latent
    is patch-embedded, passed through a single transformer encoder layer and
    mean-pooled over the patch tokens.
    """

    def __init__(self):
        super().__init__()
        # Learned weighted sum over the 4 views (plus bias).
        self.view_to_one = nn.Linear(4, 1)

        # d_model = 24 heads * 64 dims per head = 1536.
        # batch_first=True: the patch embedding below produces
        # [bs, tokens, dim]. With the default sequence-first layout the layer
        # would attend across the batch axis and mix unrelated samples.
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(1536, nhead=24, batch_first=True),
            num_layers=1,
        )
        self.pos_embed = PatchEmbed(
            height=32,
            width=32,
            patch_size=2,
            in_channels=16,
            embed_dim=1536,
            pos_embed_max_size=96,  # hard-code for now.
        )

    def forward(self, x):
        """Return one 1536-dim feature per sample.

        Args:
            x: tensor laid out as [bs, view, dim, h, w]; the layers above
                require view == 4 and dim == 16 (e.g. [64, 4, 16, 32, 32]).

        Returns:
            Tensor of shape [bs, 1536].
        """
        bs, n_views, channels, height, width = x.shape
        # Move the view axis last so the linear layer mixes views per element.
        x = x.permute(0, 2, 3, 4, 1).reshape(-1, n_views)  # [bs*dim*h*w, view]
        x = self.view_to_one(x).reshape(bs, channels, height, width)

        embedded = self.pos_embed(x)
        # The original code unpacked a 2-tuple here, but upstream diffusers'
        # PatchEmbed returns a single tensor; accept either form.
        if isinstance(embedded, (tuple, list)):
            embedded = embedded[0]

        tokens = self.encoder(embedded)  # [bs, tokens, 1536]
        return tokens.mean(dim=1)  # average over patch tokens -> [bs, 1536]


class CNNEncoder(nn.Module):
    """Encode a stack of multi-view latents into a 256-dim feature vector.

    The views are merged with a learned linear combination and the merged
    latent goes through three stride-2 conv stages followed by global average
    pooling, so any spatial size the convolutions accept can be used.
    """

    def __init__(self):
        super().__init__()
        # Learned weighted sum over the 4 views (plus bias).
        self.view_to_one = nn.Linear(4, 1)
        self.encoder = nn.Sequential(
            nn.Conv2d(16, 64, 3, 2, 1),  # 32x32 input -> [64, 16, 16]
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Dropout2d(0.2),
            nn.Conv2d(64, 128, 3, 2, 1),  # -> [128, 8, 8]
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Dropout2d(0.2),
            nn.Conv2d(128, 256, 3, 2, 1),  # -> [256, 4, 4]
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Dropout2d(0.2),
            nn.AdaptiveAvgPool2d(1),  # [256, 1, 1]
            nn.Flatten(),  # [256]
            nn.Dropout(0.2),
        )

    def forward(self, x):
        """Map x of shape [bs, 4, 16, h, w] to features of shape [bs, 256]."""
        bs, n_views, channels, height, width = x.shape
        # Move the view axis last so the linear layer mixes views per element.
        x = x.permute(0, 2, 3, 4, 1).reshape(-1, n_views)  # [bs*16*h*w, 4]
        # Shape is taken from the input instead of a fixed 16x32x32 so other
        # latent resolutions work too.
        x = self.view_to_one(x).reshape(bs, channels, height, width)
        return self.encoder(x)  # [bs, 256]


class PairwiseComparator_CNN_0(nn.Module):
    """Score a pair of multi-view latents from the difference of CNN features.

    Both inputs go through the same ``CNNEncoder``; the feature difference
    ``f2 - f1`` is fed to a small MLP that ends in a sigmoid.
    """

    def __init__(self):
        super().__init__()
        self.encoder = CNNEncoder()
        head_layers = [
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
            nn.Sigmoid(),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x1, x2):
        """Compare x1 and x2 (each [B, 4, 16, 32, 32]) and return a [B] score.

        The head ends in ``nn.Sigmoid``, so the returned values lie in (0, 1);
        they are probabilities, not raw logits.
        """
        # Shared encoder, first input first.
        features = [self.encoder(latents) for latents in (x1, x2)]  # 2 x [B, 256]
        delta = features[1] - features[0]
        score = self.classifier(delta)  # [B, 1]
        return score.squeeze(1)  # [B]


class PairwiseComparator_Transformer(nn.Module):
    """Score a pair of multi-view latents from transformer feature differences.

    Both inputs go through the same ``TransformerEncoder``; the feature
    difference ``f2 - f1`` is mapped to a single raw logit by a GELU MLP.
    """

    def __init__(self):
        super().__init__()
        self.encoder = TransformerEncoder()
        width = 1536  # feature size produced by the encoder
        self.classifier = nn.Sequential(
            nn.Linear(width, width),
            nn.GELU(),
            nn.Linear(width, 1),
        )

    def forward(self, x1, x2):
        """Compare x1 and x2 (each [B, 4, 16, 32, 32]); return raw logits [B]."""
        first = self.encoder(x1)  # [B, 1536]
        second = self.encoder(x2)  # [B, 1536]
        logits = self.classifier(second - first)  # [B, 1]
        return logits.squeeze(1)  # [B], no sigmoid applied


class PairwiseComparator_no_encoder(nn.Module):
    """Encoder-free baseline: classify the view-averaged latent difference.

    The difference ``x2 - x1`` is averaged over views, then over the spatial
    grid, and the resulting 16 channel means go through a small MLP.
    """

    def __init__(self):
        super().__init__()
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),  # average the spatial grid down to 1x1
            nn.Flatten(),  # [B, 16]
            nn.Linear(16, 64),
            nn.GELU(),
            nn.Linear(64, 1),  # outputs a logit
        )

    def forward(self, x1, x2):
        """Compare x1 and x2 (each [B, n_view, 16, h, w]); return logits [B]."""
        # Difference first, then mean over the view axis -> [B, 16, h, w].
        view_avg_delta = torch.sub(x2, x1).mean(dim=1)
        logits = self.classifier(view_avg_delta)  # [B, 1]
        return logits.squeeze(1)  # [B]


class PairwiseComparator_C0L2(nn.Module):
    """Pairwise comparator with no conv layers and a 2-layer linear head.

    Each input's views are merged with a learned linear combination; the
    difference ``x1 - x2`` of the merged latents is globally average-pooled
    and classified by a small MLP that outputs one raw logit per pair.
    """

    def __init__(self):
        super().__init__()
        # Learned weighted sum over the 4 views (plus bias).
        self.view_to_one = nn.Linear(4, 1)
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),  # [B, 16, 1, 1]
            nn.Flatten(),  # [B, 16]
            nn.Linear(16, 16),
            nn.GELU(),
            nn.Linear(16, 1),
        )

    def _collapse_views(self, x):
        """Merge the view axis of [B, 4, C, H, W] into a [B, C, H, W] tensor."""
        bs, n_views, channels, height, width = x.shape
        # Move views last so the linear layer mixes them per (c, h, w) element.
        flat = x.permute(0, 2, 3, 4, 1).reshape(-1, n_views)  # [B*C*H*W, 4]
        # Shape comes from the input rather than a fixed 16x32x32, so other
        # latent resolutions are accepted.
        return self.view_to_one(flat).reshape(bs, channels, height, width)

    def forward(self, x1, x2):
        """Compare x1 and x2 (each [B, 4, 16, H, W]); return raw logits [B]."""
        diff = self._collapse_views(x1) - self._collapse_views(x2)  # [B, 16, H, W]
        out = self.classifier(diff)  # [B, 1]
        return out.squeeze(1)  # [B]


class PairwiseComparator_C1L2(nn.Module):
    """Pairwise comparator with 1 conv layer and a 2-layer linear head.

    Each input's views are merged with a learned linear combination and
    encoded by a shared conv stage; the feature difference ``feat1 - feat2``
    is globally average-pooled and mapped to one raw logit per pair.
    """

    def __init__(self):
        super().__init__()
        # Learned weighted sum over the 4 views (plus bias).
        self.view_to_one = nn.Linear(4, 1)
        self.encoder = nn.Sequential(
            nn.Conv2d(16, 32, 3, stride=2, padding=1),  # 32x32 input -> [B, 32, 16, 16]
            nn.GELU(),
            nn.Dropout2d(0.2),
        )
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),  # [B, 32, 1, 1]
            nn.Flatten(),  # [B, 32]
            nn.Linear(32, 32),
            nn.GELU(),
            nn.Linear(32, 1),
        )

    def _collapse_views(self, x):
        """Merge the view axis of [B, 4, C, H, W] into a [B, C, H, W] tensor."""
        bs, n_views, channels, height, width = x.shape
        # Move views last so the linear layer mixes them per (c, h, w) element.
        flat = x.permute(0, 2, 3, 4, 1).reshape(-1, n_views)  # [B*C*H*W, 4]
        # Shape comes from the input rather than a fixed 16x32x32, so other
        # latent resolutions are accepted.
        return self.view_to_one(flat).reshape(bs, channels, height, width)

    def forward(self, x1, x2):
        """Compare x1 and x2 (each [B, 4, 16, H, W]); return raw logits [B]."""
        # x1 is encoded before x2, as in the original implementation.
        feat1 = self.encoder(self._collapse_views(x1))  # [B, 32, H/2, W/2]
        feat2 = self.encoder(self._collapse_views(x2))  # [B, 32, H/2, W/2]
        diff = feat1 - feat2  # difference features
        out = self.classifier(diff)  # [B, 1]
        return out.squeeze(1)  # [B]


class PairwiseComparator_C2L2(nn.Module):
    """Pairwise comparator with 2 conv layers and a 2-layer linear head.

    Each input's views are merged with a learned linear combination and
    encoded by a shared conv stack; the feature difference ``feat1 - feat2``
    is globally average-pooled and mapped to one raw logit per pair.
    """

    def __init__(self):
        super().__init__()
        # Learned weighted sum over the 4 views (plus bias).
        self.view_to_one = nn.Linear(4, 1)
        self.encoder = nn.Sequential(
            nn.Conv2d(16, 32, 3, stride=2, padding=1),  # 32x32 input -> [B, 32, 16, 16]
            nn.GELU(),
            nn.Dropout2d(0.2),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),  # -> [B, 64, 8, 8]
            nn.GELU(),
            nn.Dropout2d(0.2),
        )
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),  # [B, 64, 1, 1]
            nn.Flatten(),  # [B, 64]
            nn.Linear(64, 64),
            nn.GELU(),
            nn.Linear(64, 1),
        )

    def _collapse_views(self, x):
        """Merge the view axis of [B, 4, C, H, W] into a [B, C, H, W] tensor."""
        bs, n_views, channels, height, width = x.shape
        # Move views last so the linear layer mixes them per (c, h, w) element.
        flat = x.permute(0, 2, 3, 4, 1).reshape(-1, n_views)  # [B*C*H*W, 4]
        # Shape comes from the input rather than a fixed 16x32x32, so other
        # latent resolutions are accepted.
        return self.view_to_one(flat).reshape(bs, channels, height, width)

    def forward(self, x1, x2):
        """Compare x1 and x2 (each [B, 4, 16, H, W]); return raw logits [B]."""
        # x1 is encoded before x2, as in the original implementation.
        feat1 = self.encoder(self._collapse_views(x1))  # [B, 64, H/4, W/4]
        feat2 = self.encoder(self._collapse_views(x2))  # [B, 64, H/4, W/4]
        diff = feat1 - feat2  # difference features
        out = self.classifier(diff)  # [B, 1]
        return out.squeeze(1)  # [B]


class PairwiseComparator_C1L3(nn.Module):
    """Pairwise comparator with 1 conv layer and a 3-layer linear head.

    Each input's views are merged with a learned linear combination and
    encoded by a shared conv stage; the feature difference ``feat1 - feat2``
    is globally average-pooled and mapped to one raw logit per pair.
    """

    def __init__(self):
        super().__init__()
        # Learned weighted sum over the 4 views (plus bias).
        self.view_to_one = nn.Linear(4, 1)
        self.encoder = nn.Sequential(
            nn.Conv2d(16, 32, 3, stride=2, padding=1),  # 32x32 input -> [B, 32, 16, 16]
            nn.GELU(),
            nn.Dropout2d(0.2),
        )
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),  # [B, 32, 1, 1]
            nn.Flatten(),  # [B, 32]
            nn.Linear(32, 32),
            nn.GELU(),
            nn.Linear(32, 32),
            nn.GELU(),
            nn.Linear(32, 1),
        )

    def _collapse_views(self, x):
        """Merge the view axis of [B, 4, C, H, W] into a [B, C, H, W] tensor."""
        bs, n_views, channels, height, width = x.shape
        # Move views last so the linear layer mixes them per (c, h, w) element.
        flat = x.permute(0, 2, 3, 4, 1).reshape(-1, n_views)  # [B*C*H*W, 4]
        # Shape comes from the input rather than a fixed 16x32x32, so other
        # latent resolutions are accepted.
        return self.view_to_one(flat).reshape(bs, channels, height, width)

    def forward(self, x1, x2):
        """Compare x1 and x2 (each [B, 4, 16, H, W]); return raw logits [B]."""
        # x1 is encoded before x2, as in the original implementation.
        feat1 = self.encoder(self._collapse_views(x1))  # [B, 32, H/2, W/2]
        feat2 = self.encoder(self._collapse_views(x2))  # [B, 32, H/2, W/2]
        diff = feat1 - feat2  # difference features
        out = self.classifier(diff)  # [B, 1]
        return out.squeeze(1)  # [B]


class PairwiseComparator_C3L2(nn.Module):
    """Pairwise comparator with 3 conv layers and a 2-layer linear head.

    Each input's views are merged with a learned linear combination and
    encoded by a shared conv stack; the feature difference ``feat1 - feat2``
    is globally average-pooled and mapped to one raw logit per pair.
    """

    def __init__(self):
        super().__init__()
        # Learned weighted sum over the 4 views (plus bias).
        self.view_to_one = nn.Linear(4, 1)
        self.encoder = nn.Sequential(
            nn.Conv2d(16, 32, 3, stride=2, padding=1),  # 32x32 input -> [B, 32, 16, 16]
            nn.GELU(),
            nn.Dropout2d(0.2),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),  # -> [B, 64, 8, 8]
            nn.GELU(),
            nn.Dropout2d(0.2),
            nn.Conv2d(64, 128, 3, stride=2, padding=1),  # -> [B, 128, 4, 4]
            nn.GELU(),
            nn.Dropout2d(0.2),
        )
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),  # [B, 128, 1, 1]
            nn.Flatten(),  # [B, 128]
            nn.Linear(128, 128),
            nn.GELU(),
            nn.Linear(128, 1),
        )

    def _collapse_views(self, x):
        """Merge the view axis of [B, 4, C, H, W] into a [B, C, H, W] tensor."""
        bs, n_views, channels, height, width = x.shape
        # Move views last so the linear layer mixes them per (c, h, w) element.
        flat = x.permute(0, 2, 3, 4, 1).reshape(-1, n_views)  # [B*C*H*W, 4]
        # Shape comes from the input rather than a fixed 16x32x32, so other
        # latent resolutions are accepted.
        return self.view_to_one(flat).reshape(bs, channels, height, width)

    def forward(self, x1, x2):
        """Compare x1 and x2 (each [B, 4, 16, H, W]); return raw logits [B]."""
        # x1 is encoded before x2, as in the original implementation.
        feat1 = self.encoder(self._collapse_views(x1))  # [B, 128, H/8, W/8]
        feat2 = self.encoder(self._collapse_views(x2))  # [B, 128, H/8, W/8]
        diff = feat1 - feat2  # difference features
        out = self.classifier(diff)  # [B, 1]
        return out.squeeze(1)  # [B]


class PairwiseComparator_C1L2_prompt(nn.Module):
    """Prompt-conditioned comparator: 1 conv layer, 2-layer linear head.

    The pooled feature difference ``feat1 - feat2`` (32 values) is
    concatenated with the raw 512-dim prompt embedding and classified into
    one raw logit per pair.
    """

    def __init__(self):
        super().__init__()
        # Learned weighted sum over the 4 views (plus bias).
        self.view_to_one = nn.Linear(4, 1)
        self.encoder = nn.Sequential(
            nn.Conv2d(16, 32, 3, stride=2, padding=1),  # 32x32 input -> [B, 32, 16, 16]
            nn.GELU(),
            nn.Dropout2d(0.2),
        )
        # Not used by forward(); kept so the parameter set and state-dict keys
        # stay as they were. Plain Dropout replaces Dropout2d because this
        # stack would receive a 2-D [B, 32] tensor, which Dropout2d is not
        # meant for.
        self.prompt_encoder = nn.Sequential(
            nn.Linear(512, 32),
            nn.GELU(),
            nn.Dropout(0.2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(32 + 512, 32 + 512),
            nn.GELU(),
            nn.Linear(32 + 512, 1),
        )

    def _collapse_views(self, x):
        """Merge the view axis of [B, 4, C, H, W] into a [B, C, H, W] tensor."""
        bs, n_views, channels, height, width = x.shape
        # Move views last so the linear layer mixes them per (c, h, w) element.
        flat = x.permute(0, 2, 3, 4, 1).reshape(-1, n_views)  # [B*C*H*W, 4]
        # Shape comes from the input rather than a fixed 16x32x32, so other
        # latent resolutions are accepted.
        return self.view_to_one(flat).reshape(bs, channels, height, width)

    def forward(self, x1, x2, prompt):
        """Compare x1 and x2 given a prompt embedding; return raw logits [B].

        Args:
            x1, x2: latents of shape [B, 4, 16, H, W].
            prompt: prompt embedding of shape [B, 512].
        """
        # x1 is encoded before x2, as in the original implementation.
        feat1 = self.encoder(self._collapse_views(x1))  # [B, 32, H/2, W/2]
        feat2 = self.encoder(self._collapse_views(x2))  # [B, 32, H/2, W/2]
        diff = feat1 - feat2  # difference features
        diff = F.adaptive_avg_pool2d(diff, 1).flatten(1)  # [B, 32]
        fused = torch.cat([diff, prompt], dim=1)  # [B, 32 + 512]
        out = self.classifier(fused)  # [B, 1]
        return out.squeeze(1)  # [B]


class PairwiseComparator_C1L3_prompt(nn.Module):
    """Prompt-conditioned comparator: 1 conv layer, 3-layer linear head.

    The pooled feature difference ``feat1 - feat2`` (32 values) is
    concatenated with the raw 512-dim prompt embedding and classified into
    one raw logit per pair.
    """

    def __init__(self):
        super().__init__()
        # Learned weighted sum over the 4 views (plus bias).
        self.view_to_one = nn.Linear(4, 1)
        self.encoder = nn.Sequential(
            nn.Conv2d(16, 32, 3, stride=2, padding=1),  # 32x32 input -> [B, 32, 16, 16]
            nn.GELU(),
            nn.Dropout2d(0.2),
        )
        # Not used by forward(); kept so the parameter set and state-dict keys
        # stay as they were. Plain Dropout replaces Dropout2d because this
        # stack would receive a 2-D [B, 32] tensor, which Dropout2d is not
        # meant for.
        self.prompt_encoder = nn.Sequential(
            nn.Linear(512, 32),
            nn.GELU(),
            nn.Dropout(0.2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(32 + 512, 32 + 512),
            nn.GELU(),
            nn.Linear(32 + 512, 32 + 512),
            nn.GELU(),
            nn.Linear(32 + 512, 1),
        )

    def _collapse_views(self, x):
        """Merge the view axis of [B, 4, C, H, W] into a [B, C, H, W] tensor."""
        bs, n_views, channels, height, width = x.shape
        # Move views last so the linear layer mixes them per (c, h, w) element.
        flat = x.permute(0, 2, 3, 4, 1).reshape(-1, n_views)  # [B*C*H*W, 4]
        # Shape comes from the input rather than a fixed 16x32x32, so other
        # latent resolutions are accepted.
        return self.view_to_one(flat).reshape(bs, channels, height, width)

    def forward(self, x1, x2, prompt):
        """Compare x1 and x2 given a prompt embedding; return raw logits [B].

        Args:
            x1, x2: latents of shape [B, 4, 16, H, W].
            prompt: prompt embedding of shape [B, 512].
        """
        # x1 is encoded before x2, as in the original implementation.
        feat1 = self.encoder(self._collapse_views(x1))  # [B, 32, H/2, W/2]
        feat2 = self.encoder(self._collapse_views(x2))  # [B, 32, H/2, W/2]
        diff = feat1 - feat2  # difference features
        diff = F.adaptive_avg_pool2d(diff, 1).flatten(1)  # [B, 32]
        fused = torch.cat([diff, prompt], dim=1)  # [B, 32 + 512]
        out = self.classifier(fused)  # [B, 1]
        return out.squeeze(1)  # [B]


class PairwiseComparator_CNN_prompt(nn.Module):
    """Prompt-conditioned comparator with a fusion layer before the head.

    The pooled feature difference ``feat2 - feat1`` (32 values) is
    concatenated with the 512-dim prompt embedding, projected by ``fusion``
    and classified into one raw logit per pair.

    NOTE(review): the layer sizes were previously inconsistent. The encoder
    emits 32 channels, so the concatenation is 32 + 512 = 544 wide, but
    ``fusion`` was ``Linear(576, 576)`` and ``classifier`` expects 32 inputs,
    so ``forward`` always raised a shape error. ``fusion`` now maps 544 -> 32,
    which leaves the encoder and classifier untouched. Confirm this matches
    the intended design; the old 576 suggests a 64-channel encoder was
    planned at some point.
    """

    def __init__(self):
        super().__init__()
        feat_dim = 32  # channels produced by the encoder
        prompt_dim = 512  # width of the prompt embedding
        # Learned weighted sum over the 4 views (plus bias).
        self.view_to_one = nn.Linear(4, 1)
        self.encoder = nn.Sequential(
            nn.Conv2d(16, feat_dim, 3, stride=2, padding=1),  # 32x32 input -> [B, 32, 16, 16]
            nn.GELU(),
            nn.Dropout2d(0.2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(feat_dim, feat_dim),
            nn.GELU(),
            nn.Linear(feat_dim, 1),
        )
        # Projects the concatenated [diff, prompt] vector to the classifier width.
        self.fusion = nn.Sequential(
            nn.Linear(feat_dim + prompt_dim, feat_dim),
            nn.GELU(),
        )

    def _collapse_views(self, x):
        """Merge the view axis of [B, 4, C, H, W] into a [B, C, H, W] tensor."""
        bs, n_views, channels, height, width = x.shape
        # Move views last so the linear layer mixes them per (c, h, w) element.
        flat = x.permute(0, 2, 3, 4, 1).reshape(-1, n_views)  # [B*C*H*W, 4]
        # Shape comes from the input rather than a fixed 16x32x32, so other
        # latent resolutions are accepted.
        return self.view_to_one(flat).reshape(bs, channels, height, width)

    def forward(self, x1, x2, prompt):
        """Compare x1 and x2 given a prompt embedding; return raw logits [B].

        Args:
            x1, x2: latents of shape [B, 4, 16, H, W].
            prompt: prompt embedding of shape [B, 512].
        """
        # x1 is encoded before x2, as in the original implementation.
        feat1 = self.encoder(self._collapse_views(x1))  # [B, 32, H/2, W/2]
        feat2 = self.encoder(self._collapse_views(x2))  # [B, 32, H/2, W/2]
        diff = feat2 - feat1  # difference features
        diff = F.adaptive_avg_pool2d(diff, 1).flatten(1)  # [B, 32]
        fused = torch.cat([diff, prompt], dim=1)  # [B, 32 + 512]
        fused = self.fusion(fused)  # [B, 32]

        out = self.classifier(fused)  # [B, 1]
        return out.squeeze(1)  # [B]


def emb_prompt(prompt_path, max_prompts=100, clip_name="openai/clip-vit-base-patch32", device="cuda"):
    """Embed the prompts listed in a text file with the CLIP text encoder.

    Args:
        prompt_path: path to a UTF-8 text file with one prompt per line;
            blank lines are skipped.
        max_prompts: keep only the first ``max_prompts`` prompts
            (``None`` keeps all of them).
        clip_name: name or path of the CLIP checkpoint to load.
        device: device the text encoder runs on.

    Returns:
        L2-normalized float text embeddings of shape (N, D), located on
        ``device`` and detached from autograd.

    Raises:
        ValueError: if the file contains no non-empty prompt.
    """
    with open(prompt_path, "r", encoding="utf-8") as f:
        stripped_lines = (line.strip() for line in f)
        prompts = [prompt for prompt in stripped_lines if prompt]

    prompts = prompts[:max_prompts]
    if not prompts:
        # Fail early with a clear message instead of an obscure tokenizer error.
        raise ValueError(f"no non-empty prompts found in {prompt_path!r}")

    tokenizer = CLIPTokenizer.from_pretrained(clip_name)
    text_encoder = CLIPTextModelWithProjection.from_pretrained(clip_name).to(device).eval()

    text_inputs = tokenizer(
        prompts,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids.to(device)

    # The encoder is a frozen feature extractor here, so skip building an
    # autograd graph through it. no_grad (rather than inference_mode) keeps
    # the result usable as an input to layers that are being trained.
    with torch.no_grad():
        text_embeds = text_encoder(text_input_ids).text_embeds.float()  # (N, D)
        text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)

    print(text_embeds.shape)  # e.g. torch.Size([100, 512]) for the default checkpoint

    return text_embeds


if __name__ == "__main__":
    # Quick sanity check: build one comparator and report its size.
    import os

    os.environ["CUDA_VISIBLE_DEVICES"] = "5"

    model = PairwiseComparator_C1L2()
    total_params = 0
    for param in model.parameters():
        total_params += param.numel()
    print(f"Total number of parameters: {total_params:,}")
