import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.utils import to_dense_adj, to_dense_batch

def _similarity_threshold(feat_sim_norm, percentile, exact_max_nodes=1000, sample_size=5000,
                          min_samples=100, fallback=0.8):
    """Return the ``percentile``-th percentile of the off-diagonal entries of a square matrix.

    Graphs with more than ``exact_max_nodes`` nodes are handled differently. The percentile is
    estimated from ``sample_size`` random (i, j) draws, discarding draws with i == j, instead
    of from all N*(N-1) off-diagonal entries.

    ``fallback`` is returned when there are too few usable entries. That happens when at most
    ``min_samples`` sampled pairs survive, or for a single-node graph, which has no
    off-diagonal entries.
    """
    num_nodes = feat_sim_norm.size(0)
    device = feat_sim_norm.device
    q = percentile / 100.0

    if num_nodes > exact_max_nodes:
        n_draws = min(sample_size, num_nodes * (num_nodes - 1) // 2)
        rows = torch.randint(0, num_nodes, (n_draws,), device=device)
        cols = torch.randint(0, num_nodes, (n_draws,), device=device)
        off_diag = rows != cols
        if off_diag.sum() <= min_samples:
            return torch.tensor(fallback, device=device)
        values = feat_sim_norm[rows[off_diag], cols[off_diag]]
    else:
        values = feat_sim_norm[~torch.eye(num_nodes, device=device, dtype=torch.bool)]
        if values.numel() == 0:
            return torch.tensor(fallback, device=device)

    # torch.quantile only accepts float32/float64 input (e.g. not float16 under autocast).
    if values.dtype not in (torch.float32, torch.float64):
        values = values.float()
    return torch.quantile(values, q)


def consistent_topology_loss_with_fused_sim(fused_similarity, target_adj, tau=0.1, percentile=80):
    """Compute BCE between fused similarities and the adjacency, on consistent pairs only.

    A pair (i, j) is "feature-similar" when its rescaled similarity ``(s + 1) / 2`` exceeds the
    ``percentile``-th percentile of the off-diagonal rescaled similarities.

    A pair contributes to the loss only when feature similarity and topology agree. Either it
    is feature-similar and connected (``target_adj > 0``), or it is neither. For those pairs,
    the logit is ``fused_similarity / tau`` and the target is ``target_adj[i, j]``.

    Args:
        fused_similarity: Square ``(N, N)`` similarity matrix. The ``(s + 1) / 2`` rescaling
            assumes values in [-1, 1] (e.g. cosine similarities).
        target_adj: ``(N, N)`` adjacency. Entries > 0 count as edges, and entries are used
            directly as BCE targets (presumably expected in [0, 1]).
        tau: Temperature dividing the similarities to form logits.
        percentile: Percentile (0-100) used as the feature-similarity threshold.

    Returns:
        Scalar mean BCE over the consistent pairs (0 if there are none).
    """
    with torch.no_grad():
        # The threshold and similarity mask only feed boolean comparisons, so no autograd graph
        # is needed for them.
        feat_sim_norm = (fused_similarity + 1) / 2
        feat_thresh = _similarity_threshold(feat_sim_norm, percentile)
        feat_similar = feat_sim_norm > feat_thresh
    topo_connected = target_adj > 0

    # Keep only pairs where features and topology agree.
    consistent_positive = feat_similar & topo_connected
    consistent_negative = ~(feat_similar | topo_connected)
    constraint_mask = consistent_positive | consistent_negative

    logits = fused_similarity / tau
    targets = torch.where(constraint_mask, target_adj, torch.zeros_like(target_adj)).to(logits.dtype)
    # binary_cross_entropy_with_logits(x, y) equals binary_cross_entropy(sigmoid(x), y) in exact
    # arithmetic. Unlike the sigmoid + BCE pair, it stays accurate when |x| is large and is
    # allowed under autocast.
    loss = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')
    masked_loss = loss * constraint_mask.to(loss.dtype)

    return masked_loss.sum() / (constraint_mask.sum() + 1e-8)


class CrossViewContrastiveLoss(nn.Module):
    """Neighbour-aware InfoNCE-style contrastive loss between two views of the same nodes.

    With cosine similarity ``s``, temperature ``tau``, anchor view ``z1`` and other view ``z2``,
    the loss of anchor node ``i`` is::

        -log( sum_j refl_mask[i, j] * exp(s(z1_i, z1_j) / tau)
              + sum_j between_mask[i, j] * exp(s(z1_i, z2_j) / tau) )
        + log( sum_{j != i} exp(s(z1_i, z1_j) / tau) + sum_j exp(s(z1_i, z2_j) / tau) )

    ``forward`` averages this over both anchor directions and over all nodes.
    ``build_positive_mask`` must be called before ``forward``.

    Both sums are evaluated in log space with ``torch.logsumexp``, so a small ``tau`` cannot
    overflow ``exp``.
    """

    def __init__(self, tau=0.1):
        super().__init__()
        self.tau = tau
        self.refl_mask = None     # (N, N) intra-view positives: neighbours, self excluded
        self.between_mask = None  # (N, N) cross-view positives: neighbours plus the node itself

    def build_positive_mask(self, edge_index, num_nodes):
        """Build and store the positive-pair masks from a graph's edge list.

        Args:
            edge_index: Edge list passed to ``to_dense_adj``.
            num_nodes: Number of nodes; the masks are ``(num_nodes, num_nodes)``.

        Returns:
            ``(refl_mask, between_mask)`` as float 0/1 tensors.
        """
        adj_matrix = to_dense_adj(edge_index, max_num_nodes=num_nodes)[0]
        connected = (adj_matrix > 0).float()

        # Self-loops, including duplicated ones, are never intra-view positives.
        self.refl_mask = connected.clone().fill_diagonal_(0.0)
        # Across views, every node is a positive for itself.
        self.between_mask = connected.fill_diagonal_(1.0)

        return self.refl_mask, self.between_mask

    def sim(self, z1, z2):
        """Return the cosine-similarity matrix between the rows of ``z1`` and ``z2``."""
        z1 = F.normalize(z1, dim=-1)
        z2 = F.normalize(z2, dim=-1)
        return torch.mm(z1, z2.t())

    def _row_losses(self, z1, z2, start, end):
        """Return the per-node loss for anchors ``z1[start:end]`` against all nodes of both views."""
        anchors = z1[start:end]
        refl_logits = self.sim(anchors, z1) / self.tau
        between_logits = self.sim(anchors, z2) / self.tau

        # Row r of this chunk is node start + r; its self-similarity sits at column start + r.
        rows = torch.arange(end - start, device=refl_logits.device)
        is_self = torch.zeros_like(refl_logits, dtype=torch.bool)
        is_self[rows, rows + start] = True

        # Denominator: all intra-view pairs except the anchor itself, plus all cross-view pairs.
        log_denominator = torch.logsumexp(
            torch.cat([refl_logits.masked_fill(is_self, float('-inf')), between_logits], dim=1),
            dim=1,
        )
        # Numerator: mask-weighted sum. Adding log(mask) reproduces the multiplicative weights,
        # and log(0) = -inf drops non-positive pairs.
        log_positive = torch.logsumexp(
            torch.cat([refl_logits + torch.log(self.refl_mask[start:end]),
                       between_logits + torch.log(self.between_mask[start:end])], dim=1),
            dim=1,
        )
        return log_denominator - log_positive

    def semi_loss(self, z1, z2):
        """Return the per-node loss with ``z1`` as anchors, over the full similarity matrices."""
        return self._row_losses(z1, z2, 0, z1.size(0))

    def batched_semi_loss(self, z1, z2, batch_size):
        """Compute the same result as ``semi_loss`` in row chunks of ``batch_size`` anchors.

        Only (batch_size, N) similarity blocks are materialised at a time.
        """
        num_nodes = z1.size(0)
        return torch.cat([
            self._row_losses(z1, z2, start, min(start + batch_size, num_nodes))
            for start in range(0, num_nodes, batch_size)
        ])

    def forward(self, frozen_output, tuned_output, batch_size=0):
        """Return the symmetrised mean loss.

        ``batch_size`` of 0 or None uses the full similarity matrices; otherwise anchors are
        processed in chunks of that size.
        """
        if self.refl_mask is None or self.between_mask is None:
            raise ValueError("Please build_positive_mask() first.")
        if not batch_size:
            l1 = self.semi_loss(frozen_output, tuned_output)
            l2 = self.semi_loss(tuned_output, frozen_output)
        else:
            l1 = self.batched_semi_loss(frozen_output, tuned_output, batch_size)
            l2 = self.batched_semi_loss(tuned_output, frozen_output, batch_size)
        ret = (l1 + l2) * 0.5
        return ret.mean()


class DualBranchFramework(nn.Module):
    """Fuse a frozen branch and a prompt-tuned branch of node embeddings.

    Modes:
        'adaptive': ``alpha * frozen + (1 - alpha) * prompt``, applied to both the embeddings
            and their cosine-similarity matrices. ``alpha`` is a learnable, unconstrained scalar.
        'sum': element-wise sum of the embeddings and of the similarity matrices. Each
            similarity is a cosine, so the fused similarity lies in [-2, 2].
        'concat': concatenate the embeddings, project back to ``hidden_dim`` with a linear
            layer, and take the cosine similarity of the projection.

    ``forward`` returns ``(fused_sim, fused_output, alpha)``. ``alpha`` is returned in every
    mode but only used by 'adaptive'.
    """

    _MODES = ('adaptive', 'sum', 'concat')

    def __init__(self, hidden_dim, mode='adaptive', alpha_init=0.5):
        super().__init__()
        if mode not in self._MODES:
            raise ValueError(f"Unknown fusion mode {mode!r}; expected one of {self._MODES}.")
        self.mode = mode
        # float() so an integer alpha_init still yields a floating-point (trainable) parameter.
        self.alpha = nn.Parameter(torch.tensor(float(alpha_init)))
        if mode == 'concat':
            self.concat_mapper = nn.Linear(hidden_dim * 2, hidden_dim)

    def _compute_similarity_matrix(self, features):
        """Return the pairwise cosine-similarity matrix of the rows of ``features``."""
        features_norm = F.normalize(features, dim=-1)
        return torch.mm(features_norm, features_norm.t())

    def forward(self, frozen_output, prompt_output):
        if self.mode == 'concat':
            # The per-branch similarity matrices are not needed here; skip the O(N^2) work.
            concat_feat = torch.cat([frozen_output, prompt_output], dim=-1)
            fused_output = self.concat_mapper(concat_feat)
            fused_sim = self._compute_similarity_matrix(fused_output)
            return fused_sim, fused_output, self.alpha

        frozen_sim = self._compute_similarity_matrix(frozen_output)
        prompt_sim = self._compute_similarity_matrix(prompt_output)

        if self.mode == 'sum':
            fused_output = frozen_output + prompt_output
            fused_sim = frozen_sim + prompt_sim
        else:  # adaptive
            fused_output = self.alpha * frozen_output + (1 - self.alpha) * prompt_output
            fused_sim = self.alpha * frozen_sim + (1 - self.alpha) * prompt_sim

        return fused_sim, fused_output, self.alpha
    
class PromptModule(nn.Module):
    """Per-layer bottleneck adapters inserted after an external GNN's conv layers.

    For conv layer ``i`` of ``gnn_model``, the forward pass computes
    ``h = conv_i(x, edge_index[, edge_weight])`` and then ``x = h + beta_i * adapter_i(h)``.
    Here ``adapter_i`` is a ``Linear -> ReLU -> Linear`` bottleneck and ``beta_i`` is a
    learnable scalar. This module only calls ``gnn_model``'s layers; it never modifies it.

    Args:
        hidden_dim: Feature width of each conv output (input/output width of the adapters).
        num_layers: Number of adapters to create; must be at least ``len(gnn_model.convs)``.
        bottleneck_dim: Hidden width of each adapter.
        alpha_init: Initial value of every ``beta_i`` scale.
    """

    def __init__(self, hidden_dim, num_layers, bottleneck_dim=32, alpha_init=0.1):
        super().__init__()
        self.adapters = nn.ModuleList()
        self.betas = nn.ParameterList()

        for _ in range(num_layers):
            self.adapters.append(nn.Sequential(
                nn.Linear(hidden_dim, bottleneck_dim),
                nn.ReLU(),
                nn.Linear(bottleneck_dim, hidden_dim)
            ))
            # float() so an integer alpha_init still yields a floating-point (trainable) parameter.
            self.betas.append(nn.Parameter(torch.tensor(float(alpha_init))))

    def forward(self, x, edge_index, gnn_model, edge_weight=None):
        """Run ``gnn_model``'s conv layers with an adapter applied after each one.

        Args:
            x: Node features fed to the first conv.
            edge_index: Graph connectivity passed to every conv.
            gnn_model: Object providing ``convs``, ``jk_mode`` and, for ``"cat"``/``"max"``,
                a ``jk`` callable. It may also provide ``act``: ``F.relu`` is used if the
                attribute is absent, and no activation if it is ``None``. It may also provide
                ``dropout``: a float probability or an ``nn.Dropout``, with 0 used if absent.
            edge_weight: Optional edge weights, passed as the third positional argument to
                each conv.

        Returns:
            The last layer's output for ``jk_mode == "last"``, the list of per-layer outputs
            for ``"list"``, or ``gnn_model.jk(outputs)`` for ``"cat"``/``"max"``.

        Raises:
            ValueError: if this module has fewer adapters than ``gnn_model`` has conv layers.
            RuntimeError: for any other ``jk_mode``.
        """
        convs = gnn_model.convs
        num_layers = len(convs)
        if len(self.adapters) < num_layers:
            # zip() would silently drop the trailing convs and misplace activation/dropout.
            raise ValueError(
                f"PromptModule has {len(self.adapters)} adapters but gnn_model has "
                f"{num_layers} conv layers."
            )

        act = getattr(gnn_model, 'act', F.relu)
        dropout = getattr(gnn_model, 'dropout', 0.0)
        # Accept either a float probability or an nn.Dropout module. Dropout is gated by this
        # module's training flag, not the GNN's.
        dropout_p = dropout.p if isinstance(dropout, nn.Dropout) else dropout
        conv_args = (edge_index,) if edge_weight is None else (edge_index, edge_weight)

        xs = []
        for i, (conv, adapter, beta) in enumerate(zip(convs, self.adapters, self.betas)):
            x_gnn = conv(x, *conv_args)
            # x = GNN_out + beta * Adapter_out
            x = x_gnn + beta * adapter(x_gnn)

            # Activation and dropout between layers only, not after the last one.
            if i < num_layers - 1:
                if act is not None:
                    x = act(x)
                x = F.dropout(x, p=dropout_p, training=self.training)

            xs.append(x)

        if gnn_model.jk_mode == "last":
            return xs[-1]
        elif gnn_model.jk_mode == "list":
            return xs
        elif gnn_model.jk_mode in {"cat", "max"}:
            return gnn_model.jk(xs)
        else:
            raise RuntimeError(f"Unknown jk mode: {gnn_model.jk_mode}.")