import numpy as np
import pandas as pd
import seaborn as sns
import torch
import matplotlib.pyplot as plt

from main import load_data
from utils import Acc, NotAbstainAcc, AdjustAcc
from label_prop import PropagationSoft, PropagationHard
from sklearn.kernel_ridge import KernelRidge
from snorkel.labeling.model import LabelModel



def CheckLFs_Acc(L, labels, show = False):
    """Measure each labeling function's accuracy and coverage against `labels`.

    Args:
        L: 2-D label matrix with one column per labeling function. The value
            -1 is treated as "abstain". Column reductions must expose
            ``.item()`` (e.g. a torch tensor).
        labels: Reference labels compared elementwise with every column.
        show: When True, print the accuracy and the non-abstain count of
            every column.

    Returns:
        A pair ``(Acc_oracle, Coverage)``: a numpy array with the accuracy of
        each column in percent (correct votes over non-abstain votes), and a
        list with the fraction of rows on which each column did not abstain.
        A column that always abstains divides zero by zero.
    """
    num_points = L.shape[0]
    accuracies = []
    coverages = []
    for col in range(L.shape[1]):
        votes = L[:, col]
        num_correct = (votes == labels).sum()
        num_voted = (votes != -1).sum()

        accuracy_pct = (100*num_correct/num_voted).item()
        voted_count = num_voted.item()
        if show:
            print(accuracy_pct, 'coverage', voted_count)

        accuracies.append(accuracy_pct)
        coverages.append(voted_count/num_points)

    return np.array(accuracies), coverages

def Snorkel(L):
    """Fit a two-class Snorkel ``LabelModel`` on `L` and score the same matrix.

    Args:
        L: Label matrix handed to ``LabelModel.fit`` and
            ``LabelModel.predict_proba``.

    Returns:
        A pair ``(Acc_snorkel, snorkel_pred)``: the model's ``get_weights()``
        output multiplied by 100, and the probabilistic predictions for `L`.
    """
    model = LabelModel(cardinality=2, verbose=False)
    model.fit(L_train=L, n_epochs=200, log_freq=200)

    weights_pct = model.get_weights()*100
    probabilities = model.predict_proba(L=L)
    return weights_pct, probabilities

def GetStats(snorkel_pred, labels, show = False):
    """Summarise predictions as accuracy, coverage and coverage-adjusted accuracy.

    A row counts as "predicted" when its largest entry exceeds 0.5001; other
    rows are treated as abstentions.

    Args:
        snorkel_pred: 2-D array of per-class scores, one row per point.
        labels: Reference labels, indexable with a boolean row mask.
        show: When True, print the three statistics.

    Returns:
        ``(Accuracy, Coverage, AdjustedAcc)`` in percent. ``Accuracy`` is
        ``Acc`` (imported from ``utils``) evaluated on the predicted rows
        only, ``Coverage`` is the share of predicted rows, and
        ``AdjustedAcc`` blends ``Accuracy`` with a 50% score on the rows that
        were not predicted.
    """
    confident = snorkel_pred.max(1) > 0.5001

    Accuracy = Acc(snorkel_pred[confident, :], labels[confident])*100
    Coverage = 100*confident.sum()/snorkel_pred.shape[0]
    # Rows without a prediction are credited with 50% accuracy.
    AdjustedAcc = (Accuracy * Coverage + 50 * (100 - Coverage))/100

    if show:
        print('Not Abstain Acc: {:.2f}'.format(Accuracy))
        print('Coverage: {:.2f}'.format(Coverage))
        print('Adjusted Acc: {:.2f}'.format(AdjustedAcc))
    return Accuracy, Coverage, AdjustedAcc

