
import numpy as np
import torch
from snorkel.labeling.model import LabelModel
from sklearn.neighbors import NearestNeighbors
from sklearn.neighbors import radius_neighbors_graph
from sklearn.metrics import pairwise_distances
from sklearn.metrics.pairwise import manhattan_distances

##### Accuracy #####
def NotAbstainAcc(pred, labels):
    '''
    Compute the accuracy restricted to the non-abstained samples.

    A sample counts as abstained when the first column of `pred` lies
    within 0.001 of 0.5; those rows are dropped from both `pred` and
    `labels` before the accuracy is evaluated with `Acc`.
    '''
    distance_from_half = np.abs(pred[:, 0] - 0.5)
    keep = distance_from_half > 0.001
    return Acc(pred[keep], labels[keep])
    
def AdjustAcc(pred, labels):
    '''
    Compute accuracy over all samples, crediting abstained samples with 50%.

    A sample counts as abstained when the first column of `pred` lies
    within 0.001 of 0.5. The result is the accuracy on the remaining samples
    weighted by their share, plus 0.5 weighted by the abstained share.

    Args:
        pred: 2-D array of per-class scores, one row per sample.
        labels: array of true class indices; `labels.shape[0]` is taken as
            the total number of samples.

    Returns:
        The adjusted accuracy as a fraction. When no sample has a
        non-abstained prediction the result is exactly 0.5.
    '''
    valid_inds = np.abs(pred[:, 0] - 0.5) > 0.001
    n = np.sum(valid_inds)
    N = labels.shape[0]

    if n == 0:
        # Nothing was predicted: the mean over an empty selection would be
        # NaN and would turn the whole result into NaN instead of 0.5.
        return 0.5

    # Accuracy on the non-abstained rows (argmax over classes vs. true label).
    predicted = np.argmax(pred[valid_inds], axis=1)
    acc = np.mean(predicted == np.array(labels[valid_inds]))
    return acc * n / N + 0.5 * (N - n) / N

def Acc(pred, labels):
    '''
    Compute the accuracy of predictions against the true labels.

    The predicted class of every row of `pred` is its argmax along axis 1;
    the return value is the mean of the element-wise comparison between
    those predicted classes and `labels`.
    '''
    predicted_class = np.argmax(pred, axis=1)
    truth = np.array(labels)
    matches = predicted_class == truth
    return np.mean(matches)

def get_correct(softmax, y):
    '''
    Count the correct predictions in a batch, with half credit for abstains.

    A row is treated as abstained when its largest value along dim 1 is
    below 0.501. Non-abstained rows are scored by comparing their argmax
    with `y`; abstained rows contribute half of their count, truncated to
    an integer.

    Returns a tuple (number counted as correct, y.shape[0]).
    '''
    abstained = softmax.max(1).values < .501
    confident = ~abstained

    # Predicted class indices (kept as a column) for the confident rows only.
    predicted = softmax.max(1, keepdim=True)[1][confident]
    hits = predicted.eq(y[confident].view_as(predicted)).sum().item()

    half_credit = int(0.5 * abstained.float().sum())
    return hits + half_credit, y.shape[0]

##### Euclidean graph construction #####
def get_transition_mat(combined_dis_mat, threshold):
    '''
    Build a dense adjacency matrix from a precomputed distance matrix.

    The matrix is obtained from a radius-neighbors graph over
    `combined_dis_mat` (treated as precomputed distances) with radius
    `threshold`, including each point itself, and converted to a dense array.
    '''
    sparse_graph = radius_neighbors_graph(
        combined_dis_mat,
        threshold,
        metric='precomputed',
        include_self=True,
    )
    return sparse_graph.toarray()

