import warnings

import torch
import torch.nn.functional as F
from torch_geometric.utils import scatter

EPS = 1e-10

# Hardcoded dual thresholds derived from a similarity analysis.
#
# Training-time ("conservative") thresholds, used by GUARDDUAL while it is in
# training mode. An edge survives when the cosine similarity of its endpoint
# features is >= the threshold. A value of 0.0 therefore still removes edges
# with negative similarity; it only disables filtering entirely when all
# similarities are non-negative (e.g. non-negative features).
DUALAUTO_TRAIN_THRESHOLDS = {
    "arxiv": 0.320,
    "cora": 0.250,
    "pubmed": 0.230,
    "citeseer": 0.120,
    "wikics": 0.283,
    "reddit": 0.000,
    "instagram": 0.000,
    "photo": 0.000,
    "history": 0.260,
    "computer": 0.000,
}

# Test-time thresholds, used by GUARDDUAL when it is not in training mode.
# "balanced" entries carry the ± spread recorded by the analysis;
# "percentile-based" entries were chosen instead of an extreme value the
# analysis produced (shown in parentheses where it was recorded).
DUALAUTO_TEST_THRESHOLDS = {
    "arxiv": 0.540,      # balanced, ± 0.000
    "cora": 0.503,       # balanced, ± 0.009
    "pubmed": 0.633,     # balanced, ± 0.005
    "citeseer": 0.580,   # balanced, ± 0.008
    "wikics": 0.433,     # balanced, ± 0.012
    "reddit": 0.177,     # percentile-based (avoids extreme 0.850)
    "instagram": 0.471,  # percentile-based (avoids extreme 0.853)
    "photo": 0.476,      # percentile-based
    "history": 0.457,    # balanced, ± 0.005
    # NOTE(review): value and spread are identical to "history" -- confirm
    # this is a real measurement and not a copy-paste placeholder.
    "computer": 0.457,   # balanced, ± 0.005
}


