import time
import warnings
from copy import deepcopy
from typing import List, Optional, Tuple, Union

import numpy as np
import scipy.sparse as sp
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch import Tensor
from torch_geometric.typing import Adj, OptTensor
from sklearn.metrics.pairwise import cosine_similarity

from greatx.nn.layers import GCNConv, Sequential, activations
from greatx.utils import wrapper


class ContrastiveEmbedding(nn.Module):
    """Deep Graph Infomax (DGI) encoder contrasted against two augmented views.

    A single shared :class:`GCN_DGI` encoder produces node embeddings for the
    clean graph, for two augmented views and for a corrupted (negative)
    input. Each augmented view is summarised by :class:`AvgReadout`, squashed
    with a sigmoid, and scored against the clean and corrupted node
    embeddings by :class:`Discriminator`.

    Parameters
    ----------
    n_in : int
        input feature dimension
    n_h : int
        embedding dimension
    activation : str, optional
        activation name forwarded to :class:`GCN_DGI`, by default 'prelu'
    """

    def __init__(self, n_in: int, n_h: int, activation: str = 'prelu'):
        super().__init__()
        self.gcn = GCN_DGI(n_in, n_h, activation)
        self.read = AvgReadout()
        self.sigm = nn.Sigmoid()
        self.disc = Discriminator(n_h)

    def forward(self, seq1, seq2, seq3, seq4, adj, aug_adj1, aug_adj2,
                sparse, msk, samp_bias1, samp_bias2, aug_type):
        """Return the summed discriminator logits of both augmented views.

        ``seq1`` is encoded on ``adj`` as the positive sample and ``seq2`` is
        encoded on ``adj`` as the negative sample. ``aug_type`` selects which
        (features, adjacency) pair forms each of the two views:

        * ``'edge'``: ``(seq1, aug_adj1)`` and ``(seq1, aug_adj2)``
        * ``'mask'``: ``(seq3, adj)`` and ``(seq4, adj)``
        * ``'node'`` / ``'subgraph'``: ``(seq3, aug_adj1)`` and
          ``(seq4, aug_adj2)``

        Raises
        ------
        ValueError
            if ``aug_type`` is none of the values above
        """
        h_0 = self.gcn(seq1, adj, sparse)

        if aug_type == 'edge':
            views = ((seq1, aug_adj1), (seq1, aug_adj2))
        elif aug_type == 'mask':
            views = ((seq3, adj), (seq4, adj))
        elif aug_type in ('node', 'subgraph'):
            views = ((seq3, aug_adj1), (seq4, aug_adj2))
        else:
            raise ValueError(f"Unknown aug_type: {aug_type}")

        (feats_a, graph_a), (feats_b, graph_b) = views
        h_1 = self.gcn(feats_a, graph_a, sparse)
        h_3 = self.gcn(feats_b, graph_b, sparse)

        # One sigmoid-squashed summary vector per augmented view
        c_1 = self.sigm(self.read(h_1, msk))
        c_3 = self.sigm(self.read(h_3, msk))

        # Corrupted input encoded on the clean graph serves as the negative
        h_2 = self.gcn(seq2, adj, sparse)

        score_a = self.disc(c_1, h_0, h_2, samp_bias1, samp_bias2)
        score_b = self.disc(c_3, h_0, h_2, samp_bias1, samp_bias2)
        return score_a + score_b

    def embed(self, seq, adj, sparse, msk):
        """Return detached node embeddings and their readout summary."""
        node_emb = self.gcn(seq, adj, sparse)
        summary = self.read(node_emb, msk)
        return node_emb.detach(), summary.detach()


