import networkx as nx
import numpy as np
import torch
import torch.nn as nn
from scipy.sparse import coo_matrix
import time


class LearnableRandomWalkBaseModel(nn.Module):
    """Holder of graph/feature state for a learnable random-walk model.

    The actual computation is delegated to ``self.base_model``. That attribute
    is ``None`` here and must be assigned (by a subclass or caller) before
    ``preprocess`` or ``forward`` is used. This class builds the NetworkX
    graph, stores node features, and forwards calls to ``base_model``.

    Args:
        n: Number of nodes in the graph.
        data_name: Dataset identifier, stored as-is.
        y: Node labels; copied onto ``base_model.y`` in ``preprocess``.
    """

    def __init__(self, n, data_name, y):
        super(LearnableRandomWalkBaseModel, self).__init__()
        self.n = n
        self.data_name = data_name
        self.y = y

        # Collaborators / settings populated outside this class.
        self.base_model = None
        self.post_graph_op = None
        self.device = None
        self.tau = 0.5
        # Timing logs; nothing in this class appends to them.
        self.random_walk_time = []
        self.loss_cal_time = []
        # Presumably the ``x_grid`` inputs for kde()/entropy(); not set in this
        # class -- TODO confirm with the code that assigns them.
        self.rw_grid = None
        self.label_grid = None
        self.joint_grid = None

    def kde(self, x, x_grid, bandwidth=1.0, n_points=1000):
        """Gaussian kernel density estimate of the values in ``x`` on ``x_grid``.

        Every element of ``x`` is treated as one sample, whatever its shape.

        Args:
            x: Tensor of samples (any shape; flattened).
            x_grid: Tensor of evaluation points (any shape; flattened).
            bandwidth: Standard deviation of the Gaussian kernel.
            n_points: Kept for backward compatibility. The number of grid
                points is taken from ``x_grid`` itself.

        Returns:
            1-D tensor of length ``x_grid.numel()`` holding the estimated density.
        """
        samples = x.reshape(-1, 1)
        grid = x_grid.reshape(1, -1)
        n_samples = samples.size(0)
        diff = samples - grid
        kernel = torch.exp(-0.5 * (diff / bandwidth) ** 2)
        # sqrt(2*pi) is a plain float constant; no need to build a tensor per call.
        normaliser = n_samples * bandwidth * np.sqrt(2 * np.pi)
        return kernel.sum(dim=0) / normaliser

    def entropy(self, x, x_grid, bandwidth=1.0, n_points=1000):
        """Estimate ``-sum(p * log p) * step`` from the KDE of ``x`` on ``x_grid``.

        Args:
            x: Tensor of samples.
            x_grid: Tensor of evaluation points, ``n_points`` of them.
            bandwidth: Standard deviation of the Gaussian kernel.
            n_points: Number of grid points; also used for the step width.

        Returns:
            0-dim tensor with the entropy estimate.
        """
        density = self.kde(x, x_grid, bandwidth, n_points)
        # The 1e-10 offset keeps log() finite where the density underflows to 0.
        density_log = torch.log(density + 1e-10)
        # NOTE(review): the step width uses the range of the samples ``x``, not
        # of ``x_grid``, and divides by n_points rather than n_points - 1. This
        # is a Riemann sum over the grid only if x_grid spans [x.min(), x.max()].
        # Confirm this is intended.
        return -torch.sum(density * density_log) * (x.max() - x.min()) / n_points

    def preprocess(self, adj, feature):
        """Attach the adjacency, labels, and an undirected graph to ``base_model``.

        Args:
            adj: Sparse adjacency matrix exposing ``tocoo()`` (e.g. SciPy sparse).
            feature: Node features; converted with ``torch.FloatTensor``.
        """
        self.adj_a = adj.tocoo()
        base = self.base_model
        base.adj_a = self.adj_a
        base.graph_a = nx.Graph()
        base.y = self.y
        self.x = torch.FloatTensor(feature)
        base.graph_a.add_nodes_from(range(self.n))
        # One (row, col) edge per stored entry of the COO adjacency matrix.
        edge_pairs = np.stack((self.adj_a.row, self.adj_a.col), axis=1)
        base.graph_a.add_edges_from(edge_pairs)

    def postprocess(self, adj, output):
        """Identity hook: return ``output`` unchanged."""
        return output

    def model_forward(self, device, train_idx, walk_time=1, ori=None):
        """Alias of :meth:`forward`."""
        return self.forward(device, train_idx, walk_time, ori)

    def forward(self, device, train_idx, walk_time=1, ori=None):
        """Move features to ``device`` and run ``base_model`` on them.

        Args:
            device: Target device for ``self.x``.
            train_idx: Passed through to ``base_model``.
            walk_time: Passed through to ``base_model``.
            ori: If not ``None``, stored as ``base_model.query_edges`` first.

        Returns:
            Whatever ``base_model`` returns.
        """
        self.x = self.x.to(device)
        if ori is not None:
            self.base_model.query_edges = ori
        return self.base_model(self.x, device, train_idx, walk_time)

    
