import torch, torch_geometric, warnings
import numpy as np
import networkx as nx
from functools import partial
import torch.nn.functional as F
from scipy.sparse import csgraph
from sklearn.manifold import TSNE
from sklearn.cluster import DBSCAN
from sklearn.metrics import adjusted_mutual_info_score, pairwise_distances

# Run on the GPU whenever CUDA is usable; otherwise everything stays on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Silence FutureWarning messages for the whole process.
warnings.simplefilter("ignore", FutureWarning)

class GAE(torch.nn.Module):
    """Graph autoencoder built from six double-precision SAGEConv layers.

    The encoder maps the per-node features n_obs -> 128 -> 64 -> 10 and the
    decoder maps them back 10 -> 64 -> 128 -> n_obs.
    """

    def __init__(self, n_obs):
        """Create the layers; n_obs is the per-node input and output width."""
        super().__init__()
        sage = torch_geometric.nn.SAGEConv
        # Encoder: narrows the features down to a 10-dimensional code.
        self.enc_conv1 = sage(n_obs, 128).double()
        self.enc_conv2 = sage(128, 64).double()
        self.enc_conv3 = sage(64, 10).double()
        # Decoder: widens the code back to the input width.
        self.dec_conv1 = sage(10, 64).double()
        self.dec_conv2 = sage(64, 128).double()
        self.dec_conv3 = sage(128, n_obs).double()

    def forward(self, data):
        """Encode and decode data.x over data.edge_index.

        Softsign follows every layer except the last encoder layer and the
        last decoder layer; the final output goes through a sigmoid.
        """
        edges = data.edge_index
        hidden = F.softsign(self.enc_conv1(data.x, edges))
        hidden = F.softsign(self.enc_conv2(hidden, edges))
        # The bottleneck code is passed to the decoder without an activation.
        code = self.enc_conv3(hidden, edges)
        hidden = F.softsign(self.dec_conv1(code, edges))
        hidden = F.softsign(self.dec_conv2(hidden, edges))
        return torch.sigmoid(self.dec_conv3(hidden, edges))

def _make_tensor(G, observations, ground_truth):
   """Pack a graph and its observations into a torch_geometric Data object.

   Parameters:
      G: graph whose edges become edge_index. The two endpoints of every edge
         are used directly as indices, so node labels are presumably the
         integers 0..n-1 (verify against the caller).
      observations: array whose transpose becomes the feature matrix x.
      ground_truth: stored unchanged as y.

   Returns:
      Data with x (double precision) and edge_index (long) on `device`.
   """
   sources, targets = [], []
   for edge in G.edges:
      u, v = edge[0], edge[1]
      # Store each edge in both directions, keeping the two copies adjacent.
      sources.extend((u, v))
      targets.extend((v, u))
   edge_index = torch.tensor([sources, targets], dtype = torch.long, device = device)
   # Build x directly in double precision: going through float32 first would
   # truncate float64 observations before they are widened again.
   x = torch.tensor(observations.T, dtype = torch.double, device = device)
   return torch_geometric.data.Data(x = x, y = ground_truth, edge_index = edge_index)

def _run_autoencoder(data, epochs = 100):
   autoencoder = GAE(data.x.shape[1]).to(device)
   loss_function = torch.nn.CrossEntropyLoss()
   optimizer = torch.optim.Adam(autoencoder.parameters(), lr = 1e-3)
   for epoch in range(epochs):
      reconstructed = autoencoder(data)
      loss = loss_function(reconstructed, data.x)
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()
   return reconstructed.T

def _train_n2v(model, optimizer, loader):
   model.train()
   for pos_rw, neg_rw in loader:
      optimizer.zero_grad()
      loss = model.loss(pos_rw, neg_rw)
      loss.backward()
      optimizer.step()

def _run_shallow_gnn(data, epochs = 20):
   """Train Node2Vec on data.edge_index and project data.x onto the embeddings.

   Parameters:
      data: object with edge_index and x; x.shape[0] is taken as the number of
         nodes to embed.
      epochs: number of passes over the random-walk loader.

   Returns:
      NumPy array equal to data.x.T multiplied by the 128-dimensional node
      embeddings (one row per column of data.x).
   """
   settings = dict(
      edge_index = data.edge_index,
      embedding_dim = 128,
      walk_length = 6,
      context_size = 4,
      walks_per_node = 10,
      num_negative_samples = 1,
      p = 1,
      q = 1,
      sparse = True
   )
   model = torch_geometric.nn.Node2Vec(**settings)
   # The model is created with sparse = True, hence the sparse optimizer.
   sparse_adam = torch.optim.SparseAdam(list(model.parameters()), lr = 1)
   walk_batches = model.loader(batch_size = 128, shuffle = True, num_workers = 4)
   for _ in range(epochs):
      _train_n2v(model, sparse_adam, walk_batches)
   node_ids = torch.arange(data.x.shape[0])
   embeddings = model(node_ids).double()
   projected = torch.mm(data.x.T.cpu(), embeddings)
   return projected.detach().numpy()