def GenerateDongleW(W_x, L):
    """Build the adjacency matrix of the graph augmented with "dongle" nodes.

    With ``n = W_x.shape[0]`` points and ``m = L.shape[1]`` labeling
    functions, the result is ``(n + 2m) x (n + 2m)``:

    * rows/columns ``0..n-1`` hold `W_x`;
    * node ``n + j`` is linked (weight 1) to every point where column ``j``
      of `L` equals 0;
    * node ``n + m + j`` is linked (weight 1) to every point where column
      ``j`` of `L` equals 1.

    All remaining entries are zero, including dongle-to-dongle entries.

    Args:
        W_x: ``n x n`` weight matrix between the data points.
        L: ``n x m`` label matrix; must support ``.float().numpy()`` on
            comparisons (e.g. a torch tensor).

    Returns:
        The augmented weight matrix as a numpy array.
    """
    n = W_x.shape[0]
    m = L.shape[1]
    size = n + 2*m

    # Indicator matrices of the votes for class 0 and class 1.
    votes_by_class = ((L == 0).float().numpy(), (L == 1).float().numpy())

    W_dongle = np.zeros((size, size))
    W_dongle[:n, :n] = W_x
    for class_id, votes in enumerate(votes_by_class):
        start = n + class_id*m
        stop = start + m
        W_dongle[:n, start:stop] = votes
        W_dongle[start:stop, :n] = np.transpose(votes)
    return W_dongle

def LPA_with_dongle(W_x, L, L_acc, labels, num_labels = 0, lamb = 1):
    """Run dongle label propagation with `num_labels` randomly revealed labels.

    Draws `num_labels` distinct point indices with ``np.random.choice``
    (without replacement) and forwards everything to
    ``LPA_with_dongle_with_labeled_inds``.

    Args:
        W_x: ``n x n`` weight matrix between the data points.
        L: Label matrix with one column per labeling function.
        L_acc: Per-labeling-function accuracies, passed through unchanged.
        labels: Ground-truth labels, passed through unchanged.
        num_labels: Number of points whose true label is revealed.
        lamb: Scaling factor for the dongle edges, passed through unchanged.

    Returns:
        Whatever ``LPA_with_dongle_with_labeled_inds`` returns.
    """
    num_points = W_x.shape[0]
    revealed_inds = np.random.choice(range(num_points), size=num_labels, replace=False)
    return LPA_with_dongle_with_labeled_inds(W_x, L, L_acc, labels, revealed_inds, lamb)

def LPA_with_dongle_with_labeled_inds(W_x, L, L_acc, labels, labeled_inds, lamb = 1):
    """Propagate labels on the dongle graph with accuracy-weighted dongle edges.

    The graph from ``GenerateDongleW`` is re-weighted so that every edge
    between a point and a dongle of labeling function ``j`` carries
    ``L_acc[j] / 100 * lamb``; point-to-point weights are kept as they are.
    The ``2m`` dongles are always treated as labeled: the first ``m`` with
    label 0 and the last ``m`` with label 1.

    Args:
        W_x: ``n x n`` weight matrix between the data points.
        L: ``n x m`` label matrix.
        L_acc: Length-``m`` accuracies; divided by 100 here, so presumably
            expressed in percent.
        labels: Ground-truth labels; must expose ``.numpy()``.
        labeled_inds: Indices of the points whose label is revealed.
        lamb: Multiplier applied to all point-dongle edges.

    Returns:
        The first ``n`` rows of the ``PropagationHard`` output, i.e. the rows
        belonging to the data points.
    """
    n = W_x.shape[0]
    m = L.shape[1]
    total = n + 2*m
    dongles = slice(n, total)

    W_dongle = GenerateDongleW(W_x, L)

    # Accuracy of the labeling function behind every dongle column,
    # repeated for the class-0 and the class-1 dongles.
    Acc_mat = (np.ones((n, 2*m))*np.hstack([L_acc, L_acc]))/100
    scaled_votes = W_dongle[:n, dongles]*Acc_mat

    W_dongle_weighted = np.zeros_like(W_dongle)
    W_dongle_weighted[:n, :n] = W_dongle[:n, :n]
    W_dongle_weighted[:n, dongles] = scaled_votes*lamb
    W_dongle_weighted[dongles, :n] = np.transpose(scaled_votes)*lamb

    # Every node starts from an uninformative 50/50 prediction.
    base_preds = 0.5*np.ones((total, 2))
    dongle_labels = np.hstack([np.zeros(m), np.ones(m)])
    labels_aug = np.hstack([labels.numpy(), dongle_labels])
    labeled_inds_aug = np.hstack([labeled_inds, np.arange(n, total)]).astype(int)

    smooth_wl = PropagationHard(base_preds, W_dongle_weighted, labels = labels_aug , labeled_inds = labeled_inds_aug, alpha = 1)
    return smooth_wl[:n, :]