class ClassifierBaseModel(nn.Module):
    """Classifier over node features augmented with random-walk hop features.

    ``preprocess`` samples one walk per node via ``self.walk_op`` and appends
    ``walk_len + 1`` geometrically weighted hop-feature blocks to the raw
    features. ``forward`` runs ``self.base_model`` on the result. Both
    ``walk_op`` and ``base_model`` start as ``None`` and must be assigned
    before use.

    Args:
        n: Number of nodes.
        data_name: Dataset identifier, stored as-is.
        walk_len: Walk length; ``walk_len + 1`` positions are kept per node.
        device: Device passed to ``walk_op.model_forward``.
    """

    def __init__(self, n, data_name, walk_len, device):
        super(ClassifierBaseModel, self).__init__()
        self.n = n
        self.data_name = data_name
        self.base_model = None
        self.post_graph_op = None
        self.x = None
        self.walk_op = None
        self.walk_len = walk_len
        # Row i holds the node ids at each position of node i's walk. Stored as
        # integers: float32 only represents ids exactly up to 2**24.
        self.multineighbors = torch.zeros(self.n, self.walk_len + 1, dtype=torch.long)
        # Hop k is weighted by cof**k, normalised so the weights sum to 1.
        self.cof = 0.75
        self.weights = self.cof ** np.arange(self.walk_len + 1, dtype=np.float32)
        self.weights /= self.weights.sum()
        self.device = device

    def preprocess(self, adj, feature):
        """Build ``self.x`` = raw features followed by weighted hop features.

        Args:
            adj: Unused; kept for interface compatibility.
            feature: (n, d) node features as a CPU tensor or NumPy array.

        After the call, ``self.x`` is a float32 tensor of shape
        (n, d * (walk_len + 2)).
        """
        steps = self.walk_len + 1
        for node in range(self.n):
            # Assumes walk_op returns (1-D tensor of node ids, extra) -- TODO confirm.
            walk, _ = self.walk_op.model_forward(self.device, node, 1)
            missing = steps - walk.shape[0]
            if missing > 0:
                # Pad short walks with the start node; new_full matches the
                # walk's dtype and device, so this also works for GPU walks.
                walk = torch.cat((walk, walk.new_full((missing,), node)))
            self.multineighbors[node] = walk[:steps].to(self.multineighbors)

        rows = np.arange(self.n)
        ones = np.ones(self.n)
        pieces = [torch.as_tensor(feature)]
        hop = np.asarray(feature)
        for k in range(steps):
            cols = self.multineighbors[:, k].numpy()
            # Row r of `select` picks row cols[r] of the matrix it multiplies.
            select = coo_matrix((ones, (rows, cols)), shape=(self.n, self.n))
            # NOTE(review): each hop is applied to the previous hop's result, so
            # block k composes selections 0..k rather than directly taking the
            # features of each node's k-th walk position. Confirm this is intended.
            hop = select @ hop
            pieces.append(torch.tensor(hop * self.weights[k]))
        # Concatenate once (avoids quadratic re-copying); cast the result to float32.
        self.x = torch.cat(pieces, dim=1).float()

    def postprocess(self, adj, output):
        """Identity hook: return ``output`` unchanged."""
        return output

    def model_forward(self, device, idx, ori=None):
        """Alias of :meth:`forward`."""
        return self.forward(device, idx, ori)

    def forward(self, device, idx, ori=None):
        """Run ``base_model`` on all node features and return rows ``idx``.

        Args:
            device: Target device for ``self.x``; also passed to ``base_model``.
            idx: Index (or indices) of the output rows to return.
            ori: If not ``None``, stored as ``base_model.query_edges`` first.
        """
        self.x = self.x.to(device)
        if ori is not None:
            self.base_model.query_edges = ori
        output = self.base_model(self.x, device)
        return output[idx]
    