import os
import torch

import numpy as np

from liger import Liger, Flyingsquid_Cluster
from liger_core import load_config
from liger_utils import evaluate_thresholds, cluster_embeddings

import sys
import argparse
import warnings
if not sys.warnoptions:
	warnings.simplefilter("ignore")

import dataset
from torch.utils.data import DataLoader

from end_model import run_end_model

from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch.utils.data import DataLoader, TensorDataset

import pickle


def convert_votes(votes):
	'''
	Remap labeling-function votes between two encodings, element-wise:

	     1 -> 1   (unchanged)
	     0 -> -1
	    -1 -> 0   (abstain)

	Any other value maps to 0.

	Args:
		votes: array-like of votes, any shape.

	Returns:
		A new numpy array with the same shape and dtype as ``votes``;
		the input is not modified.
	'''

	votes = np.asarray(votes)
	new_votes = np.zeros_like(votes)
	# Abstains (-1) and any unexpected value keep the zero fill.
	new_votes[votes == 1] = 1
	new_votes[votes == 0] = -1

	return new_votes

def get_cluster_pls(L_data, labels, cluster_labels, cluster_models, best_cbs):
	'''
	Assemble per-point pseudolabels from per-cluster label models.

	For each cluster index ``c``, the rows of ``L_data`` whose entry in
	``cluster_labels`` equals ``c`` are scored with
	``cluster_models[c].triplet_models[best_cbs[c]].predict_proba`` and the
	resulting probabilities are written back at the same row positions.

	Args:
		L_data: array of label-function votes, one row per point.
		labels: unused; kept for call-site compatibility.
		cluster_labels: cluster index for each row of ``L_data``.
		cluster_models: sequence of per-cluster models exposing
			``triplet_models``, indexable by the entries of ``best_cbs``.
		best_cbs: per-cluster key selecting which triplet model to use.

	Returns:
		A ``(len(L_data), 2)`` float array of pseudolabel probabilities.
		Rows whose cluster index has no entry in ``cluster_models`` stay
		all-zero.
	'''

	L_data = np.asarray(L_data)
	cluster_labels = np.asarray(cluster_labels).ravel()
	pl_return = np.zeros((len(L_data), 2))

	for cluster_idx, FS_cluster in enumerate(cluster_models):
		# 1-D row indices. np.argwhere returns (n, 1) indices; combined with
		# .squeeze() that collapsed a single-point cluster into a 1-D vector
		# before it reached predict_proba.
		points_in_cluster = np.flatnonzero(cluster_labels == cluster_idx)
		if points_in_cluster.size == 0:
			# Nothing to score for this cluster.
			continue

		best_model = FS_cluster.triplet_models[best_cbs[cluster_idx]]
		# Indexing with 1-D indices keeps the cluster's votes 2-D (n_points, n_lfs).
		pl_return[points_in_cluster] = best_model.predict_proba(L_data[points_in_cluster])

	return pl_return

