import argparse
import copy
import os
import pickle
import time
import warnings

import numpy as np
from scipy.optimize import brentq
from scipy.optimize import linprog

from Algos import LinUCB, RestrictedLinUCB
warnings.filterwarnings("ignore")


class Environment:
    """Linear bandit environment with a parameter vector ``theta``.

    The mean reward of arm ``x`` is ``theta.T @ x``. Arms are either supplied
    once at construction ("fixed arms") or re-sampled from a normalised
    Gaussian (uniform on the unit sphere) every time ``iterate`` is called.

    Args:
        dim: Dimension ``d`` of ``theta`` and of every arm.
        num_arms: Number of arms offered each round.
        arms: Optional fixed arms, a sequence (or array) of ``(d, 1)`` column
            vectors. ``None`` or an empty sequence means random arms.
        theta: Optional parameter vector, expected as a ``(d, 1)`` column
            vector; when ``None`` it is drawn uniformly on the unit sphere.
    """

    def __init__(self, dim, num_arms, arms=None, theta=None):
        self.d = dim
        self.num_arms = num_arms
        if theta is None:
            self.generate_theta()
        else:
            self.theta = theta
        # Use len() instead of truthiness: `not arms` raises for a NumPy array
        # holding more than one arm.
        self.fixed_arms = arms is not None and len(arms) > 0
        if not self.fixed_arms:
            self.generate_arms()
            return
        self.arms = arms
        self.static_mean_rewards = [np.dot(arm.T, self.theta) for arm in self.arms]
        # gen_rewards / gen_bounds / gen_rewards_bern read current_means, so
        # fixed arms need it as well (it used to be set only for random arms).
        self.current_means = self.static_mean_rewards
        flat_means = np.asarray(self.static_mean_rewards, dtype=float).ravel()
        self.static_best_arm = np.argmax(flat_means)
        # np.argmax always returns a single index, so ties must be counted.
        if np.count_nonzero(flat_means == flat_means[self.static_best_arm]) > 1:
            print('Best arm is not unique')

    def generate_arms(self):
        """Draw ``num_arms`` unit-norm ``(d, 1)`` arms and cache their means.

        Does nothing when the arms are fixed.
        """
        if not self.fixed_arms:
            self.arms = [np.random.normal(size=(self.d, 1)) for _ in range(self.num_arms)]
            self.arms = [arm / np.linalg.norm(arm) for arm in self.arms]
            self.current_means = [np.dot(self.theta.T, arm) for arm in self.arms]

    def _flat_means(self):
        """Return ``current_means`` as a 1-D float array, one entry per arm."""
        return np.asarray(self.current_means, dtype=float).ravel()

    def refine_bounds(self, arms, upper, lower):
        """Tighten per-arm mean bounds using the constraints they jointly impose.

        The bounds define the polytope ``{t in [-1, 1]^d : lower <= X t <= upper}``
        where row ``i`` of ``X`` is arm ``i``. For each arm, minimising and
        maximising ``x_i . t`` over that polytope (two linear programs) gives the
        refined lower and upper bound.

        Args:
            arms: Sequence of arms, each ``(d, 1)`` or ``(d,)``.
            upper: Upper bounds, one per arm.
            lower: Lower bounds, one per arm.

        Returns:
            ``(new_upper, new_lower)`` as 1-D float arrays. If the solver reports
            failure for an LP, the corresponding input bound is kept unchanged
            instead of storing the invalid result.
        """
        upper = np.asarray(upper, dtype=float)
        lower = np.asarray(lower, dtype=float)
        # One arm per row; reshape (rather than [:, :, 0]) so (d,) arms work too.
        features = np.asarray(arms, dtype=float).reshape(len(arms), -1)
        A_ub = np.vstack((features, -features))
        b_ub = np.concatenate((upper, -lower))
        lims = (-1, 1)
        new_upper, new_lower = upper.copy(), lower.copy()
        for i, x in enumerate(features):
            res_min = linprog(c=x, A_ub=A_ub, b_ub=b_ub, bounds=lims)
            if res_min.success:
                new_lower[i] = res_min.fun
            res_max = linprog(c=-x, A_ub=A_ub, b_ub=b_ub, bounds=lims)
            if res_max.success:
                new_upper[i] = -res_max.fun
        return new_upper, new_lower

    def gen_bounds(self):
        """Draw noisy bounds around each arm's mean, then tighten them by LP.

        Each bound is the mean shifted by ``0.5 * Uniform[0, 1)`` noise (upward
        for the upper bound, downward for the lower bound). The result is
        clipped to ``[-1, 0]`` for arms with negative mean and to ``[0, 1]``
        otherwise, so a mean with ``|mean| <= 1`` stays inside its bounds. The
        bounds are then refined with ``refine_bounds``.

        Returns:
            ``(ub, lb)``: refined upper and lower bounds, also stored on ``self``.
        """
        means = self._flat_means()
        self.ub = np.zeros(self.num_arms)
        self.lb = np.zeros(self.num_arms)
        for i in range(self.num_arms):
            lo_clip, hi_clip = (-1, 0) if means[i] < 0 else (0, 1)
            # Draw order (upper noise, then lower noise, per arm) is kept so
            # seeded runs stay reproducible.
            self.ub[i] = np.clip(means[i] + 0.5 * np.random.uniform(), lo_clip, hi_clip)
            self.lb[i] = np.clip(means[i] - 0.5 * np.random.uniform(), lo_clip, hi_clip)
        self.ub, self.lb = self.refine_bounds(self.arms, self.ub, self.lb)
        return self.ub, self.lb

    def generate_theta(self):
        """Draw ``theta`` as a unit-norm ``(d, 1)`` vector (normalised Gaussian)."""
        self.theta = np.random.normal(size=(self.d, 1))
        self.theta /= np.linalg.norm(self.theta)

    def get_best_arm(self):
        """Return the index of the arm with the largest mean reward."""
        if self.fixed_arms:
            return self.static_best_arm
        mean_rewards = [np.dot(self.theta.T, arm) for arm in self.arms]
        return np.argmax(mean_rewards)

    def iterate(self):
        """Advance one round: re-sample the arms (no-op for fixed arms)."""
        self.generate_arms()

    def gen_rewards(self):
        """Draw one noisy reward per arm.

        Each reward is the mean plus ``N(0, 0.5**2)`` noise. It is clipped to
        ``[-1, 0]`` for arms with negative mean and to ``[0, 1]`` otherwise.

        Returns:
            ``(rews, max_idx)``: 1-D reward array and the index of the best arm.
        """
        means = self._flat_means()
        max_idx = np.argmax(means)
        rews = np.zeros(self.num_arms)
        for i in range(self.num_arms):
            noise = 0.5 * np.random.normal()
            if means[i] < 0:
                rews[i] = np.clip(means[i] + noise, -1, 0)
            else:
                rews[i] = np.clip(means[i] + noise, 0, 1)
        return rews, max_idx

    def gen_rewards_bern(self):
        """Draw one signed Bernoulli reward per arm.

        Arm ``i`` yields ``sign(mean_i)`` with probability ``|mean_i|`` and 0
        otherwise, so its expected reward equals its mean when ``|mean_i| <= 1``.

        Returns:
            ``(rews, max_idx)``: 1-D reward array and the index of the best arm.
        """
        means = self._flat_means()
        rews = np.zeros(self.num_arms)
        max_idx = np.argmax(means)
        for i in range(self.num_arms):
            u = np.random.uniform()
            # Success with probability |mean|; the sign comes from the mean
            # itself (the old code tested the still-zero output and never negated).
            hit = int(u < abs(means[i]))
            rews[i] = -hit if means[i] < 0 else hit
        return rews, max_idx

