#### Python script that loads saved pseudolabels for each dataset/method/seed and prints their train-set accuracy and coverage (mean and standard error across seeds)

import pandas as pd
import numpy as np
import torch
from utils import AdjustAcc, NotAbstainAcc
from snorkel.labeling.model import LabelModel


# Threshold per (dataset, method); the value is used as the "euc_<value>"
# component of the saved pseudolabel file names. Within each dataset every
# method currently shares the same value, so the table is built from a
# per-dataset value broadcast over the method names.
_THRESH_METHODS = (
    "Dongle_alpha_1",
    "Dongle_alpha_oracle",
    "Dongle_alpha_s",
    "Dongle_opt_oracle",
    "Dongle_opt_s",
    "Dongle_reg_alpha",
    "GCN",
    "LPA",
    "LPA_WL",
    "snorkel",
)
_EUC_THRESH_BY_DATASET = {
    "basketball": 10,
    "cdr": 1,
    "sms": 1,
    "tennis": 100,
    "youtube": 5,
}
# dict.fromkeys is called once per dataset, so each dataset gets its own inner dict.
best_thresh_dict = {
    dataset: dict.fromkeys(_THRESH_METHODS, thresh)
    for dataset, thresh in _EUC_THRESH_BY_DATASET.items()
}

# Liger's k per dataset; the value is used as the "k_<value>" component of the
# saved Liger pseudolabel file names. Keep youtube's value an int: str(1) gives
# "k_1", whereas 1.0 would give "k_1.0".
best_k = dict(
    youtube=1,
    basketball=0.995,
    cdr=0.995,
    sms=0.9975,
    tennis=0.995,
)

# Methods to evaluate. Valid names are the keys of best_thresh_dict[<dataset>]
# ("Dongle_alpha_1", "Dongle_alpha_oracle", "Dongle_alpha_s",
# "Dongle_opt_oracle", "Dongle_opt_s", "Dongle_reg_alpha", "GCN", "LPA",
# "LPA_WL", "snorkel"), plus "Liger", whose files are looked up via best_k.
methods = ["Liger"]
# Datasets with entries in best_thresh_dict / best_k:
# "youtube", "sms", "cdr", "basketball", "tennis". Only "tennis" is run here.


def _load_pseudolabels(base_path, dset, method, seed):
    """Load the saved pseudolabel array for one (dataset, method, seed).

    Liger files are named by the dataset's ``best_k`` value; every other
    method's files are named by its ``best_thresh_dict`` threshold.
    """
    if method == "Liger":
        k = best_k[dset]
        # f-string formatting of int/float matches str(): youtube's 1 -> "k_1".
        pl_path = f"{base_path}dongle/liger/k_{k}_seed_{seed}.npy"
    else:
        euc_th = best_thresh_dict[dset][method]
        pl_path = f"{base_path}dongle/{method}/euc_{euc_th}_seed_{seed}.npy"
    return np.load(pl_path)


def _fill_liger_labeled(pseudolabs, train_labels, num_labeled=100):
    """Overwrite a random subset of rows of ``pseudolabs`` in place with one-hot truth.

    The Liger pseudolabel files do not include the labeled points, so a subset
    of ``num_labeled`` indices is drawn from NumPy's global RNG; call this right
    after ``np.random.seed(seed)`` so the draw is reproducible per seed.
    NOTE(review): whether this matches the subset used at training time depends
    on the producing script — confirm there.
    Requires ``pseudolabs`` to have two columns and labels in {0, 1}.
    """
    labeled_inds = np.random.choice(range(train_labels.shape[0]), size=num_labeled, replace=False)
    labels = train_labels[labeled_inds]
    pseudolabs[labeled_inds, :] = np.stack((1 - labels, labels), axis=1)


def _coverage(pseudolabs, num_points):
    """Fraction of points whose column-1 value differs from 0.5 by more than 1e-3."""
    return np.sum(np.abs(pseudolabs[:, 1] - 0.5) > 0.001) / num_points


def _report(title, values):
    """Print ``title``, then mean and standard error of ``values`` as percentages.

    The standard error is np.std (population, ddof=0) divided by sqrt(len(values)).
    """
    print(title)
    print(
        np.round(np.mean(values) * 100, decimals=2),
        np.round(np.std(values) / np.sqrt(len(values)) * 100, decimals=2),
    )
    print("\n")


for dset in ["tennis"]:
    base_path = f"datasets/{dset}/"
    print("Dataset", dset)

    seeds = [0, 1, 2, 3, 4]
    for method in methods:
        print("Method", method)
        acs, naacs, covs = [], [], []
        for seed in seeds:
            # Seed first so the Liger labeled-subset draw below is reproducible.
            np.random.seed(seed)
            train_labels = torch.load(f"{base_path}train_labels_seed{seed}").numpy()

            pseudolabs = _load_pseudolabels(base_path, dset, method, seed)
            if method == "Liger":
                _fill_liger_labeled(pseudolabs, train_labels)

            acs.append(AdjustAcc(pseudolabs, train_labels))
            covs.append(_coverage(pseudolabs, train_labels.shape[0]))
            naacs.append(NotAbstainAcc(pseudolabs, train_labels))

        # Standard error now uses the actual number of seeds, not a hard-coded 5.
        _report("Adjust Acc", acs)
        _report("Cov", covs)
        _report("Not Abstain Acc", naacs)