## OLD -> see gen_pseudolabels for running LPA methods

import numpy as np
import torch
from torch.utils.data import DataLoader
import warnings 
warnings.filterwarnings('ignore')

from snorkel.labeling.model import LabelModel
import matplotlib.pyplot as plt
import torch.nn.functional as F
import pandas as pd
import seaborn as sns

import os.path

sns.set_style("darkgrid")
# import sys
# np.set_printoptions(threshold=sys.maxsize)

from utils import GenerateMatrix, KhopNeighbor,  Acc, normalize_matrix, AdjustAcc, NotAbstainAcc
from label_prop import PropagationSoft, PropagationHard, PropagationAdaptive
from sklearn.metrics import pairwise_distances
# from end_model import run_end_model
import dataset
import argparse

def get_results(soft_preds, labels, con_idx, method_name, split):
	'''Assemble one result row for a propagation method.

	The row layout is
	[method_name, adjusted acc, non-abstain acc, coverage, connected acc, connected coverage, split]:
	  - adjusted / non-abstain accuracy come from AdjustAcc / NotAbstainAcc on all points,
	  - coverage is the fraction of points whose class-1 score differs from 0.5 by more than 0.001,
	  - connected accuracy is Acc restricted to the points indexed by con_idx,
	  - connected coverage is the share of points that con_idx selects.
	'''
	n_points = labels.shape[0]

	adjusted_acc = AdjustAcc(soft_preds, labels)
	non_abstain_acc = NotAbstainAcc(soft_preds, labels)

	# A score within 0.001 of 0.5 is counted as an abstention.
	distance_from_tie = np.abs(soft_preds[:,1] - 0.5)
	coverage = (distance_from_tie > 0.001).sum() / n_points

	connected_acc = Acc(soft_preds[con_idx], labels[con_idx])
	connected_cov = con_idx.shape[0] / n_points

	return [method_name, adjusted_acc, non_abstain_acc, coverage, connected_acc, connected_cov, split]


def _load_or_build_graph(base_path, seed, thresh, get_distances):
	'''Return the (W, S) pair produced by GenerateMatrix for one threshold, cached on disk.

	base_path:     directory holding the cache files
	seed, thresh:  used to build the cache file names
	get_distances: zero-argument callable returning the pairwise distance matrix;
	               only invoked when the cache is missing or incomplete.
	'''
	suffix = "_seed" + str(seed) + "_thresh_" + str(thresh)
	w_path = os.path.join(base_path, "W_x" + suffix)
	s_path = os.path.join(base_path, "S_x" + suffix)

	# Both files must exist: a partial cache (e.g. an interrupted earlier run)
	# is rebuilt instead of crashing inside torch.load.
	# NOTE(review): torch.load unpickles the file; only point this at trusted caches.
	if os.path.isfile(w_path) and os.path.isfile(s_path):
		return torch.load(w_path), torch.load(s_path)

	W, S = GenerateMatrix(get_distances(), thresh = thresh)
	torch.save(W, w_path)
	torch.save(S, s_path)
	return W, S


def load_data(data_name, euc_th, wl_th, seed):
	'''1. Load training data from saved checkpoints (or build and save them)
	   2. Construct adjacency matrix / load adjacency matrix

	data_name: dataset name; files are read from / written to datasets/<data_name>/
	euc_th:    distance threshold passed to GenerateMatrix for the main graph
	wl_th:     distance threshold passed to GenerateMatrix for the "large" graph
	seed:      seed used for np.random, the dataset split and the cache file names

	Returns X, L, labels, W_x, S_x, W_x_large, S_x_large.
	'''

	# Load raw data
	np.random.seed(seed)
	base_path = os.path.join("datasets", data_name)
	x_path = os.path.join(base_path, "train_X_seed" + str(seed))
	l_path = os.path.join(base_path, "train_L_seed" + str(seed))
	labels_path = os.path.join(base_path, "train_labels_seed" + str(seed))

	# Only trust the checkpoint when all three pieces are present.
	if all(os.path.isfile(p) for p in (x_path, l_path, labels_path)):
		X = torch.load(x_path)
		L = torch.load(l_path)
		labels = torch.load(labels_path)
	else:
		train_dataset = dataset.WSDataset(dataset=data_name, split="train", feature="bert", balance=True, seed = seed)
		# A single batch spanning the whole dataset yields everything at once.
		train_loader = DataLoader(train_dataset, batch_size=len(train_dataset), shuffle=True)

		X, L, labels = next(iter(train_loader))
		X, L = X.type(torch.float), L.type(torch.long)
		torch.save(X, x_path)
		torch.save(L, l_path)
		torch.save(labels, labels_path)

	# Load adjacency matrices. The O(n^2) distance matrix is computed lazily and
	# at most once, so a fully cached run skips it entirely.
	distance_cache = []
	def euclidean_distances():
		if not distance_cache:
			distance_cache.append(pairwise_distances(X, Y = None, metric='euclidean'))
		return distance_cache[0]

	W_x, S_x = _load_or_build_graph(base_path, seed, euc_th, euclidean_distances)
	W_x_large, S_x_large = _load_or_build_graph(base_path, seed, wl_th, euclidean_distances)

	return X, L, labels, W_x, S_x, W_x_large, S_x_large

