"""
Debiasing as Nonlinear Integer Programming
Usage: 
	python dnip.py \
		-c config/default_params.json \
		--vec_dir vectors/gpt2-xl/pubmedqa_gpt2-xl_shot1_seed0
"""

import argparse
import math
import random
import time
import numpy as np
import os
import pandas as pd
from sklearn.metrics import accuracy_score
from sklearn.metrics import classification_report
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split
from collections import Counter
from itertools import combinations

from config import Config
import re
import json
import ast
import pickle
import datasets
from datasets import load_dataset


def _save_selected_vecs_txt(out_fp, sentences, labels, preds):
	"""Write sampled rows to `out_fp`, one line per row.

	Line layout: `<row> <label> <pred> <payload>`, where payload is the
	space-joined probability vector (or the item itself when it is a string).
	Label in column 1, prediction in column 2 and probabilities from column 3
	on is the layout main() parses from opt.txt/test.txt, so the file can be
	read back the same way. Column 0 is the row index within the sample --
	NOTE(review): the original vector files may store a different identifier
	there; main() does not read that column.
	"""
	with open(out_fp, 'w') as f:
		for row, (item, lbl, pred) in enumerate(zip(sentences, labels, preds)):
			payload = item if isinstance(item, str) else ' '.join(str(v) for v in item)
			f.write(f"{row} {lbl} {pred} {payload}\n")


def sample_subset_by_lbl(vec_dir, num_class, sentences, labels, preds_not_used, num_samples=None, rseed=0):
	"""Draw a reproducible training subset and cache it under `vec_dir/train_<num_samples>/train.txt`.

	For `num_samples <= 100` the draw is stratified: every class gets
	`num_samples // num_class` items and the last class absorbs the remainder;
	the result is then shuffled. For larger sizes the draw is uniform without
	replacement. When `num_samples` is None the inputs are returned unchanged
	and nothing is written.

	Args:
		vec_dir: directory under which the sampled subset is cached.
		num_class: number of classes; labels must lie in range(num_class).
		sentences: per-sample items (main() passes probability vectors).
		labels: gold class id per sample.
		preds_not_used: raw prediction per sample, carried along for the cache file.
		num_samples: subset size, or None for "use everything".
		rseed: seed for NumPy's global RNG (sampling) and `random` (shuffling).

	Returns:
		(sentences, labels, preds) lists for the selected subset.

	Raises:
		ValueError: if a class has fewer samples than its stratified quota.
		KeyError: if a label is outside range(num_class) in the stratified branch.
	"""
	if num_samples is None:
		return sentences, labels, preds_not_used

	np.random.seed(rseed)
	if num_samples <= 100:
		class_inds = {c: [] for c in range(num_class)}
		for i, lbl in enumerate(labels):
			class_inds[lbl].append(i)

		per_class = num_samples // num_class
		quota = {c: per_class for c in range(num_class - 1)}
		quota[num_class - 1] = num_samples - per_class * (num_class - 1)

		picked = []
		# Classes are visited in id order so the RNG call sequence is deterministic.
		for c, inds in class_inds.items():
			if quota[c] > len(inds):
				raise ValueError(
					f"class {c} has only {len(inds)} samples, cannot draw {quota[c]} without replacement")
			picked.extend(np.random.choice(inds, size=quota[c], replace=False))

		# Shuffle so the subset is not grouped by class.
		random.seed(rseed)
		random.shuffle(picked)
	else:
		picked = np.random.choice(len(labels), size=num_samples, replace=False)

	sel_sentences = [sentences[i] for i in picked]
	sel_labels = [labels[i] for i in picked]
	sel_preds = [preds_not_used[i] for i in picked]

	# Cache the subset; an existing file is kept as-is.
	res_dir = os.path.join(vec_dir, f"train_{num_samples}")
	os.makedirs(res_dir, exist_ok=True)
	res_fp = os.path.join(res_dir, "train.txt")
	if not os.path.exists(res_fp):
		_save_selected_vecs_txt(res_fp, sel_sentences, sel_labels, sel_preds)

	return sel_sentences, sel_labels, sel_preds




def fix_seed_rerun(rseed):
	"""Re-seed the RNGs this script relies on so a rerun is reproducible.

	Seeds Python's `random` module and NumPy's global RNG with `rseed`, and
	exports it as PYTHONHASHSEED. The environment variable only affects
	interpreters launched afterwards (e.g. subprocesses); the hash seed of the
	running process is fixed at start-up.
	"""
	os.environ['PYTHONHASHSEED'] = str(rseed)
	for seeder in (np.random.seed, random.seed):
		seeder(rseed)