def get_transition_mat_nn(combined_dis_mat, n_neighbors = 10):
    '''
    Build a dense k-nearest-neighbor adjacency matrix from a distance matrix.

    Args:
        combined_dis_mat: matrix of pairwise distances; it is handed to
            scikit-learn with metric="precomputed".
        n_neighbors: number of nearest neighbors linked to each point.

    Returns:
        The k-neighbors graph of `combined_dis_mat`, converted to a dense
        array.

    Note:
        This function previously ignored `n_neighbors` and always used 5
        neighbors; the argument is now honoured.
    '''
    neigh = NearestNeighbors(n_neighbors=n_neighbors, metric="precomputed")
    neigh.fit(combined_dis_mat)
    return neigh.kneighbors_graph(combined_dis_mat).toarray()

def GenerateMatrix(euc_mat, thresh = 10):
    '''
    Generate an adjacency matrix and its normalized form from a distance matrix.

    The connection radius is the `thresh / N` quantile of all entries of
    `euc_mat`, where N is the number of rows of `euc_mat`.

    Args:
        euc_mat: square matrix of pairwise distances.
        thresh: numerator of the quantile level used to pick the radius.

    Returns:
        A tuple (adj_mat, S): the adjacency matrix produced by
        `get_transition_mat` and the result of `normalize_matrix(adj_mat)`.
    '''
    N = euc_mat.shape[0]

    # np.quantile only accepts levels in [0, 1]; without the cap any matrix
    # with fewer than `thresh` rows raised a ValueError. A level of 1 selects
    # the largest distance as the radius.
    level = min(thresh / N, 1.0)
    threshold = np.quantile(euc_mat, level)
    adj_mat = get_transition_mat(euc_mat, threshold)

    # The normalization used to be duplicated here line for line; delegate
    # to the shared helper so both code paths stay in agreement.
    S = normalize_matrix(adj_mat)

    return adj_mat, S

def normalize_matrix(adj_mat):
    '''
    Symmetrically normalize an adjacency matrix: S = D^{-1/2} A D^{-1/2}.

    D is the diagonal matrix of row sums of `adj_mat`.

    Args:
        adj_mat: square 2-D array.

    Returns:
        Array of the same shape with S[i, j] = A[i, j] / sqrt(d_i * d_j).
        Rows and columns of nodes whose row sum is not positive are zero,
        and any remaining NaN entries are replaced by 0.

    Notes:
        * The diagonal scaling is applied by broadcasting rather than by
          multiplying with dense N x N diagonal matrices.
        * Nodes with a zero row sum get an inverse degree of 0. Previously
          the infinite inverse was turned into a huge finite number, which
          leaked into the column of such a node when it had incoming edges.
        * Floating-point warnings are silenced only inside this function;
          numpy's global error settings are left untouched.
    '''
    adj = np.asarray(adj_mat)
    degree = adj.sum(axis=1)

    # D^{-1/2} as a vector, defined as 0 wherever the degree is not positive.
    inv_sqrt_deg = np.zeros(degree.shape, dtype=float)
    has_edges = degree > 0
    inv_sqrt_deg[has_edges] = np.sqrt(1.0 / degree[has_edges])

    with np.errstate(invalid='ignore'):
        S = inv_sqrt_deg[:, None] * adj * inv_sqrt_deg[None, :]
    return np.nan_to_num(S)


##### Neighbors #####

def GetNeighbor(A, S):
    '''
    Return the members of A together with their direct neighbors in S.

    A node j is a neighbor of v when S[v][j] is non-zero. The result is an
    array of the distinct node indices.
    '''
    reached = set()
    for node in A:
        # indices of the non-zero entries in this node's row
        linked = np.where(S[node] != 0)[0]
        reached.update(linked)

    with_sources = set(A) | reached
    return np.array(list(with_sources))

def KhopNeighbor(A,S, k):
    '''
    Return the nodes reachable from the set A within k hops on graph S.

    Starts from the 1-hop neighborhood of A and, at most k-1 more times,
    expands only the nodes added in the previous round. Expansion stops
    early once a round adds no new node. For k <= 1 the result is the
    1-hop neighborhood.
    '''
    previous, current = A, GetNeighbor(A, S)
    for _ in range(k - 1):
        # only nodes discovered in the last round can contribute new neighbors
        frontier = list(set(current) - set(previous))
        expansion = GetNeighbor(frontier, S)
        merged = list(set([*current, *expansion]))
        previous, current = current, merged

        if len(current) == len(previous):
            break
    return np.array(current)