def get_auggraph(L, labels, S_x_large, a):
	'''Build the weak-label graph W_wl and the aggregated weak-label predictions.

	L:          weak-label vote matrix (one row per point, one column per labeler).
	            The code treats 0 / 1 as class votes; presumably -1 means abstain
	            (a row summing to -n_labelers is handled as "all abstain") - TODO confirm.
	labels:     passed through to PropagationSoft
	S_x_large:  graph matrix used to smooth the aggregated weak labels
	a:          quantile in [0, 1]; points whose smoothed confidence falls below the
	            a-quantile are disconnected from the weak-label graph

	Returns (W_wl, base_preds): W_wl is a 0/1 float matrix connecting points with
	(near-)identical weak-label votes; base_preds is the LabelModel's predict_proba
	output wrapped in a length-1 leading axis.
	'''
	# Weak labels aggregation
	label_model = LabelModel(cardinality=2, verbose=False)
	label_model.fit(L_train=L, n_epochs=200, log_freq=200)
	pseudolabs_soft = label_model.predict_proba(L=L)
	base_preds = np.array([pseudolabs_soft])

	# Smooth weak labels
	smooth_wl = PropagationSoft(base_preds, S_x_large , labels, labeled_inds = [], alpha = 10)
	confidence = np.max(smooth_wl , axis = 1)
	abs_th = np.quantile(confidence, a)

	# Construct a new graph from wl partitions: map votes to -1 / +1 (anything
	# else becomes 0) and connect points whose vote vectors have cosine distance < sim_threshold.
	sim_threshold = 0.01
	L_tf = (-1)* (L == 0) + (L == 1)
	W_wl = (pairwise_distances(L_tf, Y=None, metric='cosine') < sim_threshold)*1.0

	# Remove points whose weak labels are not confident, or that every labeler abstained on.
	# Boolean masks are used instead of index arrays guarded by `idx != []`: that
	# comparison is not an emptiness test (a one-element index array compares to an
	# empty, falsy result, and larger ones are a shape mismatch), whereas an
	# all-False mask is simply a no-op.
	unconfident = np.asarray(confidence < abs_th, dtype=bool)
	abstained = np.asarray(L.sum(axis=1) == -1 * L.shape[1], dtype=bool)
	isolated = unconfident | abstained

	W_wl[isolated,:] = 0
	W_wl[:,isolated] = 0

	return W_wl, base_preds