def lblInfer(y, pp, w):
	"""Predict a class per sample after rescaling its class probabilities.

	Args:
		y: weight index per class, e.g. [0, 1, 8, 3, 4]; class c is scaled by w[y[c]].
		pp: per-sample probability vectors, one entry per class.
		w: the weight scale that the indices in `y` point into.

	Returns:
		List with the argmax of the rescaled vector for every sample
		(np.argmax resolves ties toward the lowest class index).
	"""
	class_weights = [w[widx] for widx in y]
	corrected = []
	for probs in pp:
		scaled = np.multiply(probs, class_weights)
		corrected.append(np.argmax(scaled))
	return corrected

# todo: remove nTrue, rTrue in args
def ObjFunction(pred_corrected, lbl, B, nTrue, rTrue, beta, tau, k, num_class):
	"""Debiasing objective minimised by the simulated-annealing search.

	z = z1 + beta * N * z2 + tau * N * z3, where
	  z1: number of misclassified samples;
	  z2: mean absolute pairwise difference of per-class recall
	      rTrue[j] = (B[j] - errors_j) / B[j], over classes with B[j] > 0;
	  z3: negative sum over classes of the add-k smoothed PMI between
	      "gold = j and pred = j" and the predicted / gold marginals of j.

	Args:
		pred_corrected: predicted class id per sample (same length as `lbl`).
		lbl: gold class id per sample.
		B: gold support per class id (length num_class).
		nTrue, rTrue: output buffers of length num_class; overwritten with
			per-class correct counts and recalls (rTrue is 0.0 for classes with
			no support).
		beta, tau: weights of the bias (z2) and PMI (z3) terms.
		k: add-k smoothing constant for z3.
		num_class: number of classes.

	Returns:
		The objective value z.
	"""
	N = len(lbl)
	nError = [0] * num_class    # wrong predictions per gold class
	nPred = [0] * num_class     # predictions per class
	z1 = 0
	for gold, pred in zip(lbl, pred_corrected):
		nPred[pred] += 1
		if gold != pred:
			nError[gold] += 1
			z1 += 1

	z3 = 0
	supported = []  # classes whose recall is defined (B[j] > 0)
	for j in range(num_class):
		nTrue[j] = B[j] - nError[j]
		if B[j] > 0:
			rTrue[j] = nTrue[j] / B[j]
			supported.append(j)
		else:
			# No gold samples: recall is undefined, so this class is left out of z2
			# instead of raising ZeroDivisionError.
			rTrue[j] = 0.0
		# PMI between predicted and actual label; add-k smoothing keeps the log
		# finite when nTrue[j] == 0 (joint smoothed over num_class * num_class cells).
		rTrue_j_smoothed = (nTrue[j] + k) / (N + k * num_class * num_class)
		rPred_j_smoothed = (nPred[j] + k) / (N + k * num_class)
		rB_j_smoothed = (B[j] + k) / (N + k * num_class)
		z3 += -np.log(rTrue_j_smoothed / (rPred_j_smoothed * rB_j_smoothed))

	# Pairs are visited in the same (i < j) order as before so sums are bit-identical.
	z2 = 0
	n_pairs = 0
	for i, j in combinations(supported, 2):
		z2 += abs(rTrue[i] - rTrue[j])
		n_pairs += 1
	if n_pairs:
		z2 /= n_pairs

	return z1 + beta * N * z2 + tau * N * z3


def compute_bias(class_acc):
	"""Return the mean absolute pairwise difference between per-class accuracies.

	Args:
		class_acc: sequence (list or NumPy array) of per-class accuracies.

	Returns:
		The average of |class_acc[i] - class_acc[j]| over all pairs i < j, or
		0.0 when fewer than two classes are given (there is no pair to compare,
		which previously raised ZeroDivisionError). NaN entries propagate.
	"""
	pairs = list(combinations(range(len(class_acc)), 2))
	if not pairs:
		return 0.0
	bias = sum(abs(class_acc[i] - class_acc[j]) for i, j in pairs)
	return bias / len(pairs)


