import numpy as np
import random
from models import eps_greedy_mnl, ucb_mnl, ts_mnl, ofu_mnl_plus, onl_mnl
import argparse
import sys
import os

import torch
import torch.nn as nn
from tqdm import tqdm
import matplotlib.pyplot as plt

# Command-line interface of the simulation script.
parser = argparse.ArgumentParser(
    description='Run MNL bandit environment simulation',
)

# Problem dimensions.
parser.add_argument('--N', type=int, default=100,
                    help='number of base items')
parser.add_argument('--K', type=int, default=5,
                    help='size of assortment')
parser.add_argument('--T', type=int, default=1000,
                    help='horizon')
parser.add_argument('--d', type=int, default=3,
                    help='feature dimension')

# Context sampling and outside-option utility.
parser.add_argument('--dist', type=int, choices=[0, 1], default=0,
                    help='context distribution. 0:gaussian, 1:uniform')
parser.add_argument('--vzero', type=float, default=1.0,
                    help='utility for the outside option')

# Neural-network learner settings.
parser.add_argument('--hidden_m', type=int, default=10,
                    help='number of neurons in hidden layer')
parser.add_argument('--n', type=int, default=100,
                    help='length of uniform exploration')

# Experiment repetition and reproducibility.
parser.add_argument('--n_simul', type=int, default=1,
                    help='number of simulation')
parser.add_argument('--seed', type=int, default=1,
                    help='first random seed value to use. n_simul experiments will run using range(seed, seed+n_simul)')

# Ground-truth utility model.
parser.add_argument('--env_type', type=int, choices=[0, 1], default=0,
                    help='environment type. 0:2-layer-NN 1:cosine')

class TwoLayerNN(nn.Module):
    """
    Neural network with two fully-connected layers:
    linear -> sigmoid -> linear, producing one scalar per input row.

    Args:
        input_dim: size of each input feature vector.
        hidden_dim: number of units in the hidden layer.
    """
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.dx = input_dim
        self.hidden_m = hidden_dim
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        """
        Evaluate the network on `x`.

        `x` may be a tensor or anything `torch.tensor` accepts (e.g. a NumPy
        array or nested list) whose last dimension equals `input_dim`.
        Returns a tensor with that last dimension removed.
        """
        # Follow the parameters' dtype/device instead of hard-coding float64,
        # so the module works whether or not the caller has called `.double()`.
        weight = self.fc1.weight
        if isinstance(x, torch.Tensor):
            x = x.to(dtype=weight.dtype, device=weight.device)
        else:
            x = torch.tensor(x, dtype=weight.dtype, device=weight.device)

        hidden = torch.sigmoid(self.fc1(x))
        return self.fc2(hidden).squeeze(-1)

    def set_model_weights_from_vector(self, w):
        """
        Overwrite all parameters from a flat vector.

        Layout of `w`: fc1.weight (hidden_m * dx values, row-major),
        fc1.bias (hidden_m), fc2.weight (hidden_m), fc2.bias (1).

        Raises:
            ValueError: if `w` does not hold exactly
                dx * hidden_m + 2 * hidden_m + 1 values.
        """
        if not isinstance(w, torch.Tensor):
            w = torch.tensor(w, dtype=torch.float64)
        flat = w.detach().double().reshape(-1)

        fc1_w_end = self.dx * self.hidden_m
        fc1_b_end = fc1_w_end + self.hidden_m
        fc2_w_end = fc1_b_end + self.hidden_m
        expected = fc2_w_end + 1  # + fc2 bias
        if flat.numel() != expected:
            raise ValueError(
                f"expected a weight vector with {expected} entries, got {flat.numel()}"
            )

        # copy_ casts to each parameter's own dtype; no_grad keeps autograd
        # from tracking the in-place writes.
        with torch.no_grad():
            self.fc1.weight.copy_(flat[:fc1_w_end].reshape(self.hidden_m, self.dx))
            self.fc1.bias.copy_(flat[fc1_w_end:fc1_b_end])
            self.fc2.weight.copy_(flat[fc1_b_end:fc2_w_end].reshape(1, self.hidden_m))
            self.fc2.bias.copy_(flat[fc2_w_end:])