def LPA_with_dongle_with_labeled_inds_custom_alpha(W_x, L, alpha_mat, labels, labeled_inds, lamb = 1):
    """Propagate labels on the dongle graph with per-point dongle edge weights.

    Same construction as ``LPA_with_dongle_with_labeled_inds``, except that
    the edge between point ``p`` and a dongle of labeling function ``j`` is
    weighted by ``alpha_mat[p, j] * lamb`` instead of a single accuracy per
    labeling function. `alpha_mat` is used as given (no division by 100).

    Args:
        W_x: ``n x n`` weight matrix between the data points.
        L: ``n x m`` label matrix.
        alpha_mat: ``n x m`` matrix of edge weights.
        labels: Ground-truth labels; must expose ``.numpy()``.
        labeled_inds: Indices of the points whose label is revealed.
        lamb: Multiplier applied to all point-dongle edges.

    Returns:
        The first ``n`` rows of the ``PropagationHard`` output, i.e. the rows
        belonging to the data points.
    """
    n = W_x.shape[0]
    m = L.shape[1]
    total = n + 2*m
    dongles = slice(n, total)

    W_dongle = GenerateDongleW(W_x, L)

    # The same weights apply to the class-0 and the class-1 dongles.
    Acc_mat = np.hstack([alpha_mat, alpha_mat])
    scaled_votes = W_dongle[:n, dongles]*Acc_mat

    W_dongle_weighted = np.zeros_like(W_dongle)
    W_dongle_weighted[:n, :n] = W_dongle[:n, :n]
    W_dongle_weighted[:n, dongles] = scaled_votes*lamb
    W_dongle_weighted[dongles, :n] = np.transpose(scaled_votes)*lamb

    # Every node starts from an uninformative 50/50 prediction; the dongles
    # are labeled 0 (first m) and 1 (last m).
    base_preds = 0.5*np.ones((total, 2))
    dongle_labels = np.hstack([np.zeros(m), np.ones(m)])
    labels_aug = np.hstack([labels.numpy(), dongle_labels])
    labeled_inds_aug = np.hstack([labeled_inds, np.arange(n, total)]).astype(int)

    smooth_wl = PropagationHard(base_preds, W_dongle_weighted, labels = labels_aug , labeled_inds = labeled_inds_aug, alpha = 1)
    return smooth_wl[:n, :]

def LPA_with_dongle_with_custom_alpha(W_x, L, alpha_mat, labels, num_labels = 0, lamb = 1):
    """Run custom-alpha dongle propagation with `num_labels` random revealed labels.

    Draws `num_labels` distinct point indices with ``np.random.choice``
    (without replacement) and forwards everything to
    ``LPA_with_dongle_with_labeled_inds_custom_alpha``.

    Args:
        W_x: ``n x n`` weight matrix between the data points.
        L: Label matrix with one column per labeling function.
        alpha_mat: Per-point, per-labeling-function edge weights, passed
            through unchanged.
        labels: Ground-truth labels, passed through unchanged.
        num_labels: Number of points whose true label is revealed.
        lamb: Scaling factor for the dongle edges, passed through unchanged.

    Returns:
        Whatever ``LPA_with_dongle_with_labeled_inds_custom_alpha`` returns.
    """
    num_points = W_x.shape[0]
    revealed_inds = np.random.choice(range(num_points), size=num_labels, replace=False)
    return LPA_with_dongle_with_labeled_inds_custom_alpha(W_x, L, alpha_mat, labels, revealed_inds, lamb)