def compute_pmi(class_acc, pred_test, B, N, k=1):
	"""Per-class add-k smoothed PMI between correct predictions and class marginals.

	For class j, with nTrue_j = B[j] * class_acc[j] correct predictions:
		pmi_j = log( P~(gold=j, pred=j) / (P~(pred=j) * P~(gold=j)) )
	where the joint is smoothed over num_class * num_class cells and each
	marginal over num_class cells.

	Args:
		class_acc: per-class accuracies; its length defines the number of classes.
		pred_test: predicted class ids (list or NumPy array).
		B: gold support per class id.
		N: number of samples.
		k: add-k smoothing constant.

	Returns:
		(z3, pmi_per_class) where z3 = -sum(pmi_per_class).
	"""
	num_class = len(class_acc)
	# One pass over the predictions; Counter accepts lists and NumPy arrays alike
	# (list.count is not available on ndarrays).
	pred_counts = Counter(pred_test)
	z3 = 0
	z3_pmi = []
	for j in range(num_class):
		nTrue_j = B[j] * class_acc[j]
		# add-k smoothing keeps the log finite when nTrue_j == 0
		rTrue_j_smoothed = (nTrue_j + k) / (N + k * num_class * num_class)
		rPred_j_smoothed = (pred_counts[j] + k) / (N + k * num_class)
		rB_j_smoothed = (B[j] + k) / (N + k * num_class)
		pmi_j = np.log(rTrue_j_smoothed / (rPred_j_smoothed * rB_j_smoothed))
		z3 += -pmi_j
		z3_pmi.append(pmi_j)
	return z3, z3_pmi

def save_to_file(out_fp, res, mode='a'):
	"""Write the string `res` plus a trailing newline to `out_fp`.

	`mode` is passed straight to open(); the default 'a' appends.
	"""
	with open(out_fp, mode) as handle:
		handle.write(res + '\n')

# Main
def _count_classes(file_path):
	"""Return the number of probability columns on the first line of `file_path`.

	Columns 0-2 are non-probability fields (label in column 1, prediction in
	column 2); every column after them is counted as one class.
	"""
	with open(file_path, 'r') as f:
		first_line = f.readline()
	return len(first_line.split()[3:])


def _read_vec_file(file_path, num_class, read_preds=False):
	"""Parse a vector file into labels, optional predictions and probability vectors.

	Each non-empty line is split on whitespace: column 1 is the integer label,
	column 2 the integer prediction (parsed only when `read_preds` is True) and
	columns 3 .. 3 + num_class - 1 the per-class probabilities. Column 0 is ignored.

	Returns:
		(labels, preds, probs); preds is None unless `read_preds` is True.
	"""
	labels, probs = [], []
	preds = [] if read_preds else None
	with open(file_path, 'r') as f:
		for line in f:
			parts = line.split()
			if not parts:
				# Tolerate blank lines instead of failing with IndexError.
				continue
			labels.append(int(parts[1]))
			if read_preds:
				preds.append(int(parts[2]))
			probs.append([float(parts[n + 3]) for n in range(num_class)])
	return labels, preds, probs


def _support(labels, num_class):
	"""Count each class id in range(num_class); absent classes get 0 so indices stay aligned."""
	counts = Counter(labels)
	return [counts[c] for c in range(num_class)]


