import torch
from torch import nn
import torch.nn.functional as F
from torch_geometric.nn import global_mean_pool
from torch_scatter import scatter_add, scatter_mean
import numpy as np

from GOOD.networks.models.GINs import GINFeatExtractor
from GOOD.networks.models.GINvirtualnode import vGINFeatExtractor

from .gnnconv import GNN_node


class MolecularSemanticPrototypeModule(nn.Module):
    """Softly assign embeddings to a bank of learnable prototype vectors.

    Args:
        emb_dim: width ``d`` of each prototype (and of the input embeddings).
        num_prototypes: number ``K`` of prototypes in the bank.
        temperature: divisor applied to the dot-product scores before softmax.
    """

    def __init__(self, emb_dim, num_prototypes, temperature=1.0):
        super().__init__()
        self.num_prototypes = num_prototypes
        self.emb_dim = emb_dim
        self.temperature = temperature

        # A randn tensor is allocated first and then overwritten in place by
        # Xavier-uniform initialisation.
        bank = torch.randn(num_prototypes, emb_dim)
        self.prototypes = nn.Parameter(bank)
        nn.init.xavier_uniform_(self.prototypes)

    def forward(self, u_G):
        """Project ``u_G`` onto the prototype bank.

        Returns:
            ``(alpha, z_G)`` where ``alpha`` is the softmax over prototypes of
            ``u_G @ prototypes.T / temperature`` (shape ``[B, K]``) and ``z_G``
            is the ``alpha``-weighted sum of prototypes (shape ``[B, d]``).
        """
        logits = u_G @ self.prototypes.t() / self.temperature
        weights = logits.softmax(dim=-1)
        return weights, weights @ self.prototypes


class MolecularEncoderModule(nn.Module):
    """Turn a batch of graphs into one mean-pooled embedding per graph.

    The backbone depends on ``args.dataset``:

    * names starting with ``'GOOD'`` use ``GINFeatExtractor`` when
      ``config.model.model_name == 'GIN'`` and ``vGINFeatExtractor`` otherwise,
      both built with ``without_readout=True``; the width comes from
      ``config.model.dim_hidden``;
    * every other dataset uses ``GNN_node`` configured from ``args``, with
      width ``args.emb_dim``.

    The chosen width is stored as ``self.emb_dim``.
    """

    def __init__(self, args, config):
        super().__init__()
        self.args = args

        if not args.dataset.startswith('GOOD'):
            width = args.emb_dim
            self.gnn = GNN_node(
                num_layer=args.layer,
                emb_dim=args.emb_dim,
                drop_ratio=args.dropout,
                gnn_type=args.gnn_type,
            )
        else:
            width = config.model.dim_hidden
            extractor_cls = (
                GINFeatExtractor
                if config.model.model_name == 'GIN'
                else vGINFeatExtractor
            )
            self.gnn = extractor_cls(config, without_readout=True)

        self.emb_dim = width
        self.pool = global_mean_pool

    def forward(self, data):
        """Return one pooled embedding per graph in ``data`` (``[B, d]``).

        GOOD backbones receive ``data`` as a keyword argument, ``GNN_node``
        receives it positionally; node outputs are then pooled per graph using
        ``data.batch``.
        """
        good_backbone = self.args.dataset.startswith('GOOD')
        node_feat = self.gnn(data=data) if good_backbone else self.gnn(data)
        return self.pool(node_feat, data.batch)


class SemanticConsistencyModule(nn.Module):
    """Cosine-similarity utilities for finding in-batch semantic neighbours.

    Args:
        emb_dim: embedding width (stored only).
        top_k: default neighbour count used by ``get_topk_neighbors``.
        inter_temperature: stored for callers; not read by this module.
    """

    def __init__(self, emb_dim, top_k=5, inter_temperature=1.0):
        super().__init__()
        self.emb_dim = emb_dim
        self.top_k = top_k
        self.inter_temperature = inter_temperature

    def compute_semantic_similarity(self, z_G, z_H):
        """Return pairwise cosine similarities between rows of ``z_G`` and ``z_H``."""
        left = F.normalize(z_G, p=2, dim=1)
        right = F.normalize(z_H, p=2, dim=1)
        return left @ right.t()

    def get_topk_neighbors(self, z_G, k=None):
        """Return, for every row of ``z_G``, the indices of its most similar other rows.

        Args:
            z_G: embeddings, one per row.
            k: neighbours per row; ``None`` means ``self.top_k``. It is capped at
                ``batch_size - 1`` because a row is never its own neighbour.

        Returns:
            Long tensor of shape ``[B, k]`` ordered by decreasing similarity.
        """
        n_rows = z_G.size(0)
        wanted = self.top_k if k is None else k
        num_neighbors = min(wanted, n_rows - 1)

        # Push self-similarity to -inf so the diagonal is never selected.
        self_pairs = torch.eye(n_rows, device=z_G.device, dtype=torch.bool)
        similarity = self.compute_semantic_similarity(z_G, z_G).masked_fill(
            self_pairs, float('-inf'))
        return similarity.topk(num_neighbors, dim=1).indices


