import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv
from torch_geometric.nn import GCNConv, SAGEConv, GATConv
import time
import psutil
import os


class ClassifierModel(nn.Module):
    """Backbone encoder followed by a linear classification head.

    The backbone (``"mlp"``, ``"gcn"`` or ``"gat"``) maps inputs of width
    ``feat_dim * (walk_len + 2)`` to ``hidden_dim`` features; ``classification``
    maps ``hidden_dim`` (``task_level == "node"``) or ``2 * hidden_dim``
    (any other ``task_level``) features to ``output_dim`` logits.

    Graph backbones read ``self.edge_index``, which starts as None and must be
    assigned by the caller before ``forward``.

    Args:
        dataset_name: Stored as ``self.dataset_name``.
        n: Stored as ``self.n``.
        feat_dim: Per-slot feature width; the backbone input width is
            ``feat_dim * (walk_len + 2)``.
        hidden_dim: Backbone output width.
        output_dim: Number of output logits.
        dropout: Dropout probability (backbone and MLP post-activation).
        walk_len: Walk length used to size the backbone input.
        task_level: ``"node"`` selects a ``hidden_dim``-wide head, anything
            else a ``2 * hidden_dim``-wide head.
        backbone: ``"mlp"``, ``"gcn"`` or ``"gat"``.
        device: Accepted for interface compatibility; unused here.
        args: Object providing ``num_layers`` and ``num_heads``.

    Raises:
        ValueError: If ``backbone`` is not a supported name.
    """

    def __init__(self, dataset_name, n, feat_dim, hidden_dim, output_dim, dropout, walk_len, task_level, backbone, device, args):
        super(ClassifierModel, self).__init__()
        self.n = n
        self.walk_len = walk_len
        self.adj = None
        self.query_edges = None
        self.dataset_name = dataset_name
        self.edge_index = None
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        self.backbone = backbone
        self.K = 5  # NOTE(review): not read anywhere in this file
        self.convs = nn.ModuleList()  # unused here; kept so external references still resolve
        self.num_layers = args.num_layers
        self.num_heads = args.num_heads
        in_channels = feat_dim * (self.walk_len + 2)
        if self.backbone == "mlp":
            self.gnn2 = LRW_MLP(in_channels, hidden_dim, hidden_dim, self.num_layers, dropout)
        elif self.backbone == "gcn":
            self.gnn2 = GCN(in_channels=in_channels,
                            hidden_channels=hidden_dim,
                            out_channels=hidden_dim,
                            num_layers=self.num_layers,
                            dropout=dropout,
                            use_bn=False)
        elif self.backbone == "gat":
            self.gnn2 = GAT(in_channels=in_channels,
                            hidden_channels=hidden_dim,
                            out_channels=hidden_dim,
                            num_layers=self.num_layers,
                            dropout=dropout,
                            heads=self.num_heads)
        else:
            # Previously this built a model without gnn2 that only failed later in forward().
            raise ValueError(
                f"unsupported backbone {backbone!r}; expected 'mlp', 'gcn' or 'gat'")
        # NOTE(review): forward() feeds the hidden_dim-wide backbone output straight into
        # this head, so the 2 * hidden_dim variant presumably relies on callers building
        # pair features elsewhere — confirm.
        head_in = hidden_dim if task_level == "node" else hidden_dim * 2
        self.classification = nn.Linear(head_in, output_dim)

    def forward(self, feature, device):
        """Encode ``feature`` with the backbone and return classification logits.

        Args:
            feature: Backbone input.
            device: Device that ``self.edge_index`` is moved to (when set).

        Returns:
            Output of ``self.classification`` applied to the embeddings.

        Raises:
            ValueError: If a graph backbone is used while ``self.edge_index`` is
                None, or if ``self.backbone`` is unsupported.
        """
        # The MLP backbone never uses edge_index, so a missing one must not crash it.
        if self.edge_index is not None:
            self.edge_index = self.edge_index.to(device)
        if self.backbone == "mlp":
            embedding = self.dropout(self.relu(self.gnn2(feature)))
        elif self.backbone in ("gcn", "gat"):
            if self.edge_index is None:
                raise ValueError(
                    f"backbone {self.backbone!r} requires self.edge_index to be set before forward()")
            embedding = self.gnn2(feature, self.edge_index)
        else:
            raise ValueError(f"unsupported backbone {self.backbone!r}")
        return self.classification(embedding)
    
