import torch.nn as nn
import torch
import math
import numpy as np
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from torch_geometric.utils import erdos_renyi_graph, remove_self_loops, add_self_loops, degree, add_remaining_self_loops
from data_utils import sys_normalized_adjacency, sparse_mx_to_torch_sparse_tensor
from torch_sparse import SparseTensor, matmul
from torch_geometric.nn import global_mean_pool
from improved_filter import ImprovedCausalFilter

def gcn_conv(x, edge_index):
    """Propagate node features once over a degree-normalised adjacency.

    Args:
        x: Node feature matrix; its first dimension is the number of nodes.
        edge_index: Pair of index tensors ``(row, col)`` describing the edges.

    Returns:
        The product of the normalised sparse adjacency with ``x``.
    """
    num_nodes = x.shape[0]
    src, dst = edge_index

    # Degrees are counted over the second row of ``edge_index``; the
    # inverse square root is computed once per node and then gathered
    # for both endpoints of every edge.
    inv_sqrt_deg = (1. / degree(dst, num_nodes).float()).sqrt()
    scale_dst = inv_sqrt_deg[dst]
    scale_src = inv_sqrt_deg[src]

    edge_weight = torch.ones_like(src) * scale_dst * scale_src
    # Endpoints with zero counted degree yield inf/nan weights; zero them
    # so those edges contribute nothing.
    edge_weight = torch.nan_to_num(edge_weight, nan=0.0, posinf=0.0, neginf=0.0)

    norm_adj = SparseTensor(
        row=dst,
        col=src,
        value=edge_weight,
        sparse_sizes=(num_nodes, num_nodes),
    )
    return matmul(norm_adj, x)