def func(x, m):
    """Stationarity condition for the sub-Gaussian variance proxy of Bernoulli(m).

    Let ``f(x) = m*exp((1-m)*x) + (1-m)*exp(-m*x)``, the moment generating
    function of ``X - m`` for ``X ~ Bernoulli(m)``. This returns
    ``f'(x)/f(x) - (2/x)*log(f(x))``, which is zero exactly where
    ``2*log(f(x))/x**2`` is stationary in ``x``. The script finds its root with
    ``brentq`` and evaluates ``2*log(f)/x**2`` there.

    Args:
        x: Point at which to evaluate (must be non-zero).
        m: Bernoulli success probability.

    Returns:
        The value of the stationarity condition at ``x``.
    """
    grow = np.exp((1 - m) * x)
    decay = np.exp(-m * x)
    mgf = m * grow + (1 - m) * decay
    slope_term = m * (1 - m) * (grow - decay) / mgf
    return slope_term - (2 / x) * np.log(mgf)


# Command-line configuration for one batch of runs. The numeric settings are
# required so a missing flag fails immediately with a usage message.
parser = argparse.ArgumentParser()
parser.add_argument('--Dimension', '-dim', help='Dimension', type=int, required=True)
parser.add_argument('--Number_of_arms', '-arms', help='Number of arms', type=int, required=True)
parser.add_argument('--Number_of_Instances', '-num_inst', help='Number of instances to run',
                    type=int, required=True)
