import numpy as np
import torch
from torch.utils.data import DataLoader
import warnings 
warnings.filterwarnings('ignore')

from snorkel.labeling.model import LabelModel
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torch.nn import Softmax
import pandas as pd
import seaborn as sns
import sys

import os.path
sns.set_style("darkgrid")


from sklearn.metrics import pairwise_distances
import dataset
import argparse

from utils import Acc, NotAbstainAcc, AdjustAcc, Flip_L, CheckLFs_Acc, Snorkel, GetStats
from label_prop import PropagationSoft, PropagationHard
from snorkel.labeling.model import LabelModel
from extension import alpha_from_LPA, LPA_with_dongle_with_labeled_inds_custom_alpha, LPA_with_dongle_with_labeled_inds
from extension import Adaboost_weight_norm, Adaboost_weight
from extension import Generate_data_var_reg, alpha_from_reg
from utils import GenerateMatrix
from gnn import eval, adjmat_to_graph

def load_data(data_name, euc_th, seed):

	'''
	Load one seeded training split of a dataset together with its Euclidean graph.

	Both pieces are cached with torch.save under datasets/<data_name>/:
	the split as train_X_seed<seed>, train_L_seed<seed>, train_labels_seed<seed>,
	and the graph as W_x_seed<seed>_thresh_<euc_th>. Cached files are reused
	when present; otherwise they are built and written.

	Args:
	data_name - string representing the dataset
	euc_th - value representing the thresholding on the euclidean radius graph
	seed - random seed number (passed to dataset.WSDataset and used in cache names)

	Returns:
	X - training features (cast to float when first built)
	L - weak-labeler outputs (cast to long when first built)
	labels - ground-truth labels
	W_x - first output of GenerateMatrix(euc_mat, thresh=euc_th)
	'''

	base_path = os.path.join("datasets", data_name)
	os.makedirs(base_path, exist_ok=True)
	seed_tag = "_seed" + str(seed)

	# Raw split: reuse the cache only when all three pieces are on disk, so a
	# partially written cache is rebuilt instead of failing inside torch.load.
	split_paths = [os.path.join(base_path, name + seed_tag) for name in ("train_X", "train_L", "train_labels")]
	if all(os.path.isfile(p) for p in split_paths):
		X, L, labels = [torch.load(p) for p in split_paths]
	else:
		train_dataset = dataset.WSDataset(dataset=data_name, split="train", feature="bert", balance=True, seed = seed)
		# A single batch holding the whole split; shuffle=True only permutes point order.
		train_loader = DataLoader(train_dataset, batch_size=len(train_dataset), shuffle=True)

		X, L, labels = next(iter(train_loader))
		X, L = X.type(torch.float), L.type(torch.long)
		for tensor, p in zip((X, L, labels), split_paths):
			torch.save(tensor, p)

	# Graph cache: test for the same W_x file that is loaded and saved, so a
	# previously built graph is actually reused across calls.
	w_path = os.path.join(base_path, "W_x" + seed_tag + "_thresh_" + str(euc_th))
	if os.path.isfile(w_path):
		# NOTE(review): if GenerateMatrix returns a non-tensor (e.g. a NumPy array),
		# torch>=2.6 needs torch.load(..., weights_only=False) here — confirm.
		W_x = torch.load(w_path)
	else:
		# Dense n x n distance matrix; only computed when the graph is not cached.
		euc_mat = pairwise_distances(X, Y = None, metric='euclidean')
		W_x, _ = GenerateMatrix(euc_mat, thresh = euc_th)
		torch.save(W_x, w_path)

	return X, L, labels, W_x

def _save_pseudolabels(data_name, method, euc_th, seed, pseudolabels):

	'''
	Write one method's pseudolabels to
	datasets/<data_name>/dongle/<method>/euc_<euc_th>_seed_<seed>.npy
	(np.save appends the .npy extension), creating the directory if needed.
	'''

	out_dir = os.path.join("datasets", data_name, "dongle", method)
	# Without this, np.save raises on the first method whose folder is missing,
	# after the earlier (possibly slow) methods have already run.
	os.makedirs(out_dir, exist_ok=True)
	np.save(os.path.join(out_dir, "euc_" + str(euc_th) + "_seed_" + str(seed)), pseudolabels)