def alpha_from_LPA(X, L, labels, L_acc, labeled_inds, thresh = 5, alpha_LPA = 1000):
    """Estimate a per-point accuracy for every labeling function via propagation.

    For labeling function ``i`` only the points where it does not abstain
    (``L[:, i] != -1``) are considered. If there are at least ``thresh + 1``
    such points, a graph is built on them and ``PropagationHard`` spreads the
    "was the vote correct" indicator of the labeled points; column 1 of its
    output becomes the estimate. Otherwise, or if that computation raises,
    every covered point receives the flat prior ``L_acc[i] / 100``.
    Points on which the function abstains keep the value 0.

    Args:
        X: Feature matrix, indexable with the non-abstain mask.
        L: 2-D label matrix, -1 meaning "abstain".
        labels: Ground-truth labels (only the labeled entries are used to
            seed the propagation).
        L_acc: Per-labeling-function accuracies; divided by 100 here, so
            presumably expressed in percent.
        labeled_inds: Indices (into the full data set) of points with a
            revealed label.
        thresh: Minimum neighbourhood size; also sets the distance quantile
            ``thresh / N`` used as graph threshold.
        alpha_LPA: Forwarded to ``PropagationHard`` as ``alpha``.

    Returns:
        A float numpy array shaped like `L` with the estimated accuracies.
    """
    Alpha = np.zeros_like(L).astype(float)
    for i in range(L.shape[1]):
        idx = (L[:,i] != -1)
        N = X[idx].shape[0]

        propagated = False
        if N >= thresh + 1:
            try:
                # NOTE(review): `pairwise_distances` and `get_transition_mat`
                # are not imported in the visible part of this file. If they
                # are undefined at runtime the resulting NameError is caught
                # below and every labeling function silently falls back to
                # the flat prior -- confirm where these names come from.
                euc_mat = pairwise_distances(X[idx], Y = None, metric='euclidean')
                threshold = np.quantile(euc_mat, thresh/N)
                A = get_transition_mat(euc_mat, threshold)

                # Positions (within the covered subset) of the labeled points.
                idx_i = np.array([row for row in range(idx.shape[0]) if idx[row] == True])
                labeled_inds_i = [j for j, row in enumerate(idx_i) if row in labeled_inds]

                # 1 where the labeling function agrees with the true label.
                labels_i = (L[idx,i] == labels[idx]).long()
                L_acc_i = L_acc[i]* np.ones(N)/100
                L_acc_i[labeled_inds_i] = labels_i[labeled_inds_i]
                base_acc_i = np.stack((1- L_acc_i, L_acc_i), axis = 1)
                alpha_i = PropagationHard(base_acc_i, A, labels = labels_i , labeled_inds = labeled_inds_i, alpha = alpha_LPA)[:,1]
                Alpha[idx,i] = alpha_i
                propagated = True
            except Exception:
                # Best-effort estimate: any failure in the propagation falls
                # back to the flat prior. `Exception` (not a bare `except`)
                # so that KeyboardInterrupt / SystemExit still propagate.
                propagated = False

        if not propagated:
            Alpha[idx,i] = L_acc[i]* np.ones(N)/100
    return Alpha


def Adaboost_weight(w, clip = 5, graph_acc = 0.95):
    """Turn accuracies into AdaBoost-style weights relative to `graph_acc`.

    Args:
        w: Accuracies in ``[0, 1]`` (scalar or array).
        clip: Margin in percent; `w` is clipped to
            ``[clip / 100, 1 - clip / 100]`` before taking the log-odds.
        graph_acc: Reference accuracy whose half log-odds is the unit, so an
            input equal to `graph_acc` (inside the clip range) maps to 1.

    Returns:
        ``0.5 * log(w / (1 - w))`` of the clipped input divided by the same
        quantity for `graph_acc`.
    """
    margin = 0.01*clip
    w_clip = np.clip(w, a_min = margin, a_max = 1 - margin)
    half_log_odds = 0.5*np.log(w_clip/(1 - w_clip))
    reference = 0.5*np.log(graph_acc/(1 - graph_acc))
    return half_log_odds/reference