def _ge(Q, src, trg):
   diff = src - trg
   return np.sqrt(diff.T.dot(Q.dot(diff)))

def _ge_Q(G):
   """Return the pseudo-inverse of the unnormalised Laplacian of G.

   The adjacency matrix follows the node order used by
   nx.adjacency_matrix(G). The result is a dense NumPy array.
   """
   # A plain ndarray is used instead of the discouraged np.matrix wrapper.
   adjacency = nx.adjacency_matrix(G).toarray().astype(float)
   laplacian = csgraph.laplacian(adjacency, normed = False)
   return np.linalg.pinv(laplacian)

def _ge_gpu(Q, src, trg):
   diff = src - trg
   return torch.sqrt((diff * torch.mm(Q, diff.T).T).sum(dim = 1))

def _ge_Q_gpu(G):
   """Return the Laplacian pseudo-inverse of G as a double tensor on `device`.

   The matrix itself comes from _ge_Q, so the CPU and GPU paths cannot drift
   apart; only the conversion to a tensor happens here.
   """
   return torch.from_numpy(_ge_Q(G)).double().to(device)

def _pairwise_gpu(X, metric = None, n_jobs = None):
   if metric is None:
      return pairwise_distances(X, n_jobs = n_jobs)
   else:
      out = torch.zeros((X.shape[0], X.shape[0])).to(device)
      for i in range(X.shape[0]):
         out[i,i + 1:] = metric(X[i], X[(i + 1):])
      out = out + out.T
      return out

def _clustering_performance(distance_matrix, ground_truth):
   """Return the best adjusted mutual information DBSCAN reaches on the data.

   DBSCAN (min_samples = 2, precomputed distances) is run for 50 eps values
   spaced evenly between the smallest positive distance and the mean
   distance. Points left unclassified each become their own cluster before
   scoring against ground_truth. Returns 1 as soon as a score reaches 1,
   otherwise the highest score seen.
   """
   smallest_positive = distance_matrix[distance_matrix > 0].min()
   scores = []
   for eps in np.linspace(smallest_positive, distance_matrix.mean(), num = 50):
      model = DBSCAN(eps = eps, min_samples = 2, metric = "precomputed")
      labels = model.fit(distance_matrix).labels_
      # Give every noise point (-1) a fresh label of its own.
      is_noise = labels == -1
      first_free = labels.max() + 1
      labels[is_noise] = np.arange(first_free, first_free + np.count_nonzero(is_noise))
      score = adjusted_mutual_info_score(labels, ground_truth)
      scores.append(score)
      if score >= 1:
         return 1
   return max(scores)