class mnlEnv_nn:
    """
    Multinomial logit (MNL) bandit environment with a 2-layer neural network utility.

    Choice model: item i in the offered set S is chosen with probability
    exp(u_i) / (vzero + sum_{j in S} exp(u_j)); the outside option (no
    purchase) gets the remaining mass. Every item has unit revenue.

    Args:
        model: torch module mapping item features to utilities; it is
            converted to float64 via `.double()`.
        K: maximum assortment size, used by `get_opt_rwd`.
        vzero: weight of the outside option in the MNL denominator
            (expected to be positive).
    """
    def __init__(self, model, K, vzero=1):
        self.K = K
        self.vzero = vzero
        self.model = model.double()

    def compute_rwd(self, x_S):
        """
        Simulate one round for the offered assortment.

        Args:
            x_S: features of the offered items, one row per item.

        Returns:
            (rwd, Y): `rwd` is the expected reward, i.e. the probability that
            any offered item is chosen; `Y` is a one-hot draw from
            `np.random.multinomial` over the offered items followed by the
            outside option (last entry).
        """
        with torch.no_grad():
            u = self.model(torch.as_tensor(x_S, dtype=torch.float64))
        # The outside option enters with weight vzero, i.e. utility log(vzero),
        # so this agrees with the denominator used in get_opt_rwd.
        logits = np.append(u.detach().numpy(), np.log(self.vzero))
        logits -= np.max(logits)  # shift for a numerically stable softmax
        exp_u = np.exp(logits)
        prob = exp_u / np.sum(exp_u)
        revenue = np.ones_like(prob)
        revenue[-1] = 0.0  # the outside option yields no reward
        rwd = np.dot(prob, revenue)
        Y = np.random.multinomial(1, prob)  # User choice
        return rwd, Y

    def get_opt_rwd(self, x):
        """
        Return the expected reward of the best assortment of size K among
        the items in `x` (one feature row per item).

        With unit revenues the expected reward increases with the summed
        weight exp(u) of the offered items, so the K largest are optimal.
        """
        with torch.no_grad():
            u = self.model(torch.as_tensor(x, dtype=torch.float64))
        exp_u = torch.exp(u).detach().numpy()
        opt_exp_u = np.sort(exp_u)[::-1][:self.K]
        total = np.sum(opt_exp_u)
        return total / (self.vzero + total)

class mnlEnv_cos:
    """
    Multinomial logit (MNL) bandit environment with cosine utility
    u(x) = cos(2*pi*<x, w>) - <x, w>/2.

    Choice model: item i in the offered set S is chosen with probability
    exp(u_i) / (vzero + sum_{j in S} exp(u_j)); the outside option (no
    purchase) gets the remaining mass. Every item has unit revenue.

    Args:
        w: parameter vector of the utility function.
        K: maximum assortment size, used by `get_opt_rwd`.
        vzero: weight of the outside option in the MNL denominator
            (expected to be positive).
    """
    def __init__(self, w, K, vzero=1):
        self.K = K
        self.vzero = vzero
        self.w = w

    def _utility(self, x):
        """Evaluate the cosine utility for each feature row of `x`."""
        xw = np.dot(x, self.w)
        return np.cos(2*np.pi*xw) - xw/2

    def compute_rwd(self, x_S):
        """
        Simulate one round for the offered assortment.

        Args:
            x_S: features of the offered items, one row per item.

        Returns:
            (rwd, Y): `rwd` is the expected reward, i.e. the probability that
            any offered item is chosen; `Y` is a one-hot draw from
            `np.random.multinomial` over the offered items followed by the
            outside option (last entry).
        """
        # The outside option enters with weight vzero, i.e. utility log(vzero),
        # so this agrees with the denominator used in get_opt_rwd.
        logits = np.append(self._utility(x_S), np.log(self.vzero))
        logits -= np.max(logits)  # shift for a numerically stable softmax
        exp_u = np.exp(logits)
        prob = exp_u / np.sum(exp_u)
        revenue = np.ones_like(prob)
        revenue[-1] = 0.0  # the outside option yields no reward
        rwd = np.dot(prob, revenue)
        Y = np.random.multinomial(1, prob)  # User choice
        return rwd, Y

    def get_opt_rwd(self, x):
        """
        Return the expected reward of the best assortment of size K among
        the items in `x` (one feature row per item).

        With unit revenues the expected reward increases with the summed
        weight exp(u) of the offered items, so the K largest are optimal.
        """
        exp_u = np.exp(self._utility(x))
        opt_exp_u = np.sort(exp_u)[::-1][:self.K]
        total = np.sum(opt_exp_u)
        return total / (self.vzero + total)
    