def liger_em():
	'''
	Train the end model on previously saved LIGER pseudolabels.

	For every threshold ``k`` in (0.995, 0.9975, 1) and seed 0-4 this:
	  1. seeds torch / numpy,
	  2. loads train features and labels for ``--dataset``,
	  3. loads the pseudolabels saved by ``main()`` and overwrites 100
	     randomly drawn training rows with their one-hot true labels,
	  4. drops rows whose pseudolabels abstain (P(class 0) within 0.001 of 0.5),
	  5. trains the end model on soft labels and appends the row
	     ``dataset,seed,best_val,best_test,k`` to
	     ``results/<dataset>_liger_results_soft.txt``.

	Returns:
		list of ``best_test`` values, one per (k, seed) run, in run order.
	'''

	parser = argparse.ArgumentParser()
	parser.add_argument('--dataset', default="youtube", type=str, help="Dataset to run (youtube, sms, cdr, basketball, tennis)")
	args = parser.parse_args()

	dataset_dir = os.path.join("datasets", args.dataset)
	results_dir = "results"
	os.makedirs(results_dir, exist_ok=True)
	results_path = os.path.join(results_dir, args.dataset + "_liger_results_soft.txt")

	# basketball/tennis are loaded without the BERT feature set.
	feature = None if args.dataset in ("basketball", "tennis") else "bert"

	test_scores = []

	for k in [0.995, 0.9975, 1]:
		for seed in range(5):

			# setting seeds
			torch.manual_seed(seed)
			torch.cuda.manual_seed(seed)
			torch.backends.cudnn.deterministic = True
			torch.backends.cudnn.benchmark = False
			np.random.seed(seed)

			train_data = torch.load(os.path.join(dataset_dir, "train_X_seed" + str(seed)))
			train_labels = torch.load(os.path.join(dataset_dir, "train_labels_seed" + str(seed))).numpy()

			# First numpy draw after seeding, mirroring the draw in main(),
			# so the same 100 points are treated as labeled.
			labeled_inds = np.random.choice(range(train_labels.shape[0]), size=100, replace=False)

			pl_path = os.path.join(dataset_dir, "dongle", "liger", "k_" + str(k) + "_seed_" + str(seed) + ".npy")
			pseudolabs = np.load(pl_path)
			pseudolabs[labeled_inds, :] = np.stack((1 - train_labels[labeled_inds], train_labels[labeled_inds]), axis=1)

			# filtering out points pseudolabels abstain on
			valid_inds = np.abs(pseudolabs[:, 0] - 0.5) > 0.001
			train_data = train_data[valid_inds]
			pseudolabs = pseudolabs[valid_inds]

			val_dataset = dataset.WSDataset(dataset=args.dataset, split="val", feature=feature, balance=False, seed=seed)
			test_dataset = dataset.WSDataset(dataset=args.dataset, split="test", feature=feature, balance=False, seed=seed)

			# Load each split as one full-size batch, then split data / labels.
			val_loader = DataLoader(val_dataset, batch_size=len(val_dataset), shuffle=False)
			val_data, _, val_labels = next(iter(val_loader))
			test_loader = DataLoader(test_dataset, batch_size=len(test_dataset), shuffle=False)
			test_data, _, test_labels = next(iter(test_loader))

			best_val, best_test = run_end_model(train_data, torch.tensor(pseudolabs, dtype=torch.float), val_data, val_labels, test_data, test_labels, soft=True)
			test_scores.append(best_test)

			# Distinct name so the CSV row never overwrites the score list.
			row = [args.dataset, seed, best_val, best_test, k]
			with open(results_path, "a") as f:
				f.write(",".join(str(x) for x in row) + "\n")

	return test_scores