class GraphConvolutionBase(nn.Module):
    """Graph convolution layer with an optional linear skip branch.

    The layer computes ``gcn_conv(x, adj) @ weight`` and, when ``residual``
    is enabled, adds ``x @ weight_r``.

    Args:
        in_features: Width of the incoming node features.
        out_features: Width of the produced node features.
        residual: When true, a second weight matrix ``weight_r`` is created
            and applied directly to the un-propagated input.
    """

    def __init__(self, in_features, out_features, residual=False):
        super(GraphConvolutionBase, self).__init__()
        self.residual = residual
        self.in_features = in_features

        self.out_features = out_features
        self.weight = Parameter(torch.FloatTensor(self.in_features, self.out_features))
        if self.residual:
            self.weight_r = Parameter(torch.FloatTensor(self.in_features, self.out_features))
        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw every weight uniformly from ``[-1/sqrt(out), 1/sqrt(out)]``."""
        stdv = 1. / math.sqrt(self.out_features)
        self.weight.data.uniform_(-stdv, stdv)
        # ``weight_r`` only exists for residual layers; touching it
        # unconditionally made non-residual construction raise AttributeError.
        if self.residual:
            self.weight_r.data.uniform_(-stdv, stdv)

    def forward(self, x, adj, x0):
        """Apply the layer.

        Args:
            x: Node features; multiplied by ``weight`` after propagation.
            adj: Edge index forwarded unchanged to ``gcn_conv``.
            x0: Unused; kept so existing callers that pass the initial
                embedding keep working.

        Returns:
            The transformed node features.
        """
        hi = gcn_conv(x, adj)
        output = torch.mm(hi, self.weight)
        if self.residual:
            output = output + torch.mm(x, self.weight_r)
        return output

class CaNetConv(nn.Module):
    """Mixture-of-branches graph convolution.

    ``K`` parallel branches (each with its own weights) are evaluated on the
    same input and then combined per node with the mixing coefficients ``e``.

    Args:
        in_features: Width of the incoming node features.
        out_features: Width of each branch output.
        K: Number of parallel branches.
        residual: When true, the input ``x`` is added to the mixed output
            (this requires ``in_features == out_features``).
        backbone_type: ``'gcn'`` or ``'gat'``.
        variant: Only used by the ``'gcn'`` backbone; when true, an
            unnormalised adjacency built from the edge index is used instead
            of ``gcn_conv``.
        device: Device for helper tensors created during ``forward``. When
            ``None`` the device of the input features is used.

    Raises:
        ValueError: If ``backbone_type`` is not ``'gcn'`` or ``'gat'``.
    """

    def __init__(self, in_features, out_features, K, residual=True,
                 backbone_type='gcn', variant=False, device=None):
        super(CaNetConv, self).__init__()
        # Fail early and clearly: previously an unknown backbone left
        # ``self.weights`` undefined and crashed inside reset_parameters.
        if backbone_type not in ('gcn', 'gat'):
            raise ValueError(
                f"Unsupported backbone_type {backbone_type!r}; expected 'gcn' or 'gat'."
            )
        self.backbone_type = backbone_type
        self.out_features = out_features
        self.residual = residual
        if backbone_type == 'gcn':
            # Input is the concatenation [propagated, raw], hence ``* 2``.
            self.weights = Parameter(torch.FloatTensor(K, in_features * 2, out_features))
        else:
            self.leakyrelu = nn.LeakyReLU()
            self.weights = nn.Parameter(torch.zeros(K, in_features, out_features))
            self.a = nn.Parameter(torch.zeros(K, 2 * out_features, 1))
        self.K = K
        self.device = device
        self.variant = variant
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise branch weights (and the attention vector for 'gat')."""
        stdv = 1.0 / math.sqrt(self.out_features)
        self.weights.data.uniform_(-stdv, stdv)
        if self.backbone_type == 'gat':
            nn.init.xavier_uniform_(self.a.data, gain=1.414)

    def specialspmm(self, adj, spm, size, h):
        """Multiply the sparse matrix given by (``adj``, ``spm``) with ``h``.

        Args:
            adj: Edge index; ``adj[0]`` are row indices, ``adj[1]`` columns.
            spm: One value per edge.
            size: Sparse matrix shape.
            h: Dense right-hand operand.
        """
        adj = SparseTensor(row=adj[0], col=adj[1],
                           value=spm, sparse_sizes=size)
        return matmul(adj, h)

    def _gcn_branches(self, x, adj, weights, device):
        """Return the (N, K, out) stack of GCN branch outputs."""
        if not self.variant:
            hi = gcn_conv(x, adj)
        else:
            num_nodes = x.size(0)
            sparse_adj = torch.sparse_coo_tensor(
                adj,
                torch.ones(adj.shape[1], device=device),
                size=(num_nodes, num_nodes),
            ).to(device)
            hi = torch.sparse.mm(sparse_adj, x)
        hi = torch.cat([hi, x], dim=1)
        # (N, 2*in) @ (K, 2*in, out) broadcasts over K; no explicit copy of
        # the features per branch is needed.
        outputs = torch.matmul(hi, weights)
        return outputs.transpose(1, 0)

    def _gat_branches(self, x, adj, weights, device):
        """Return the (N, K, out) stack of attention branch outputs."""
        num_nodes = x.size(0)
        # (N, in) @ (K, in, out) -> (K, N, out) via broadcasting.
        h_lin = torch.matmul(x, weights)
        adj, _ = remove_self_loops(adj)
        adj, _ = add_self_loops(adj, num_nodes=num_nodes)
        edge_h = torch.cat(
            (h_lin[:, adj[0, :], :], h_lin[:, adj[1, :], :]),
            dim=2
        )
        logits = self.leakyrelu(torch.matmul(edge_h, self.a)).squeeze(2)
        # Subtract the per-branch maximum before exponentiating to keep
        # exp() from overflowing.
        logits_max, _ = torch.max(logits, dim=1, keepdim=True)
        edge_e = torch.exp(logits - logits_max)

        # Loop-invariant operands, built once instead of once per branch.
        size = torch.Size([num_nodes, num_nodes])
        ones = torch.ones(num_nodes, 1, device=device)
        eps = 1e-8

        outputs = []
        for k in range(self.K):
            edge_e_k = edge_e[k]
            # Row sums of the attention scores; eps avoids division by zero.
            norm = self.specialspmm(adj, edge_e_k, size, ones) + eps
            hi_k = self.specialspmm(adj, edge_e_k, size, h_lin[k])
            outputs.append(hi_k / norm)
        return torch.stack(outputs, dim=1)

    def forward(self, x, adj, e, weights=None):
        """Evaluate all branches and mix them per node.

        Args:
            x: Node features of shape (N, in_features).
            adj: Edge index of shape (2, E).
            e: Mixing coefficients of shape (N, K).
            weights: Optional replacement for ``self.weights``.

        Returns:
            Tensor of shape (N, out_features).
        """
        if weights is None:
            weights = self.weights

        # ``device=None`` used to place helper tensors on the default device
        # regardless of where ``x`` lives; follow the input in that case.
        device = self.device if self.device is not None else x.device

        if self.backbone_type == 'gcn':
            outputs = self._gcn_branches(x, adj, weights, device)
        else:
            outputs = self._gat_branches(x, adj, weights, device)

        # (N, K, 1) * (N, K, out) broadcasts over the feature axis.
        out = torch.sum(e.unsqueeze(2) * outputs, dim=1)

        if self.residual:
            out = out + x

        return out

