import numpy as np
import random
import time
from models_dmnl_uniform import ucb_mnl, ts_mnl, ofu_mnl_plus, ofu_mnl_dr, ofu_dmnl_full, ofu_dmnl, set_function
import argparse
import pickle


parser = argparse.ArgumentParser(description='mnl bandit for uniform rewards')

parser.add_argument('--C', type = int, default=1000, help='number of item candidates')
parser.add_argument('--S', type = int, default=1000, help='number of item set')
parser.add_argument('--N', type = int, default=50, help='number of base items')
parser.add_argument('--K', type = int, default=5, help='size of assortment')
parser.add_argument('--d', type = int, default=5, help='feature dimension')
parser.add_argument('--ncat', type = int, default=5, help='category feature dimension')
parser.add_argument('--LamD', type = float, default=0.4, help='true balancing par for diversity')
parser.add_argument('--step',type = float, default=1, help='maxstep' )
parser.add_argument('--omega',type = float, default=0.3, help='strict submodular' )
parser.add_argument('--dist', type = int, default=0, help='context distribution - 0:gaussian, 1:uniform, 2:elliptical')
parser.add_argument('--id', type = int, default=10, help='job ID') 
parser.add_argument('--T', type = int, default=10000, help='horizon')
parser.add_argument('--simul_n', type = int, default=10, help='number of simulation')


class dmnlEnv:
	"""Simulated multinomial-logit choice environment.

	Besides the offered items there is one extra alternative with fixed
	weight ``vzero`` (set to 1); it is always the last entry of the choice
	vector returned by :meth:`compute_rwd_div`.
	"""

	def __init__(self, theta, LamD, N, K, omega, step):
		"""Store the problem parameters; none of them is transformed here."""
		super().__init__()
		self.theta, self.LamD = theta, LamD
		self.N, self.K = N, K
		self.omega, self.step = omega, step
		# Weight of the extra alternative appended to every choice set.
		self.vzero = 1

	def compute_rwd_div(self, means, setmean):
		"""Return the expected reward of an assortment and one sampled choice.

		Args:
			means: per-item utilities of the offered assortment.
			setmean: set-level bonus added to every item's utility.

		Returns:
			``(rwd, Y)`` where ``rwd`` is the total probability that one of
			the offered items is picked, and ``Y`` is a one-hot draw from
			``np.random.multinomial`` with one entry per item followed by
			the extra ``vzero`` alternative.
		"""
		weights = np.exp(means + setmean)
		offered_mass = weights.sum()
		total_mass = self.vzero + offered_mass
		choice_prob = np.append(weights, [self.vzero]) / total_mass
		outcome = np.random.multinomial(1, choice_prob)
		return offered_mass / total_mass, outcome
	
def set_function(cat_info,omega,step):         #(1-r^(1/step))-strict-sub-modular
    """Return the normalised coverage score of a set of items.

    ``cat_info`` is a 2-D array with one row per item. A column counts as
    covered to the extent of its largest entry over the rows, and the summed
    coverage ``c`` is mapped to ``(1 - r**c) / (1 - r**n)`` with
    ``r = (1 - omega)**step`` and ``n`` the number of rows of ``cat_info``.

    NOTE(review): the normaliser uses the row count (items), not the column
    count (categories) -- confirm this is the intended constant.
    """
    decay = (1 - omega) ** step
    covered = np.max(cat_info, axis=0).sum()
    n_rows = cat_info.shape[0]
    return (1 - decay ** covered) / (1 - decay ** n_rows)
	



