import math
from typing import Callable
import os
import random

import torch
import torch.nn as nn
from torch.nn.functional import relu, gelu, leaky_relu, elu, silu, softmax, tanh

from torch.functional import F

import numpy as np

from torch_geometric.data import Data
import networkx as nx

def init_param(module, n_layers):
    """Re-initialise a linear or embedding module in place.

    ``nn.Linear`` weights are redrawn from a normal distribution with mean
    0.02 and standard deviation ``0.02 / sqrt(n_layers)``; the bias, when the
    layer has one, is set to zero.  ``nn.Embedding`` weights are redrawn from
    a normal distribution with mean 0 and standard deviation 0.2.  Modules of
    any other type are left untouched.

    Args:
        module: Module whose parameters should be re-initialised.
        n_layers: Layer count used to shrink the spread of linear weights.
    """
    if isinstance(module, nn.Linear):
        # NOTE(review): a non-zero mean (0.02) for linear weights is unusual;
        # confirm it is intended rather than a typo for 0.0.
        spread = 0.02 / math.sqrt(n_layers)
        module.weight.data.normal_(mean=0.02, std=spread)
        bias = module.bias
        if bias is not None:
            bias.data.zero_()
    if isinstance(module, nn.Embedding):
        module.weight.data.normal_(mean=0.0, std=0.2)

def get_activation_fn(activation: str) -> Callable:
    """Return the functional activation registered under ``activation``.

    Supported names are ``relu``, ``gelu``, ``leaky_relu``, ``elu``,
    ``silu``, ``softmax`` and ``tanh``; each maps to the function of the same
    name imported from ``torch.nn.functional``.

    Args:
        activation: Name of the activation function.

    Returns:
        The matching callable.

    Raises:
        ValueError: If ``activation`` is not one of the supported names.
    """
    known = (
        ("relu", relu),
        ("gelu", gelu),
        ("leaky_relu", leaky_relu),
        ("elu", elu),
        ("silu", silu),
        ("softmax", softmax),
        ("tanh", tanh),
    )
    for label, fn in known:
        if activation == label:
            return fn
    raise ValueError(f"Unsupported activation function: {activation}")

def check_nan_inf(tensor, name, *args):
    """Assert that ``tensor`` contains neither NaN nor infinite values.

    This is a debugging aid built on ``assert``, so the checks disappear when
    Python runs with optimisations enabled (``-O``).

    Args:
        tensor: Tensor to inspect.
        name: Label used in the failure message.
        *args: Extra context values appended to the failure message.

    Raises:
        AssertionError: If any element is NaN, or if any element is infinite.
    """
    has_nan = torch.isnan(tensor).any()
    assert not has_nan, f"{name} contains nan, {args}"
    has_inf = torch.isinf(tensor).any()
    assert not has_inf, f"{name} contains inf, {args}"

def get_padded_features(batch, device, orgfeat):
    """Pack the node features of a batch of graphs into a dense, padded tensor.

    The number of nodes per graph is obtained with ``torch.bincount(batch)``.
    Rows of ``orgfeat`` are then taken as contiguous ranges given by the
    cumulative node counts, so ``orgfeat`` is expected to be grouped by graph
    in ascending graph-index order.

    Args:
        batch: 1-D integer tensor assigning each node to a graph index.
        device: Device on which the two output tensors are allocated.
        orgfeat: 2-D node feature tensor with one row per node.

    Returns:
        A tuple ``(padding_flags, padded_feats)``:

        * ``padding_flags`` - bool tensor ``[n_graphs, max_nodes]`` that is
          ``True`` at positions holding a real node and ``False`` at padding.
        * ``padded_feats`` - tensor ``[n_graphs, max_nodes, feat_dim + 1]``;
          channel 0 is 1 for real nodes and 0 for padding, the remaining
          channels hold the node features (zeros at padding).
    """
    counts = torch.bincount(batch)
    widest = counts.max().item()
    zero = torch.tensor([0], device=counts.device)
    offsets = torch.cat((zero, counts.cumsum(0)))

    # Allocate both outputs once up front instead of repeatedly growing and
    # concatenating tensors.
    n_graphs = len(counts)
    padded_feats = torch.zeros(n_graphs, widest, orgfeat.shape[1] + 1, device=device)
    padding_flags = torch.zeros(n_graphs, widest, dtype=torch.bool, device=device)

    for graph, n_nodes in enumerate(counts.tolist()):
        rows = orgfeat[offsets[graph]:offsets[graph + 1]]

        # Flag the occupied (non-padding) slots of this graph.
        padding_flags[graph, :n_nodes] = True

        # Channel 0 marks non-padding positions; the rest carries the features.
        padded_feats[graph, :n_nodes, 0] = 1
        padded_feats[graph, :n_nodes, 1:] = rows

    return padding_flags, padded_feats