class ContrastiveProjectionHead(nn.Module):
    """Two-layer MLP used to project embeddings before contrastive scoring.

    Layout: ``Linear(emb_dim, 2*emb_dim) -> BatchNorm1d -> ReLU ->
    Linear(2*emb_dim, proj_dim)``.

    Args:
        emb_dim: input width.
        proj_dim: output width; ``None`` means ``emb_dim``.
    """

    def __init__(self, emb_dim, proj_dim=None):
        super().__init__()
        hidden_dim = emb_dim * 2
        out_dim = emb_dim if proj_dim is None else proj_dim
        layers = [
            nn.Linear(emb_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim),
        ]
        self.proj_head = nn.Sequential(*layers)

    def forward(self, z):
        """Apply the projection MLP to ``z``."""
        return self.proj_head(z)


class AdversarialPerturbationModule(nn.Module):
    """Search for a bounded residual perturbation that increases an InfoNCE loss.

    A perturbation ``delta_G`` is added to the residual part of the embedding,
    pushed along the raw gradient of the InfoNCE loss for ``inner_steps`` steps,
    and after every step projected onto the per-row L2 ball of radius ``epsilon``.

    Args:
        emb_dim: embedding width (stored only).
        epsilon: maximum L2 norm of each row of the perturbation.
        inner_steps: number of gradient-ascent steps.
        inner_lr: step size multiplying the (unnormalised) gradient.
    """

    def __init__(self, emb_dim, epsilon, inner_steps, inner_lr):
        super(AdversarialPerturbationModule, self).__init__()
        self.emb_dim = emb_dim
        self.epsilon = epsilon
        self.inner_steps = inner_steps
        self.inner_lr = inner_lr

    def project_to_constraint(self, delta_G, epsilon):
        """Rescale each row of ``delta_G`` so its L2 norm is at most ``epsilon``.

        Rows already inside the ball are left unchanged; the ``1e-8`` avoids a
        division by zero for all-zero rows.
        """
        norm = torch.norm(delta_G, p=2, dim=-1, keepdim=True)  # [B, 1]
        scale = torch.clamp(epsilon / (norm + 1e-8), max=1.0)
        return delta_G * scale

    def generate_perturbation(self, r_G_d, z_G_d, semantic_prototype,
                              contrastive_head, temperature, device):
        """Return a detached perturbation of ``r_G_d`` that makes views harder to match.

        Args:
            r_G_d: residual embeddings (callers pass detached tensors), ``[B, d]``.
            z_G_d: prototype embeddings (callers pass detached tensors), ``[B, d]``.
            semantic_prototype: callable returning a pair whose second element is
                the re-projected embedding of its input.
            contrastive_head: projection applied to both views before scoring.
            temperature: InfoNCE temperature.
            device: device on which the perturbation and labels are allocated.

        Returns:
            A tensor shaped like ``r_G_d`` whose rows have L2 norm at most
            ``epsilon`` (all zeros when ``inner_steps`` is 0), detached from any
            autograd graph.
        """
        delta_G = torch.zeros_like(r_G_d, device=device)
        # Row i of the clean view should match row i of the perturbed view.
        labels = torch.arange(r_G_d.size(0), device=device)

        for _ in range(self.inner_steps):
            delta_G.requires_grad_(True)

            u_G_tilde = z_G_d + (r_G_d + delta_G)  # [B, d]
            _, z_G_tilde = semantic_prototype(u_G_tilde)  # [B, d]

            h_G_norm = F.normalize(contrastive_head(z_G_d), p=2, dim=1)
            h_tilde_norm = F.normalize(contrastive_head(z_G_tilde), p=2, dim=1)
            sim_matrix = torch.matmul(h_G_norm, h_tilde_norm.t()) / temperature  # [B, B]

            # Ascend on the InfoNCE loss itself. Negating it and then stepping
            # along +grad would *minimise* the loss, i.e. make the perturbed view
            # easier to match, which is not an adversarial perturbation.
            loss_inv = F.cross_entropy(sim_matrix, labels)

            # Differentiate w.r.t. delta_G only: unlike loss.backward(), this
            # leaves the .grad buffers of the prototype / head parameters
            # untouched, so the inner search does not leak gradients into the
            # outer optimisation step.
            (grad_delta,) = torch.autograd.grad(loss_inv, delta_G)

            with torch.no_grad():
                delta_G = self.project_to_constraint(
                    delta_G + self.inner_lr * grad_delta, self.epsilon)

        return delta_G.detach()


