import time
import warnings
from copy import deepcopy
from typing import List, Optional, Tuple, Union

import numpy as np
import scipy.sparse as sp
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch_geometric.typing import Adj, OptTensor
from torch_geometric.utils import degree, to_dense_adj, to_scipy_sparse_matrix, from_scipy_sparse_matrix

from greatx.nn.layers import GCNConv, Sequential, activations
from greatx.utils import wrapper


class PurificationGCN(nn.Module):
    """Graph Purification + GCN for robust node classification

    This model implements graph purification based on Jaccard/Cosine similarity
    filtering followed by standard GCN for node classification. It serves as
    a baseline defense method against adversarial attacks.

    The purification process:
    1. Compute similarity (Jaccard or Cosine) between connected nodes
    2. Remove edges whose similarity is not strictly above the threshold
    3. Train GCN on the purified graph

    Parameters
    ----------
    in_channels : int
        the input dimensions of model
    out_channels : int
        the output dimensions of model
    hids : List[int], optional
        the number of hidden units for each hidden layer, by default [128]
    acts : List[str], optional
        the activation function for each hidden layer, by default ['relu']
    dropout : float, optional
        the dropout ratio of model, by default 0.5
    bias : bool, optional
        whether to use bias in the layers, by default True
    purification_method : str, optional
        purification method ('jaccard', 'cosine', 'both'), by default 'jaccard'
    jaccard_threshold : float, optional
        threshold for Jaccard similarity filtering, by default 0.03
    cosine_threshold : float, optional
        threshold for cosine similarity filtering, by default 0.1
    allow_singleton : bool, optional
        whether to allow singleton nodes after purification, by default False.
        When False, an edge is never removed if its target node has at most
        one incident edge in the graph being filtered.
    device : str, optional
        device to use, by default 'cpu'

    Attributes
    ----------
    original_num_edges : int or None
        number of edges seen by the last call to :meth:`fit`
    purified_num_edges : int or None
        number of edges kept by the last call to :meth:`fit`
    removed_edges : int or None
        number of edges dropped by the last call to :meth:`fit`
    purified_edge_index, purified_edge_weight : Tensor or None
        the purified graph stored at the end of :meth:`fit`

    Examples
    --------
    >>> # PurificationGCN with Jaccard filtering
    >>> model = PurificationGCN(100, 10, purification_method='jaccard')

    >>> # PurificationGCN with Cosine filtering
    >>> model = PurificationGCN(100, 10, purification_method='cosine',
    ...                        cosine_threshold=0.2)

    >>> # PurificationGCN with both filtering methods
    >>> model = PurificationGCN(100, 10, purification_method='both')
    """

    # NOTE: the list defaults below are intentionally left as-is; they are
    # only read (never mutated) here, and the ``wrapper`` decorator may
    # inspect them, so changing them could alter its behaviour.
    @wrapper
    def __init__(self, in_channels: int, out_channels: int,
                 hids: List[int] = [128], acts: List[str] = ['relu'],
                 dropout: float = 0.5, bias: bool = True,
                 purification_method: str = 'jaccard',
                 jaccard_threshold: float = 0.03,
                 cosine_threshold: float = 0.1,
                 allow_singleton: bool = False,
                 device: str = 'cpu'):

        super().__init__()

        # Model parameters
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hids = hids
        self.acts = acts
        self.dropout = dropout
        self.bias = bias
        self.device = device

        # Purification parameters
        self.purification_method = purification_method
        self.jaccard_threshold = jaccard_threshold
        self.cosine_threshold = cosine_threshold
        self.allow_singleton = allow_singleton

        # Build GCN backbone
        self._build_gcn()

        # Purification statistics (filled in by ``fit``)
        self.removed_edges = None
        self.original_num_edges = None
        self.purified_num_edges = None

        # Purified graph (filled in by ``fit``); defined up front so that
        # ``test`` can check for it without relying on ``hasattr``.
        self.purified_edge_index = None
        self.purified_edge_weight = None

    def _build_gcn(self):
        """Build GCN backbone following GreatX style

        Raises
        ------
        ValueError
            if ``hids`` and ``acts`` do not have the same length
        """
        # An explicit error (instead of ``assert``) survives ``python -O`` and
        # prevents ``zip`` from silently truncating the longer list.
        if len(self.hids) != len(self.acts):
            raise ValueError(
                f"`hids` and `acts` must have the same length, "
                f"got {len(self.hids)} and {len(self.acts)}")

        conv = []
        in_channels = self.in_channels
        for hid, act in zip(self.hids, self.acts):
            conv.append(GCNConv(in_channels, hid, bias=self.bias))
            conv.append(activations.get(act))
            conv.append(nn.Dropout(self.dropout))
            in_channels = hid

        # Output layer: no activation/dropout, produces the class logits
        conv.append(GCNConv(in_channels, self.out_channels, bias=self.bias))
        self.gcn = Sequential(*conv)

    def reset_parameters(self):
        """Reset model parameters"""
        self.gcn.reset_parameters()

    def forward(self, x: Tensor, edge_index: Adj, edge_weight: OptTensor = None) -> Tensor:
        """Forward pass through the GCN backbone (no purification is applied here)"""
        return self.gcn(x, edge_index, edge_weight)

    def fit(self, x: Tensor, edge_index: Adj, y: Tensor,
            train_mask: Tensor, val_mask: Tensor,
            edge_weight: OptTensor = None, epochs: int = 200,
            lr: float = 0.01, weight_decay: float = 5e-4, **kwargs):
        """Purify the graph, then train the GCN on the purified graph

        The weights restored at the end are those of the last epoch in which
        either the validation accuracy or the validation loss improved.

        Parameters
        ----------
        x : Tensor
            node features
        edge_index : Adj
            edge indices (``torch.long``) or a dense adjacency matrix (any
            other dtype, converted to edge indices internally; in that case
            ``edge_weight`` is taken from the matrix entries)
        y : Tensor
            node labels
        train_mask : Tensor
            training mask
        val_mask : Tensor
            validation mask
        edge_weight : OptTensor, optional
            edge weights, by default None
        epochs : int, optional
            number of training epochs, by default 200
        lr : float, optional
            learning rate, by default 0.01
        weight_decay : float, optional
            weight decay, by default 5e-4
        **kwargs
            accepted for call compatibility and ignored
        """
        print("=== PurificationGCN Training Started ===")
        print(f"Device: {self.device}")
        print(f"Purification method: {self.purification_method}")
        print(f"Jaccard threshold: {self.jaccard_threshold}")
        print(f"Cosine threshold: {self.cosine_threshold}")
        print(f"Allow singleton: {self.allow_singleton}")

        # Move to device
        self.to(self.device)
        x, y = x.to(self.device), y.to(self.device)
        train_mask, val_mask = train_mask.to(self.device), val_mask.to(self.device)

        # Convert edge_index to proper format if needed
        edge_index = edge_index.to(self.device)
        if edge_index.dtype == torch.long:
            # Sparse format
            if edge_weight is not None:
                edge_weight = edge_weight.to(self.device)
        else:
            # Dense format - convert to sparse
            edge_index, edge_weight = self._dense_to_sparse(edge_index)
        self.original_num_edges = edge_index.size(1)

        print(f"Original graph - Nodes: {x.shape[0]}, Edges: {self.original_num_edges}")
        print(f"Train nodes: {train_mask.sum().item()}, Val nodes: {val_mask.sum().item()}")

        # Apply purification (pure preprocessing, no gradients needed)
        print("\n=== Applying Graph Purification ===")
        with torch.no_grad():
            purified_edge_index, purified_edge_weight = self._purify_graph(
                x, edge_index, edge_weight)

        self.purified_num_edges = purified_edge_index.size(1)
        removed_edges = self.original_num_edges - self.purified_num_edges
        self.removed_edges = removed_edges
        removal_rate = self._removal_rate(self.original_num_edges, removed_edges)

        print(f"Purified graph - Edges: {self.purified_num_edges}")
        print(f"Removed edges: {removed_edges} ({removal_rate:.2f}%)")

        # Initialize optimizer
        optimizer = torch.optim.Adam(self.parameters(), lr=lr, weight_decay=weight_decay)

        # Training variables (kept as Python floats, not tensors)
        best_val_acc = 0.0
        best_val_loss = float('inf')
        best_weights = None

        print("\n=== Training GCN on Purified Graph ===")
        t_total = time.time()

        for epoch in range(epochs):
            # Training
            self.train()
            optimizer.zero_grad()

            output = self.forward(x, purified_edge_index, purified_edge_weight)
            loss_train = F.cross_entropy(output[train_mask], y[train_mask])

            pred_train = output[train_mask].argmax(dim=1)
            acc_train = (pred_train == y[train_mask]).float().mean()

            loss_train.backward()
            optimizer.step()

            # Validation
            self.eval()
            with torch.no_grad():
                output = self.forward(x, purified_edge_index, purified_edge_weight)
                loss_val = F.cross_entropy(output[val_mask], y[val_mask]).item()
                pred_val = output[val_mask].argmax(dim=1)
                acc_val = (pred_val == y[val_mask]).float().mean().item()

            # Save best model: snapshot once if either criterion improved
            improved = False
            if acc_val > best_val_acc:
                best_val_acc = acc_val
                improved = True
            if loss_val < best_val_loss:
                best_val_loss = loss_val
                improved = True
            if improved:
                best_weights = deepcopy(self.state_dict())

            # Print progress
            if epoch % 20 == 0 or epoch == epochs - 1:
                print(f"Epoch {epoch + 1:3d}/{epochs} | "
                      f"Train Loss: {loss_train.item():.4f} | "
                      f"Train Acc: {acc_train.item():.4f} | "
                      f"Val Loss: {loss_val:.4f} | "
                      f"Val Acc: {acc_val:.4f}")

        total_time = time.time() - t_total
        print("\n=== Training Completed ===")
        print(f"Total time: {total_time:.4f}s")
        print(f"Best validation accuracy: {best_val_acc:.4f}")
        print(f"Best validation loss: {best_val_loss:.4f}")

        # Load best model
        if best_weights is not None:
            self.load_state_dict(best_weights)
            print("Loaded best model weights")

        # Store purified graph for later use
        self.purified_edge_index = purified_edge_index
        self.purified_edge_weight = purified_edge_weight

    def _purify_graph(self, x: Tensor, edge_index: Tensor,
                      edge_weight: OptTensor = None) -> Tuple[Tensor, OptTensor]:
        """Apply graph purification based on similarity metrics

        Dispatches on ``self.purification_method``; ``'both'`` runs the
        Jaccard filter first and the cosine filter on its output.

        Raises
        ------
        ValueError
            if ``self.purification_method`` is not one of
            'jaccard', 'cosine', 'both'
        """
        if self.purification_method == 'jaccard':
            return self._jaccard_purification(x, edge_index, edge_weight)
        if self.purification_method == 'cosine':
            return self._cosine_purification(x, edge_index, edge_weight)
        if self.purification_method == 'both':
            # Apply Jaccard first, then Cosine
            edge_index, edge_weight = self._jaccard_purification(x, edge_index, edge_weight)
            return self._cosine_purification(x, edge_index, edge_weight)
        raise ValueError(f"Unknown purification method: {self.purification_method}")

    def _jaccard_purification(self, x: Tensor, edge_index: Tensor,
                              edge_weight: OptTensor = None) -> Tuple[Tensor, OptTensor]:
        """Apply Jaccard similarity-based purification

        Similarity of an edge is ``|nonzero(a) & nonzero(b)| / |nonzero(a) |
        nonzero(b)|`` over the feature vectors of its two endpoints.
        """
        print(f"  Applying Jaccard purification (threshold: {self.jaccard_threshold})")

        row, col = edge_index
        A = x[row]  # Features of source nodes
        B = x[col]  # Features of target nodes

        # Compute Jaccard similarity on the non-zero feature supports
        intersection = torch.count_nonzero(A * B, dim=1)
        union = (torch.count_nonzero(A, dim=1) +
                 torch.count_nonzero(B, dim=1) - intersection)
        # Small epsilon avoids 0/0 when both endpoints have all-zero features
        jaccard_sim = intersection.float() / (union.float() + 1e-7)

        return self._filter_edges(x, edge_index, edge_weight,
                                  jaccard_sim, self.jaccard_threshold, "Jaccard")

    def _cosine_purification(self, x: Tensor, edge_index: Tensor,
                             edge_weight: OptTensor = None) -> Tuple[Tensor, OptTensor]:
        """Apply Cosine similarity-based purification"""
        print(f"  Applying Cosine purification (threshold: {self.cosine_threshold})")

        row, col = edge_index
        A = x[row]  # Features of source nodes
        B = x[col]  # Features of target nodes

        # Compute cosine similarity
        cosine_sim = F.cosine_similarity(A, B, dim=1)

        return self._filter_edges(x, edge_index, edge_weight,
                                  cosine_sim, self.cosine_threshold, "Cosine")

    def _filter_edges(self, x: Tensor, edge_index: Tensor, edge_weight: OptTensor,
                      similarity: Tensor, threshold: float,
                      label: str) -> Tuple[Tensor, OptTensor]:
        """Keep the edges whose per-edge ``similarity`` is above ``threshold``

        Shared by the Jaccard and cosine filters. When ``allow_singleton`` is
        False, an edge is additionally kept whenever its target node has at
        most one incident edge, so that filtering does not strip a node of
        its only connection.

        Parameters
        ----------
        x : Tensor
            node features (only used for the number of nodes)
        edge_index : Tensor
            edge indices of the graph being filtered
        edge_weight : OptTensor
            optional edge weights, filtered alongside ``edge_index``
        similarity : Tensor
            one similarity score per edge
        threshold : float
            edges with ``similarity <= threshold`` are candidates for removal
        label : str
            name of the metric, used in the progress message

        Returns
        -------
        Tuple[Tensor, OptTensor]
            the filtered edge indices and (if given) edge weights
        """
        row, col = edge_index
        keep = similarity > threshold

        if not self.allow_singleton:
            # Degrees are counted once on the graph passed in (occurrences as
            # a source node), so this is a heuristic: a node whose several
            # edges all score low can still end up isolated.
            deg = degree(row, num_nodes=x.size(0))
            keep = torch.logical_or(keep, deg[col] <= 1)

        # Filter edges
        filtered_edge_index = edge_index[:, keep]
        filtered_edge_weight = edge_weight[keep] if edge_weight is not None else None

        removed_count = edge_index.size(1) - filtered_edge_index.size(1)
        print(f"  {label}: Removed {removed_count} edges")

        return filtered_edge_index, filtered_edge_weight

    def _dense_to_sparse(self, adj: Tensor) -> Tuple[Tensor, Tensor]:
        """Convert dense adjacency matrix to sparse edge_index format

        Returns the indices of the non-zero entries (one column per edge) and
        the corresponding entries of ``adj`` as edge weights.
        """
        edge_index = adj.nonzero().t().contiguous()
        edge_weight = adj[edge_index[0], edge_index[1]]
        return edge_index, edge_weight

    def test(self, x: Tensor, y: Tensor, test_mask: Tensor,
             edge_index: Optional[Adj] = None,
             edge_weight: OptTensor = None) -> float:
        """Test the model performance

        Parameters
        ----------
        x : Tensor
            node features
        y : Tensor
            node labels
        test_mask : Tensor
            mask selecting the nodes to evaluate on
        edge_index : Optional[Adj], optional
            graph to fall back on when no purified graph has been stored by
            :meth:`fit`; ignored otherwise, by default None
        edge_weight : OptTensor, optional
            edge weights going with ``edge_index``, by default None

        Returns
        -------
        float
            accuracy on the nodes selected by ``test_mask``

        Raises
        ------
        RuntimeError
            if no purified graph is stored and ``edge_index`` is not given
        """
        # Evaluate where the model parameters live so that inputs handed over
        # on a different device than the one used for training still work.
        param = next(self.parameters(), None)
        device = param.device if param is not None else x.device
        x, y, test_mask = x.to(device), y.to(device), test_mask.to(device)

        purified_edge_index = getattr(self, 'purified_edge_index', None)

        self.eval()
        with torch.no_grad():
            if purified_edge_index is not None:
                output = self.forward(x, purified_edge_index,
                                      getattr(self, 'purified_edge_weight', None))
            elif edge_index is not None:
                # If not trained yet, use original graph (should not happen in normal usage)
                warnings.warn("Model not trained yet, using original graph structure")
                edge_index = edge_index.to(device)
                if edge_weight is not None:
                    edge_weight = edge_weight.to(device)
                output = self.forward(x, edge_index, edge_weight)
            else:
                raise RuntimeError(
                    "No purified graph available: call `fit` first or pass "
                    "`edge_index` explicitly")

            pred = output[test_mask].argmax(dim=1)
            acc = (pred == y[test_mask]).float().mean()
            return acc.item()

    def fit_with_trainer(self, trainer, data, train_mask, val_mask, epochs=200):
        """Fit method compatible with Trainer

        Reads ``x``, ``y``, ``edge_index`` and (if present) ``edge_weight``
        from ``data`` and forwards them to :meth:`fit`. ``trainer`` is
        accepted for call compatibility and is not used.

        Returns
        -------
        PurificationGCN
            ``self``, to allow chaining

        Raises
        ------
        ValueError
            if ``data`` has no ``edge_index`` attribute
        """
        if not hasattr(data, 'edge_index'):
            raise ValueError("Data must have edge_index attribute")

        # Call the original fit method
        self.fit(data.x, data.edge_index, data.y, train_mask, val_mask,
                 edge_weight=getattr(data, 'edge_weight', None), epochs=epochs)

        # Return self for chaining
        return self

    @staticmethod
    def _removal_rate(original_edges: int, removed_edges: int) -> float:
        """Percentage of removed edges; 0.0 for a graph without edges"""
        if not original_edges:
            return 0.0
        return removed_edges / original_edges * 100

    def get_purification_stats(self):
        """Get purification statistics

        Returns
        -------
        dict or None
            ``None`` before :meth:`fit` has been called, otherwise a dict with
            the keys 'original_edges', 'purified_edges', 'removed_edges'
            and 'removal_rate' (percentage)
        """
        if self.original_num_edges is None or self.purified_num_edges is None:
            return None

        removed_edges = self.original_num_edges - self.purified_num_edges
        removal_rate = self._removal_rate(self.original_num_edges, removed_edges)

        return {
            'original_edges': self.original_num_edges,
            'purified_edges': self.purified_num_edges,
            'removed_edges': removed_edges,
            'removal_rate': removal_rate
        }