def Adaboost_weight_norm(w, clip = 5):
    """Turn accuracies into AdaBoost-style weights scaled so the largest is 100.

    Args:
        w: Array of accuracies in ``[0, 1]``.
        clip: Margin in percent; `w` is clipped to
            ``[clip / 100, 1 - clip / 100]`` before taking the log-odds.

    Returns:
        ``0.5 * log(w / (1 - w))`` of the clipped input, divided by its
        maximum and multiplied by 100. The maximum is used as is, so the
        result is only meaningful when at least one half log-odds is positive.
    """
    margin = 0.01*clip
    clipped = np.clip(w, a_min = margin, a_max = 1 - margin)
    half_log_odds = 0.5*np.log(clipped/(1 - clipped))
    return (half_log_odds/np.max(half_log_odds))*100


def Generate_data_var_reg(X,L, labels, labeled_inds, i, clip = 1e-5):
    """Build regression data for the log squared error of labeling function `i`.

    Only points on which labeling function `i` does not abstain
    (``L[:, i] != -1``) are kept. The target is
    ``log((L[:, i] - labels) ** 2 + floor)``.

    Args:
        X: Feature matrix, one row per point.
        L: 2-D label matrix, -1 meaning "abstain".
        labels: Reference labels for all points.
        labeled_inds: Indices of the points that form the training split; the
            remaining points form the test split.
            NOTE(review): a boolean mask also selects the training rows
            correctly, but the test split is then derived from the mask's
            values rather than its positions -- verify if the test split of
            such callers is ever used.
        i: Column of `L` to model.
        clip: Floor added inside the log for the *test* targets only.

    Returns:
        ``(X_train, y_train, X_test, y_test)``.
    """
    # Sorted and explicitly integer-typed: when every point is labeled the
    # set difference is empty, and an untyped empty array would be float64,
    # which cannot be used as an index.
    unlabeled_inds = np.array(sorted(set(range(X.shape[0])) - set(labeled_inds)), dtype=int)

    # Training split: labeled points where labeling function i makes a prediction.
    idx = (L[labeled_inds,i] != -1)
    X_train = X[labeled_inds][idx]
    labels_train = labels[labeled_inds][idx]
    L_train = L[labeled_inds,i][idx]
    # The training floor is fixed at 1e-5 and deliberately does not follow
    # `clip`: alpha_from_reg / alpha_from_cotrain cap their estimates at
    # 316.22, which equals 1 / sqrt(1e-5).
    y_train = np.log(np.square(L_train - labels_train) + 1e-5)

    # Test split: unlabeled points where labeling function i makes a prediction.
    idx_test = (L[unlabeled_inds,i] != -1)
    X_test = X[unlabeled_inds][idx_test]
    labels_test = labels[unlabeled_inds][idx_test]
    L_test = L[unlabeled_inds,i][idx_test]
    y_test = np.log(np.square(L_test - labels_test) + clip)

    return X_train, y_train, X_test, y_test