def main():
	"""Run the bandit simulation and save cumulative regret / runtime CSVs.

	Reads the item set (CSV) and the pre-computed dataset (pickle) selected by
	the command-line arguments, then runs ``simul_n`` independent simulations
	of ``T`` rounds. In every round one stored sample set is drawn and each of
	the six learners (ucb_mnl, ts_mnl, ofu_mnl_plus, ofu_mnl_dr, ofu_dmnl_full,
	ofu_dmnl -- always in this order) picks an assortment, observes a simulated
	choice and updates itself.

	Three files are written to ``./results``: regret against ``True_opt_rwd``,
	regret against ``Gamma_opt_rwd`` and cumulative wall-clock time. Each file
	has ``6 * simul_n`` rows (``simul_n`` consecutive rows per learner, learners
	in the order above) and ``T`` columns.
	"""
	import os  # local import: only needed to create the output directory

	args = parser.parse_args()
	# Seed both generators: `random` picks the sample set of each round, while
	# the environment draws the observed choice through np.random.multinomial.
	random.seed(args.id)
	np.random.seed(args.id)

	C = args.C
	S = args.S
	N = args.N
	K = args.K
	ncat = args.ncat
	LamD = args.LamD
	d = args.d
	step = args.step
	omega = args.omega
	dist = args.dist
	T = args.T
	simul_n = args.simul_n

	loaditem = "./data/itemset_C={}_d={}_ncat={}_step={}_dist={}.csv".format(C, d, ncat, step, dist)
	itemset = np.loadtxt(loaditem, delimiter=",")
	# The first d columns are item features, the remaining ones category indicators.
	features = itemset[:, :d]
	category = itemset[:, d:]

	loaddata = "./data/dataset_S={}_N={}_K={}_d={}_ncat={}_LamD={}_step={}_omega={}_dist={}.pkl".format(S, N, K, d, ncat, LamD, step, omega, dist)
	# NOTE: unpickling can execute arbitrary code -- only load trusted dataset files.
	with open(loaddata, "rb") as f:
		loaded = pickle.load(f)

	theta = loaded["Theta"]
	Sample_set = loaded["Sample_set"]
	True_opt_rwd = loaded["True_opt_rwd"]
	Gamma_opt_rwd = loaded["Gamma_opt_rwd"]

	vzero = 1
	kappa = np.exp(-1) / (vzero + K * np.exp(1))**2

	result_prefix = "./results/dmnlBandit_LamD={}_N={}_K={}_d={}_ncat={}_omega={}_step={}_dist={}_id={}".format(LamD, N, K, d, ncat, omega, step, args.dist, args.id)
	regret_savename = result_prefix + "_regret.csv"
	gamma_regret_savename = result_prefix + "_gamma_regret.csv"
	cum_runtime_savename = result_prefix + "_cum_runtime.csv"
	# Create the output directory up front so a missing folder cannot discard
	# the results of a finished (and potentially long) simulation.
	os.makedirs("./results", exist_ok=True)

	n_algos = 6
	cumulated_regret = [[] for _ in range(n_algos)]
	gamma_cumulated_regret = [[] for _ in range(n_algos)]
	cumulated_time = [[] for _ in range(n_algos)]

	for simul in range(simul_n):
		print(simul, "-th simulation started")

		env = dmnlEnv(theta, LamD, N, K, omega, step)
		M1 = ucb_mnl(N=N, K=K, d=d, kappa=kappa)
		M2 = ts_mnl(N=N, K=K, d=d, kappa=kappa)
		M3 = ofu_mnl_plus(N=N, K=K, d=d, kappa=kappa, vzero=vzero)
		M4 = ofu_mnl_dr(LamD=1, N=N, K=K, d=d, omega=omega, step=step, kappa=kappa)
		M5 = ofu_dmnl_full(N=N, K=K, d=d, omega=omega, step=step, kappa=kappa)
		M6 = ofu_dmnl(N=N, K=K, d=d, omega=omega, step=step, kappa=kappa)

		# (learner, choose_S also takes the category matrix,
		#  learner is updated with update_theta(Y, round) instead of update_state(Y))
		learners = [
			(M1, False, True),
			(M2, False, True),
			(M3, False, False),
			(M4, True, False),
			(M5, True, False),
			(M6, True, False),
		]

		# Exactly one entry per learner per round, so every list has length T
		# and lines up with the optimal-reward lists below.
		rewards = [[] for _ in range(n_algos)]
		runtimes = [[] for _ in range(n_algos)]
		true_optRWD = []
		gamma_optRWD = []

		for t in range(T):
			sample_id = random.choice(range(S))
			samples = Sample_set[sample_id]  # Set of N items
			x = features[samples]
			cat = category[samples]

			for k, (model, uses_cat, theta_update) in enumerate(learners):
				start_time = time.time()
				if uses_cat:
					chosen = model.choose_S(t+1, x, cat)
				else:
					chosen = model.choose_S(t+1, x)
				rwd, Y = env.compute_rwd_div(np.dot(x[chosen, :], theta), LamD*set_function(cat[chosen], omega, step))
				rewards[k].append(rwd)
				if theta_update:
					model.update_theta(Y, t+1)
				else:
					model.update_state(Y)
				runtimes[k].append(time.time() - start_time)

			true_optRWD.append(True_opt_rwd[sample_id])
			gamma_optRWD.append(Gamma_opt_rwd[sample_id])

			if t == T-1:
				print("Number of adaptive exploration rounds of OFU-DMNL:", M6.numinit)

		opt_cum = np.cumsum(true_optRWD)
		gamma_cum = np.cumsum(gamma_optRWD)
		for k in range(n_algos):
			rwd_cum = np.cumsum(rewards[k])
			cumulated_regret[k].append(opt_cum - rwd_cum)
			gamma_cumulated_regret[k].append(gamma_cum - rwd_cum)
			cumulated_time[k].append(np.cumsum(runtimes[k]))

	# Stack the per-learner (simul_n, T) blocks on top of each other.
	regret = np.vstack([np.asarray(runs) for runs in cumulated_regret])
	np.savetxt(regret_savename, regret, delimiter=",")

	gamma_regret = np.vstack([np.asarray(runs) for runs in gamma_cumulated_regret])
	np.savetxt(gamma_regret_savename, gamma_regret, delimiter=",")

	cum_runtime = np.vstack([np.asarray(runs) for runs in cumulated_time])
	np.savetxt(cum_runtime_savename, cum_runtime, delimiter=",")


	
# Run the experiment only when this file is executed as a script.
if __name__ == "__main__":
	main()
