import torch
import numpy as np
from torch import nn
import networkx as nx
import scipy.sparse as sp


def map_row(H1, H2, n, row, verbose=False):
    """Return the average precision of one node's nearest-neighbour ranking.

    Positions where ``H1`` equals 1.0 are treated as the true edges of the
    node. All candidates are ranked by ascending ``H2`` and, every time a
    true edge is met in that ranking, the precision at that rank is recorded.

    Args:
        H1: 1-D array; entries equal to 1.0 mark true edges.
        H2: 1-D array of distances used to rank the candidates.
        n: Number of candidates; ranks ``1 .. n-1`` are scanned.
        row: Identifier of the node, only used in the verbose output.
        verbose: When true, print diagnostic information.

    Returns:
        The sum of the recorded precisions divided by the number of edges.
        Edges that are never met in ranks ``1 .. n-1`` contribute zero.

    Note:
        An ``assert`` rejects rows without any edge.
    """
    edge_mask = (H1 == 1.0)
    m = np.count_nonzero(edge_mask)
    assert m > 0
    if verbose:
        print(f"\t There are {m} edges for {row} of {n}")

    d = H2
    sorted_dist = np.argsort(d)
    if verbose:
        print(f"\t {sorted_dist[0:5]} vs. {np.array(range(n))[edge_mask]}")
        print(f"\t {d[sorted_dist[0:5]]} vs. {H1[edge_mask]}")

    precisions = np.zeros(m)
    hits = 0
    # Rank 0 is skipped on the assumption that it is the node itself
    # (distance to self is expected to be the smallest).
    for rank in range(1, n):
        if not edge_mask[sorted_dist[rank]]:
            continue
        hits += 1
        precisions[hits - 1] = hits / float(rank)
        if hits == m:
            # Every edge has been found; later ranks cannot add anything.
            break
    return np.sum(precisions) / m


def map_score(H1, H2, n, jobs):
    """Return the mean of the per-row average precisions.

    Args:
        H1: 2-D array indexed as ``H1[i, :]``; row ``i`` marks the true edges
            of node ``i`` with the value 1.0.
        H2: 2-D array indexed as ``H2[i, :]``; row ``i`` holds the distances
            used to rank the candidates of node ``i``.
        n: Number of rows to evaluate.
        jobs: Not used by this implementation.

    Returns:
        The sum of the ``n`` per-row scores divided by ``n``.
    """
    row_scores = []
    for node in range(n):
        row_scores.append(map_row(H1[node, :], H2[node, :], n, node))
    return np.sum(row_scores) / n


class TaskModel(nn.Module):
    """Wraps an encoder and computes a loss and MAP score for its embeddings."""

    def __init__(self, args, encoder) -> None:
        """Store the run configuration and the encoder.

        Args:
            args: Configuration object; only stored on the instance here.
            encoder: Object providing ``sqdist(x, y)``, used to compute
                distances between embeddings.
        """
        super().__init__()

        self.args = args
        self.encoder = encoder

    def compute_metrics(self, embeddings, data, split):
        """Compute the training loss and the mean average precision.

        Args:
            embeddings: 2-D tensor with one embedding per row.
            data: Mapping with at least
                ``'adj_train'`` (adjacency; a SciPy sparse container or a
                dense array-like) and ``'G'`` (a NetworkX graph whose nodes
                are expected to be ``0 .. G.order() - 1``).
            split: Not used by this implementation.

        Returns:
            Dict with ``'loss'`` (a scalar tensor) and ``'mapscore'``.
        """
        # `Tensor.get_device()` returns -1 for CPU tensors, which is not a
        # valid target for `.to()`; `.device` is valid on every backend.
        device = embeddings.device
        num = embeddings.size(0)

        # `.A` is no longer available on recent SciPy sparse containers;
        # `toarray()` is the supported spelling. Dense inputs are accepted too.
        adj_train = data['adj_train']
        if hasattr(adj_train, 'toarray'):
            adj_dense = adj_train.toarray()
        else:
            adj_dense = np.asarray(adj_train)
        adj = torch.Tensor(adj_dense).to(device)
        positive = adj.bool()
        negative = ~positive

        # Build every ordered pair of rows so that a single `sqdist` call
        # yields the full (num, num) distance matrix.
        x_1 = embeddings.repeat(num, 1)
        x_2 = embeddings.repeat_interleave(num, 0)
        emb_dist = self.encoder.sqdist(x_1, x_2).view(num, num)

        simi = torch.clamp(torch.exp(-emb_dist), min=1e-15)
        positive_sim = simi * positive.long()
        negative_sim = simi * negative.long()

        # Per-row sum over non-edges, broadcast against each positive entry.
        negative_sum = negative_sim.sum(dim=1, keepdim=True)
        ratio = torch.clamp(torch.div(positive_sim, negative_sum)[positive], min=1e-15)
        loss = (-torch.log(ratio)).sum()

        G = data['G']
        n = G.order()
        true_adj = nx.to_scipy_sparse_array(G, nodelist=list(range(n))).toarray()
        mapscore = map_score(true_adj, emb_dist.cpu().detach().numpy(), n, 16)

        metrics = {'loss': loss, 'mapscore': mapscore}
        return metrics

    def init_metric_dict(self):
        """Return the initial metrics, with ``'mapscore'`` set to -1."""
        return {'mapscore': -1}

    def has_improved(self, m1, m2):
        """Return True when ``m2['mapscore']`` is strictly greater than ``m1['mapscore']``."""
        return m1["mapscore"] < m2["mapscore"]