def alpha_from_reg(X,L, labels, L_acc, labeled_inds, kernel = 'linear'):
    """Estimate per-point reliabilities by regressing each labeling function's error.

    For every labeling function a ``KernelRidge`` model is fitted on the
    training data from ``Generate_data_var_reg`` (log squared error on the
    labeled, covered points) and evaluated on all of `X`; the prediction
    ``p`` is mapped to ``1 / exp(p / 2)``. When there is no training data, or
    all training targets are identical, the column is filled with
    ``316.22 * L_acc[i] / 100`` instead.

    Args:
        X: Feature matrix, one row per point.
        L: 2-D label matrix, -1 meaning "abstain".
        labels: Reference labels for all points.
        L_acc: Per-labeling-function accuracies (divided by 100 here).
        labeled_inds: Indices of the points with a revealed label.
        kernel: Kernel name forwarded to ``KernelRidge``.

    Returns:
        A float array shaped like `L`, clipped to ``[1, 316.22]`` and divided
        by 316.22, so every entry lies in ``[1 / 316.22, 1]``.
    """
    # 316.22 is 1 / sqrt(1e-5): the value 1 / exp(p / 2) takes when p equals
    # the smallest training target log(1e-5).
    max_alpha = 316.22

    Alpha = np.zeros_like(L).astype(float)
    for i in range(L.shape[1]):
        X_train, y_train, _, _ = Generate_data_var_reg(X,L, labels, labeled_inds, i, clip = 1e-10)
        krr = KernelRidge(alpha=1, kernel = kernel, gamma = 1)

        # `y_train.unique()` is only evaluated when there is training data.
        trainable = X_train.shape[0] > 0 and y_train.unique().shape[0] > 1
        if trainable:
            krr.fit(X_train, y_train)
            Alpha[:,i] = 1/np.exp(krr.predict(X)/2)
        else:
            Alpha[:,i] = max_alpha*L_acc[i]/100.0

    return np.clip(Alpha, a_min = 1 , a_max = max_alpha)/max_alpha


def alpha_from_cotrain(Alpha_0, X, L, L_acc, labeled_inds, labels, threshold = 1, confidence_level = 0.01):
    """Re-estimate per-point reliabilities using pseudo-labels from the other functions.

    For labeling function ``i`` the remaining functions vote (weighted by
    `Alpha_0`) to produce pseudo-labels via ``Gen_pseudolabel_not_i``. A
    linear ``KernelRidge`` model is then fitted on the data returned by
    ``Generate_data_var_reg_cotrain`` and evaluated on all of `X`; the
    prediction ``p`` is mapped to ``1 / exp(p / 2)``. If there is no training
    data, or building/fitting/predicting raises, the column is filled with
    ``316.22 * L_acc[i] / 100`` instead.

    Args:
        Alpha_0: Current reliability matrix, same shape as `L`.
        X: Feature matrix, one row per point.
        L: 2-D label matrix, -1 meaning "abstain".
        L_acc: Per-labeling-function accuracies (divided by 100 here).
        labeled_inds: Indices of the points with a revealed label.
        labels: Reference labels for all points.
        threshold: Forwarded to ``Gen_pseudolabel_not_i``.
        confidence_level: Forwarded to ``Gen_pseudolabel_not_i``.

    Returns:
        A float array shaped like `L`, clipped to ``[1, 316.22]`` and divided
        by 316.22.
    """
    Alpha = np.zeros_like(L).astype(float)
    for i in range(L.shape[1]):
        # Not guarded: a failure while generating pseudo-labels propagates.
        pseudolabels_not_i = Gen_pseudolabel_not_i(L, Alpha_0, i, labels,threshold = threshold, confidence_level = confidence_level)

        fitted = False
        try:
            X_train, y_train = Generate_data_var_reg_cotrain(X,L, labels, pseudolabels_not_i, labeled_inds, i)
            krr = KernelRidge(alpha=1, kernel = 'linear', gamma = 1)
            if X_train.shape[0] > 0 :
                krr.fit(X_train, y_train)
                Alpha[:,i] = 1/np.exp(krr.predict(X)/2)
                fitted = True
        except Exception:
            # Best-effort estimate: any regression failure falls back to the
            # accuracy-based value. `Exception` (not a bare `except`) so that
            # KeyboardInterrupt / SystemExit still propagate.
            fitted = False

        if not fitted:
            Alpha[:,i] = 316.22*L_acc[i]/100.0
    return np.clip(Alpha, a_min = 1 , a_max = 316.22)/316.22