def gen_pl(data_name, num_labels, lamb, seed, euc_th = 10, wl_th = 10):

	'''
	Run every pseudolabeling method on one dataset split and save its outputs.
	-> *** saves outputs of algo before training end model

	Each method's pseudolabels are written to
	datasets/<data_name>/dongle/<method>/euc_<euc_th>_seed_<seed>.npy.
	Methods run, in order: Snorkel, LPA + WL, Dongle_alpha_s, Dongle_alpha_oracle,
	Dongle_alpha_1, LPA, Dongle_opt_s, Dongle_opt_oracle, Dongle_reg_alpha, GCN.

	Args:
	data_name - dataset
	num_labels - amount of labeled data (for all of our experiments we use 100)
	lamb - lamb value forwarded to every dongle-LPA propagation call
	seed - selects the cached data split (see load_data) and tags output files
	euc_th - threshold for euclidean graph for LPA
	wl_th - threshold for smoothing graph for ALPA (currently unused here)

	Returns:
	results - list with one row per method:
	[*GetStats(pseudolabels, labels), method, data_name, num_labels, lamb, seed, euc_th]
	'''

	results = []
	X, L, labels, W_x = load_data(data_name, euc_th, seed)

	def record(pseudolabels, method, out_name=None):
		# Score against ground truth, tag the row with this run's settings, then
		# persist under out_name (defaults to the method name).
		results.append(list(GetStats(pseudolabels, labels)) + [method, data_name, num_labels, lamb, seed, euc_th])
		_save_pseudolabels(data_name, out_name or method, euc_th, seed, pseudolabels)

	# Label-model estimate (L_acc) and oracle per-labeler accuracy from ground truth;
	# NaNs in the oracle accuracies are zeroed.
	L_acc, snorkel_pred = Snorkel(L)
	L_acc_oracle, _coverage = CheckLFs_Acc(L, labels, show = False)
	L_acc_oracle = np.nan_to_num(L_acc_oracle)

	# Reveal ground truth on num_labels random points by overwriting their rows
	# with one-hot [1 - y, y] vectors (two-class labels assumed by this stacking).
	# NOTE(review): drawn from NumPy's global RNG, which this function does not
	# seed with `seed` — confirm whether runs are meant to be reproducible.
	labeled_inds = np.random.choice(W_x.shape[0], size=num_labels, replace=False)
	snorkel_pred[labeled_inds, :] = np.stack((1 - labels[labeled_inds], labels[labeled_inds]), axis=1)

	# LPA + WL: hard propagation from the (label-injected) Snorkel output
	LPA_WL = PropagationHard(snorkel_pred, W_x, labels, labeled_inds, alpha = 1)

	# Baselines (saved under the legacy folder names 'snorkel' and 'LPA_WL')
	record(snorkel_pred, 'Snorkel', out_name='snorkel')
	record(LPA_WL, 'LPA + WL', out_name='LPA_WL')

	# LPA with dongle nodes, one run per per-labeler weight vector alpha_j.
	# alpha_s aliases L_acc (no copy), as before; alpha = 0 is recorded as 'LPA'.
	dongle_runs = [
		('Dongle_alpha_s', L_acc),
		('Dongle_alpha_oracle', L_acc_oracle),
		('Dongle_alpha_1', np.ones_like(L_acc)),
		('LPA', np.zeros_like(L_acc)),
	]
	for method, alpha_j in dongle_runs:
		record(LPA_with_dongle_with_labeled_inds(W_x, L, alpha_j, labels, labeled_inds, lamb = lamb), method)

	# "Optimal" weights from Adaboost_weight_norm. Accuracies are divided by 100,
	# suggesting they are percentages. Computed only after the runs above, as in
	# the original ordering.
	opt_runs = [
		('Dongle_opt_s', Adaboost_weight_norm(L_acc/100, clip = 5)),
		('Dongle_opt_oracle', Adaboost_weight_norm(L_acc_oracle/100, clip = 5)),
	]
	for method, alpha_j in opt_runs:
		record(LPA_with_dongle_with_labeled_inds(W_x, L, alpha_j, labels, labeled_inds, lamb = lamb), method)

	# Alpha_j depends on x: per-point weights fitted by alpha_from_reg.
	# alpha_mat_lpa = alpha_from_LPA(X,L,labels, L_acc, labeled_inds, thresh = thresh)
	alpha_mat_reg = alpha_from_reg(X, L, labels, L_acc, labeled_inds)
	# Use the caller's lamb, like the other dongle runs, so the results row
	# tagged with `lamb` reflects the value actually used.
	pseudolabels = LPA_with_dongle_with_labeled_inds_custom_alpha(W_x, L, alpha_mat_reg, labels, labeled_inds, lamb = lamb)
	record(pseudolabels, 'Dongle_reg_alpha')

	# GCN baseline on the same graph and labeled points. `eval` is the
	# module-level import from gnn, not the builtin.
	g = adjmat_to_graph(W_x, X, labels, labeled_inds)
	_acc, final_pred = eval(g)
	with torch.no_grad():
		# Explicit class axis (assuming final_pred is (n_points, n_classes)); for
		# 2-D input this is the axis nn.Softmax() picked implicitly.
		gcn_probs = F.softmax(final_pred, dim=1).detach().numpy()
	# Tag the row 'GCN' explicitly instead of reusing a leftover loop variable.
	record(gcn_probs, 'GCN')

	return results


if __name__ == "__main__":

	# Sweep every (euclidean threshold, seed) pair and write pseudolabels for each.
	# All defaults reproduce the original fixed sweep: num_labels=100, lamb=1,
	# seeds 0-4, euc_th in [1, 2, 5, 10, 100].
	parser = argparse.ArgumentParser()
	parser.add_argument('--dataset', default="youtube", type=str, help="Dataset to run (spam, agnews, yelp, awa2)")
	parser.add_argument('--num_labels', default=100, type=int, help="Number of ground-truth labeled points per run")
	parser.add_argument('--lamb', default=1, type=float, help="lamb value forwarded to gen_pl")
	parser.add_argument('--num_seeds', default=5, type=int, help="Run seeds 0 .. num_seeds-1")
	parser.add_argument('--euc_th', default=[1, 2, 5, 10, 100], type=int, nargs='+', help="Euclidean graph thresholds to sweep")
	args = parser.parse_args()

	for euc_th in args.euc_th:
		for seed in range(args.num_seeds):
			print("Seed: ", seed, "Euc: ", euc_th)
			gen_pl(args.dataset, num_labels=args.num_labels, lamb = args.lamb, seed=seed, euc_th = euc_th)