class GCN_DGI(nn.Module):
    """Single graph-convolution layer used as the DGI encoder.

    Computes ``act(adj @ (seq W^T) + bias)`` on batched input.

    Parameters
    ----------
    in_ft : int
        input feature dimension
    out_ft : int
        output feature dimension
    act : str
        ``'prelu'`` selects :class:`torch.nn.PReLU`; any other value is
        looked up as an attribute of ``torch.nn`` and falls back to
        :class:`torch.nn.ReLU` when no such attribute exists
    bias : bool, optional
        whether to add a learnable bias after propagation, by default True
    """

    def __init__(self, in_ft: int, out_ft: int, act: str, bias: bool = True):
        super().__init__()
        self.fc = nn.Linear(in_ft, out_ft, bias=False)

        if act == 'prelu':
            self.act = nn.PReLU()
        else:
            self.act = getattr(nn, act, nn.ReLU)()

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_ft, dtype=torch.float32))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-initialise the projection weight and zero the bias."""
        if hasattr(self.fc, 'weight'):
            nn.init.xavier_uniform_(self.fc.weight.data)
        if self.bias is not None:
            self.bias.data.zero_()

    def forward(self, seq, adj, sparse=False):
        """Project ``seq``, propagate it over ``adj`` and apply the activation.

        With ``sparse=True`` the leading dimension of the projected features
        is squeezed away for ``torch.spmm`` and restored afterwards; otherwise
        ``adj`` is multiplied with ``torch.bmm``.
        """
        projected = self.fc(seq)

        if sparse:
            propagated = torch.spmm(adj, projected.squeeze(0)).unsqueeze(0)
        else:
            propagated = torch.bmm(adj, projected)

        if self.bias is not None:
            propagated += self.bias

        return self.act(propagated)


class AvgReadout(nn.Module):
    """Readout that averages node embeddings along dimension 1."""

    def __init__(self):
        super().__init__()

    def forward(self, seq, msk):
        """Return the (optionally masked) mean of ``seq`` over dimension 1.

        Without a mask this is a plain mean. With a mask, ``seq`` is weighted
        by ``msk`` (broadcast over the last dimension), summed over
        dimension 1 and divided by the total of all mask entries.
        """
        if msk is not None:
            weights = msk.unsqueeze(-1)
            return (seq * weights).sum(dim=1) / weights.sum()
        return seq.mean(dim=1)


class Discriminator(nn.Module):
    """Bilinear discriminator scoring node embeddings against a summary.

    Parameters
    ----------
    n_h : int
        dimension of both the node embeddings and the summary vector
    """

    def __init__(self, n_h: int):
        super().__init__()
        self.f_k = nn.Bilinear(n_h, n_h, 1)
        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-initialise the bilinear weight and zero its bias."""
        nn.init.xavier_uniform_(self.f_k.weight.data)
        bias = self.f_k.bias
        if bias is not None:
            bias.data.fill_(0.0)

    def forward(self, c, h_pl, h_mi, s_bias1=None, s_bias2=None):
        """Return positive and negative scores concatenated along dim 1.

        The summary ``c`` is expanded to the shape of ``h_pl`` and scored
        against the positive embeddings ``h_pl`` and the negative embeddings
        ``h_mi``. ``s_bias1`` / ``s_bias2``, when given, are added to the
        positive / negative scores respectively.
        """
        summary = c.unsqueeze(1).expand_as(h_pl)

        pos_scores = self.f_k(h_pl, summary).squeeze(2)
        neg_scores = self.f_k(h_mi, summary).squeeze(2)

        if s_bias1 is not None:
            pos_scores += s_bias1
        if s_bias2 is not None:
            neg_scores += s_bias2

        return torch.cat((pos_scores, neg_scores), 1)