def main():
	'''
	Generate LIGER pseudolabels for the training split of ``--dataset``.

	For every threshold ``k`` in (0.995, 0.9975, 1) and seed 0-4 this:
	  1. seeds torch / numpy and loads the train/val/test splits,
	  2. remaps votes with ``convert_votes`` and labels to {-1, 1},
	  3. (except basketball/tennis) expands the votes with
	     ``Liger.expand_lfs`` using a threshold of ``k`` for every entry,
	  4. clusters the train embeddings, builds one ``Flyingsquid_Cluster``
	     per cluster and tunes them via ``evaluate_thresholds`` on dev,
	  5. saves the train pseudolabels to
	     ``datasets/<dataset>/dongle/liger/k_<k>_seed_<seed>.npy``.
	'''

	# Number of k-means clusters per dataset.
	n_clusters_by_dataset = {
		"youtube": 2,
		"tennis": 2,
		"basketball": 2,
		"cdr": 2,
		"sms": 2,
	}
	# Length of the threshold vector per dataset (presumably one entry per
	# labeling function -- TODO confirm against Liger.expand_lfs).
	n_thresholds_by_dataset = {
		"youtube": 10,
		"tennis": 6,
		"basketball": 4,
		"sms": 73,
		"cdr": 33,
	}
	supported = sorted(n_clusters_by_dataset)

	parser = argparse.ArgumentParser()
	parser.add_argument('--dataset', default="youtube", type=str, choices=supported, help="Dataset to run (" + ", ".join(supported) + ")")
	args = parser.parse_args()

	ks = [0.995, 0.9975, 1]
	print("Dataset", args.dataset)

	n_clusters = n_clusters_by_dataset[args.dataset]
	n_thresholds = n_thresholds_by_dataset[args.dataset]
	base_path = os.path.join("datasets", args.dataset)
	out_dir = os.path.join(base_path, "dongle", "liger")
	os.makedirs(out_dir, exist_ok=True)
	neg_balances_to_try = np.arange(.01, .99, .01)

	# basketball/tennis: no BERT features and no LF expansion.
	no_expansion = args.dataset in ("basketball", "tennis")
	feature = None if no_expansion else "bert"

	T = 1
	tune_by = 'acc'

	def to_signed(y):
		# Threshold labels at 0.5 into the {-1, 1} encoding.
		return np.array([1 if pred > 0.5 else -1 for pred in y])

	for k in ks:

		thresholds = k * np.ones(n_thresholds)

		for seed in range(5):

			# setting seeds
			torch.manual_seed(seed)
			torch.cuda.manual_seed(seed)
			torch.backends.cudnn.deterministic = True
			torch.backends.cudnn.benchmark = False
			np.random.seed(seed)

			avg_embeddings_train = torch.load(os.path.join(base_path, "train_X_seed" + str(seed))).numpy().astype(float)
			L_train_raw_orig = torch.load(os.path.join(base_path, "train_L_seed" + str(seed))).numpy().astype(float)
			Y_train_raw = torch.load(os.path.join(base_path, "train_labels_seed" + str(seed))).numpy()

			# labeled_inds is not used in this function. The draw is kept so
			# the numpy global RNG state for everything below is unchanged;
			# it also mirrors the draw made in liger_em(). Do not remove.
			labeled_inds = np.random.choice(range(Y_train_raw.shape[0]), size=100, replace=False)

			val_dataset = dataset.WSDataset(dataset=args.dataset, split="val", feature=feature, balance=False, seed=seed)
			test_dataset = dataset.WSDataset(dataset=args.dataset, split="test", feature=feature, balance=False, seed=seed)

			# Load each split as one full-size batch.
			val_loader = DataLoader(val_dataset, batch_size=len(val_dataset), shuffle=False)
			test_loader = DataLoader(test_dataset, batch_size=len(test_dataset), shuffle=False)

			avg_embeddings_dev, L_dev_raw_orig, Y_dev_raw = next(iter(val_loader))
			avg_embeddings_test, L_test_raw_orig, Y_test_raw = next(iter(test_loader))

			avg_embeddings_dev = avg_embeddings_dev.numpy().astype(float)
			avg_embeddings_test = avg_embeddings_test.numpy().astype(float)
			L_dev_raw_orig = L_dev_raw_orig.numpy().astype(float)
			L_test_raw_orig = L_test_raw_orig.numpy().astype(float)

			# converting votes
			L_train_raw_orig = convert_votes(L_train_raw_orig)
			L_dev_raw_orig = convert_votes(L_dev_raw_orig)
			L_test_raw_orig = convert_votes(L_test_raw_orig)

			# converting labels
			Y_train_raw = to_signed(Y_train_raw)
			Y_dev_raw = to_signed(Y_dev_raw.numpy())
			Y_test_raw = to_signed(Y_test_raw.numpy())

			if no_expansion:
				L_train_raw = L_train_raw_orig
				L_dev_raw = L_dev_raw_orig
				L_test_raw = L_test_raw_orig
			else:
				liger = Liger()
				L_train_raw = liger.expand_lfs(
					L_train_raw_orig, L_train_raw_orig, avg_embeddings_train, avg_embeddings_train,
					thresholds=thresholds)
				L_dev_raw = liger.expand_lfs(
					L_train_raw_orig, L_dev_raw_orig, avg_embeddings_train, avg_embeddings_dev,
					thresholds=thresholds)
				L_test_raw = liger.expand_lfs(
					L_train_raw_orig, L_test_raw_orig, avg_embeddings_train, avg_embeddings_test,
					thresholds=thresholds)

			m_per_task = L_train_raw.shape[1]

			kmeans, embedding_groups, train_cluster_labels = cluster_embeddings(avg_embeddings_train, n_clusters)
			dev_cluster_labels = kmeans.predict(avg_embeddings_dev)
			test_cluster_labels = kmeans.predict(avg_embeddings_test)
			cluster_models = [
				Flyingsquid_Cluster(X=embedding_groups[i], mu=kmeans.cluster_centers_[i], T=T, m_per_task=m_per_task)
				for i in range(len(embedding_groups))
			]

			evaluate_thresholds(thresholds, cluster_models, neg_balances_to_try,
				L_train_raw, L_dev_raw, L_test_raw,
				Y_dev_raw, Y_test_raw, train_cluster_labels, dev_cluster_labels, test_cluster_labels,
				evaluate_test=False, tune_test=False, tune_by=tune_by)

			best_cbs = [x.best_cb for x in cluster_models]

			pls = get_cluster_pls(L_train_raw, Y_train_raw, train_cluster_labels, cluster_models, best_cbs)
			print("K", k, "Seed", seed)
			print(pls.shape)
			# np.save appends the ".npy" suffix that liger_em() loads.
			np.save(os.path.join(out_dir, "k_" + str(k) + "_seed_" + str(seed)), pls)


if __name__ == '__main__':
	# Two-stage pipeline: main() writes the LIGER train pseudolabels under
	# datasets/<dataset>/dongle/liger/, and liger_em() trains the end model
	# on them. Only the second stage runs here; call main() first to
	# (re)generate the pseudolabels.
	liger_em()