class CaNet(nn.Module):
    """Stack of ``CaNetConv`` layers driven by per-layer environment encoders.

    An input linear layer embeds the features, ``args.num_layers``
    ``CaNetConv`` layers refine them, and an output linear layer produces
    ``c`` scores. Before every conv layer an environment encoder emits
    ``args.K`` logits that are turned into mixing coefficients for that layer.
    Optional ``ImprovedCausalFilter`` modules are applied to the raw input and
    to the hidden state before each layer.

    Args:
        d: Input feature width.
        c: Output width of the final linear layer.
        args: Namespace providing ``dropout``, ``num_layers``, ``tau``,
            ``env_type``, ``use_causal_filter``, ``hidden_channels``, ``K``,
            ``backbone_type``, ``variant`` and, when filters are enabled,
            ``filter_lambda_init``, ``filter_lambda_min``, ``filter_decay``,
            ``filter_temp`` and ``filter_residual``.
        device: Device forwarded to every ``CaNetConv``.

    Raises:
        NotImplementedError: If ``args.env_type`` is neither ``'node'`` nor
            ``'graph'`` (raised while building the encoders).
    """

    def __init__(self, d, c, args, device):
        super(CaNet, self).__init__()
        self.dropout = args.dropout
        self.act_fn = nn.ReLU()
        self.num_layers = args.num_layers
        self.tau = args.tau
        self.env_type = args.env_type
        self.use_causal_filter = args.use_causal_filter
        self.device = device

        if self.use_causal_filter:
            self.input_filter = self._make_filter(d, args)
            self.hidden_filters = nn.ModuleList(
                [self._make_filter(args.hidden_channels, args)
                 for _ in range(args.num_layers)]
            )

        self.convs = nn.ModuleList()
        for _ in range(args.num_layers):
            self.convs.append(CaNetConv(
                args.hidden_channels, args.hidden_channels, args.K,
                backbone_type=args.backbone_type, residual=True,
                device=device, variant=args.variant,
            ))

        self.fcs = nn.ModuleList()
        self.fcs.append(nn.Linear(d, args.hidden_channels))
        self.fcs.append(nn.Linear(args.hidden_channels, c))

        self.env_enc = nn.ModuleList()
        for _ in range(args.num_layers):
            if args.env_type == 'node':
                self.env_enc.append(nn.Linear(args.hidden_channels, args.K))
            elif args.env_type == 'graph':
                self.env_enc.append(GraphConvolutionBase(args.hidden_channels, args.K, residual=True))
            else:
                raise NotImplementedError

    @staticmethod
    def _make_filter(input_dim, args):
        """Build one causal filter of width ``input_dim`` from ``args``."""
        return ImprovedCausalFilter(
            input_dim=input_dim,
            lambda_init=args.filter_lambda_init,
            lambda_min=args.filter_lambda_min,
            decay_rate=args.filter_decay,
            temperature=args.filter_temp,
            residual_weight=args.filter_residual,
            normalize=False,
            dropout=0.1
        )

    def reset_parameters(self):
        """Re-initialise the conv layers, linear layers and env encoders.

        The causal filters are not touched here.
        """
        for conv in self.convs:
            conv.reset_parameters()
        for fc in self.fcs:
            fc.reset_parameters()
        for enc in self.env_enc:
            enc.reset_parameters()

    def forward(self, x, edge_index, batch=None, training=False):
        """Run the network.

        Args:
            x: Node features with ``d`` columns.
            edge_index: Edge index passed to every graph layer.
            batch: Optional node-to-graph assignment; when given, node
                embeddings are mean-pooled per graph before the output layer.
            training: Selects training behaviour (dropout, Gumbel-softmax
                sampling, regulariser accumulation).

        Returns:
            ``(out, reg)`` when ``training`` is truthy, where ``reg`` is the
            regulariser averaged over layers; otherwise just ``out``.
        """
        # Go through nn.Module.train() so that child modules switch mode
        # together with this module; a bare ``self.training = ...`` left
        # submodules (e.g. dropout inside the filters) in a stale mode.
        self.train(bool(training))

        if self.use_causal_filter:
            x = self.input_filter(x)

        x = F.dropout(x, self.dropout, training=self.training)
        h = self.act_fn(self.fcs[0](x))
        h0 = h.clone()
        reg = 0.0

        for i, conv in enumerate(self.convs):
            h = F.dropout(h, self.dropout, training=self.training)

            if self.use_causal_filter:
                h = self.hidden_filters[i](h)

            # The encoder call is identical in both modes; only the way the
            # logits become mixing coefficients differs.
            if self.env_type == 'node':
                logit = self.env_enc[i](h)
            else:
                logit = self.env_enc[i](h, edge_index, h0)

            if self.training:
                e = F.gumbel_softmax(logit, tau=self.tau, dim=-1)
                reg += self.reg_loss(e, logit)
            else:
                e = F.softmax(logit, dim=-1)

            h = self.act_fn(conv(h, edge_index, e))

        h = F.dropout(h, self.dropout, training=self.training)

        if batch is not None:
            graph_embeddings = global_mean_pool(h, batch)
            out = self.fcs[-1](graph_embeddings)
        else:
            out = self.fcs[-1](h)

        if self.training:
            return out, reg / self.num_layers
        else:
            return out

    def reg_loss(self, z, logit, logit_0 = None):
        """Return ``mean_i sum_k z[i, k] * log_softmax(logit)[i, k]``.

        Args:
            z: Mixing coefficients, same shape as ``logit``.
            logit: Two-dimensional encoder logits.
            logit_0: Unused; kept for interface compatibility.
        """
        # keepdim=True lets the normaliser broadcast across the last axis.
        log_pi = logit - torch.logsumexp(logit, dim=-1, keepdim=True)
        return torch.mean(torch.sum(
            torch.mul(z, log_pi), dim=1))

    def sup_loss_calc(self, y, pred, criterion, args):
        """Compute the supervised loss for ``pred`` against labels ``y``.

        For ``args.dataset`` in ``('twitch', 'elliptic')`` the labels are
        one-hot encoded when they have a single column and handed to
        ``criterion`` as floats. For every other dataset ``criterion``
        receives log-softmax scores and the squeezed integer labels.
        """
        if args.dataset in ('twitch', 'elliptic'):
            if y.shape[1] == 1:
                true_label = F.one_hot(y, y.max() + 1).squeeze(1)
            else:
                true_label = y
            loss = criterion(pred, true_label.squeeze(1).to(torch.float))
        else:
            out = F.log_softmax(pred, dim=1)
            target = y.squeeze(1)
            loss = criterion(out, target)
        return loss

    def loss_compute(self, d, criterion, args):
        """Return supervised loss plus ``args.lamda`` times the regulariser.

        ``d`` must expose ``x``, ``edge_index``, ``y`` and ``train_idx``; only
        the rows selected by ``train_idx`` enter the supervised loss.
        """
        logits, reg_loss = self.forward(d.x, d.edge_index, training=True)
        sup_loss = self.sup_loss_calc(d.y[d.train_idx], logits[d.train_idx], criterion, args)
        loss = sup_loss + args.lamda * reg_loss
        return loss

    def step_epoch(self):
        """Call ``step()`` on every causal filter; no-op when filters are off.

        NOTE(review): presumably advances each filter's lambda schedule -
        confirm against ``ImprovedCausalFilter.step``.
        """
        if self.use_causal_filter:
            self.input_filter.step()
            for filter_module in self.hidden_filters:
                filter_module.step()

    def get_filter_info(self):
        """Return a multi-line, human-readable summary of the filter stats."""
        if not self.use_causal_filter:
            return "Causal filters are disabled."

        info = []

        stats = self.input_filter.get_stats()
        info.append(f"Input Filter: λ={stats['lambda']:.4f}, Gate Stats={stats['gate_stats']}")

        for i, filter_module in enumerate(self.hidden_filters):
            stats = filter_module.get_stats()
            info.append(f"Hidden Filter {i}: λ={stats['lambda']:.4f}, Gate Stats={stats['gate_stats']}")

        return "\n".join(info)