def _build_env(env_type, d, hidden_m, K, vzero):
    """Create the ground-truth environment; its parameters are drawn from NumPy's global RNG."""
    if env_type == 0:
        # fc1 weights + fc1 bias + fc2 weights + fc2 bias of the two layer NN
        dw = d * hidden_m + (2 * hidden_m) + 1
        w_star = torch.from_numpy(np.random.uniform(-1., 1., dw))
        model_env = TwoLayerNN(input_dim=d, hidden_dim=hidden_m).double()
        model_env.set_model_weights_from_vector(w_star)
        return mnlEnv_nn(model=model_env, K=K, vzero=vzero)
    if env_type == 1:
        w_star = np.random.uniform(-1., 1., d)
        return mnlEnv_cos(w=w_star, K=K, vzero=vzero)
    raise ValueError(f"unknown env_type: {env_type}")


def _sample_contexts(dist, N, d):
    """Draw the (N, d) item feature matrix for one round from NumPy's global RNG."""
    if dist == 0:
        return np.random.randn(N, d)
    if dist == 1:
        return np.random.uniform(low=-3.0, high=3.0, size=(N, d))
    raise ValueError(f"unknown dist: {dist}")


def _run_simulation(seed, args, kappa):
    """
    Run all five learners for args.T rounds under one seed.

    Returns a list with one cumulative-regret array (length T) per learner,
    ordered UCB-MNL, TS-MNL, OFU-MNL+, eps-greedy-MNL, ONL-MNL.
    """
    N, K, d, T = args.N, args.K, args.d, args.T
    hidden_m, vzero = args.hidden_m, args.vzero

    # Set random seeds
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    # The environment is given the same outside-option weight that kappa and
    # the learners are configured with.
    env = _build_env(args.env_type, d, hidden_m, K, vzero)

    ucb = ucb_mnl(N=N, K=K, d=d, kappa=kappa)
    ts = ts_mnl(N=N, K=K, d=d, kappa=kappa)
    ofu = ofu_mnl_plus(N=N, K=K, d=d, kappa=kappa, vzero=vzero)
    eps = eps_greedy_mnl(N, K, d, hidden_m, T, eps=0.1, eps_decay=0.995, eps_min=0.001)
    onl = onl_mnl(N, K, d, hidden_m, args.n, T, kappa, c_lam=7.2e-10, c_beta=1e-15)

    # (learner, update callback taking (Y, round)); the learners expose
    # differently named update methods, and OFU-MNL+ takes no round index.
    learners = (
        (ucb, ucb.update_theta),
        (ts, ts.update_theta),
        (ofu, lambda Y, _round: ofu.update_state(Y)),
        (eps, eps.update_w),
        (onl, onl.update_w),
    )
    rewards = [[] for _ in learners]
    opt_rewards = []

    for t in tqdm(range(T)):
        x = _sample_contexts(args.dist, N, d)
        # Learners act in a fixed order so the shared global RNG stream is
        # consumed identically from run to run.
        for (learner, update), history in zip(learners, rewards):
            S = learner.choose_S(t + 1, x)
            rwd, Y = env.compute_rwd(x[S, :])
            history.append(rwd)
            update(Y, t + 1)
        opt_rewards.append(env.get_opt_rwd(x))

    opt_cumsum = np.cumsum(opt_rewards)
    return [opt_cumsum - np.cumsum(history) for history in rewards]