def _run_simulated_annealing(config, pp, lbl, B, w, beta, tau, k, num_class, start_time):
	"""Search per-class weight indices that minimise ObjFunction via simulated annealing.

	Every class starts at the largest weight index (len(w) - 1). Each proposal
	moves one random class to a different weight index; proposals that do not
	worsen the objective are accepted, worse ones with probability
	exp((z_cur - z_new) / temperature). Annealing hyperparameters (T_min,
	r_temp, iter_min, iter_max, n_out_loop, n_in_loop, low_tem) are read from
	`config`. A progress line is printed per outer iteration.

	Returns:
		(y_best, z_best): best weight-index list found and its objective value.

	Raises:
		ValueError: if `w` has fewer than two entries (no alternative weight exists,
			and the proposal loop would never terminate).
	"""
	num_w = len(w)
	if num_w < 2:
		raise ValueError(f"need at least two weights to search over, got {num_w}")

	T_min = config.T_min
	r_temp = config.r_temp
	iter_min = config.iter_min
	iter_max = config.iter_max
	n_out_loop = config.n_out_loop
	n_in_loop = config.n_in_loop
	low_tem = config.low_tem

	# Scratch buffers ObjFunction writes into; their contents are not used here.
	nTrue = [0] * num_class
	rTrue = [0] * num_class
	header_format = "{:^10} {:^8} {:^10} {:^10} {:^8} {:^8} {:^12} {:^12} {:^12} {:^12} "
	header = header_format.format("Iter.", "Temp", "Accept rate", "Accept sol.", "Num of sol.", "Avg. z", "Min z", "Max z", "Total run time", "Iter duration")
	print(header)

	# y[c] is the index into w of the weight applied to class c, e.g. [9, 9, 9, 9, 9].
	y_cur = [num_w - 1] * num_class
	y_best = y_cur.copy()
	z_cur = ObjFunction(lblInfer(y_cur, pp, w), lbl, B, nTrue, rTrue, beta, tau, k, num_class)
	z_best = z_cur
	for T in range(n_out_loop):
		start_iteration_time = time.time()
		z_total = 0
		z_max = -np.inf
		z_min = np.inf
		n_generate = 0
		n_accept = 0
		for _ in range(n_in_loop):
			y_new = y_cur.copy()
			# Propose: pick a class ii and give it a new weight index jj.
			ii = random.randint(0, num_class - 1)
			jj = random.randint(0, num_w - 1)
			# jj must differ from the CURRENT solution's weight for class ii; comparing
			# with the initial solution allowed no-op proposals that were always
			# "accepted", inflating n_accept and ending the inner loop early.
			while jj == y_cur[ii]:
				jj = random.randint(0, num_w - 1)
			y_new[ii] = jj
			z_new = ObjFunction(lblInfer(y_new, pp, w), lbl, B, nTrue, rTrue, beta, tau, k, num_class)
			n_generate += 1
			z_total += z_new
			z_min = min(z_min, z_new)
			z_max = max(z_max, z_new)
			# z_cur tracks the current (possibly worse) solution, z_best the best seen.
			# Accepting some worse moves lets the search escape local minima.
			if z_new <= z_cur:
				n_accept += 1
				y_cur = y_new.copy()
				z_cur = z_new
				if z_new < z_best:
					z_best = z_new
					y_best = y_new.copy()
			elif random.uniform(0, 1) < np.exp((z_cur - z_new) / r_temp):
				y_cur = y_new.copy()
				z_cur = z_new
			# SA inner loop stopping criteria
			if n_accept >= iter_min or n_generate >= iter_max:
				break
		r_temp = r_temp * low_tem
		end_iteration_time = time.time()
		iteration_duration = end_iteration_time - start_iteration_time
		accept_rate = n_accept / n_generate if n_generate > 0 else 0
		total_run_time = end_iteration_time - start_time
		z_average = z_total / n_generate
		iteration_info_format = "{:^10d} {:^15.3f} {:^10.4f} {:^15d} {:^15d} {:^16d} {:^15d} {:^20d} {:^17.2f} {:^17.2f}"
		iteration_info = iteration_info_format.format(T, r_temp, accept_rate, n_accept, n_generate, int(z_average),
													  int(z_min), int(z_max), total_run_time, iteration_duration)
		print(iteration_info)
		# SA outer loop stopping criterion
		if r_temp < T_min:
			break
	return y_best, z_best


def _evaluate_split(acc_title, prefix, labels, preds, support, num_class, k):
	"""Print accuracy, a classification report, per-class accuracy, bias and PMI for one split.

	The confusion matrix is built over all `num_class` labels so per-class
	accuracies stay aligned with `support` even when a class is missing from
	the split; a class with no gold samples gets a NaN accuracy.
	"""
	print(acc_title, accuracy_score(labels, preds))
	print(classification_report(labels, preds))
	matrix = confusion_matrix(labels, preds, labels=list(range(num_class)))
	with np.errstate(divide='ignore', invalid='ignore'):
		class_acc = matrix.diagonal() / matrix.sum(axis=1)
	print(f'{prefix} class acc.', class_acc)
	print(f'{prefix} bias ', compute_bias(class_acc))
	pmi_total, pmi_class = compute_pmi(class_acc, preds, support, len(labels), k)
	print(f'{prefix} PMI ', pmi_total, pmi_class)