def _make_G(n_nodes, n_comms, d_out):
   p_in = (20 - d_out) / 49
   p_out = d_out / (n_nodes - 50)
   sizes = [n_nodes // n_comms] * n_comms
   ps = np.full((n_comms, n_comms), p_out)
   np.fill_diagonal(ps, p_in)
   G = nx.stochastic_block_model(sizes, ps)
   while nx.number_connected_components(G) > 1:
      G = nx.stochastic_block_model(sizes, ps)
   H = nx.Graph()
   H.add_nodes_from(sorted(G.nodes))
   H.add_edges_from(G.edges)
   return H

def _make_observations(G, n_comms, n_obs, noise):
   ground_truth = []
   observations = []
   size = len(G.nodes) // n_comms
   clusters = list(range(n_comms))
   ground_truth = np.random.choice(clusters, size = n_obs, replace = True)
   for n in range(n_obs):
      mask = (np.array(G.nodes) // 50) == ground_truth[n]
      observation = np.zeros(len(G.nodes))
      observation[mask] = np.random.uniform(low = 0.5, high = 1, size = mask.sum())
      observation[~mask] = np.random.uniform(low = 0.0, high = 0.5, size = (~mask).sum())
      observation += np.random.normal(scale = noise, size = len(G.nodes))
      observations.append(observation)
   return np.array(observations), ground_truth

def benchmark_data(n_nodes, d_out, n_obs, noise):
   """Create a synthetic benchmark: graph, data tensor and Laplacian pseudo-inverse.

   The number of communities is n_nodes // 50. Returns (G, tensor, Q), where
   Q is a torch tensor when running on CUDA and a NumPy array otherwise.
   """
   n_comms = n_nodes // 50
   G = _make_G(n_nodes, n_comms, d_out)
   observations, ground_truth = _make_observations(G, n_comms, n_obs, noise)
   tensor = _make_tensor(G, observations, ground_truth)
   if device == "cuda":
      Q = _ge_Q_gpu(G)
   else:
      Q = _ge_Q(G)
   return G, tensor, Q

def _to_numpy(values):
   """Return values as a NumPy array; tensors are detached and moved to the CPU first."""
   if torch.is_tensor(values):
      return values.detach().cpu().numpy()
   return np.asarray(values)

def _nvd_distances(X, Q):
   """Return the pairwise Q-weighted distances between the rows of X as a NumPy array."""
   if device == "cuda":
      return _to_numpy(_pairwise_gpu(X, metric = partial(_ge_gpu, Q), n_jobs = -1))
   # sklearn works on NumPy data and already returns a NumPy array.
   return pairwise_distances(_to_numpy(X), metric = partial(_ge, _to_numpy(Q)), n_jobs = -1)

def _tsne_distances(points, precomputed = False):
   """Embed points in 2-D with t-SNE and return the Euclidean distances between the embedded points."""
   if precomputed:
      # Random initialisation is requested explicitly so a precomputed
      # distance matrix is never combined with PCA initialisation.
      tsne = TSNE(n_components = 2, metric = "precomputed", init = "random", n_jobs = -1)
   else:
      tsne = TSNE(n_components = 2, n_jobs = -1)
   return pairwise_distances(tsne.fit(points).embedding_, n_jobs = -1)

def compute_distances(tensor, method, Q = None):
   """Compute the pairwise distance matrix between the columns of tensor.x.

   Parameters:
      tensor: data object with x (and edge_index for the gcn / n2v methods).
      method: one of
         "nvd"          Q-weighted distance between the columns of x
         "flat"         Euclidean distance between the columns of x
         "emb"          Euclidean distance after a 2-D t-SNE of the columns
         "nvd+emb"      t-SNE on the "nvd" distances, then Euclidean distance
         "gcn"          Euclidean distance between autoencoder reconstructions
         "gcn+nvd"      Q-weighted distance between the reconstructions
         "gcn+emb"      t-SNE of the reconstructions, then Euclidean distance
         "nvd+gcn+emb"  t-SNE on the "gcn+nvd" distances, then Euclidean distance
         "n2v"          Euclidean distance between Node2Vec projections
         "n2v+emb"      t-SNE of the Node2Vec projections, then Euclidean distance
      Q: matrix of the quadratic form, required by every method containing
         "nvd".

   Returns:
      The distance matrix as a NumPy array.

   Raises:
      ValueError: for an unknown method, or when an "nvd" method is requested
      without Q.
   """
   if method in ("nvd", "nvd+emb", "gcn+nvd", "nvd+gcn+emb") and Q is None:
      raise ValueError("method %r requires Q" % (method,))
   if method == "nvd":
      return _nvd_distances(tensor.x.T, Q)
   if method == "flat":
      return pairwise_distances(_to_numpy(tensor.x.T), n_jobs = -1)
   if method == "emb":
      return _tsne_distances(_to_numpy(tensor.x.T))
   if method == "nvd+emb":
      return _tsne_distances(_nvd_distances(tensor.x.T, Q), precomputed = True)
   if method == "gcn":
      return pairwise_distances(_to_numpy(_run_autoencoder(tensor)), n_jobs = -1)
   if method == "gcn+nvd":
      return _nvd_distances(_run_autoencoder(tensor), Q)
   if method == "gcn+emb":
      return _tsne_distances(_to_numpy(_run_autoencoder(tensor)))
   if method == "nvd+gcn+emb":
      return _tsne_distances(_nvd_distances(_run_autoencoder(tensor), Q), precomputed = True)
   if method == "n2v":
      return pairwise_distances(_run_shallow_gnn(tensor), n_jobs = -1)
   if method == "n2v+emb":
      return _tsne_distances(_run_shallow_gnn(tensor))
   raise ValueError("unknown method: %r" % (method,))

def cluster(distance_matrix, eps = 1, min_samples = 2, benchmark = False, ground_truth = None):
   """Cluster a precomputed distance matrix with DBSCAN.

   With benchmark unset, returns the DBSCAN estimator fitted with the given
   eps and min_samples. With benchmark set, eps and min_samples are ignored
   and the score computed by _clustering_performance against ground_truth is
   returned instead.
   """
   if not benchmark:
      model = DBSCAN(eps = eps, min_samples = min_samples, metric = "precomputed")
      return model.fit(distance_matrix)
   return _clustering_performance(distance_matrix, ground_truth)