parser.add_argument('--Number_of_Iterations', '-num_iter', help='Number of iterations per instance',
                    type=int, required=True)
# An empty default lets the save path be built without this flag instead of
# failing on `str + None` after the whole experiment has run.
parser.add_argument('--Save_suffix', '-suf', help='Suffix to savefile', type=str, default='')
args = parser.parse_args()
dim = args.Dimension
num_arms = args.Number_of_arms
num_inst = args.Number_of_Instances
num_iter = args.Number_of_Iterations
suffix = args.Save_suffix

# arange(dim) is the zero vector (or empty) for dim < 2, so normalising it
# below would yield NaNs or an empty theta.
if dim < 2:
    parser.error('--Dimension must be at least 2')

# Ground-truth parameter: (0, 1, ..., dim-1) scaled to unit norm, as a
# (dim, 1) column vector.
theta = np.arange(dim)
theta = theta / np.linalg.norm(theta)
theta = theta.reshape((len(theta), 1))

# Table of sub-Gaussian variance proxies for Bernoulli(p), keyed by p rounded
# to 4 decimals on the grid p = 0, 0.0005, ..., (just below 0.5). For each p,
# the root of func(., p) is the stationary point x_star of 2*log(MGF(x))/x**2,
# and the proxy is that ratio evaluated at x_star.
sg = {}
for p in np.arange(0, 0.5, 0.0005):
    p = round(p, 4)
    x_star = brentq(func, 0.0005, 100, args=p)
    mgf = p*np.exp((1-p)*x_star) + (1-p)*np.exp(-p*x_star)
    sg[p] = (2/x_star**2)*np.log(mgf)
# p = 0.5 is set directly rather than solved for. NOTE(review): presumably the
# root degenerates towards x -> 0 in this symmetric case; 0.25 is the known
# proxy for Bernoulli(0.5).
sg[0.5] = 0.25

# Noise scale passed to LinUCB, sqrt(0.25) = 0.5, and the confidence
# parameter shared by both algorithms.
SIGMA = 0.25**0.5
DELTA = 0.01

# Resolve the output path and create its directory before the experiment, so
# a missing folder fails fast instead of discarding a finished run.
savepath = './Files/LinBanditRun6/regfile'+suffix+'.p'
os.makedirs(os.path.dirname(savepath), exist_ok=True)

env = Environment(dim, num_arms, theta=theta)
l1 = LinUCB(dim, SIGMA, DELTA)
l2 = RestrictedLinUCB(dim, DELTA, sg)

# Per-instance regret curves, stacked once after all instances. Stacking
# inside the loop would copy the growing matrix on every instance.
reg1_rows = []
reg2_rows = []

# Elapsed seconds spent in each algorithm's iterate() call, one entry per round.
time_l1 = []
time_l2 = []
for _inst in range(num_inst):
    l1.restart()
    l2.restart()
    for _round in range(num_iter):
        # Both algorithms see the same arms, rewards and bounds this round.
        rews, true_best = env.gen_rewards()
        upper, lower = env.gen_bounds()
        arms = env.arms

        t_start = time.perf_counter()
        l1.iterate(arms, rews, true_best)
        time_l1.append(time.perf_counter() - t_start)

        t_start = time.perf_counter()
        l2.iterate(arms, upper, lower, rews, true_best)
        time_l2.append(time.perf_counter() - t_start)

        # Re-sample the (non-fixed) arms for the next round.
        env.iterate()
    # The first entry of get_regret() is dropped, as before. NOTE(review):
    # presumably an initial value recorded at restart(); confirm in Algos.
    reg1_rows.append(np.asarray(l1.get_regret(), dtype='object')[1:])
    reg2_rows.append(np.asarray(l2.get_regret(), dtype='object')[1:])

# One row per instance; with no instances the result is an empty
# (0, num_iter) float array, as before.
reg1 = np.vstack(reg1_rows) if reg1_rows else np.zeros((0, num_iter))
reg2 = np.vstack(reg2_rows) if reg2_rows else np.zeros((0, num_iter))

run_dict = {
    'reg1': reg1,
    'reg2': reg2,
    'time_l1': time_l1,
    'time_l2': time_l2,
}

with open(savepath, 'wb') as f:
    pickle.dump(run_dict, f, protocol=pickle.HIGHEST_PROTOCOL)