class MyModel(nn.Module):
    """Prototype-based graph classifier with adversarial invariance regularisation.

    ``forward`` encodes a batch into pooled graph embeddings ``u_G``, softly
    projects them onto learnable prototypes to obtain ``z_G``, and classifies
    from ``z_G``. With ``compute_adv=True`` in training mode it also perturbs
    the residual ``r_G = u_G - z_G`` and adds an InfoNCE invariance loss.
    Prototype/embedding regularisers are always returned in the ``losses`` dict.

    Args:
        args: run configuration namespace. This class reads ``dataset``,
            ``emb_dim`` (non-GOOD datasets), ``num_prototypes``,
            ``prototype_temperature``, ``top_k``, ``inter_temperature``,
            ``epsilon``, ``inner_steps``, ``inner_lr``, ``dropout`` (non-GOOD
            datasets) and, when present, ``proj_dim``, ``contrastive_temperature``
            (fallback 0.1) and ``lambda_mu`` (fallback 0.01). It is also passed
            to ``MolecularEncoderModule``.
        config: configuration object; for ``GOOD*`` datasets
            ``config.model.dim_hidden`` and ``config.dataset.num_classes`` are
            read here. It is also passed to ``MolecularEncoderModule``.
    """

    def __init__(self, args, config):
        super(MyModel, self).__init__()
        self.args = args
        self.config = config

        if args.dataset.startswith('GOOD'):
            self.emb_dim = config.model.dim_hidden
        else:
            self.emb_dim = args.emb_dim

        # Graph encoder producing pooled embeddings u_G.
        self.molecular_encoder = MolecularEncoderModule(args, config)

        # Soft assignment of u_G onto learnable prototypes, giving z_G.
        self.semantic_prototype = MolecularSemanticPrototypeModule(
            emb_dim=self.emb_dim,
            num_prototypes=args.num_prototypes,
            temperature=args.prototype_temperature,
        )

        # In-batch neighbour search used by the inter-sample consistency loss.
        self.semantic_consistency = SemanticConsistencyModule(
            emb_dim=self.emb_dim,
            top_k=args.top_k,
            inter_temperature=args.inter_temperature,
        )

        self.contrastive_head = ContrastiveProjectionHead(
            emb_dim=self.emb_dim,
            proj_dim=getattr(args, 'proj_dim', self.emb_dim),
        )

        self.adv_perturb = AdversarialPerturbationModule(
            emb_dim=self.emb_dim,
            epsilon=args.epsilon,
            inner_steps=args.inner_steps,
            inner_lr=args.inner_lr,
        )

        if args.dataset.startswith('GOOD'):
            # GOOD datasets: a single linear layer with num_classes outputs.
            num_classes = config.dataset.num_classes
            self.classifier = nn.Sequential(
                nn.Linear(self.emb_dim, num_classes)
            )
        else:
            # Other datasets: MLP head with a single output.
            self.classifier = nn.Sequential(
                nn.Linear(self.emb_dim, self.emb_dim * 2),
                nn.BatchNorm1d(self.emb_dim * 2),
                nn.ReLU(),
                nn.Dropout(args.dropout),
                nn.Linear(self.emb_dim * 2, 1)
            )

    def _contrastive_temperature(self):
        """Return ``args.contrastive_temperature``, or 0.1 when it is not set."""
        return getattr(self.args, 'contrastive_temperature', 0.1)

    def forward(self, data, compute_adv=False):
        """Run the model on a batch.

        Args:
            data: batch object handed to the encoder; ``data.x.device`` is used
                as the device for the adversarial search.
            compute_adv: when True *and* the module is in training mode, compute
                the adversarial invariance loss ``losses['inv']``; otherwise that
                entry is a zero tensor.

        Returns:
            ``(logit, z_G, u_G, r_G, losses)`` where ``losses`` has the keys
            ``'inv'``, ``'intra'``, ``'inter'`` and ``'orth'``.
        """
        u_G = self.molecular_encoder(data)  # [B, d]

        _, z_G = self.semantic_prototype(u_G)  # z_G: [B, d]
        r_G = u_G - z_G  # part of u_G not explained by the prototypes, [B, d]

        logit = self.classifier(z_G)

        losses = {}

        if compute_adv and self.training:
            contrastive_temp = self._contrastive_temperature()
            # The perturbation search runs on detached copies so its inner
            # iterations do not build graph back into the encoder.
            delta_G = self.adv_perturb.generate_perturbation(
                r_G.detach(), z_G.detach(), self.semantic_prototype,
                self.contrastive_head, contrastive_temp, data.x.device
            )

            # Re-apply delta_G on the non-detached tensors so the invariance
            # loss back-propagates into the encoder and the prototypes.
            u_G_tilde = z_G + (r_G + delta_G)  # [B, d]
            _, z_G_tilde = self.semantic_prototype(u_G_tilde)  # [B, d]

            # NOTE(review): the perturbation search scores views through
            # self.contrastive_head, but this InfoNCE term compares z_G and
            # z_G_tilde directly, so the head gets no signal from this loss.
            # Confirm that skipping the head here is intended.
            losses['inv'] = self.compute_infonce_loss(z_G, z_G_tilde, contrastive_temp)
        else:
            losses['inv'] = torch.tensor(0.0, device=z_G.device)

        # Semantic regularisation losses.
        losses['intra'] = self.compute_intra_loss(u_G, z_G)
        losses['inter'] = self.compute_inter_loss(z_G)
        losses['orth'] = self.compute_orth_loss()

        return logit, z_G, u_G, r_G, losses

    def compute_infonce_loss(self, z_G, z_G_tilde, temperature=None):
        """InfoNCE loss where row i of ``z_G`` must match row i of ``z_G_tilde``.

        Similarities are cosine similarities divided by ``temperature``
        (``None`` means the configured contrastive temperature, default 0.1).
        """
        if temperature is None:
            temperature = self._contrastive_temperature()

        z_G_norm = F.normalize(z_G, p=2, dim=1)
        z_G_tilde_norm = F.normalize(z_G_tilde, p=2, dim=1)
        sim_matrix = torch.matmul(z_G_norm, z_G_tilde_norm.t()) / temperature  # [B, B]

        # Positives sit on the diagonal.
        labels = torch.arange(z_G.size(0), device=z_G.device)
        return F.cross_entropy(sim_matrix, labels)

    def compute_intra_loss(self, u_G, z_G):
        """Mean squared reconstruction error of ``u_G`` by ``z_G`` plus prototype L2.

        Returns ``mean_i ||u_G[i] - z_G[i]||^2 + lambda_mu * ||prototypes||_F^2``
        with ``lambda_mu`` read from ``args`` (default 0.01).
        """
        recon_error = u_G - z_G  # [B, d]
        recon_loss = (recon_error ** 2).sum(dim=1).mean()  # [B] -> scalar

        prototypes = self.semantic_prototype.prototypes  # [K, d]
        prototype_reg = (prototypes ** 2).sum()

        lambda_mu = getattr(self.args, 'lambda_mu', 0.01)
        return recon_loss + lambda_mu * prototype_reg

    def compute_inter_loss(self, z_G):
        """Pull each embedding towards its top-k cosine neighbours in the batch.

        Each row's squared distances to its neighbours are weighted by a softmax
        of the neighbour similarities divided by ``args.inter_temperature``, then
        summed and averaged over rows. Returns 0 for batches with fewer than 2
        rows.
        """
        batch_size = z_G.size(0)

        if batch_size < 2:
            return torch.tensor(0.0, device=z_G.device)

        topk_indices = self.semantic_consistency.get_topk_neighbors(
            z_G, k=self.args.top_k
        )  # [B, k]

        sim_matrix = self.semantic_consistency.compute_semantic_similarity(z_G, z_G)
        mask = torch.eye(batch_size, device=z_G.device, dtype=torch.bool)
        sim_matrix = sim_matrix.masked_fill(mask, float('-inf'))

        topk_sim = torch.gather(sim_matrix, dim=1, index=topk_indices)  # [B, k]

        rho = self.args.inter_temperature
        weights = F.softmax(topk_sim / rho, dim=1)  # [B, k]

        z_G_expanded = z_G.unsqueeze(1)  # [B, 1, d]
        z_neighbors = z_G[topk_indices]  # [B, k, d]

        diff_norm_sq = ((z_G_expanded - z_neighbors) ** 2).sum(dim=2)  # [B, k]

        return (weights * diff_norm_sq).sum(dim=1).mean()

    def compute_orth_loss(self):
        """Squared Frobenius distance between the prototype cosine Gram matrix and I."""
        P_norm = F.normalize(self.semantic_prototype.prototypes, p=2, dim=1)  # [K, d]
        P_PT = torch.matmul(P_norm, P_norm.t())  # [K, K]

        # Size the identity from the actual prototype bank rather than trusting
        # args.num_prototypes to still match it.
        I = torch.eye(P_norm.size(0), device=P_norm.device)
        return ((P_PT - I) ** 2).sum()
