import time
import warnings
from copy import deepcopy
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch import Tensor
from torch.optim.optimizer import Optimizer, required
from torch_geometric.typing import Adj, OptTensor

from greatx.nn.layers import GCNConv, Sequential, activations
from greatx.utils import wrapper


class ProxOperators:
    """Proximal operators used by ProGNN's proximal-gradient steps.

    ``nuclear_norm`` holds the nuclear norm (sum of singular values) of the
    last matrix passed to :meth:`prox_nuclear_cuda`, measured *before*
    shrinkage, so callers can log it as a loss term.
    """

    def __init__(self):
        self.nuclear_norm = None

    def prox_l1(self, data: Tensor, alpha: float) -> Tensor:
        """Proximal operator of the l1 norm (element-wise soft-thresholding).

        Every entry is moved toward zero by ``alpha``; entries whose
        magnitude is at most ``alpha`` become zero.
        """
        return torch.sign(data) * torch.clamp(data.abs() - alpha, min=0)

    def prox_nuclear_cuda(self, data: Tensor, alpha: float) -> Tensor:
        """Proximal operator of the nuclear norm (singular-value thresholding).

        Computes ``U @ diag(max(S - alpha, 0)) @ Vh`` from the reduced SVD of
        ``data``. It runs on whatever device ``data`` is on, and works for
        any 2-D shape, square or not.

        Side effect: stores ``S.sum()`` of the input in ``self.nuclear_norm``.
        """
        # Reduced SVD: U is (n, k), S is (k,), Vh is (k, m) with k = min(n, m)
        U, S, Vh = torch.linalg.svd(data, full_matrices=False)
        self.nuclear_norm = S.sum()
        S = torch.clamp(S - alpha, min=0)
        # (U * S) scales column i of U by S[i], i.e. U @ diag(S), without
        # materializing a (sparse or dense) diagonal matrix.
        return (U * S) @ Vh