# Main
def main(config):
	"""Grid-search debiasing weights with simulated annealing and report metrics.

	Reads `opt.txt` (training vectors) and `test.txt` from `config.vec_dir`.
	It holds out 5% of opt.txt as a dev set. Then, for every training-set size
	in `num_samples` and every (beta, num_w, k, tau) combination, it runs
	simulated annealing and prints train, dev and test metrics.
	"""
	ks = [4000]
	betas = [2.7, 3]
	taus = [0.2]
	num_ws = [30, 50, 70, 90]
	rseed = 1
	# vary train size
	num_samples = ['full', 10, 50, 100, 500, 1000]

	# Dataset name = first '_'-separated token of the vector directory's name
	# (normpath makes a trailing slash harmless).
	ds = os.path.basename(os.path.normpath(config.vec_dir)).split('_')[0]
	print('==ds==', ds)

	# The raw data is the same for every train size, so read it once.
	train_fp = os.path.join(config.vec_dir, 'opt.txt')
	num_class = _count_classes(train_fp)
	print('num_class ', num_class)
	lbl_raw, pred_raw, pp_raw = _read_vec_file(train_fp, num_class, read_preds=True)

	# split raw train set into train and dev by 0.95/0.05
	np.random.seed(rseed)
	train_inds = np.random.choice(len(lbl_raw), size=int(0.95 * len(lbl_raw)), replace=False)
	train_ind_set = set(train_inds.tolist())  # O(1) membership for the dev complement
	dev_inds = [i for i in range(len(lbl_raw)) if i not in train_ind_set]
	pp_train = [pp_raw[i] for i in train_inds]
	lbl_train = [lbl_raw[i] for i in train_inds]
	preds_not_used = [pred_raw[i] for i in train_inds]
	pp_dev = [pp_raw[i] for i in dev_inds]
	lbl_dev = [lbl_raw[i] for i in dev_inds]
	# Supports are indexed by class id so they stay aligned when a class is absent.
	B_dev = _support(lbl_dev, num_class)

	# The test set does not depend on the configuration: load it once.
	lbl_test, _, pp_test = _read_vec_file(os.path.join(config.vec_dir, 'test.txt'), num_class)
	B_test = _support(lbl_test, num_class)

	# select hyperparameters based on acc. on a dev set
	for num_sample in num_samples:
		# NOTE(review): pubmedqa stops before the 1000-sample run; the reason is not stated here.
		if ds == 'pubmedqa' and num_sample == 1000:
			break
		if num_sample == 'full':
			pp, lbl = pp_train, lbl_train
			print(f"===running on full train set, {len(lbl)} train samples===")
		elif num_sample > len(pp_train):
			print(f"==={num_sample} is greater than the size of the input dataset ({len(pp_train)})! Skipping...")
			break
		else:
			if num_sample == 10 and ds == 'dbpedia':
				num_sample = 15  # dbpedia: 14 classes
			pp, lbl, _ = sample_subset_by_lbl(config.vec_dir, num_class, pp_train, lbl_train, preds_not_used, num_sample, rseed)
			print(f"===Using {config.vec_dir} {len(pp)} train samples===")

		B = _support(lbl, num_class)
		print('Train support:', B)

		for beta in betas:
			for num_w in num_ws:
				for k in ks:
					for tau in taus:
						print('===train_size: {} w: {} beta: {} tau: {} k: {} starts==='.format(num_sample, num_w, beta, tau, k))

						fix_seed_rerun(rseed)
						start_time = time.time()

						# Weight scale: num_w evenly spaced weights in (0, 1].
						w = [(i + 1) / num_w for i in range(num_w)]
						y_best, z_best = _run_simulated_annealing(config, pp, lbl, B, w, beta, tau, k, num_class, start_time)

						# Re-evaluate the returned solution as a consistency check.
						pred_corrected = lblInfer(y_best, pp, w)
						z_check = ObjFunction(pred_corrected, lbl, B, [0] * num_class, [0] * num_class, beta, tau, k, num_class)
						print(z_check, z_best)
						if z_check != z_best:
							print('z_check != z_best')
							print('____________________')
						print(f'W is on a scale of {num_w}')
						print('The selected correction weights: ' + str(y_best))
						print('Objective function value:  ' + str(z_best))

						_evaluate_split('===Train acc===', 'train', lbl, pred_corrected, B, num_class, k)
						_evaluate_split('===dev acc===', 'dev', lbl_dev, lblInfer(y_best, pp_dev, w), B_dev, num_class, k)

						run_time = time.time() - start_time
						print("Program execution time:" + str("{:.4}".format(run_time)) + '  seconds')

						_evaluate_split('===Test acc===', 'test', lbl_test, lblInfer(y_best, pp_test, w), B_test, num_class, k)
						print('=======train_size {} rseed {} beta {} tau {} k {} w {} ends========='.format(num_sample, rseed, beta, tau, k, len(w)))


if __name__ == '__main__':
	# Command-line entry point: parse arguments, build the Config and run main().
	parser = argparse.ArgumentParser()
	parser.add_argument('-d', "--vec_dir", default=None, help="Directory with the raw class probabilities output by an LLM (opt.txt and test.txt).")
	# Load config
	parser.add_argument('-c', '--config', type=str, default=None, help="Path to a JSON config file (default: config/default_params.json).")

	args = parser.parse_args()
	if args.config is not None:
		cur_config_path = args.config
	else:
		cur_config_path = os.path.join("config", "default_params.json")

	# The parsed CLI arguments are handed to Config as `update_config`
	# (presumably overriding values from the JSON file -- see config.Config).
	update_config = vars(args)
	print('Update config', update_config)
	config = Config(cur_config_path, update_config)
	main(config)