def Flip(a):
    '''
    Swap the binary values of `a`: 1 becomes 0 and 0 becomes 1.

    Every entry that is neither 0 nor 1 is set to -1. `a` itself is not
    modified; a new array is returned.
    '''
    was_one = a == 1
    was_zero = a == 0
    flipped = np.ones_like(a) * -1
    flipped[was_zero] = 1
    flipped[was_one] = 0
    return flipped

def Flip_L(L, L_acc):
    '''
    Flip the columns of L whose accuracy is not above 50.

    Column i is copied unchanged when L_acc[i] > 50 and passed through
    `Flip` otherwise. The result is returned as a torch tensor.
    '''
    flipped_L = np.ones_like(L) * -1
    for col in range(L.shape[1]):
        column = L[:, col]
        flipped_L[:, col] = column if L_acc[col] > 50 else Flip(column)

    return torch.tensor(flipped_L)


def CheckLFs_Acc(L, labels, show = False):
    '''
    Compute per-column accuracy (in percent) and coverage of L.

    For each column, accuracy is 100 * (#entries equal to `labels`) /
    (#entries different from -1), and coverage is (#entries different
    from -1) / (number of rows). With `show` set, the accuracy and the
    count of non -1 entries are printed for every column.

    Returns a tuple (array of accuracies, list of coverages).
    '''
    accuracies, coverages = [], []
    n_rows = L.shape[0]
    for col in range(L.shape[1]):
        votes = L[:, col]
        n_correct = (votes == labels).sum()
        n_voted = (votes != -1).sum()
        accuracy_pct = (100*n_correct/n_voted).item()
        if show:
            print(accuracy_pct, 'coverage', n_voted.item())
        accuracies.append(accuracy_pct)
        coverages.append(n_voted.item()/n_rows)

    return np.array(accuracies), coverages

def Snorkel(L):
    '''
    Fit a binary Snorkel LabelModel on L and return its weights and predictions.

    The model is trained for 200 epochs on L. Returns a tuple of
    (model weights multiplied by 100, predicted probabilities for L).
    '''
    model = LabelModel(cardinality=2, verbose=False)
    model.fit(L_train=L, n_epochs=200, log_freq=200)

    weights_pct = model.get_weights() * 100
    proba = model.predict_proba(L=L)
    return weights_pct, proba

def GetStats(snorkel_pred, labels, show = False):
    '''
    Compute accuracy, coverage and adjusted accuracy (all in percent).

    A row of `snorkel_pred` counts as a confident prediction when its
    largest value exceeds 0.5001.

    Args:
        snorkel_pred: 2-D array of predicted class probabilities.
        labels: array of true class indices, indexable by a boolean mask.
        show: when True, print the three statistics.

    Returns:
        A tuple (Accuracy, Coverage, AdjustedAcc):
        * Accuracy: accuracy on the confident rows, NaN when there are none.
        * Coverage: percentage of confident rows.
        * AdjustedAcc: accuracy over all rows with the non-confident rows
          credited at 50%; exactly 50 when there are no confident rows.
    '''
    pred_idx = snorkel_pred.max(1) > 0.5001
    n_confident = pred_idx.sum()
    Coverage = 100*n_confident/snorkel_pred.shape[0]

    if n_confident == 0:
        # Accuracy is undefined without any confident prediction; computing
        # it would yield NaN and drag AdjustedAcc to NaN as well.
        Accuracy = float('nan')
        AdjustedAcc = 50.0
    else:
        Accuracy = Acc(snorkel_pred[pred_idx,:], labels[pred_idx])*100
        AdjustedAcc = (Accuracy * Coverage + 50 * (100 - Coverage))/100

    if show:
        print('Not Abstain Acc: {:.2f}'.format(Accuracy))
        print('Coverage: {:.2f}'.format(Coverage))
        print('Adjusted Acc: {:.2f}'.format(AdjustedAcc))
    return Accuracy, Coverage, AdjustedAcc