class Stable(nn.Module):
    """Stable GNN: Spectral-based Graph Neural Network for Robust Node Classification
    
    This is a GreatX-style implementation of the Stable algorithm that combines:
    1. Graph preprocessing with Jaccard/Cosine similarity filtering
    2. Contrastive learning embeddings via Deep Graph Infomax (DGI)
    3. Reliable neighbor selection and graph reconstruction
    4. Standard GCN backbone for final classification
    
    Parameters
    ----------
    in_channels : int
        the input dimensions of model
    out_channels : int
        the output dimensions of model
    hids : List[int], optional
        the number of hidden units for each hidden layer, by default [128]
    acts : List[str], optional
        the activation function for each hidden layer, by default ['relu']
    dropout : float, optional
        the dropout ratio of model, by default 0.5
    bias : bool, optional
        whether to use bias in the layers, by default True
    jaccard_threshold : float, optional
        threshold for Jaccard similarity filtering, by default 0.03
    cosine_threshold : float, optional
        threshold for cosine similarity filtering, by default 0.1
    k_neighbors : int, optional
        number of reliable neighbors to add, by default 7
    degree_threshold : float, optional
        degree threshold for reliable neighbor selection, by default 1.0
    alpha : float, optional
        normalization parameter for adjacency matrix, by default 0.3
    beta : float, optional
        self-loop weight for contrastive learning, by default 2.0
    contrastive_epochs : int, optional
        number of epochs for contrastive learning, by default 1000
    contrastive_lr : float, optional
        learning rate for contrastive learning, by default 0.001
    recover_percent : float, optional
        percentage of edges to recover in augmentation, by default 0.2
    device : str, optional
        device to use, by default 'cpu'
        
    Examples
    --------
    >>> # Stable with default parameters
    >>> model = Stable(100, 10)
    
    >>> # Stable with custom architecture and parameters
    >>> model = Stable(100, 10, hids=[32, 16], acts=['relu', 'elu'],
    ...                jaccard_threshold=0.05, k_neighbors=5)
    """
    
    @wrapper
    def __init__(self, in_channels: int, out_channels: int,
                 hids: List[int] = [128], acts: List[str] = ['relu'],
                 dropout: float = 0.5, bias: bool = True,
                 jaccard_threshold: float = 0.03, cosine_threshold: float = 0.1,
                 k_neighbors: int = 7, degree_threshold: float = 1.0,
                 alpha: float = 0.3, beta: float = 2.0,
                 contrastive_epochs: int = 1000, contrastive_lr: float = 0.001,
                 recover_percent: float = 0.2, device: str = 'cpu'):
        """Record the hyper-parameters; no layers are created here.

        The classifier stored in ``self.gcn`` is only built by ``fit`` (via
        ``_build_final_gcn``), once the embedding dimension is known.
        """
        # NOTE(review): ``hids``/``acts`` are mutable defaults. They are only
        # stored here, never mutated, and the ``wrapper`` decorator may rely
        # on them - confirm before changing the signature.
        super().__init__()

        # State produced by ``fit``; all ``None`` until then
        self.gcn = None
        self.cleaned_adj = None
        self.learned_embeddings = None

        # Architecture of the final classifier
        self.in_channels, self.out_channels = in_channels, out_channels
        self.hids, self.acts = hids, acts
        self.dropout, self.bias = dropout, bias
        self.device = device

        # Edge pruning and reliable-neighbour selection
        self.jaccard_threshold, self.cosine_threshold = jaccard_threshold, cosine_threshold
        self.k_neighbors, self.degree_threshold = k_neighbors, degree_threshold

        # Adjacency normalisation (alpha) and contrastive self-loop weight (beta)
        self.alpha, self.beta = alpha, beta

        # Contrastive pre-training
        self.contrastive_epochs, self.contrastive_lr = contrastive_epochs, contrastive_lr
        self.recover_percent = recover_percent
        
    def reset_parameters(self):
        """Reset model parameters"""
        if hasattr(self, 'gcn') and self.gcn is not None:
            self.gcn.reset_parameters()
    
    def forward(self, x: Tensor, edge_index: Adj, edge_weight: OptTensor = None) -> Tensor:
        """Forward pass using learned embeddings and cleaned adjacency"""
        if self.learned_embeddings is None or self.cleaned_adj is None:
            raise RuntimeError("Model must be fitted before forward pass")
        
        # Use learned embeddings and cleaned adjacency
        return self.gcn(self.learned_embeddings, self.cleaned_adj, edge_weight)
    
    def fit(self, x: Tensor, edge_index: Adj, y: Tensor,
            train_mask: Tensor, val_mask: Tensor,
            edge_weight: OptTensor = None, **kwargs):
        """Train Stable model

        Runs the whole pipeline: Jaccard pruning of the input graph,
        contrastive embedding learning, cosine pruning with the learned
        embeddings, reliable-neighbor insertion, normalization and training
        of the final GCN. The learned embeddings and the normalized cleaned
        adjacency are stored on the instance for ``forward`` / ``test``.

        Parameters
        ----------
        x : Tensor
            node features
        edge_index : Adj
            graph structure; a tensor of dtype ``torch.long`` is interpreted
            as a ``[2, num_edges]`` edge list and densified, any other tensor
            is used as the adjacency matrix itself (so a dense adjacency must
            not be stored as ``torch.long``)
        y : Tensor
            node labels
        train_mask : Tensor
            training mask
        val_mask : Tensor
            validation mask
        edge_weight : OptTensor, optional
            edge weights, only used with an edge list, by default None
        **kwargs
            accepted for call compatibility and ignored
        """
        # Move every input to the model device first, so the dense adjacency
        # can be assembled without mixing devices.
        x, y = x.to(self.device), y.to(self.device)
        train_mask, val_mask = train_mask.to(self.device), val_mask.to(self.device)
        edge_index = edge_index.to(self.device)

        if edge_index.dtype == torch.long:
            # Edge list -> dense adjacency
            num_nodes = x.shape[0]
            adj = torch.zeros(num_nodes, num_nodes, device=self.device)
            if edge_weight is None:
                adj[edge_index[0], edge_index[1]] = 1.0
            else:
                # Indexed assignment needs matching device and dtype
                adj[edge_index[0], edge_index[1]] = edge_weight.to(
                    device=self.device, dtype=adj.dtype)
        else:
            adj = edge_index

        # Step 1: Preprocess adjacency matrix with Jaccard similarity
        adj_sparse = self._to_scipy_sparse(adj)
        x_sparse = self._to_scipy_sparse(x)

        print("Step 1: Preprocessing adjacency matrix...")
        adj_preprocessed = self._preprocess_adj(x_sparse, adj_sparse, threshold=self.jaccard_threshold, jaccard=True)
        # Edges dropped in step 1; candidates for recovery during augmentation
        adj_deleted = adj_sparse - adj_preprocessed

        # Step 2: Learn contrastive embeddings
        print("Step 2: Learning contrastive embeddings...")
        x_tensor = torch.FloatTensor(x_sparse.todense()).to(self.device)
        embeddings = self._get_contrastive_embeddings(adj_preprocessed, x_tensor, adj_deleted)

        # Step 3: Clean the ORIGINAL adjacency using the learned embeddings.
        # ``_get_contrastive_embeddings`` returns a dense 2-D tensor, so it
        # converts directly to a scipy matrix.
        print("Step 3: Cleaning adjacency matrix...")
        embeddings_sparse = sp.csr_matrix(embeddings.detach().cpu().numpy())
        adj_cleaned = self._preprocess_adj(embeddings_sparse, adj_sparse, threshold=self.cosine_threshold, jaccard=False)

        # Step 4: Add reliable neighbors (in place), then normalize
        print("Step 4: Adding reliable neighbors...")
        adj_cleaned_tensor = self._sparse_to_dense_tensor(adj_cleaned).to(self.device)
        self._add_reliable_neighbors(adj_cleaned_tensor, embeddings)
        adj_normalized = self._normalize_adj(adj_cleaned_tensor, self.alpha)

        # Step 5: Build and train GCN
        print("Step 5: Training GCN...")
        embedding_dim = embeddings.shape[1]
        self._build_final_gcn(embedding_dim)

        # Store learned components
        self.learned_embeddings = embeddings
        self.cleaned_adj = adj_normalized

        # Train the GCN
        self._train_gcn(embeddings, adj_normalized, y, train_mask, val_mask)

        print("Stable model training completed!")
    
    def _build_final_gcn(self, embedding_dim: int):
        """Create the GCN classifier that consumes the learned embeddings.

        One ``GCNConv -> activation -> Dropout`` group is stacked per entry
        of ``self.hids`` / ``self.acts``, followed by an output ``GCNConv``.
        All convolutions are created with ``normalize=False``. The result is
        stored in ``self.gcn`` on ``self.device``.
        """
        layers = []
        assert len(self.hids) == len(self.acts)

        width = embedding_dim
        for hidden, act_name in zip(self.hids, self.acts):
            layers.extend((
                GCNConv(width, hidden, bias=self.bias, normalize=False),
                activations.get(act_name),
                nn.Dropout(self.dropout),
            ))
            width = hidden

        layers.append(GCNConv(width, self.out_channels, bias=self.bias, normalize=False))
        self.gcn = Sequential(*layers)
        self.gcn = self.gcn.to(self.device)
    
    def _get_contrastive_embeddings(self, adj_preprocessed, features, adj_deleted):
        """Learn contrastive embeddings using DGI

        Two augmented graphs are built by recovering a random share of the
        deleted edges. A :class:`ContrastiveEmbedding` model is then trained
        with a binary cross-entropy objective, and the embeddings of the
        preprocessed graph are extracted from the lowest-loss checkpoint.

        Parameters
        ----------
        adj_preprocessed : scipy sparse matrix
            adjacency after Jaccard pruning
        features : Tensor
            dense 2-D node feature matrix
        adj_deleted : scipy sparse matrix
            edges removed by the pruning step

        Returns
        -------
        Tensor
            detached node embeddings with the batch dimension removed
        """
        features = features.unsqueeze(0).to(self.device)  # Add batch dimension
        nb_nodes, ft_size = features.shape[1], features.shape[2]

        def normalized_tensor(graph):
            # Add beta-weighted self-loops, normalize, convert to a torch sparse tensor
            with_loops = graph + sp.eye(graph.shape[0]) * self.beta
            normalized = self._normalize_adj_sparse(with_loops)
            return self._sparse_to_sparse_tensor(normalized).to(self.device)

        # Create augmented adjacencies
        aug_adj1 = self._augment_adj(adj_preprocessed, adj_deleted, self.recover_percent)
        aug_adj2 = self._augment_adj(adj_preprocessed, adj_deleted, self.recover_percent)

        sp_adj = normalized_tensor(adj_preprocessed)
        sp_aug_adj1 = normalized_tensor(aug_adj1)
        sp_aug_adj2 = normalized_tensor(aug_adj2)

        # Initialize DGI model
        model = ContrastiveEmbedding(ft_size, 512, 'prelu').to(self.device)
        optimizer = torch.optim.Adam(model.parameters(), lr=self.contrastive_lr, weight_decay=0.0)
        b_xent = nn.BCEWithLogitsLoss()

        # Labels never change: ones for the real nodes, zeros for the shuffled ones
        lbl = torch.cat((torch.ones(1, nb_nodes), torch.zeros(1, nb_nodes)), 1).to(self.device)

        best_loss = float('inf')
        # Start from the initial weights so a checkpoint always exists, even
        # with zero epochs or when the loss is never a finite improvement.
        best_model = deepcopy(model.state_dict())
        patience = 20
        wait = 0

        for epoch in range(self.contrastive_epochs):
            model.train()
            optimizer.zero_grad()

            # Shuffle features for negative samples
            idx = np.random.permutation(nb_nodes)
            shuf_fts = features[:, idx, :]

            # Forward pass
            logits = model(features, shuf_fts, features, features,
                           sp_adj, sp_aug_adj1, sp_aug_adj2,
                           True, None, None, None, aug_type='edge')

            loss = b_xent(logits, lbl)
            # Track a plain float so no autograd graph is kept alive
            loss_value = loss.item()

            if epoch % 100 == 0:
                print(f'Contrastive learning epoch {epoch}, Loss: {loss_value:.4f}')

            if loss_value < best_loss:
                best_loss = loss_value
                wait = 0
                # Snapshot taken before the optimizer step, i.e. the weights
                # that produced this loss
                best_model = deepcopy(model.state_dict())
            else:
                wait += 1
                if wait >= patience:
                    print(f'Early stopping at epoch {epoch}')
                    break

            loss.backward()
            optimizer.step()

        # Load best model and extract embeddings
        model.load_state_dict(best_model)
        model.eval()

        with torch.no_grad():
            embeddings, _ = model.embed(features, sp_adj, True, None)

        return embeddings.squeeze(0)
    
    def _train_gcn(self, embeddings, adj, y, train_mask, val_mask):
        """Train the final GCN classifier"""
        optimizer = torch.optim.Adam(self.gcn.parameters(), lr=0.01, weight_decay=5e-4)
        
        best_val_acc = 0
        best_model = None
        patience = 100
        wait = 0
        
        for epoch in range(200):
            self.gcn.train()
            optimizer.zero_grad()
            
            logits = self.gcn(embeddings, adj)
            loss = F.cross_entropy(logits[train_mask], y[train_mask])
            
            loss.backward()
            optimizer.step()
            
            # Validation
            if epoch % 10 == 0:
                self.gcn.eval()
                with torch.no_grad():
                    val_logits = self.gcn(embeddings, adj)
                    val_pred = val_logits[val_mask].argmax(dim=1)
                    val_acc = (val_pred == y[val_mask]).float().mean()
                
                if val_acc > best_val_acc:
                    best_val_acc = val_acc
                    best_model = deepcopy(self.gcn.state_dict())
                    wait = 0
                else:
                    wait += 1
                    if wait >= patience:
                        break
                
                print(f'Epoch {epoch}, Loss: {loss.item():.4f}, Val Acc: {val_acc:.4f}')
        
        if best_model is not None:
            self.gcn.load_state_dict(best_model)
    
    def test(self, x: Tensor, y: Tensor, test_mask: Tensor) -> float:
        """Test the model performance"""
        self.gcn.eval()
        with torch.no_grad():
            logits = self.gcn(self.learned_embeddings, self.cleaned_adj)
            pred = logits[test_mask].argmax(dim=1)
            acc = (pred == y[test_mask]).float().mean()
            return acc.item()
    
    # Utility methods
    def _to_scipy_sparse(self, tensor):
        """Convert tensor to scipy sparse matrix"""
        if tensor.is_sparse:
            indices = tensor._indices().cpu().numpy()
            values = tensor._values().cpu().numpy()
            return sp.csr_matrix((values, indices), shape=tensor.shape)
        else:
            # Handle both 2D and 1D tensors
            if tensor.dim() == 1:
                tensor = tensor.unsqueeze(0)
            return sp.csr_matrix(tensor.cpu().numpy())
    
    def _sparse_to_dense_tensor(self, sparse_mx):
        """Convert scipy sparse matrix to dense tensor"""
        return torch.FloatTensor(sparse_mx.todense())
    
    def _sparse_to_sparse_tensor(self, sparse_mx):
        """Convert scipy sparse matrix to torch sparse tensor"""
        sparse_mx = sparse_mx.tocoo().astype(np.float32)
        indices = torch.from_numpy(np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64))
        values = torch.from_numpy(sparse_mx.data)
        shape = torch.Size(sparse_mx.shape)
        return torch.sparse_coo_tensor(indices, values, shape)
    
    def _preprocess_adj(self, features, adj, threshold=0.03, jaccard=True):
        """Preprocess adjacency matrix by removing dissimilar edges"""
        if not sp.issparse(adj):
            adj = sp.csr_matrix(adj)

        adj_triu = sp.triu(adj, format='csr')

        if sp.issparse(features):
            features = features.todense()
        
        # Convert to numpy array if it's a matrix
        if hasattr(features, 'A'):
            features = features.A
        
        removed_cnt = 0
        if jaccard:
            removed_cnt = self._dropedge_jaccard(adj_triu, features, threshold)
        else:
            removed_cnt = self._dropedge_cosine(adj_triu, features, threshold)
        
        print(f'Removed {removed_cnt} edges in preprocessing')
        return adj_triu + adj_triu.transpose()
    
    def _dropedge_jaccard(self, adj_triu, features, threshold):
        """Remove edges based on Jaccard similarity"""
        removed_cnt = 0
        data, indices, indptr = adj_triu.data, adj_triu.indices, adj_triu.indptr
        
        for row in range(len(indptr) - 1):
            for i in range(indptr[row], indptr[row + 1]):
                n1 = row
                n2 = indices[i]
                
                # Ensure features are 2D
                if features.ndim == 1:
                    features = features.reshape(1, -1)
                
                a, b = features[n1], features[n2]
                intersection = np.count_nonzero(a * b)
                union = np.count_nonzero(a) + np.count_nonzero(b) - intersection
                
                if union == 0:
                    J = 0
                else:
                    J = intersection / union
                
                if J < threshold:
                    data[i] = 0
                    removed_cnt += 1
        
        return removed_cnt

    def _dropedge_cosine(self, adj_triu, features, threshold):
        """Remove edges based on cosine similarity"""
        removed_cnt = 0
        data, indices, indptr = adj_triu.data, adj_triu.indices, adj_triu.indptr
        
        for row in range(len(indptr) - 1):
            for i in range(indptr[row], indptr[row + 1]):
                n1 = row
                n2 = indices[i]
                
                # Ensure features are 2D
                if features.ndim == 1:
                    features = features.reshape(1, -1)
                
                a, b = features[n1], features[n2]
                inner_product = np.dot(a, b)
                norm_a = np.linalg.norm(a)
                norm_b = np.linalg.norm(b)
                
                if norm_a == 0 or norm_b == 0:
                    C = 0
                else:
                    C = inner_product / (norm_a * norm_b)
                
                if C <= threshold:
                    data[i] = 0
                    removed_cnt += 1
        
        return removed_cnt

    def _augment_adj(self, adj, adj_deleted, recover_percent):
        """Augment adjacency matrix by recovering some deleted edges"""
        adj_delete_tril = sp.tril(adj_deleted)
        row_idx, col_idx = adj_delete_tril.nonzero()
        
        if len(row_idx) == 0:
            return adj.copy()
        
        edge_num = len(row_idx)
        add_edge_num = int(edge_num * recover_percent)
        
        if add_edge_num > 0:
            add_idx = np.random.choice(edge_num, add_edge_num, replace=False)
            aug_adj = adj.copy().tolil()
            
            for i in add_idx:
                aug_adj[row_idx[i], col_idx[i]] = 1
                aug_adj[col_idx[i], row_idx[i]] = 1
            
            return aug_adj.tocsr()
        
        return adj.copy()
    
    def _normalize_adj_sparse(self, adj):
        """Normalize sparse adjacency matrix"""
        adj = sp.coo_matrix(adj)
        rowsum = np.array(adj.sum(1))
        d_inv_sqrt = np.power(rowsum, -0.5).flatten()
        d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.
        d_mat_inv_sqrt = sp.diags(d_inv_sqrt)
        return adj.dot(d_mat_inv_sqrt).transpose().dot(d_mat_inv_sqrt).tocoo()
    
    def _add_reliable_neighbors(self, adj, embeddings):
        """Connect every node to its most similar sufficiently-connected nodes.

        ``adj`` is modified in place. Candidate neighbours are the nodes
        whose row sum in ``adj`` exceeds ``self.degree_threshold``. Each node
        receives an edge to its ``self.k_neighbors`` candidates with the
        highest cosine similarity of ``embeddings``, and the diagonal is
        cleared afterwards. Nothing happens when there are fewer candidates
        than ``self.k_neighbors``.
        """
        eligible = adj.sum(dim=1) > self.degree_threshold
        if eligible.sum().item() < self.k_neighbors:
            return

        sim = cosine_similarity(embeddings.cpu().numpy())
        sim = torch.FloatTensor(sim).to(self.device)
        # Nodes below the degree threshold can never be selected
        sim[:, ~eligible] = 0

        _, top_k_indices = sim.topk(k=self.k_neighbors, dim=1)

        # Row i gets ones at columns top_k_indices[i]
        row_ids = torch.arange(adj.shape[0], device=top_k_indices.device).unsqueeze(1)
        adj[row_ids, top_k_indices] = 1
        adj.fill_diagonal_(0)  # Remove self-loops
    
    def _normalize_adj(self, adj, alpha):
        """Normalize adjacency matrix with parameter alpha"""
        adj = adj + torch.eye(adj.shape[0]).to(self.device)
        degree = adj.sum(dim=1)
        
        in_degree_norm = torch.pow(degree.view(1, -1), alpha).expand(adj.shape[0], adj.shape[0])
        out_degree_norm = torch.pow(degree.view(-1, 1), alpha).expand(adj.shape[0], adj.shape[0])
        
        adj = adj * in_degree_norm * out_degree_norm
        
        if alpha != -0.5:
            row_sum = adj.sum(dim=1, keepdim=True)
            adj = adj / (row_sum + 1e-12)
        
        return adj

    def fit_with_trainer(self, trainer, data, train_mask, val_mask, epochs=200):
        """Fit method compatible with Trainer for structure learning models"""
        # Convert data to appropriate format
        x, y = data.x, data.y
        if hasattr(data, 'edge_index'):
            edge_index = data.edge_index
        else:
            raise ValueError("Data must have edge_index attribute")
        
        # Call the original fit method
        self.fit(x, edge_index, y, train_mask, val_mask)
        
        # Return self for chaining
        return self
    
    def get_learned_adj(self):
        """Get the learned adjacency matrix"""
        return self.cleaned_adj