def _plot_regret(stats, T, env_type, dist):
    """Draw mean cumulative regret with error bars; `stats` holds one (mean, std) pair per learner."""
    series = (
        (r"$\mathtt{UCB-MNL}$", 'tab:purple', 's'),
        (r"$\mathtt{TS-MNL}$", 'tab:green', '^'),
        (r"$\mathtt{OFU-MNL+}$", 'tab:blue', 'o'),
        (r"$\varepsilon-\mathtt{greedy-MNL}$", 'tab:brown', 'v'),
        (r"$\mathtt{ONL-MNL}$", 'tab:red', 'D'),
    )
    titles = {
        (0, 0): r"$\mathbf{x}\sim \mathtt{Gaussian}, \, f_{\mathbf{w}^*}(\mathbf{x})= \mathtt{linear2(sigmoid(linear1(\mathbf{x})))}$ ",
        (0, 1): r"$\mathbf{x}\sim \mathtt{Uniform}, \, f_{\mathbf{w}^*}(\mathbf{x})= \mathtt{linear2(sigmoid(linear1(\mathbf{x})))}$ ",
        (1, 0): r"$\mathbf{x} \sim \mathtt{Gaussian}, \quad f_{\mathbf{w}^*}(\mathbf{x}) = \cos (2\pi (\mathbf{x}^\top \mathbf{w}^*)) - (\mathbf{x}^\top \mathbf{w}^*)/2$",
        (1, 1): r"$\mathbf{x} \sim \mathtt{Uniform}, \quad f_{\mathbf{w}^*}(\mathbf{x}) = \cos (2\pi (\mathbf{x}^\top \mathbf{w}^*)) - (\mathbf{x}^\top \mathbf{w}^*)/2$",
    }

    x = np.arange(T)
    # Error bars are drawn at six evenly spaced rounds only, to keep the plot readable.
    plot_indices = np.linspace(0, len(x) - 1, 6, dtype=int)
    plt.figure(figsize=(10, 6))
    fontsize = 20

    for (label, color, marker), (mean, std) in zip(series, stats):
        plt.plot(x, mean, label=label, color=color)
        plt.errorbar(x[plot_indices], mean[plot_indices], yerr=std[plot_indices], fmt=marker, color=color,
                     markerfacecolor='white', capsize=5, markersize=8)

    plt.tick_params(axis='both', labelsize=fontsize)
    plt.xlabel("Round (t)", fontsize=fontsize)
    plt.ylabel("Cumulative Regret", fontsize=fontsize)

    title = titles.get((env_type, dist))
    if title is not None:
        plt.title(title, fontsize=fontsize)

    plt.legend(loc='upper left', fontsize=fontsize-1)
    plt.grid(True)
    plt.tight_layout()


def main():
    """
    Run the MNL bandit comparison described by the command-line arguments.

    For each seed in range(seed, seed + n_simul) the five learners are run
    for T rounds against a freshly drawn environment. The cumulative regrets
    are written to ./results/..._regret.csv (n_simul rows per learner, in
    learner order) and the mean/std plot to ./results/..._figure.png.
    """
    # Parse arguments
    try:
        args = parser.parse_args()
    except argparse.ArgumentError as e:
        print(f"Argument error: {e}")
        sys.exit(1)

    # The statistics and plot below need at least one round and one run.
    if args.T < 1 or args.n_simul < 1:
        parser.error("--T and --n_simul must be at least 1")

    N = args.N
    K = args.K
    d = args.d
    T = args.T
    n = args.n
    vzero = args.vzero
    dist = args.dist
    hidden_m = args.hidden_m
    env_type = args.env_type
    n_simul = args.n_simul

    seeds = range(args.seed, args.seed + n_simul)
    kappa = np.exp(-1) / (vzero + K * np.exp(1)) ** 2

    # cumulated_regret[i] collects one regret curve per seed for learner i.
    cumulated_regret = [[] for _ in range(5)]
    for seed in seeds:
        for collected, regret in zip(cumulated_regret, _run_simulation(seed, args, kappa)):
            collected.append(regret)

    # Convert to numpy arrays of shape (n_simul, T)
    cumulated_regret = [np.asarray(curves) for curves in cumulated_regret]

    # Compute mean/std across seeds
    stats = [(curves.mean(axis=0), curves.std(axis=0)) for curves in cumulated_regret]

    # Visualization
    _plot_regret(stats, T, env_type, dist)

    # Save regret data and figure
    os.makedirs("./results", exist_ok=True)
    savepath = "./results/mnlBandit_N={}_K={}_d={}_m={}_n={}_dist={}_env={}".format(N, K, d, hidden_m, n, dist, env_type)
    regret_savepath = f"{savepath}_regret.csv"
    figure_savepath = f"{savepath}_figure.png"
    regret = np.vstack(cumulated_regret)
    np.savetxt(regret_savepath, regret, delimiter=",")
    plt.savefig(figure_savepath)
    plt.close()  # release the figure once it is on disk

# Run the simulation only when this file is executed as a script.
if __name__ == "__main__":
    main()