class GUARDDUAL(torch.nn.Module):
    r"""Implementation of GUARD-DUAL: A dual-threshold variant of GNNGUARD
    that uses conservative thresholds during training and balanced thresholds during testing.

    Each edge is scored by the cosine similarity of its endpoint features.
    Edges scoring below the active threshold are removed. The remaining scores
    are normalized, optionally extended with self-loops, and returned
    exponentiated as edge weights.

    Parameters
    ----------
    dataset_name : str
        Name of the dataset (e.g., 'cora', 'citeseer') used to look up the
        thresholds; must be a key of both ``DUALAUTO_TRAIN_THRESHOLDS`` and
        ``DUALAUTO_TEST_THRESHOLDS``.
    add_self_loops : bool, optional
        whether to add self-loops to the input graph, by default False
    training : bool, optional
        initial mode: True selects the training threshold, False the test
        threshold, by default True
    train_mask : torch.Tensor, optional
        Boolean node mask indicating training nodes, by default None
    val_mask : torch.Tensor, optional
        Boolean node mask indicating validation nodes, by default None

    Raises
    ------
    ValueError
        If ``dataset_name`` has no entry in the threshold tables.
    """

    def __init__(self, dataset_name: str, add_self_loops: bool = False, training: bool = True,
                 train_mask: torch.Tensor = None, val_mask: torch.Tensor = None):
        super().__init__()
        self.dataset_name = dataset_name
        self.add_self_loops = add_self_loops
        self.train_mask = train_mask
        self.val_mask = val_mask

        # Both tables are indexed below; validate against both so a missing
        # test entry is reported as ValueError rather than a bare KeyError.
        if (dataset_name not in DUALAUTO_TRAIN_THRESHOLDS
                or dataset_name not in DUALAUTO_TEST_THRESHOLDS):
            raise ValueError(f"Dataset '{dataset_name}' not supported by GUARDDUAL. "
                             f"Supported datasets: {list(DUALAUTO_TRAIN_THRESHOLDS.keys())}")

        self.train_threshold = DUALAUTO_TRAIN_THRESHOLDS[dataset_name]
        self.test_threshold = DUALAUTO_TEST_THRESHOLDS[dataset_name]

        # Go through train() so that nn.Module.training, training_mode and
        # current_threshold all agree from the start (previously
        # ``self.training`` stayed True even when ``training=False``).
        self.train(bool(training))

    def set_split_info(self, train_mask: torch.Tensor, val_mask: torch.Tensor):
        """Set the boolean train/val node masks used to pick per-edge thresholds in test mode."""
        self.train_mask = train_mask
        self.val_mask = val_mask

    def set_training_mode(self, training: bool):
        """Record the mode and select the matching threshold.

        Only the threshold bookkeeping is updated here; use ``train``/``eval``
        to also update ``nn.Module.training``.
        """
        self.training_mode = training
        self.current_threshold = self.train_threshold if training else self.test_threshold

    def train(self, mode: bool = True):
        """Set the module mode and select the matching threshold.

        ``nn.Module.train``/``eval`` on a parent model recurse into children
        through their ``train`` method, so the threshold follows the parent.
        """
        super().train(mode)
        self.set_training_mode(mode)
        return self

    def eval(self):
        """Switch to evaluation mode, which selects the test threshold."""
        # nn.Module.eval() is train(False); call the override above directly.
        return self.train(False)

    def forward(self, x, edge_index, edge_weight=None):
        """Prune low-similarity edges and compute attention edge weights.

        Parameters
        ----------
        x : torch.Tensor
            Node features, indexed by node id along dim 0.
        edge_index : torch.Tensor
            Edge list unpacked as ``(row, col)``; presumably shape ``(2, E)``.
        edge_weight : torch.Tensor, optional
            Not supported; ignored with a warning.

        Returns
        -------
        edge_index : torch.Tensor
            Surviving edges, followed by one self-loop per node when
            ``add_self_loops`` is set.
        edge_weight : torch.Tensor
            ``exp`` of the normalized similarity of each returned edge
            (self-loops get ``exp(1 / (in_degree + 1))``).
        """
        if edge_weight is not None:
            warnings.warn("`edge_weight` is not supported in GUARDDUAL "
                          "and will be ignored for computation.")

        row, col = edge_index
        att_score = F.cosine_similarity(x[row], x[col])

        if (not self.training_mode and self.train_mask is not None
                and self.val_mask is not None):
            # Test mode with split info: an edge whose endpoints are both
            # train/val nodes is judged by the training threshold; every other
            # edge must pass the test threshold. (For every shipped table
            # entry test >= train, so this matches the previous mask exactly.)
            trainval_node = (self.train_mask.bool() | self.val_mask.bool()).to(row.device)
            trainval_edge = trainval_node[row] & trainval_node[col]
            mask = torch.where(trainval_edge,
                               att_score >= self.train_threshold,
                               att_score >= self.current_threshold)
        else:
            # Training mode, or test mode without split info: one threshold
            # for every edge.
            mask = att_score >= self.current_threshold

        edge_index = edge_index[:, mask]
        att_score = att_score[mask]

        num_nodes = x.size(0)
        row, col = edge_index
        # NOTE(review): scores are summed per ``col`` but the sum is looked up
        # per ``row``. These agree when the surviving edge set and scores are
        # symmetric (undirected graph); confirm that is the intended input.
        row_sum = scatter(att_score, col, dim_size=num_nodes)
        att_score_norm = att_score / (row_sum[row] + EPS)

        if self.add_self_loops:
            degree = scatter(torch.ones_like(att_score_norm), col,
                             dim_size=num_nodes)
            self_weight = 1.0 / (degree + 1)
            att_score_norm = torch.cat([att_score_norm, self_weight])
            loop_index = torch.arange(0, num_nodes, dtype=torch.long,
                                      device=edge_index.device)
            loop_index = loop_index.unsqueeze(0).repeat(2, 1)
            edge_index = torch.cat([edge_index, loop_index], dim=1)

        # Weights (including self-loop weights) are returned exponentiated.
        att_score_norm = att_score_norm.exp()
        return edge_index, att_score_norm

    def extra_repr(self) -> str:
        return f"dataset={self.dataset_name}, train_threshold={self.train_threshold}, test_threshold={self.test_threshold}, current={self.current_threshold:.3f}"