def cal_mmd(atoms, x, node_silce, gamma):
    """Compute the MMD distance between every graph's nodes and every atom.

    The kernel is ``exp(-squared_euclidean_distance / gamma)`` (see
    ``neg_exp``).  For each graph ``g`` and atom ``a`` the squared distance
    ``mean k(a, a) + mean k(g, g) - 2 * mean k(a, g)`` is formed, where each
    mean runs over the full pairwise kernel matrix (self-pairs included).

    Args:
        atoms: Tensor indexed as ``[n_atoms, n_supp, dim]``.
        x: Node feature tensor; rows ``node_silce[i]:node_silce[i + 1]``
            belong to graph ``i``.
        node_silce: 1-D tensor of ``batch_size + 1`` row offsets into ``x``.
        gamma: Kernel bandwidth (divisor of the squared distance).

    Returns:
        Tensor ``[batch_size, n_atoms]`` of MMD distances.
    """
    n_graphs = node_silce.shape[0] - 1
    n_atoms = atoms.shape[0]

    cross_terms = []
    self_terms = []
    for graph in range(n_graphs):
        nodes = x[node_silce[graph]:node_silce[graph + 1]]

        # Atom-vs-node kernel: replicate the node set once per atom so that
        # cdist sees matching batch dimensions.
        tiled_nodes = nodes.repeat(n_atoms, 1, 1)
        sq_cross = torch.cdist(atoms, tiled_nodes, p=2) ** 2
        cross_terms.append(neg_exp(sq_cross, gamma).mean(dim=(1, 2)))

        # Node-vs-node kernel from explicit pairwise differences.
        node_diff = nodes[:, np.newaxis, :] - nodes
        sq_self = (node_diff ** 2).sum(axis=-1)
        self_terms.append(neg_exp(sq_self, gamma).mean())

    xy = torch.stack(cross_terms, dim=0).transpose(0, 1)  # [n_atoms, batch_size]
    xx = torch.stack(self_terms, dim=0)  # [batch_size]

    # Atom-vs-atom kernel, computed once because it does not depend on the graph.
    atom_diff = atoms[:, :, np.newaxis, :] - atoms[:, np.newaxis, :, :]
    sq_atoms = (atom_diff ** 2).sum(axis=-1)  # [n_atoms, n_supp, n_supp]
    yy = neg_exp(sq_atoms, gamma).mean(dim=(1, 2))  # [n_atoms]

    squared_mmd = (yy.unsqueeze(-1) + xx - 2 * xy).T  # [batch_size, n_atoms]
    # Keep the square-root argument strictly positive.
    squared_mmd = squared_mmd.clamp(1e-7)
    return squared_mmd ** 0.5

def neg_exp(dist, gamma):
    """Return ``exp(-dist / gamma)`` element-wise.

    Args:
        dist: Tensor of distances.
        gamma: Scale that divides the distances.

    Returns:
        Tensor with the same shape as ``dist``.
    """
    scaled = -dist / gamma
    return torch.exp(scaled)


