import numpy as np
import random
from models import *
import time
import argparse
import matplotlib.pyplot as plt

# Command-line interface for the simulation script. `main()` reads these
# options via `parser.parse_args()`.
parser = argparse.ArgumentParser(description='sparse bandit')

parser.add_argument('--N', type=int, default=20, help='number of arms')
parser.add_argument('--T', type=int, default=2000, help='horizon')
parser.add_argument('--d', type=int, default=100, help='feature dimension')
parser.add_argument('--s0', type=int, default=5, help='sparsity')
parser.add_argument('--R', type=float, default=0.5, help='noise')
# Only distributions 1 and 2 are implemented by generate_instance(); the former
# default of 0 had no implementation and crashed the run, so the default is
# now 1 and unsupported values are rejected at parse time.
parser.add_argument('--dist', type=int, default=1, choices=(1, 2), help='context distribution - 1:correlated gaussian, 2:fixed sub-optimal arms')
parser.add_argument('--id', type=int, default=0, help='job ID')
parser.add_argument('--simul_n', type=int, default=20, help='Number of simulations')

# experiment 1 : --N 20 --s0 5 --d 100 --T 2000 --dist 1 --simul_n 100
# experiment 2 : --N 10 --s0 20 --d 100 --T 2000 --dist 2 --simul_n 100

def generate_instance(N, d, T, R, beta, dist):
	"""Sample one bandit instance: per-round contexts, noisy rewards and regrets.

	Parameters
	----------
	N : int
		Number of arms.
	d : int
		Feature dimension.
	T : int
		Number of rounds.
	R : float
		Scale of the reward noise (each noise term is R times a standard normal).
	beta : ndarray of shape (d,)
		True parameter vector; expected rewards are ``x @ beta``.
	dist : int
		Context distribution.
		1 - for every feature coordinate the N arms are jointly Gaussian with
		    covariance ``0.3 * I + 0.7 * ones``.
		2 - N - 1 sub-optimal arms are drawn once and kept fixed over all
		    rounds, shifted along ``beta`` so that their expected rewards are
		    evenly spaced in [0.1, 0.9]; one optimal arm is redrawn every
		    round with expected reward uniform in [0.9, 1). Arm order is
		    shuffled each round. Requires ``beta`` to have zero entries.

	Returns
	-------
	contexts : list of T arrays of shape (N, d)
	rwd : list of T arrays of shape (N,)
		Noisy reward of every arm in every round.
	regret : list of T arrays of shape (N,)
		Gap between the best expected reward and each arm's expected reward.

	Raises
	------
	ValueError
		If ``dist`` is not 1 or 2.
	"""
	# Validate first: previously an unsupported value left `contexts` unbound
	# and surfaced later as an opaque UnboundLocalError.
	if dist not in (1, 2):
		raise ValueError(
			'unsupported context distribution dist={!r}; expected 1 (correlated '
			'gaussian) or 2 (fixed sub-optimal arms)'.format(dist)
		)

	if dist == 1:
		V = 0.3 * np.eye(N) + 0.7 * np.ones((N, N))
		# Each draw has shape (d, N): one correlated N-vector per feature
		# coordinate. Transpose to the (N, d) arms-by-features layout.
		contexts = [np.random.multivariate_normal(np.zeros(N), V, d).T for _ in range(T)]
	else:
		beta_norm = np.linalg.norm(beta, ord=2) ** 2
		x_sub_optimal = np.random.randn(N - 1, d)
		# Pin a few coordinates outside the support of beta to a constant;
		# they do not affect the expected reward.
		fixed_ind = np.random.choice(np.where(beta == 0)[0], 5)
		x_sub_optimal[:, fixed_ind] = 5
		# With N == 2 there is a single sub-optimal arm and N - 2 == 0; guard
		# the denominator so that arm simply gets the lowest target (0.1).
		spacing_denom = max(N - 2, 1)
		for i in range(N - 1):
			# Shift along beta so that x_sub_optimal[i] @ beta equals the target.
			target = i / spacing_denom * 0.8 + 0.1
			x_sub_optimal[i] += (target - np.dot(x_sub_optimal[i], beta)) / beta_norm * beta
		contexts = []
		for _ in range(T):
			x_optimal = np.random.randn(1, d)
			x_optimal[:, fixed_ind] = 5
			x_optimal += (np.random.uniform(0.9, 1) - np.dot(x_optimal, beta)) / beta_norm * beta
			# Shuffle rows so the optimal arm's index changes between rounds.
			context = np.concatenate((x_sub_optimal, x_optimal), axis=0)[np.random.permutation(N)]
			contexts.append(context)

	rwd = []
	regret = []

	for x in contexts:
		rwd_expected = np.dot(x, beta)
		err = R * np.random.randn(N)
		rwd.append(rwd_expected + err)
		optRWD = np.amax(rwd_expected)
		regret.append(optRWD - rwd_expected)
	return contexts, rwd, regret

def run(models, contexts, rwd, regret):
	"""Play every model through the same sequence of rounds.

	In round t (1-based) each model picks an arm via ``choose_a(t, x)``, the
	regret of that arm is recorded, and the model is then fed the arm's reward
	via ``update_beta(reward, t)``.

	Returns an array of shape (len(models), len(contexts)) whose row i is the
	cumulative regret of ``models[i]`` over the rounds.
	"""
	per_model_regret = [[] for _ in models]
	for t, x in enumerate(contexts):
		round_no = t + 1
		for history, model in zip(per_model_regret, models):
			arm = model.choose_a(round_no, x)
			history.append(regret[t][arm])
			model.update_beta(rwd[t][arm], round_no)
	return np.cumsum(per_model_regret, axis=1)