def experiment(data_name = 'youtube',  num_labels = 100, euc_th = 10, wl_th = 100, a = 0.6, mu = 0.01, lamb = 0.001, seeds = None):
	'''Run the propagation methods over several seeds and collect result rows.

	data_name:  dataset to load (see load_data)
	num_labels: number of points whose true label is revealed
	euc_th:     threshold for the feature graph
	wl_th:      threshold for the graph used to smooth weak labels
	a:          confidence quantile used by get_auggraph
	mu:         weight of the weak-label graph when added to the feature graph
	lamb:       regularisation coefficient; propagation runs with alpha = 1 / lamb
	seeds:      iterable of seeds to run; defaults to range(5)

	Returns a list of rows, five per seed ("wl", "lp", "lp+wl", "lp+ag", "lp+ag+wl"),
	each laid out like the output of get_results.
	'''
	if seeds is None:
		seeds = range(5)

	results = []
	alpha = 1/lamb

	for seed in seeds:
		### Load data
		np.random.seed(seed)
		X, L, labels, W_x, S_x, W_x_large, S_x_large = load_data(data_name = data_name, euc_th = euc_th, wl_th = wl_th, seed = seed)
		W_wl, base_preds = get_auggraph(L, labels, S_x_large, a)
		W_x_wl = W_x + mu * W_wl
		S_x_wl = normalize_matrix(W_x_wl)

		### Label propagation
		# Draw the labelled set from a dedicated generator so it depends only on
		# `seed`, not on how much of the global NumPy RNG load_data / get_auggraph
		# happened to consume (which may differ between cached and uncached runs).
		rng = np.random.RandomState(seed)
		labeled_inds = rng.choice(len(X), size=num_labels, replace=False)
		con_idx  = KhopNeighbor(labeled_inds, S_x, 1000) # get all connected
		con_idx2 = KhopNeighbor(labeled_inds, S_x_wl, 1000) # get all connected

		# WL: weak-label predictions with the revealed labels written in as one-hot rows
		base_preds_wlb = base_preds[0].copy()
		base_preds_wlb[labeled_inds,:] = np.stack((1-labels[labeled_inds], labels[labeled_inds]), axis = 1)
		# Coverage uses the same 0.001 tolerance around 0.5 as get_results, rather
		# than an exact floating-point comparison. No connectivity stats for this row.
		wl_cov = (np.abs(base_preds_wlb[:,1] - 0.5) > 0.001).sum() / labels.shape[0]
		results.append(['wl', AdjustAcc(base_preds_wlb, labels), \
						NotAbstainAcc(base_preds_wlb, labels), \
						wl_cov, 0, 0, "train"])

		# Uninformative prior (all 0.5) used by the methods that ignore weak labels
		flat_prior = np.ones_like(base_preds)*0.5

		# LP
		f_baseline = PropagationHard(flat_prior, S_x, labels, labeled_inds, alpha = alpha)
		results.append(get_results(f_baseline, labels, con_idx, "lp", "train"))
		# LP + WL
		f_lp_wl = PropagationHard(base_preds, S_x, labels, labeled_inds, alpha = alpha)
		results.append(get_results(f_lp_wl, labels, con_idx, "lp+wl", "train"))
		# LPAG
		f_lp_ag = PropagationHard(flat_prior, S_x_wl, labels, labeled_inds, alpha = alpha)
		results.append(get_results(f_lp_ag, labels, con_idx2, "lp+ag", "train"))
		# LPAG + WL
		f_lp_ag_wl = PropagationHard(base_preds, S_x_wl, labels, labeled_inds, alpha = alpha)
		results.append(get_results(f_lp_ag_wl, labels, con_idx2, "lp+ag+wl", "train"))
	return results

def exp_wrapper(data_name,  num_labels = 100, euc_th = 10, wl_th = 100, a = 0.2, mu = 0.01, lamb = 0.001):
	'''Run experiment() with the given settings, print a per-method summary and plot it.

	All arguments are forwarded unchanged to experiment(). The printed table is the
	mean over seeds of every numeric metric, scaled by 100; the bar plot shows
	AdjustedAcc per method. Nothing is returned.
	'''
	result = experiment(data_name = data_name,  num_labels = num_labels, euc_th = euc_th, wl_th = wl_th, a = a, mu = mu, lamb = lamb)
	df = pd.DataFrame(result, columns= ['Method','AdjustedAcc', 'NonAbstainAcc', 'Cov', 'ConnectedAcc', 'ConnectedCov', "split"])

	df_train = df[df.split == "train"]
	print({'data_name' : data_name,  'num_labels' : num_labels, 'euc_th' : euc_th, 'wl_th' : wl_th, 'a' : a, 'mu' : mu, 'lamb' : lamb})
	print("Train:")
	# numeric_only=True: the frame still carries the string column "split", which
	# cannot be averaged (pandas >= 2.0 raises instead of silently dropping it).
	print(100*df_train.groupby('Method').mean(numeric_only=True))


	sns.barplot(data = df_train, x = 'Method', y ='AdjustedAcc' ,palette = "Set2")
	plt.ylim(0.5, 0.95)
	plt.title( "Train:" + data_name )
	plt.show()

if __name__ == "__main__":
	# Command-line entry point: parse the experiment settings and run exp_wrapper.
	parser = argparse.ArgumentParser()
	parser.add_argument('--dataset', default="youtube", type=str, help="Dataset to run (spam, agnews, yelp, awa2)")
	# NOTE(review): --method is parsed but never read below; kept so existing
	# command lines that pass it keep working.
	parser.add_argument('--method', default="lpag", type=str, help="Which method to load pseudolabels from")
	parser.add_argument('--euc_th', default=2, type=int, help="Euclidean distance threshold")
	parser.add_argument('--wl_th', default=20, type=int, help="Weak Labeler partition threshold")
	parser.add_argument('--a', default=0, type=float, help="Abstain threshold")
	parser.add_argument('--mu', default=0.001, type=float, help="Weighting of weak labeler graph")
	parser.add_argument('--lamb', default=0.001, type=float, help="Prior regularization coefficient")
	# Help text fixed: this is a count of labelled points, not a coefficient.
	parser.add_argument('--num_labels', default=100, type=int, help="Number of initially labeled points")
	args = parser.parse_args()

	exp_wrapper(args.dataset,  num_labels = args.num_labels, euc_th = args.euc_th, wl_th = args.wl_th, a = args.a, mu = args.mu, lamb = args.lamb)