class MLP(nn.Module):
    """Multi-layer perceptron with batch normalisation on its hidden layers.

    With ``num_layers == 1`` the network is a single ``nn.Linear`` from
    ``in_dim`` to ``out_dim`` and ``hidden_dim`` is unused.  Otherwise every
    layer except the last is ``Linear -> BatchNorm1d -> activation`` and the
    last layer is a plain ``Linear``.

    Args:
        in_dim: Size of the input features.
        out_dim: Size of the output features.
        num_layers: Number of linear layers; values below 1 are treated as 1.
        hidden_dim: Width of the hidden layers.
        activation: Name understood by ``get_activation_fn``.
        **kwargs: Extra keyword arguments forwarded to the activation function
            on every call (for example ``dim`` for ``softmax``).
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        num_layers: int = 2,
        hidden_dim: int = 32,
        activation: str = "relu",
        **kwargs
    ):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.num_layers = max(1, num_layers)  # ensure at least 1 layer
        self.hidden_dim = hidden_dim

        self.linear_layers = nn.ModuleList()
        self.bns = nn.ModuleList()
        self.activation = get_activation_fn(activation)
        self.activation_args = kwargs

        if self.num_layers == 1:
            # Single layer: no hidden width, no normalisation.
            self.linear_layers.append(nn.Linear(in_dim, out_dim))
        else:
            # Input layer.
            self.linear_layers.append(nn.Linear(in_dim, hidden_dim))
            self.bns.append(nn.BatchNorm1d(hidden_dim))

            # Hidden-to-hidden layers.
            for _ in range(self.num_layers - 2):
                self.linear_layers.append(nn.Linear(hidden_dim, hidden_dim))
                self.bns.append(nn.BatchNorm1d(hidden_dim))

            # Output layer (no normalisation, no activation).
            self.linear_layers.append(nn.Linear(hidden_dim, out_dim))

    def forward(self, x):
        """Apply the network to ``x`` and return the output features."""
        if self.num_layers == 1:
            return self.linear_layers[0](x)

        # ``bns`` has one entry fewer than ``linear_layers``, so zip pairs
        # every layer except the output layer with its normalisation.
        for linear, bn in zip(self.linear_layers, self.bns):
            x = linear(x)
            # Batch statistics cannot be computed from a single row, so batch
            # norm is skipped only in that training-time case.  In eval mode
            # the running statistics are used, which keeps the output for a
            # sample independent of how many other rows share its batch.
            if not self.training or x.shape[0] > 1:
                x = bn(x)
            x = self.activation(x, **self.activation_args)
        return self.linear_layers[-1](x)

    def reset_parameters(self):
        """Re-initialise all linear layers and reset all batch-norm layers."""
        for layer in self.linear_layers:
            init_param(layer, self.num_layers)
        # Also clear the batch-norm affine parameters and running statistics
        # so no state survives a reset.
        for bn in self.bns:
            bn.reset_parameters()

    
def elimitate_padding(x, padding_flag):
    """Drop padded positions and concatenate the remaining rows of every sample.

    Args:
        x: Padded tensor whose first dimension indexes the samples.
        padding_flag: Per-sample index (e.g. a boolean mask) selecting the
            entries of ``x[i]`` to keep.

    Returns:
        The kept entries of all samples concatenated along dimension 0, in
        sample order.
    """
    kept = [x[i][padding_flag[i]] for i in range(x.shape[0])]
    return torch.cat(kept, dim=0)

def get_scale_Mu(dataset, len_scale, dim):
    """Compute the mean feature vector of each scale block of ``dataset.x``.

    The columns of ``dataset.x`` are read as ``len_scale`` consecutive blocks
    of ``dim`` columns each; block ``j`` covers columns ``j * dim`` up to
    ``(j + 1) * dim``.

    Args:
        dataset: Object exposing a 2-D feature tensor as ``dataset.x``.
        len_scale: Number of column blocks.
        dim: Number of columns per block.

    Returns:
        Tensor ``[len_scale, dim]`` whose row ``j`` is the column-wise mean of
        block ``j`` taken over all rows of ``dataset.x``.
    """
    means = torch.empty(len_scale, dim)
    for scale in range(len_scale):
        first_col = scale * dim
        block = dataset.x[:, first_col:first_col + dim]
        means[scale] = block.mean(dim=0)
    return means

def modify_graph_feat(dataset, node_features):
    """Rebuild every graph of ``dataset`` with replacement node features.

    ``node_features`` holds the new features of all graphs stacked row-wise in
    dataset order; each graph receives as many consecutive rows as its
    original ``x`` has.

    Args:
        dataset: Iterable of graphs exposing ``x``, ``edge_index`` and ``y``.
        node_features: 2-D tensor with the stacked replacement features.

    Returns:
        List of new ``Data`` objects that reuse each graph's ``edge_index``
        and ``y`` and carry the matching slice of ``node_features``.
    """
    rebuilt = []
    offset = 0
    for graph in dataset:
        n_nodes = graph.x.shape[0]
        rows = node_features[offset:offset + n_nodes, :]
        rebuilt.append(Data(rows, edge_index=graph.edge_index, y=graph.y))
        offset += n_nodes
    return rebuilt
        