class PGD(Optimizer):
    """Proximal "gradient descent" optimizer for ProGNN.

    ``step`` does not read gradients. For every parameter in every group it
    applies each proximal operator in ``proxs`` in order. Each operator gets
    the matching entry of ``alphas`` scaled by the group's ``lr`` as its
    threshold: ``param <- prox(param, alpha=alpha * lr)``.

    Parameters
    ----------
    params : iterable
        parameters (or parameter groups) to update
    proxs : list of callables
        proximal operators with signature ``prox(data, alpha=...) -> Tensor``
    alphas : list of float
        one coefficient per operator in ``proxs``
    lr : float
        step size multiplied into every ``alpha``
    momentum, dampening, weight_decay :
        stored in the param groups for interface compatibility; unused by
        ``step``
    """

    def __init__(self, params, proxs, alphas, lr=required, momentum=0,
                 dampening=0, weight_decay=0):
        defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
                        weight_decay=weight_decay, nesterov=False)
        super().__init__(params, defaults)

        for group in self.param_groups:
            group.setdefault('proxs', proxs)
            group.setdefault('alphas', alphas)

    @torch.no_grad()
    def step(self, closure=None):
        """Apply the proximal operators once.

        Follows the ``torch.optim.Optimizer.step`` contract: if ``closure``
        is given, it is evaluated (with grad enabled) and its return value is
        returned; otherwise returns None.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group['lr']
            for param in group['params']:
                for prox_operator, alpha in zip(group['proxs'],
                                                group['alphas']):
                    # In-place copy keeps the same Parameter storage that
                    # other optimizers over this parameter reference.
                    param.copy_(prox_operator(param, alpha=alpha * lr))
        return loss


class EstimateAdj(nn.Module):
    """Learnable dense adjacency matrix used by ProGNN.

    Parameters
    ----------
    adj : Tensor
        initial adjacency matrix; an ``n x n`` float32 parameter (with
        ``n = adj.shape[0]``) is created and filled with a copy of ``adj``
    symmetric : bool, optional
        if True, :meth:`normalize` first averages the matrix with its
        transpose, by default False
    device : str, optional
        stored for backward compatibility; :meth:`normalize` now follows the
        device of the parameter itself, by default 'cpu'
    """

    def __init__(self, adj: Tensor, symmetric: bool = False, device: str = 'cpu'):
        super().__init__()
        n = adj.shape[0]
        self.estimated_adj = nn.Parameter(torch.empty(n, n, dtype=torch.float32))
        self._init_estimation(adj)
        self.symmetric = symmetric
        self.device = device

    def _init_estimation(self, adj: Tensor):
        """Overwrite the learnable matrix with a copy of ``adj`` (no autograd)."""
        with torch.no_grad():
            self.estimated_adj.copy_(adj)

    def forward(self) -> Tensor:
        """Return the raw (un-normalized) learnable adjacency parameter."""
        return self.estimated_adj

    def normalize(self) -> Tensor:
        """Return ``D^{-1/2} (A + I) D^{-1/2}`` for the current estimate ``A``.

        ``A`` is replaced by ``(A + A^T) / 2`` when ``symmetric`` is set. The
        result stays attached to the autograd graph of ``estimated_adj``.
        """
        adj = self.estimated_adj
        if self.symmetric:
            adj = (adj + adj.t()) / 2
        # Build the identity on the parameter's own device/dtype so the module
        # keeps working after .to(...) regardless of the stored device string.
        eye = torch.eye(adj.shape[0], device=adj.device, dtype=adj.dtype)
        return self._normalize(adj + eye)

    def _normalize(self, mx: Tensor) -> Tensor:
        """Symmetric normalization ``D^{-1/2} mx D^{-1/2}`` with D = row sums.

        Rows whose sum is zero get a zero scaling factor instead of inf.
        """
        r_inv = mx.sum(1).pow(-1 / 2)
        r_inv = r_inv.masked_fill(torch.isinf(r_inv), 0.)
        # Broadcasting r_i * mx_ij * r_j equals diag(r) @ mx @ diag(r), but
        # costs O(n^2) instead of two dense O(n^3) matmuls.
        return r_inv.view(-1, 1) * mx * r_inv.view(1, -1)


class ProGNN(nn.Module):
    """ProGNN (Properties Graph Neural Network) implementation for GreatX

    Each training epoch alternates two phases:

    * ``outer_steps`` updates of a learnable dense adjacency matrix
      (:class:`EstimateAdj`). Each update is one SGD step on the Frobenius,
      GCN, feature-smoothing and symmetry losses. It is followed by proximal
      steps for the nuclear-norm (``beta``) and L1 (``alpha``) penalties, and
      then a clamp of every entry to ``[0, 1]``.
    * ``inner_steps`` Adam updates of the GCN weights on the current graph.

    Whenever validation accuracy or validation loss improves, the GCN weights
    and the normalized graph are snapshotted into ``weights`` and
    ``best_graph``. :meth:`fit` restores those weights at the end, and
    :meth:`forward` and :meth:`test` evaluate with ``best_graph``.

    Parameters
    ----------
    in_channels : int
        the input dimensions of model
    out_channels : int
        the output dimensions of model
    hids : List[int], optional
        the number of hidden units for each hidden layer, by default [128]
    acts : List[str], optional
        the activation function for each hidden layer, by default ['relu']
    dropout : float, optional
        the dropout ratio of model, by default 0.5
    bias : bool, optional
        whether to use bias in the layers, by default True
    symmetric : bool, optional
        whether to symmetrize the learned adjacency matrix before
        normalization, by default False
    lr : float, optional
        learning rate for GNN parameters, by default 0.001
    lr_adj : float, optional
        learning rate for adjacency matrix, by default 0.01
    alpha : float, optional
        coefficient for L1 penalty, by default 5e-4
    beta : float, optional
        coefficient for nuclear norm penalty, by default 1.5
    gamma : float, optional
        coefficient for GCN loss, by default 1.0
    lambda_ : float, optional
        coefficient for feature smoothing, by default 0.001
    phi : float, optional
        coefficient for symmetry penalty, by default 0.0
    epochs : int, optional
        number of training epochs, by default 200
    inner_steps : int, optional
        number of GCN update steps per epoch, by default 2
    outer_steps : int, optional
        number of adjacency update steps per epoch, by default 1
    device : str, optional
        device to use, by default 'cpu'

    Examples
    --------
    >>> # ProGNN with default parameters
    >>> model = ProGNN(100, 10)

    >>> # ProGNN with custom architecture
    >>> model = ProGNN(100, 10, hids=[32, 16], acts=['relu', 'elu'])
    """

    # NOTE: the list defaults below are never mutated by this class.
    @wrapper
    def __init__(self, in_channels: int, out_channels: int,
                 hids: List[int] = [128], acts: List[str] = ['relu'],
                 dropout: float = 0.5, bias: bool = True,
                 symmetric: bool = False, lr: float = 0.001, lr_adj: float = 0.01,
                 alpha: float = 5e-4, beta: float = 1.5, gamma: float = 1.0,
                 lambda_: float = 0.001, phi: float = 0.0, epochs: int = 200,
                 inner_steps: int = 2, outer_steps: int = 1, device: str = 'cpu'):

        super().__init__()

        # Model parameters
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hids = hids
        self.acts = acts
        self.dropout = dropout
        self.bias = bias
        self.symmetric = symmetric
        self.device = device

        # Training parameters
        self.lr = lr
        self.lr_adj = lr_adj
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.lambda_ = lambda_
        self.phi = phi
        self.epochs = epochs
        self.inner_steps = inner_steps
        self.outer_steps = outer_steps

        # Build GCN backbone
        self._build_gcn()

        # Training state
        self.estimator = None
        self.prox_operators = ProxOperators()
        # Dense (un-normalized) adjacency the estimator was initialized from
        # in fit(); kept so reset_parameters() can restore it.
        self._init_adj = None
        self._reset_best_state()

    def _reset_best_state(self):
        """Clear the best-checkpoint bookkeeping before a new training run."""
        self.best_val_acc = 0.
        # inf (rather than a finite sentinel) guarantees the first
        # evaluation always records a checkpoint.
        self.best_val_loss = float('inf')
        self.best_graph = None
        self.weights = None

    def _build_gcn(self):
        """Build the GCN backbone.

        Layout: ``[GCNConv -> activation -> Dropout]`` per hidden layer, then
        a final ``GCNConv`` to ``out_channels``.

        Raises
        ------
        ValueError
            if ``hids`` and ``acts`` have different lengths
        """
        if len(self.hids) != len(self.acts):
            raise ValueError(
                f"`hids` and `acts` must have the same length, "
                f"got {len(self.hids)} and {len(self.acts)}")

        layers = []
        in_channels = self.in_channels
        for hid, act in zip(self.hids, self.acts):
            layers.append(GCNConv(in_channels, hid, bias=self.bias))
            layers.append(activations.get(act))
            layers.append(nn.Dropout(self.dropout))
            in_channels = hid

        layers.append(GCNConv(in_channels, self.out_channels, bias=self.bias))
        self.gcn = Sequential(*layers)

    def reset_parameters(self):
        """Re-initialize the GCN weights and restore the learned adjacency.

        The adjacency estimate is reset to the matrix it was initialized from
        in :meth:`fit`. Before the first ``fit`` there is no estimator, so
        only the GCN is reset.
        """
        self.gcn.reset_parameters()
        if self.estimator is not None and self._init_adj is not None:
            self.estimator._init_estimation(self._init_adj)

    @staticmethod
    def _sym_normalize(adj: Tensor) -> Tensor:
        """Return ``D^{-1/2} adj D^{-1/2}`` with D = row sums of ``adj``.

        Rows whose sum is zero get a zero scaling factor instead of inf.
        """
        d_inv_sqrt = adj.sum(1).pow(-1 / 2)
        d_inv_sqrt = d_inv_sqrt.masked_fill(torch.isinf(d_inv_sqrt), 0.)
        return d_inv_sqrt.view(-1, 1) * adj * d_inv_sqrt.view(1, -1)

    def _edge_index_to_dense(self, edge_index: Tensor, num_nodes: int,
                             edge_weight: OptTensor = None) -> Tensor:
        """Scatter an edge list into a dense ``num_nodes x num_nodes`` matrix.

        ``edge_index[0]`` / ``edge_index[1]`` give the row / column of each
        edge. The entry is ``edge_weight`` when given, otherwise 1.0.
        """
        edge_index = edge_index.to(self.device)
        adj = torch.zeros(num_nodes, num_nodes, device=self.device)
        row, col = edge_index[0], edge_index[1]
        if edge_weight is None:
            adj[row, col] = 1.0
        else:
            adj[row, col] = edge_weight.to(device=adj.device, dtype=adj.dtype)
        return adj

    def forward(self, x: Tensor, edge_index: Adj, edge_weight: OptTensor = None) -> Tensor:
        """Run the GCN backbone.

        After :meth:`fit`, the graph is ``best_graph`` (the graph recorded
        together with the restored weights), falling back to the current
        estimate. Without an estimator, the graph comes from the inputs:
        a ``torch.long`` edge list is densified (weighted by
        ``edge_weight``), given self-loops and symmetrically normalized,
        while any other tensor is used as a dense adjacency as-is.

        ``edge_weight`` is folded into the dense matrix and is not passed on
        to the backbone, matching how the backbone is called during training.
        """
        # NOTE(review): assumes GCNConv accepts a pre-normalized dense
        # adjacency as its second argument - confirm it does not re-normalize.
        if self.estimator is not None:
            if self.best_graph is not None:
                adj = self.best_graph
            else:
                adj = self.estimator.normalize()
            return self.gcn(x, adj)

        if edge_index.dtype == torch.long:
            adj = self._edge_index_to_dense(edge_index, x.shape[0], edge_weight)
            eye = torch.eye(adj.shape[0], device=adj.device, dtype=adj.dtype)
            return self.gcn(x, self._sym_normalize(adj + eye))

        # Dense adjacency matrix, used as-is
        return self.gcn(x, edge_index.to(self.device))

    def fit(self, x: Tensor, edge_index: Adj, y: Tensor,
            train_mask: Tensor, val_mask: Tensor,
            edge_weight: OptTensor = None, **kwargs):
        """Train ProGNN model

        Parameters
        ----------
        x : Tensor
            node features
        edge_index : Adj
            either a ``torch.long`` edge list, which is scattered into a
            dense adjacency matrix, or a dense adjacency matrix used as-is
        y : Tensor
            node labels
        train_mask : Tensor
            training mask
        val_mask : Tensor
            validation mask
        edge_weight : OptTensor, optional
            edge weights for an edge-list ``edge_index`` (ignored for a dense
            matrix), by default None
        **kwargs
            accepted for API compatibility and ignored
        """
        print("=== ProGNN Training Started ===")
        print(f"Device: {self.device}")
        print(f"Epochs: {self.epochs}, Inner steps: {self.inner_steps}, Outer steps: {self.outer_steps}")
        print(f"Learning rates - GCN: {self.lr}, Adj: {self.lr_adj}")
        print(f"Regularization - Alpha: {self.alpha}, Beta: {self.beta}, Gamma: {self.gamma}, Lambda: {self.lambda_}, Phi: {self.phi}")

        # Convert edge_index to dense adjacency matrix if needed
        if edge_index.dtype == torch.long:
            adj = self._edge_index_to_dense(edge_index, x.shape[0], edge_weight)
            print(f"Converted sparse edge_index to dense adjacency matrix: {adj.shape}")
        else:
            adj = edge_index.to(self.device)
            print(f"Using dense adjacency matrix: {adj.shape}")

        # Start from a clean slate so a previous run's checkpoint cannot be
        # restored at the end of this one.
        self._reset_best_state()
        self._init_adj = adj

        # Initialize estimator
        self.estimator = EstimateAdj(adj, symmetric=self.symmetric,
                                     device=self.device).to(self.device)
        print(f"Initialized adjacency estimator with symmetric={self.symmetric}")

        # Initialize optimizers
        self.optimizer = optim.Adam(self.gcn.parameters(), lr=self.lr)
        self.optimizer_adj = optim.SGD(self.estimator.parameters(),
                                       momentum=0.9, lr=self.lr_adj)

        # Proximal operators
        self.optimizer_l1 = PGD(self.estimator.parameters(),
                                proxs=[self.prox_operators.prox_l1],
                                alphas=[self.alpha], lr=self.lr_adj)

        self.optimizer_nuclear = PGD(self.estimator.parameters(),
                                     proxs=[self.prox_operators.prox_nuclear_cuda],
                                     alphas=[self.beta], lr=self.lr_adj)

        print("Initialized optimizers: Adam (GCN), SGD (Adj), PGD (L1), PGD (Nuclear)")

        # Move to device
        self.to(self.device)
        x, y = x.to(self.device), y.to(self.device)
        train_mask, val_mask = train_mask.to(self.device), val_mask.to(self.device)

        print(f"Data shapes - X: {x.shape}, Y: {y.shape}")
        print(f"Train nodes: {train_mask.sum().item()}, Val nodes: {val_mask.sum().item()}")

        # Training loop
        t_total = time.time()
        for epoch in range(self.epochs):
            print(f"\n--- Epoch {epoch + 1}/{self.epochs} ---")

            # Train adjacency matrix
            for step in range(self.outer_steps):
                print(f"  Adj training step {step + 1}/{self.outer_steps}")
                self._train_adj(epoch, x, adj, y, train_mask, val_mask)

            # Train GCN
            for step in range(self.inner_steps):
                print(f"  GCN training step {step + 1}/{self.inner_steps}")
                self._train_gcn(epoch, x, y, train_mask, val_mask)

            # Print epoch summary
            if epoch % 10 == 0 or epoch == self.epochs - 1:
                print(f"  Best val acc: {self.best_val_acc:.4f}, Best val loss: {self.best_val_loss:.4f}")

        total_time = time.time() - t_total
        print("\n=== ProGNN Training Completed ===")
        print(f"Total time elapsed: {total_time:.4f}s")
        print(f"Final best validation accuracy: {self.best_val_acc:.4f}")
        print(f"Final best validation loss: {self.best_val_loss:.4f}")

        # Load best model
        if self.weights is not None:
            self.gcn.load_state_dict(self.weights)
            print("Loaded best model weights")
        else:
            print("Warning: No best weights found, using current model")

    @staticmethod
    def _accuracy(output: Tensor, y: Tensor, mask: Tensor) -> Tensor:
        """Fraction of ``mask``-selected rows whose argmax equals ``y``."""
        return (output[mask].argmax(dim=1) == y[mask]).float().mean()

    def _evaluate(self, x: Tensor, y: Tensor, mask: Tensor,
                  normalized_adj: Tensor) -> Tuple[Tensor, Tensor]:
        """Return ``(nll_loss, accuracy)`` of the GCN on ``mask`` in eval mode."""
        self.eval()
        with torch.no_grad():
            output = self.gcn(x, normalized_adj)
            loss = F.nll_loss(F.log_softmax(output, dim=1)[mask], y[mask])
            acc = self._accuracy(output, y, mask)
        return loss, acc

    def _update_best(self, epoch: int, loss_val: Tensor, acc_val: Tensor,
                     normalized_adj: Tensor):
        """Snapshot weights and graph if validation accuracy or loss improved."""
        acc_val, loss_val = float(acc_val), float(loss_val)
        improved_acc = acc_val > self.best_val_acc
        improved_loss = loss_val < self.best_val_loss

        if improved_acc:
            self.best_val_acc = acc_val
        if improved_loss:
            self.best_val_loss = loss_val
        if improved_acc or improved_loss:
            self.best_graph = normalized_adj.detach()
            self.weights = deepcopy(self.gcn.state_dict())

        if epoch % 10 == 0:
            if improved_acc:
                print(f"    *** New best val acc: {self.best_val_acc:.4f} ***")
            if improved_loss:
                print(f"    *** New best val loss: {self.best_val_loss:.4f} ***")

    def _train_gcn(self, epoch: int, x: Tensor, y: Tensor,
                   train_mask: Tensor, val_mask: Tensor):
        """Run one Adam step on the GCN weights with the graph held fixed."""
        self.train()
        self.optimizer.zero_grad()

        # The graph is not updated here. The adjacency optimizer zeroes its
        # gradients before every backward pass in _train_adj, so a gradient
        # reaching estimated_adj from this step would never be used; building
        # the graph outside autograd just avoids that wasted work.
        with torch.no_grad():
            normalized_adj = self.estimator.normalize()
        output = self.gcn(x, normalized_adj)

        # NLL on log-probabilities, to match DeepRobust implementation
        loss_train = F.nll_loss(F.log_softmax(output, dim=1)[train_mask],
                                y[train_mask])
        acc_train = self._accuracy(output, y, train_mask)

        loss_train.backward()
        self.optimizer.step()

        # Validation (the graph is unchanged by the GCN step)
        loss_val, acc_val = self._evaluate(x, y, val_mask, normalized_adj)

        if epoch % 10 == 0:
            print(f"    GCN - Train loss: {loss_train.item():.4f}, Train acc: {acc_train.item():.4f}")
            print(f"    GCN - Val loss: {loss_val.item():.4f}, Val acc: {acc_val.item():.4f}")

        self._update_best(epoch, loss_val, acc_val, normalized_adj)

    def _train_adj(self, epoch: int, x: Tensor, adj: Tensor, y: Tensor,
                   train_mask: Tensor, val_mask: Tensor):
        """Run one proximal-gradient update of the learned adjacency matrix.

        ``adj`` is the dense observed adjacency; the Frobenius term keeps the
        estimate close to it.
        """
        estimator = self.estimator
        estimator.train()
        self.optimizer_adj.zero_grad()

        estimated_adj = estimator.estimated_adj
        loss_l1 = torch.norm(estimated_adj, 1)
        loss_fro = torch.norm(estimated_adj - adj, p='fro')
        loss_symmetric = torch.norm(estimated_adj - estimated_adj.t(), p='fro')

        # Feature smoothing loss
        if self.lambda_ > 0:
            loss_smooth_feat = self._feature_smoothing(estimated_adj, x)
        else:
            loss_smooth_feat = 0 * loss_l1

        # GCN loss - NLL on log-probabilities, to match DeepRobust
        normalized_adj = estimator.normalize()
        output = self.gcn(x, normalized_adj)
        loss_gcn = F.nll_loss(F.log_softmax(output, dim=1)[train_mask],
                              y[train_mask])
        acc_train = self._accuracy(output, y, train_mask)

        # The L1 and nuclear-norm penalties are not in the gradient step;
        # they are applied by the proximal optimizers below.
        loss_total = (loss_fro + self.gamma * loss_gcn +
                      self.lambda_ * loss_smooth_feat +
                      self.phi * loss_symmetric)

        loss_total.backward()
        self.optimizer_adj.step()

        # Apply proximal operators
        loss_nuclear = 0.0
        if self.beta != 0:
            self.optimizer_nuclear.zero_grad()
            self.optimizer_nuclear.step()
            if self.prox_operators.nuclear_norm is not None:
                loss_nuclear = self.prox_operators.nuclear_norm.item()

        self.optimizer_l1.zero_grad()
        self.optimizer_l1.step()

        # Project every entry back into [0, 1]
        estimator.estimated_adj.data.clamp_(min=0, max=1)

        # Validation on the updated graph
        self.eval()
        with torch.no_grad():
            normalized_adj = estimator.normalize()
        loss_val, acc_val = self._evaluate(x, y, val_mask, normalized_adj)

        if epoch % 10 == 0:
            print(f"    Adj - Train acc: {acc_train.item():.4f}, Val acc: {acc_val.item():.4f}, Val loss: {loss_val.item():.4f}")
            print("    Adj - Loss components:")
            print(f"      Frobenius: {loss_fro.item():.4f}, GCN: {loss_gcn.item():.4f}")
            print(f"      L1: {loss_l1.item():.4f}, Nuclear: {loss_nuclear:.4f}")
            print(f"      Smooth feat: {loss_smooth_feat.item():.4f}, Symmetric: {loss_symmetric.item():.4f}")
            print(f"      Total: {loss_total.item():.4f}")

        self._update_best(epoch, loss_val, acc_val, normalized_adj)

    def _feature_smoothing(self, adj: Tensor, x: Tensor) -> Tensor:
        """Feature-smoothness penalty ``tr(X^T L X)`` on the symmetrized graph.

        With ``A = (adj + adj^T) / 2`` and row sums ``d``, ``L`` is
        ``diag(d) - A`` scaled on both sides by ``(d + 1e-3)^{-1/2}``.
        """
        adj = (adj.t() + adj) / 2
        degree = adj.sum(1)
        laplacian = torch.diag(degree) - adj

        # The 1e-3 offset keeps the scaling finite for isolated nodes
        d_inv_sqrt = (degree + 1e-3).pow(-1 / 2)
        d_inv_sqrt = d_inv_sqrt.masked_fill(torch.isinf(d_inv_sqrt), 0.)
        laplacian = d_inv_sqrt.view(-1, 1) * laplacian * d_inv_sqrt.view(1, -1)

        # tr(X^T L X) == sum(X * (L @ X)); avoids forming the d x d product
        return (x * (laplacian @ x)).sum()

    def test(self, x: Tensor, y: Tensor, test_mask: Tensor) -> float:
        """Return the accuracy on ``test_mask``.

        Uses ``best_graph`` when available, else the current estimate.

        Raises
        ------
        RuntimeError
            if the model has not been fitted (no graph is available)
        """
        if self.best_graph is None and self.estimator is None:
            raise RuntimeError("ProGNN has not been fitted yet; call fit() first.")

        self.eval()
        x, y = x.to(self.device), y.to(self.device)
        test_mask = test_mask.to(self.device)
        with torch.no_grad():
            if self.best_graph is not None:
                adj = self.best_graph
            else:
                adj = self.estimator.normalize()
            output = self.gcn(x, adj)
            return self._accuracy(output, y, test_mask).item()

    def fit_with_trainer(self, trainer, data, train_mask, val_mask, epochs=200):
        """Fit method compatible with Trainer for structure learning models.

        ``trainer`` and ``epochs`` are accepted for interface compatibility
        but are not used; training runs for ``self.epochs`` epochs.

        Returns
        -------
        ProGNN
            ``self``, for chaining

        Raises
        ------
        ValueError
            if ``data`` has no ``edge_index`` attribute
        """
        x, y = data.x, data.y
        if not hasattr(data, 'edge_index'):
            raise ValueError("Data must have edge_index attribute")

        self.fit(x, data.edge_index, y, train_mask, val_mask)
        return self

    def get_learned_adj(self):
        """Return the normalized *current* adjacency estimate, or None before fit.

        This is the final estimate, which may differ from ``best_graph``.
        """
        if self.estimator is None:
            return None
        return self.estimator.normalize()