def test_models(models_info, s0, N, d, T, R, dist, simul_n, savename):
	"""Benchmark several bandit models over repeated simulations and save results.

	``models_info`` is a sequence of entries ``(constructor, label, kwargs)``.
	Each of the ``simul_n`` simulations draws a fresh ``s0``-sparse parameter
	of dimension ``d`` with unit l1 norm, generates one instance with
	``generate_instance``, builds every model anew and plays them all with
	``run``.

	The mean and standard deviation (over simulations) of the cumulative regret
	are plotted to ``savename + '.png'`` and written to ``savename + '.csv'``
	(mean rows first, then standard-deviation rows). The final mean cumulative
	regret of every model is printed.
	"""
	per_simulation = []

	for _ in range(simul_n):
		# Random support of size s0 with uniform weights, scaled to unit l1 norm.
		support = np.random.choice(range(d), s0, replace=False)
		beta = np.zeros(d)
		beta[support] = np.random.uniform(-1., 1., s0)
		beta /= np.linalg.norm(beta, ord=1)

		contexts, rwd, regret = generate_instance(N, d, T, R, beta, dist)

		# Fresh model objects for every simulation.
		models = [info[0](**info[2]) for info in models_info]

		per_simulation.append(run(models, contexts, rwd, regret))

	# Stack the per-simulation results; statistics are taken over axis 0.
	per_simulation = np.asarray(per_simulation)
	avg_regret = np.mean(per_simulation, axis=0)
	std_regret = np.std(per_simulation, axis=0)

	rounds = range(1, T + 1)
	bar_spacing = T // 6 + 1
	for i, info in enumerate(models_info):
		plt.errorbar(
			rounds,
			avg_regret[i],
			yerr=std_regret[i],
			errorevery=bar_spacing,
			label=info[1],
			capsize=3,
		)
	plt.legend()
	plt.xlabel('Time')
	plt.ylabel('Cumulative Regret')

	distnames = [' ', 'Correlated Gaussian', 'Fixed Sub-optimal Arms']
	title = '{}, d = {}, s = {}, K = {}'.format(distnames[dist], d, s0, N)
	plt.title(title)
	plt.savefig(savename + '.png')

	stacked_stats = np.concatenate((avg_regret, std_regret), axis=0)
	np.savetxt(savename + '.csv', stacked_stats, delimiter=",")
	print(avg_regret[:, -1])

def main():
	"""Parse the command line, seed the random generators and run experiment 1.

	The model line-up and hyper-parameters below are those of experiment 1;
	the configuration for experiment 2 is kept as a comment for reference.
	Outputs are written by ``test_models`` under the base name ``experiment1``.
	"""
	args = parser.parse_args()
	# Seed BOTH generators. The instance sampling in this file draws from
	# np.random, so seeding only the stdlib `random` module left runs
	# irreproducible for a given job ID.
	random.seed(args.id)
	# np.random.seed only accepts integers in [0, 2**32 - 1].
	np.random.seed(args.id % (2 ** 32))

	N = args.N
	d = args.d
	s0 = args.s0
	R = args.R
	T = args.T
	simul_n = args.simul_n

	# NOTE: no parameter vector is built here; test_models draws a fresh one
	# for every simulation.

	savename = "experiment1"

	lam0 = 0.3

	# experiment 1
	models_info = [
		(DRLassoBandit, 'DR Lasso', { 'lam1' : 1, 'lam2' : lam0,  'd' : d, 'N' : N, 'tc' : 1, 'tr' : True, 'zt' : 10}),
		(SALassoBandit, 'SA Lasso', {'lam0' : lam0, 'd' : d, 'N': N}),
		(THLassoBandit, 'TH Lasso', {'K' : N, 'lam0' : 0.02, 'd':d}),
		(LassoUCBBandit, 'L1-CB Lasso', { 'lam0' : lam0, 'd' : d, 'N' : N, 'tau' : 1}),
		(ESTCBandit, 'ESTC', {'M_0':150, 'lam0' : lam0, 'd':d}),
		(ETCLassoBandit, 'FS-WLasso', { 'M_0' : 10, 'w': 1, 'd': d, 'sigma': 0.06, 'delta' : 0.01}),
		(FSLassoBandit, 'FS-Lasso', { 'q' : 5, 'h':0.02, 'lam1' : 0.08, 'lam2' : lam0,'d' : d, 'N' : N}),
	]

	# experiment 2
	# models_info = [
	# 	(DRLassoBandit, 'DR Lasso', { 'lam1' : 1, 'lam2' : lam0,  'd' : d, 'N' : N, 'tc' : 1, 'tr' : True, 'zt' : 10}),
	# 	(SALassoBandit, 'SA Lasso', {'lam0' : lam0, 'd' : d, 'N': N}),
	# 	(THLassoBandit, 'TH Lasso', {'K' : N, 'lam0' : 0.02, 'd':d}),
	# 	(LassoUCBBandit, 'L1-CB Lasso', { 'lam0' : lam0, 'd' : d, 'N' : N, 'tau' : 1}),
	# 	(ESTCBandit, 'ESTC', {'M_0':200, 'lam0' : lam0, 'd':d}),
	# 	(ETCLassoBandit, 'FS-WLasso', { 'M_0' : 200, 'w': 1, 'd': d, 'sigma': 0.06, 'delta' : 0.01}),
	# 	(FSLassoBandit, 'FS-Lasso', { 'q' : 12, 'h':0.01, 'lam1' : 0.02, 'lam2' : lam0,'d' : d, 'N' : N}),
	# ]

	test_models(models_info, s0, N, d, T, R, args.dist, simul_n, savename)
	
# Run the experiment only when this file is executed as a script, not when it
# is imported.
if __name__ == '__main__':
	main()