def WeightedVote(Alpha, L):
    """Combine the votes in `L` into one score per point, weighted by `Alpha`.

    Abstentions (-1) get weight zero. The score of a point is the weighted
    mean of its non-abstain votes; points whose total weight is zero (0/0)
    receive 0.5.

    Args:
        Alpha: Weights with the same shape as `L`, convertible with
            ``torch.tensor``.
        L: 2-D torch tensor of votes, -1 meaning "abstain".

    Returns:
        A 1-D torch tensor with one score per row of `L`.
    """
    voted = (L != -1).float()
    weights = torch.tensor(Alpha)*voted

    weighted_sum = (weights * L.numpy()).sum(axis = 1)
    total_weight = weights.sum(axis = 1)
    score = weighted_sum/total_weight

    # Rows without any weight produce NaN; map them to the neutral 0.5.
    return torch.tensor(np.nan_to_num(x = score, nan = 0.5))

def Gen_pseudolabel_not_i(L, Alpha_0, i, labels, threshold = 1, confidence_level = 0.01):
    """Create pseudo-labels from a weighted vote of all labeling functions but `i`.

    A point receives pseudo-label 1 when the weighted vote exceeds
    ``0.5 + confidence_level`` and 0 when it is below
    ``0.5 - confidence_level``, provided more than `threshold` labeling
    functions voted on it; every other point gets -1.

    Args:
        L: 2-D torch tensor of votes, -1 meaning "abstain".
        Alpha_0: Vote weights with the same shape as `L`.
        i: Column excluded from the vote.
        labels: Only used as a template for the shape/dtype of the output.
        threshold: A point needs strictly more than this many votes.
        confidence_level: Half-width of the undecided band around 0.5.

    Returns:
        An array shaped like `labels` holding 1, 0 or -1.
    """
    other_cols = [j for j in range(Alpha_0.shape[1]) if j != i]
    vote = WeightedVote(Alpha_0[:, other_cols], L[:, other_cols])

    # NOTE(review): the vote count uses all columns of L, including column i
    # itself -- confirm whether it should be restricted to the other columns.
    enough_votes = ((L != -1).float()).sum(axis = 1) > threshold

    pseudolabels_not_i = -1*np.ones_like(labels)
    upper = 0.5 + confidence_level
    lower = 0.5 - confidence_level
    pseudolabels_not_i[(vote > upper) * enough_votes] = 1
    pseudolabels_not_i[(vote < lower) * enough_votes] = 0
    return pseudolabels_not_i

def Generate_data_var_reg_cotrain(X,L, labels, pseudolabels_not_i, labeled_inds, i, clip = 1e-5):
    """Build training data for labeling function `i` from true and pseudo labels.

    The true labels of `labeled_inds` are written into `pseudolabels_not_i`
    (the argument is modified in place). Every point whose resulting label is
    not -1 is then used as training data through ``Generate_data_var_reg``.

    Args:
        X: Feature matrix, one row per point.
        L: 2-D label matrix, -1 meaning "abstain".
        labels: Reference labels; only the entries at `labeled_inds` are read.
        pseudolabels_not_i: Pseudo-labels with -1 for "no label"; overwritten
            at `labeled_inds`.
        labeled_inds: Indices of the points with a revealed label.
        i: Column of `L` to model.
        clip: Forwarded to ``Generate_data_var_reg``.

    Returns:
        ``(X_train, y_train)`` as produced by ``Generate_data_var_reg``.
    """
    # Revealed labels take precedence over pseudo-labels.
    pseudolabels_not_i[labeled_inds] = labels[labeled_inds]

    has_target = (pseudolabels_not_i != -1)
    train_X, train_y, _unused_X, _unused_y = Generate_data_var_reg(X,L, pseudolabels_not_i, has_target, i, clip = clip)
    return train_X, train_y
