import pdb
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import networkx as nx

class PrototypeSet(nn.Module):
    """Modulate a 4-D token tensor with a global prototype and learned local prototypes.

    ``forward`` works in four steps:

    1. It scales tokens channel-wise by a global prototype, which is
       optionally blended with a conditioning vector.
    2. It attends from every token to ``num_prototypes`` local prototypes to
       build a per-token bias.
    3. It mixes the scaled tokens and the bias with a per-position sigmoid gate.
    4. It returns two small parameter vectors (``beta``, ``scale``) computed
       from the global prototype.
    """

    def __init__(self, token_dim=2048, num_prototypes=16):
        super().__init__()
        self.token_dim = token_dim
        self.num_prototypes = num_prototypes

        # Channel-wise scaling vector; initialized to all ones (identity scaling).
        self.global_prototype = nn.Parameter(torch.ones(token_dim))
        self.local_prototypes = nn.Parameter(torch.randn(num_prototypes, token_dim))

        # Bottleneck MLP applied residually to the local prototypes in forward().
        self.prototype_interaction = nn.Sequential(
            nn.Linear(token_dim, token_dim // 4),
            nn.GELU(),
            nn.Linear(token_dim // 4, token_dim)
        )

        # kernel_size=1 with groups=token_dim: a learned per-channel scale,
        # used to turn scaled tokens into attention keys.
        self.depthwise_conv = nn.Conv1d(
            in_channels=token_dim,
            out_channels=token_dim,
            kernel_size=1,
            groups=token_dim,
            bias=False
        )

        self.beta_mlp = nn.Sequential(
            nn.Linear(token_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, 3)
        )

        self.scale_mlp = ScaleMLP(token_dim)

        # 1x1-conv bottleneck producing one gate value in (0, 1) per position.
        self.spatial_adaptor = nn.Sequential(
            nn.Conv2d(token_dim, token_dim // 16, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(token_dim // 16, 1, kernel_size=1),
            nn.Sigmoid()
        )

    def forward(self, tokens, cond=None):
        """Modulate ``tokens`` and compute the beta/scale parameters.

        Args:
            tokens: tensor of shape [B, H, N, C] with C == token_dim.
            cond: optional 1-D tensor of length C. When given, it is blended
                with the global prototype through a scalar sigmoid gate.

        Returns:
            modulated_tokens: tensor of shape [B, H, N, C].
            beta: tensor of shape [1, 1, 1, 3], ``beta_mlp`` of the global prototype.
            scale: tensor of shape [1, 1, 1, 2], ``scale_mlp`` of the global prototype.
        """
        B, H, N, C = tokens.shape
        assert C == self.token_dim

        if cond is not None:
            # Scalar gate from the length-normalized dot product of prototype and condition.
            gate = torch.sigmoid(torch.dot(self.global_prototype, cond) / C)
            global_proto = gate * self.global_prototype + (1 - gate) * cond
        else:
            global_proto = self.global_prototype

        global_scaled = tokens * global_proto.view(1, 1, 1, -1)

        # Treat every token as a length-1 sequence with C channels for the depthwise conv.
        # reshape (not view): the product keeps the strides of a non-contiguous input.
        tokens_reshape = global_scaled.reshape(-1, C, 1)
        keys = self.depthwise_conv(tokens_reshape).squeeze(-1)  # [B*H*N, C]

        interacted_prototypes = self.local_prototypes + self.prototype_interaction(self.local_prototypes)

        attn_scores = torch.matmul(keys, interacted_prototypes.t()) / (C ** 0.5)
        attn_weights = F.softmax(attn_scores, dim=-1)
        # Use the actual batch size so inputs with B > 1 are supported.
        local_bias = torch.matmul(attn_weights, interacted_prototypes).view(B, H, N, C)

        # [B, C, H, N] -> gate [B, 1, H, N] -> [B, H, N, 1]
        spatial_weights = self.spatial_adaptor(tokens.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)

        modulated_tokens = spatial_weights * global_scaled + (1 - spatial_weights) * local_bias

        beta = self.beta_mlp(global_proto).view(1, 1, 1, 3)
        scale = self.scale_mlp(global_proto).view(1, 1, 1, 2)

        return modulated_tokens, beta, scale




# MLP to map global prototype to scale parameters
class ScaleMLP(nn.Module):
    """Three-layer MLP mapping a feature vector to two constrained scalars.

    The first output goes through softplus (positive). The second goes
    through sigmoid (in [0, 1]).
    """

    def __init__(self, token_dim):
        super().__init__()
        self.fc1 = nn.Linear(token_dim, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, 2)

    def forward(self, x):
        """Map ``x`` of shape [..., token_dim] to a tensor of shape [..., 2].

        For a 1-D input the result has shape [2].
        """
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        # Index the last (feature) axis so batched inputs work. For a 1-D
        # input this matches indexing the first axis.
        x1 = F.softplus(x[..., 0])  # Ensure positive
        x2 = torch.sigmoid(x[..., 1])  # Constrain to [0,1]
        return torch.stack([x1, x2], dim=-1)

def get_fixed_As(num_prototypes):
    """Build a fixed table of (R, G, B) values with exactly one row per prototype.

    Rows are grouped into six color families, in order: blue, green, yellow,
    cyan, red, dark. Within each family every channel is spread linearly
    across a family-specific range.

    When ``num_prototypes`` is not a multiple of six, the leftover rows go one
    each to the first families. The table therefore always has
    ``num_prototypes`` rows. For multiples of six every family gets the same
    count.

    Args:
        num_prototypes: number of rows to produce.

    Returns:
        Float tensor of shape [num_prototypes, 3]; columns are R, G, B.
    """
    # ((R low, R high), (G low, G high), (B low, B high)) per family, in output order.
    family_ranges = [
        ((0.3, 0.6), (0.5, 0.8), (0.8, 1.0)),     # blue
        ((0.3, 0.5), (0.7, 1.0), (0.4, 0.7)),     # green
        ((0.5, 0.9), (0.6, 0.9), (0.3, 0.6)),     # yellow
        ((0.4, 0.7), (0.6, 0.9), (0.5, 0.8)),     # cyan
        ((0.7, 1.0), (0.2, 0.5), (0.2, 0.4)),     # red
        ((0.05, 0.3), (0.05, 0.3), (0.05, 0.3)),  # dark
    ]
    num_each, remainder = divmod(num_prototypes, len(family_ranges))

    family_blocks = []
    for family_idx, channel_ranges in enumerate(family_ranges):
        count = num_each + (1 if family_idx < remainder else 0)
        channels = [torch.linspace(low, high, count) for low, high in channel_ranges]
        family_blocks.append(torch.stack(channels, dim=1))  # [count, 3]

    return torch.cat(family_blocks, dim=0)  # [num_prototypes, 3]



class LightAEstimator(nn.Module):
    """Small CNN that regresses a 3-vector in (0, 1) from a 3-channel image.

    The network has three stride-2 Conv/BatchNorm/ReLU stages, then global
    average pooling, a linear head, and an element-wise sigmoid.
    """

    def __init__(self):
        super().__init__()
        # (in_channels, out_channels, kernel_size, padding); every stage uses stride 2.
        stage_specs = ((3, 16, 5, 2), (16, 32, 5, 2), (32, 64, 3, 1))

        layers = []
        for c_in, c_out, k, pad in stage_specs:
            layers += [
                nn.Conv2d(c_in, c_out, kernel_size=k, stride=2, padding=pad),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
            ]
        layers.append(nn.AdaptiveAvgPool2d((1, 1)))  # -> [B, 64, 1, 1]

        # The positional layout (conv, bn, relu per stage, then pool) defines the
        # state_dict keys (features.0 ... features.9); keep it stable for checkpoints.
        self.features = nn.Sequential(*layers)
        self.fc = nn.Linear(64, 3)

    def forward(self, x):
        pooled = self.features(x)              # [B, 64, 1, 1]
        flat = torch.flatten(pooled, 1)        # [B, 64]
        return torch.sigmoid(self.fc(flat))    # [B, 3]


class TokenPrototypeModulatorGAT(nn.Module):
    """Blend the outputs of the top-k PrototypeSets selected by similarity to a predicted A.

    A trainable table ``A_values`` holds one 3-component row per prototype,
    initialized by ``get_fixed_As``. Rows closer than 0.3 in Euclidean
    distance are connected into a graph. A GAT over node features
    ``[A_values, A_pred]``, followed by ``proj_mlp``, produces one
    conditioning vector per prototype.

    The ``top_k`` prototype sets whose sigmoid-squashed A values are closest
    to ``A_pred`` each modulate every layer's tokens. Their outputs are summed
    with normalized similarity weights.
    """

    def __init__(self, token_dim=2048, num_prototypes=16, top_k=4):
        super().__init__()
        self.token_dim = token_dim
        self.num_prototypes = num_prototypes
        # How many prototype sets to blend; clamped in forward() to the table size.
        self.top_k = top_k

        self.prototype_sets = nn.ModuleList([
            PrototypeSet(token_dim, num_prototypes)
            for _ in range(num_prototypes)
        ])

        A_vals = get_fixed_As(num_prototypes)
        self.A_values = nn.Parameter(A_vals)

        # Node features are [A_values row (3), A_pred (3)] -> 6 inputs.
        self.gat_new = SimpleGATLayer(6, 128)

        self.proj_mlp = nn.Sequential(
            nn.Linear(128, token_dim),
            nn.ReLU(),
            nn.Linear(token_dim, token_dim)
        )

    def forward(self, aggregated_tokens_list, A_pred):
        """Modulate each layer's tokens with a similarity-weighted mix of prototype sets.

        Args:
            aggregated_tokens_list: list of token tensors, one per layer. Each
                is passed to ``PrototypeSet.forward``, presumably with shape
                [B, H, N, token_dim] (TODO confirm with caller).
            A_pred: predicted A. It must be expandable to [P, 3], e.g. shape
                [1, 3] or [3]; a batch of several predictions is not supported.

        Returns:
            modulated_tokens_list: list with the same length and shapes as the input.
            A_val: tensor of shape [1, 3], the weighted average of
                sigmoid(A_values) over the selected rows.
            beta_stack: tensor of shape [1, 1, 1, 3], the weighted sum of the
                selected sets' beta.
            scale_stack: tensor of shape [1, 1, 1, 2], the weighted sum of the
                selected sets' scale.
        """
        device = A_pred.device
        # Size everything from the A table itself so that A_pred.expand(P, -1)
        # always matches A_values in the concatenation below.
        P = self.A_values.shape[0]

        # Graph over prototypes: an edge wherever two A rows are within 0.3
        # (the diagonal distance is 0, so self-loops are always present).
        dist_matrix = torch.cdist(self.A_values, self.A_values, p=2)
        adj = (dist_matrix < 0.3).float().to(device)

        gat_input = torch.cat([self.A_values, A_pred.expand(P, -1)], dim=1)  # [P, 6]
        proto_feat, _ = self.gat_new(gat_input, adj)

        refined_prototypes = self.proj_mlp(proto_feat)  # [P, token_dim]

        # NOTE(review): the similarity and A_val use sigmoid(A_values), while the
        # adjacency and GAT input above use the raw values; confirm this is intended.
        similarities = torch.exp(-10 * torch.norm(A_pred - torch.sigmoid(self.A_values), dim=1))  # [P]
        # Clamp k so a table with fewer than top_k rows does not make topk raise.
        k = min(self.top_k, P)
        topk_vals, topk_indices = torch.topk(similarities, k)
        weights = topk_vals / topk_vals.sum()

        modulated_tokens_list = [torch.zeros_like(tokens) for tokens in aggregated_tokens_list]
        beta_stack = torch.zeros(1, 1, 1, 3).to(device)
        scale_stack = torch.zeros(1, 1, 1, 2).to(device)

        for l, layer_tokens in enumerate(aggregated_tokens_list):
            for i, idx in enumerate(topk_indices):
                ps = self.prototype_sets[idx]
                # NOTE(review): .item() turns the weight into a Python float, so no
                # gradient flows through it here and each call syncs with the device.
                # Confirm this is intended.
                w = weights[i].item()
                cond = refined_prototypes[idx]

                mod_tokens, beta, scale = ps(layer_tokens, cond=cond)

                modulated_tokens_list[l] += w * mod_tokens
                # beta/scale depend only on the prototype set and cond, not on the
                # layer's tokens, so they are accumulated once (first layer only).
                if l == 0:
                    beta_stack += w * beta
                    scale_stack += w * scale

        A_val = (weights.unsqueeze(1) * torch.sigmoid(self.A_values[topk_indices])).sum(0, keepdim=True)

        return modulated_tokens_list, A_val, beta_stack, scale_stack


class SimpleGATLayer(nn.Module):
    """Single-head graph attention layer over a dense adjacency matrix.

    Node features are projected by ``fc``. The score for each ordered pair
    (i, j) is ``LeakyReLU(attn_fc([Wh_i, Wh_j]))``. Pairs with
    ``adj[i, j] == 0`` are masked out before a row-wise softmax, and each
    output row is the attention-weighted sum of the projected features.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.fc = nn.Linear(in_dim, out_dim, bias=False)
        self.attn_fc = nn.Linear(2 * out_dim, 1, bias=False)
        self.leaky_relu = nn.LeakyReLU(0.2)

    def forward(self, h, adj):
        """Apply attention over the graph.

        Args:
            h: node features of shape [P, in_dim].
            adj: tensor of shape [P, P]; entries equal to 0 mark missing edges.

        Returns:
            h_prime: tensor of shape [P, out_dim].
            attention: tensor of shape [P, P], row-stochastic.
        """
        projected = self.fc(h)  # [P, D']
        num_nodes, width = projected.shape

        # pair_feats[i, j] = concat(projected[i], projected[j])  -> [P, P, 2D']
        src = projected.unsqueeze(1).expand(num_nodes, num_nodes, width)
        dst = projected.unsqueeze(0).expand(num_nodes, num_nodes, width)
        pair_feats = torch.cat((src, dst), dim=-1)

        scores = self.leaky_relu(self.attn_fc(pair_feats)).squeeze(-1)  # [P, P]
        # A large negative fill instead of -inf keeps fully masked rows finite.
        scores = scores.masked_fill(adj == 0, -9e15)
        attention = scores.softmax(dim=1)

        return attention @ projected, attention