class LearnableRandomWalkModel(nn.Module):
    """Scores nodes by sampling embedding-similarity-biased random walks.

    ``gnn1`` (an MLP) embeds node features and ``gnn2`` (a linear layer) maps
    each embedding to one scalar. A walk over ``graph_a`` moves from node ``u``
    to neighbour ``v`` with probability ``softmax_v(dot(emb[u], emb[v]))``. The
    walk stops when the current node has no neighbours, has already been
    visited, or ``walk_len`` steps were taken. The nodes of a walk are scored
    with a temperature-``tau`` softmax over their ``gnn2`` outputs.

    In training mode ``forward`` returns one loss per sampled walk: the negated
    kernel-density estimate of the mutual information between those scores and
    the labels ``self.y`` of the visited nodes.

    ``graph_a`` (an object with a ``neighbors(node)`` method) and ``y`` (labels
    indexable by a tensor of node ids) start as None. The caller must assign
    them before ``forward``. ``hidden_att_dim``, ``output_dim`` and
    ``task_level`` are accepted but not used by this class.
    """

    # Edges are scored in chunks of this many so the temporary gathered
    # embedding blocks stay bounded in size.
    _SIM_CHUNK = 1 << 16

    def __init__(self, n, feat_dim, hidden_dim, hidden_att_dim, output_dim, dropout, task_level, walk_len, tau, data_name): 
        super(LearnableRandomWalkModel, self).__init__()
        self.n = n
        self.graph_a = None
        self.graph_aa = None
        self.adj_a = None
        self.adj_aa = None
        self.query_edges = None
        self.tau = tau
        self.y = None
        self.path_feature = None
        self.mask = None
        self.walk_len = walk_len
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        self.hidden_dim = hidden_dim
        self.rw = None
        self.walks = None
        self.rw_grid = None
        self.label_grid = None
        self.joint_grid = None
        self.p_labels = None
        self.losses = None
        self.process = psutil.Process(os.getpid())
        self.data_name = data_name
        self.embedding = None

        self.gnn1 = LRW_MLP(feat_dim, hidden_dim, hidden_dim, 2, dropout)
        self.gnn2 = nn.Linear(self.hidden_dim, 1)
        self.cosine_similarity_matrix = None

    def _edge_scores(self, node_id, neighbors):
        """Return similarity(node_id, j) for every j in ``neighbors`` as a float32 vector.

        Raises:
            RuntimeError: If the sparse similarity matrix lacks one of the edges.
        """
        sim = self.cosine_similarity_matrix
        cols = neighbors.long()
        if self.data_name != "arxivdir":
            return sim[node_id, cols].float()
        # Coalesced sparse COO: keep only this row's stored entries, then
        # locate each neighbour among that row's column indices.
        indices = sim.indices()
        in_row = indices[0] == node_id
        row_cols = indices[1][in_row]
        row_vals = sim.values()[in_row]
        hits = cols.unsqueeze(1) == row_cols.unsqueeze(0)
        if not bool(hits.any(dim=1).all()):
            raise RuntimeError(
                f"similarity matrix has no entry for some neighbour of node {node_id}")
        return row_vals[hits.float().argmax(dim=1)].float()

    def generate_random_walk(self, idx, device):
        """Sample one walk starting at ``idx`` and score its nodes.

        Args:
            idx: Start node id (int or one-element tensor).
            device: Device on which the walk is sampled.

        Returns:
            1-D tensor with one entry per visited node: softmax over the walk of
            ``gnn2(embedding) / tau``. The visited node ids are stored (on CPU)
            in ``self.walks``.
        """
        # Reset per-walk state so stale values from a previous walk are never read.
        self.walks = None
        self.rw = None
        self.rw_grid = None
        self.label_grid = None
        self.joint_grid = None
        start = torch.as_tensor(idx, dtype=torch.int).reshape(1).to(device)
        # Concatenating onto an empty default-dtype tensor keeps the walk in the
        # same dtype as before (callers receive it through self.walks).
        walks = torch.cat((torch.Tensor().to(device), start))
        for _ in range(self.walk_len):
            # Once the current node repeats, the walk can never grow again.
            if torch.sum(walks == walks[-1]) > 1:
                break
            node_id = int(walks[-1].item())
            neighbors = torch.tensor(list(self.graph_a.neighbors(node_id))).to(device).int()
            if neighbors.numel() == 0:
                break
            p_walk = F.softmax(self._edge_scores(node_id, neighbors), dim=0)
            step = neighbors[torch.multinomial(p_walk, 1, replacement=True)]
            walks = torch.cat((walks, step))
        rw = self.gnn2(self.embedding[walks.long()]).reshape(len(walks))
        rw = F.softmax(rw / self.tau, dim=0)
        self.walks = walks.detach().cpu().clone()
        torch.cuda.empty_cache()
        return rw

    def kde(self, x, x_grid, bandwidth=1.0, n_points=1000):
        """Gaussian kernel density estimate of all values of ``x`` at ``x_grid``.

        ``x`` is flattened (any shape). ``x_grid`` must hold ``n_points`` values.
        Returns a 1-D tensor of length ``n_points``.
        """
        n = x.numel()
        diff = x.reshape(n, 1) - x_grid.reshape(1, n_points)
        kernel = torch.exp(-0.5 * (diff / bandwidth) ** 2)
        density = kernel.sum(dim=0) / (n * bandwidth * torch.sqrt(torch.tensor(2 * torch.pi)))
        return density

    def entropy(self, x, x_grid, bandwidth=1.0, n_points=1000):
        """Riemann-sum estimate of -∫ p log p for the KDE of ``x`` on ``x_grid``.

        The step width is ``(x.max() - x.min()) / n_points``.
        """
        density = self.kde(x, x_grid, bandwidth, n_points)
        density_log = torch.log(density + 1e-10)  # small offset avoids log(0)
        entropy = -torch.sum(density * density_log) * (x.max() - x.min()) / n_points
        return entropy

    def loss(self, walks, rw, device):
        """Return -(H(rw) + H(labels) - H(joint)) for one walk (0-d tensor).

        Args:
            walks: Visited node ids (indexes ``self.y``).
            rw: Per-node walk scores, same length as ``walks``.
            device: Device for the computation.

        NOTE(review): ``kde`` flattens its input, so the "joint" term pools both
        columns of ``[rw, labels]`` into a single 1-D sample rather than
        estimating a 2-D joint density.
        """
        self.device = device
        labels = self.y[walks.long()].to(device).float().reshape(len(walks), 1)
        rw = rw.reshape(len(walks), 1)
        joint = torch.cat([rw, labels], dim=1).to(device)
        n_points = 10

        def grid(t):
            # n_points evenly spaced values spanning t's range, as a column.
            return torch.linspace(t.min().item(), t.max().item(), n_points).reshape(-1, 1).to(device)

        self.rw_grid = grid(rw)
        self.joint_grid = grid(joint)
        self.label_grid = grid(labels)
        h_rw = self.entropy(rw, self.rw_grid, bandwidth=1.0, n_points=n_points)
        h_labels = self.entropy(labels, self.label_grid, bandwidth=1.0, n_points=n_points)
        h_joint = self.entropy(joint, self.joint_grid, bandwidth=1.0, n_points=n_points)
        mi = h_rw + h_labels - h_joint
        # Release per-walk state. Reset to None rather than deleting the
        # attributes so later reads do not raise AttributeError.
        self.walks = None
        self.rw = None
        self.rw_grid = None
        self.label_grid = None
        self.joint_grid = None
        torch.cuda.empty_cache()
        return -mi

    def _build_similarity_matrix(self, device):
        """Score every edge (i, j) of ``graph_a`` with dot(emb[i], emb[j]) (detached).

        Returns a coalesced sparse COO (n, n) tensor when ``data_name`` is
        ``"arxivdir"``; otherwise a dense (n, n) tensor that is zero off the edges.
        """
        row, col = [], []
        for i in range(self.n):
            for j in self.graph_a.neighbors(i):
                row.append(i)
                col.append(j)
        emb = self.embedding.detach()
        # Build indices directly as int64: the old float32 round-trip lost
        # precision for node ids above 2**24.
        src = torch.as_tensor(row, dtype=torch.long, device=emb.device)
        dst = torch.as_tensor(col, dtype=torch.long, device=emb.device)
        weight = emb.new_empty(src.numel())
        for lo in range(0, src.numel(), self._SIM_CHUNK):
            hi = lo + self._SIM_CHUNK
            weight[lo:hi] = (emb[src[lo:hi]] * emb[dst[lo:hi]]).sum(dim=1)
        if self.data_name == "arxivdir":
            index = torch.stack((src, dst)).to(device)
            sim = torch.sparse_coo_tensor(
                index, weight.float().to(device), (self.n, self.n), device=device)
            return sim.coalesce()
        sim = torch.zeros(self.n, self.n).to(device)
        sim[src.to(device), dst.to(device)] = weight.to(device)
        return sim

    def _walk_loss(self, idx, device):
        """Sample one walk from ``idx`` and return its loss (0-d tensor)."""
        self.rw = self.generate_random_walk(idx, device)
        return self.loss(self.walks, self.rw, device)

    def forward(self, feature, device, train_idx, walk_time):
        """Return walk losses (training) or one sampled walk (evaluation).

        Training mode:
            Embeds ``feature`` with ``gnn1`` and rebuilds the edge-similarity
            matrix. It then samples ``walk_time`` walks per node of
            ``train_idx`` (a tensor of node ids). A node for which
            ``graph_a.neighbors`` returns None is skipped, except that
            ``train_idx[0]`` always gets its first walk. Returns a 1-D tensor
            with one loss per walk, also stored in ``self.losses``.

        Evaluation mode:
            Moves ``gnn1``/``gnn2`` to ``device``, embeds ``feature``, builds
            the similarity matrix if none is cached, and samples one walk from
            ``train_idx`` (a single node id). Returns ``(walks, rw)``.
        """
        if self.training:
            self.embedding = self.gnn1(feature)
            self.cosine_similarity_matrix = self._build_similarity_matrix(device)
            first = train_idx[0]
            losses = [self._walk_loss(first, device)]
            for idx in train_idx:
                if self.graph_a.neighbors(idx.item()) is None:
                    continue
                # train_idx[0] already received one walk above.
                repeats = walk_time - 1 if idx == first else walk_time
                for _ in range(repeats):
                    losses.append(self._walk_loss(idx, device))
            self.losses = torch.stack(losses)
            # Reset rather than delete: the evaluation branch tests for None and
            # previously crashed with AttributeError after a training pass.
            self.cosine_similarity_matrix = None
            torch.cuda.empty_cache()
            return self.losses

        self.gnn1.to(device)
        self.gnn2.to(device)
        self.embedding = self.gnn1(feature)
        if self.cosine_similarity_matrix is None:
            self.cosine_similarity_matrix = self._build_similarity_matrix(device)
        self.rw = self.generate_random_walk(train_idx, device)
        self.rw_grid = None
        self.label_grid = None
        self.joint_grid = None
        torch.cuda.empty_cache()
        return self.walks, self.rw

class LRW_MLP(nn.Module):
    """Multilayer perceptron encoder.

    Every layer except the last is ``Linear -> ReLU -> BatchNorm1d -> dropout``;
    the last layer is a bare ``Linear``. ``num_layers == 1`` yields one
    ``Linear`` and no batch norm. Any other value yields ``max(num_layers, 2)``
    linear layers.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers, dropout):
        super().__init__()
        self.lins = torch.nn.ModuleList()
        self.bns = torch.nn.ModuleList()
        if num_layers == 1:
            widths = [in_channels, out_channels]
        else:
            widths = [in_channels, *([hidden_channels] * max(num_layers - 1, 1)), out_channels]
        n_hidden = len(widths) - 2
        for depth, (fan_in, fan_out) in enumerate(zip(widths, widths[1:])):
            self.lins.append(torch.nn.Linear(fan_in, fan_out))
            if depth < n_hidden:
                self.bns.append(torch.nn.BatchNorm1d(fan_out))
        self.dropout = dropout

    def reset_parameters(self):
        """Re-initialise every linear and batch-norm layer."""
        for layer in [*self.lins, *self.bns]:
            layer.reset_parameters()

    def forward(self, x, query_edges=None):
        """Apply the MLP to ``x``.

        When ``query_edges`` is given (indexed as ``query_edges[:, 0]`` and
        ``query_edges[:, 1]``), the rows of ``x`` for both endpoints are
        concatenated along the last dimension before the first layer.
        """
        if query_edges is not None:
            heads, tails = query_edges[:, 0], query_edges[:, 1]
            x = torch.cat((x[heads], x[tails]), dim=-1)
        h = x
        for depth, lin in enumerate(self.lins[:-1]):
            h = self.bns[depth](F.relu(lin(h), inplace=True))
            h = F.dropout(h, p=self.dropout, training=self.training)
        return self.lins[-1](h)
    
class GCN(nn.Module):
    """Stack of ``GCNConv`` layers with activation and dropout between layers.

    Widths are in -> hidden, (num_layers - 2) x hidden -> hidden, and
    hidden -> out. At least two convolutions are always built, so
    ``num_layers=1`` still yields two layers. ``save_mem=False`` enables
    ``GCNConv``'s ``cached`` option. ``use_bn`` is stored, but no batch-norm
    layers are created or applied.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers,
                 dropout, save_mem=True, use_bn=False):
        super(GCN, self).__init__()

        self.convs = nn.ModuleList()
        self.convs.append(
            GCNConv(in_channels, hidden_channels, cached=not save_mem, normalize=True))
        for _ in range(num_layers - 2):
            self.convs.append(
                GCNConv(hidden_channels, hidden_channels, cached=not save_mem, normalize=True))
        self.convs.append(
            GCNConv(hidden_channels, out_channels, cached=not save_mem, normalize=True))

        self.dropout = dropout
        self.activation = F.relu
        self.use_bn = use_bn

    def reset_parameters(self):
        """Re-initialise all convolutions (and batch norms, if any were attached)."""
        for conv in self.convs:
            conv.reset_parameters()
        # No `bns` list is built in __init__; iterating it unconditionally raised AttributeError.
        for bn in getattr(self, "bns", ()):
            bn.reset_parameters()

    def forward(self, x, edge_index, edge_weight=None):
        """Run all layers; ``edge_weight`` (optional) is passed to every convolution."""
        for conv in self.convs[:-1]:
            x = conv(x, edge_index, edge_weight)
            x = self.activation(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        # The last layer previously ignored edge_weight, unlike all earlier layers.
        return self.convs[-1](x, edge_index, edge_weight)

class SAGE(nn.Module):
    """GraphSAGE encoder built from ``SAGEConv`` layers.

    Widths are in -> hidden, (num_layers - 2) x hidden -> hidden, and
    hidden -> out. Each convolution except the last is followed by an optional
    ``BatchNorm1d`` (when ``use_bn``), then ``activation`` (ReLU), then dropout.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers=2,
                 dropout=0.5, use_bn=True):
        super().__init__()
        self.convs = nn.ModuleList([SAGEConv(in_channels, hidden_channels)])
        self.bns = nn.ModuleList([nn.BatchNorm1d(hidden_channels)])
        extra_hidden = num_layers - 2
        for _ in range(extra_hidden):
            self.convs.append(SAGEConv(hidden_channels, hidden_channels))
            self.bns.append(nn.BatchNorm1d(hidden_channels))
        self.convs.append(SAGEConv(hidden_channels, out_channels))
        self.dropout = dropout
        self.activation = F.relu
        self.use_bn = use_bn

    def reset_parameters(self):
        """Re-initialise all convolutions, then all batch norms."""
        for layer in (*self.convs, *self.bns):
            layer.reset_parameters()

    def forward(self, x, edge_index):
        """Apply every layer to node features ``x`` over ``edge_index``."""
        h = x
        for depth, layer in enumerate(self.convs[:-1]):
            h = layer(h, edge_index)
            h = self.bns[depth](h) if self.use_bn else h
            h = F.dropout(self.activation(h), p=self.dropout, training=self.training)
        return self.convs[-1](h, edge_index)


class GAT(nn.Module):
    """Stack of ``GATConv`` layers with ELU and dropout between layers.

    Hidden layers use ``heads`` attention heads with ``concat=True``, so their
    output width is ``hidden_channels * heads``. The last layer uses
    ``concat=False``. At least two convolutions are always built, so
    ``num_layers=1`` still yields two layers.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers=2,
                 dropout=0.5, heads=2):
        super(GAT, self).__init__()

        self.convs = nn.ModuleList()
        self.convs.append(
            GATConv(in_channels, hidden_channels, heads=heads, concat=True))

        for _ in range(num_layers - 2):
            self.convs.append(
                GATConv(hidden_channels * heads, hidden_channels, heads=heads, concat=True))

        self.convs.append(
            GATConv(hidden_channels * heads, out_channels, heads=heads, concat=False))

        self.dropout = dropout
        self.activation = F.elu

    def reset_parameters(self):
        """Re-initialise all convolutions (and batch norms, if any were attached)."""
        for conv in self.convs:
            conv.reset_parameters()
        # No `bns` list is built in __init__; iterating it unconditionally raised AttributeError.
        for bn in getattr(self, "bns", ()):
            bn.reset_parameters()

    def forward(self, x, edge_index):
        """Apply every layer to node features ``x`` over ``edge_index``."""
        for conv in self.convs[:-1]:
            x = conv(x, edge_index)
            x = self.activation(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        return self.convs[-1](x, edge_index)