import argparse
import copy
import os
import pickle
import time
import warnings

import matplotlib.pyplot as plt
import numpy as np
from scipy.optimize import brentq
from scipy.optimize import linprog

from Algos import LinUCB, RestrictedLinUCB
warnings.filterwarnings("ignore")

# Generation seeds (presumably): 400 with variant 1,
# 280 with variant 2 -- TODO confirm what the two variants refer to.
class Environmnet():
    """Synthetic linear bandit whose reward has a visible and a hidden part.

    Arm ``k`` carries a visible feature vector ``arms_z[k]`` of shape
    ``(vis_dim, 1)`` and a hidden one ``arms_u[k]`` of shape
    ``(hid_dim, 1)``.  Its mean reward is
    ``theta_z.T @ arms_z[k] + theta_u.T @ arms_u[k]``.

    Construction (1) draws parameters and arms from ``seed``, (2) simulates
    ``log_len`` logged rounds to derive per-arm bounds ``ub``/``lb`` on the
    hidden mean reward, and (3) tightens those bounds with linear programs
    into ``new_ub``/``new_lb``.

    Parameters
    ----------
    vis_dim : int
        Dimension of the visible feature vectors.
    hid_dim : int
        Dimension of the hidden feature vectors.  Must be at least 2 when
        ``num_arms > 1`` because every arm but the first activates two
        distinct hidden coordinates.
    num_arms : int
        Number of arms.
    log_len : int
        Number of simulated logged rounds used to build the bounds.
    log_noise_lvl : number
        Stored on the instance; no method of this class reads it.
    seed : int or None
        Seed passed to ``np.random.seed`` before arms are generated.
    """

    def __init__(self, vis_dim, hid_dim, num_arms, log_len=100000, log_noise_lvl = 2, seed=280):
        self.vis_dim = vis_dim
        self.hid_dim = hid_dim
        self.num_arms = num_arms
        self.seed = seed
        self.log_len = log_len
        self.log_noise_lvl = log_noise_lvl
        self.gen_env()

    @staticmethod
    def _scalar(value):
        """Return a size-1 array (or a plain number) as a Python float."""
        return np.asarray(value, dtype=float).item()

    def _arm_matrix(self):
        """Stack the hidden arm vectors into a ``(num_arms, hid_dim)`` array."""
        return np.hstack(self.arms_u).T

    def _visible_means(self):
        """Return the visible mean rewards as a flat ``(num_arms,)`` array."""
        return np.asarray(self.all_means_z, dtype=float).reshape(self.num_arms)

    def _noised_theta_u(self):
        """Return ``theta_u`` plus a random perturbation of Euclidean norm 0.1."""
        temp = np.random.normal(size=(self.hid_dim, 1))
        noise = (0.1/np.linalg.norm(temp))*temp
        return np.add(self.theta_u, noise)

    def gen_env(self):
        """Draw parameters and arms from ``self.seed``, then build the bounds.

        ``theta_z`` is the unit vector along ``(1, ..., vis_dim)`` scaled by
        one uniform draw from [0, 1); ``theta_u`` is 0.9 times the unit
        vector along ``(1, ..., hid_dim)``.  Hidden arm 0 is ``theta_u``
        plus a perturbation of norm 0.01, renormalised; every other hidden
        arm has two randomly chosen coordinates set to one and is then
        normalised.  Visible arms are normalised absolute Gaussian vectors.
        """
        np.random.seed(self.seed)
        # Numerator and denominator use the same ramp so the direction is a
        # true unit vector (and vis_dim == 1 does not divide by zero).
        ramp_z = np.arange(1, self.vis_dim + 1, 1)
        self.theta_z = ramp_z/np.linalg.norm(ramp_z)
        self.theta_z = np.random.rand()*self.theta_z.reshape((self.vis_dim, 1))
        ramp_u = np.arange(1, self.hid_dim + 1, 1)
        self.theta_u = 0.9*(ramp_u/np.linalg.norm(ramp_u))
        self.theta_u = self.theta_u.reshape((self.hid_dim, 1))

        self.arms_u, self.arms_z = [], []
        for i in range(self.num_arms):
            if i == 0:
                noise = np.random.normal(size=(self.hid_dim, 1))
                noise = (0.01/np.linalg.norm(noise))*noise
                arm = np.add(self.theta_u, noise)
            else:
                arm = np.zeros((self.hid_dim, 1))
                idx = np.random.choice(range(self.hid_dim), 2, replace=False)
                arm[idx] = 1
            self.arms_u.append(arm/np.linalg.norm(arm))
            tmp = np.abs(np.random.normal(size=(self.vis_dim, 1)))
            self.arms_z.append(tmp/np.linalg.norm(tmp))

        # Each mean is kept as a (1, 1) array, the natural result of the
        # row-by-column product.
        self.all_means_u = [np.dot(self.theta_u.T, arm) for arm in self.arms_u]
        self.all_means_z = [np.dot(self.theta_z.T, arm) for arm in self.arms_z]
        self.full_means = [mean_u + mean_z for mean_u, mean_z
                           in zip(self.all_means_u, self.all_means_z)]
        self.best_arm = np.argmax(self.full_means)
        self.best_arm_u = np.argmax(self.all_means_u)
        # Re-seed from fresh entropy: only the arms/parameters are tied to
        # ``seed``; the simulated log and online rewards are not.
        np.random.seed()
        self.gen_bounds()

    def gen_bounds(self):
        """Simulate the logged rounds and derive bounds on the hidden means.

        In every logged round ``theta_u`` is perturbed, the arm with the
        largest (perturbed hidden reward + visible mean) is selected, and a
        noisy reward for that arm is recorded.  Afterwards

        * ``p[k]`` is the fraction of rounds in which arm ``k`` was selected,
        * ``nu[k]`` is the sum of arm ``k``'s logged rewards divided by
          ``log_len``,
        * ``ub[k]``/``lb[k]`` are
          ``nu[k]*p[k] + (1-p[k])*z_k + (1-p[k])*c*hid_dim*sum(arms_u[k]) - z_k``
          with ``z_k`` the visible mean and ``c`` the largest (for ``ub``)
          or smallest (for ``lb``) coordinate seen in any perturbed
          ``theta_u``.

        ``ub`` and ``lb`` are stored with shape ``(num_arms, 1, 1)``.

        Raises
        ------
        ValueError
            If ``log_len`` is not positive.
        """
        if self.log_len <= 0:
            raise ValueError('log_len must be a positive integer')
        arms = self._arm_matrix()
        means_z = self._visible_means()
        counts = np.zeros(self.num_arms)
        reward_sums = np.zeros(self.num_arms)
        M = np.full(self.hid_dim, -np.inf)
        m = np.full(self.hid_dim, np.inf)
        for _ in range(self.log_len):
            noised_theta_u = self._noised_theta_u()
            M = np.maximum(M, noised_theta_u[:, 0])
            m = np.minimum(m, noised_theta_u[:, 0])
            rews_u = (arms @ noised_theta_u)[:, 0]
            full_rews = rews_u + means_z
            # argmax without an axis always yields a single index.
            best = np.argmax(full_rews)
            counts[best] += 1
            # Noise half-width keeps the hidden reward inside [0, 1].
            scale = min(rews_u[best], 1 - rews_u[best])
            rew_noise = scale*np.random.uniform(-1, 1)
            reward_sums[best] += full_rews[best] + rew_noise
        self.p = counts/self.log_len
        # NOTE(review): nu is normalised by log_len rather than by the
        # per-arm selection count, so nu*p below is p**2 times the
        # conditional mean reward -- confirm this is intended.
        self.nu = reward_sums/self.log_len
        common = self.nu*self.p + (1 - self.p)*means_z
        M, m = np.amax(M), np.amin(m)
        spread = (1 - self.p)*self.hid_dim*arms.sum(axis=1)
        self.ub = (common + M*spread - means_z).reshape(self.num_arms, 1, 1)
        self.lb = (common + m*spread - means_z).reshape(self.num_arms, 1, 1)
        self.new_ub, self.new_lb = self.tighten_bounds()
        self.check_bounds()

    def tighten_bounds(self):
        """Tighten ``ub``/``lb`` by linear programming.

        For every arm the value ``arms_u[i] . theta`` is minimised and
        maximised over all ``theta`` in ``[-1, 1]^hid_dim`` that satisfy
        ``lb[j] <= arms_u[j] . theta <= ub[j]`` for every arm ``j``.  The
        results are clipped to ``[0, 1]``.

        Returns
        -------
        (new_upper, new_lower) : tuple of ndarray, each of shape (num_arms,)

        Raises
        ------
        RuntimeError
            If a linear program does not terminate successfully (for
            example because the constraints are infeasible).
        """
        upper = np.asarray(self.ub, dtype=float).ravel()
        lower = np.asarray(self.lb, dtype=float).ravel()
        arms = self._arm_matrix()
        # Two-sided constraints written as A @ theta <= b.
        A = np.vstack((arms, -arms))
        b = np.concatenate((upper, -lower))
        lims = (-1, 1)
        new_upper, new_lower = np.zeros(upper.size), np.zeros(lower.size)
        for i in range(self.num_arms):
            obj = arms[i]
            opt_min = linprog(c=obj, A_ub=A, b_ub=b, bounds=lims)
            opt_max = linprog(c=-obj, A_ub=A, b_ub=b, bounds=lims)
            for opt in (opt_min, opt_max):
                if not opt.success:
                    raise RuntimeError(
                        f'linprog could not bound arm {i}: {opt.message}')
            new_lower[i] = max(0, opt_min.fun)
            new_upper[i] = min(1, -opt_max.fun)
        return new_upper, new_lower

    def check_bounds(self):
        """Print ``'Problem'`` for each arm whose hidden mean lies outside
        ``[new_lb, new_ub]``."""
        for i in range(self.num_arms):
            mean_u = self._scalar(self.all_means_u[i])
            if self.new_lb[i] > mean_u or self.new_ub[i] < mean_u:
                print('Problem')

    def gen_online_rews(self):
        """Draw one round of rewards for all arms with a perturbed ``theta_u``.

        The hidden reward of each arm is computed with ``theta_u`` plus a
        perturbation of norm 0.1, the visible mean is added, and uniform
        noise of half-width ``min(r, 1 - r)`` (``r`` the hidden reward) is
        applied per arm.

        Returns
        -------
        (rews, best_arm) : ndarray of shape (num_arms,), index of the arm
            with the largest noise-free mean reward.
        """
        noised_theta_u = self._noised_theta_u()
        rews_u = (self._arm_matrix() @ noised_theta_u)[:, 0]
        full_rews = rews_u + self._visible_means()
        scale = np.minimum(rews_u, 1 - rews_u)
        rew_noise = scale*np.random.uniform(-1, 1, size=(self.num_arms))
        return full_rews + rew_noise, self.best_arm

    def gen_rews(self):
        """Draw one round of rewards around the unperturbed mean rewards.

        Arm ``i`` receives ``full_means[i]`` plus uniform noise of
        half-width ``min(u, 1 - u)`` where ``u`` is its hidden mean.

        Returns
        -------
        (rews, best_arm) : ndarray of shape (num_arms,), index of the arm
            with the largest mean reward.
        """
        rews = np.zeros(self.num_arms)
        for i in range(self.num_arms):
            mean_u = self._scalar(self.all_means_u[i])
            half_width = min(mean_u, 1 - mean_u)
            rews[i] = (self._scalar(self.full_means[i])
                       + np.random.uniform(-1*half_width, half_width))
        return rews, self.best_arm

    def iterate(self):
        """Placeholder; does nothing and returns ``None``."""
        return None

# Command-line interface.  The size arguments are mandatory because the
# experiment cannot be built without them; argparse now reports a missing
# one up front.  The save suffix defaults to an empty string so the output
# file name can always be formed.
parser = argparse.ArgumentParser()
parser.add_argument('--Vis_Dimension', '-VD', help='Dimension of the visible features', type=int, required=True)
parser.add_argument('--Hid_Dimension', '-HD', help='Dimension of the hidden features', type=int, required=True)
parser.add_argument('--Number_of_arms', '-NA', help='Number of arms', type=int, required=True)
parser.add_argument('--Number_of_Instances', '-INST', help='Number of instances to run', type=int, required=True)
parser.add_argument('--Number_of_Iterations', '-ITER', help='Number of iterations per instance', type=int, required=True)
parser.add_argument('--Save_suffix', '-S', help='Suffix to savefile', type=str, default='')
# When the seed is omitted, None is forwarded below, so the environment is
# seeded from fresh entropy rather than with Environmnet's own default seed.
parser.add_argument('--Generation_seed', '-SEED', help='Seed for generating arms and bounds', type=int)
args = parser.parse_args()
vis_dim = args.Vis_Dimension
hid_dim = args.Hid_Dimension
num_arms = args.Number_of_arms
num_inst = args.Number_of_Instances
num_iter = args.Number_of_Iterations
suffix = args.Save_suffix
seed = args.Generation_seed

# Building the environment also simulates the logged data and solves the
# bound-tightening linear programs.
env = Environmnet(vis_dim, hid_dim, num_arms, seed=seed)


def func(x,m):
    """Evaluate g(x) = m(1-m)(e^{(1-m)x} - e^{-mx})/f(x) - (2/x) ln f(x).

    Here f(x) = m e^{(1-m)x} + (1-m) e^{-mx}.  Since f'(x) equals
    m(1-m)(e^{(1-m)x} - e^{-mx}), g(x) is x**2/2 times the derivative of
    h(x) = (2/x**2) ln f(x); a root of g is therefore a stationary point
    of h.

    x must be non-zero (it appears as a divisor); m is the mean parameter.
    """
    grow = np.exp(x*(1-m))
    decay = np.exp(-x*m)
    mgf = m*grow + (1-m)*decay
    slope_term = m*(1-m)*(grow - decay)/mgf
    log_term = (2/x)*np.log(mgf)
    return slope_term - log_term

# Lookup table keyed by a mean value rounded to 4 decimals, on the grid
# 0, 0.0005, ..., 0.4995.  For each mean the root of func on [0.0005, 100]
# is located and (2/x**2) * ln f(x) is stored, f being the same
# exponential mixture that func evaluates.  The entry for 0.5 is set
# directly to 0.25.
# NOTE(review): presumably a sub-Gaussian variance proxy consumed by
# RestrictedLinUCB -- confirm against Algos.
sg = {}
for mean in np.arange(0, 0.5, 0.0005):
    mean = round(mean, 4)
    root = brentq(func, 0.0005, 100, args=(mean,))
    mixture = mean*np.exp((1-mean)*root) + (1-mean)*np.exp(-mean*root)
    sg[mean] = (2/root**2)*np.log(mixture)
sg[0.5] = 0.25


# Constants handed to the two algorithms.
# NOTE(review): SIGMA is the square root of the 0.25 stored in sg[0.5];
# DELTA is presumably a confidence parameter -- confirm against Algos.
SIGMA = 0.25**0.5
DELTA = 0.01

theta_z = env.theta_z
arms_z = env.arms_z

l1 = LinUCB(hid_dim, SIGMA, DELTA)
l2 = RestrictedLinUCB(hid_dim, DELTA, sg)

# Regret rows are stacked under a zero placeholder row that is dropped once
# all instances have run, so the result is (num_inst, num_iter) even when
# num_inst is 0.
reg1 = np.zeros((1, num_iter))
reg2 = np.zeros((1, num_iter))
upper, lower = env.new_ub, env.new_lb
# The hidden arm set does not change between rounds.
arms = env.arms_u
time_l1 = []
time_l2 = []
for inst in range(num_inst):
    l1.restart()
    l2.restart()
    for _ in range(num_iter):
        # Both algorithms see the same reward draw in every round.
        rews, true_best = env.gen_online_rews()
        # perf_counter is the clock intended for measuring short intervals.
        start = time.perf_counter()
        l1.iterate(arms, theta_z, arms_z, rews, true_best)
        time_l1.append(time.perf_counter() - start)
        start = time.perf_counter()
        l2.iterate(arms, theta_z, arms_z, upper, lower, rews, true_best)
        time_l2.append(time.perf_counter() - start)
    # The first entry of each regret sequence is skipped so the row has
    # num_iter columns.
    iter_reg1 = np.asarray(l1.get_regret(), dtype='object')
    reg1 = np.vstack((reg1, iter_reg1[1:]))
    iter_reg2 = np.asarray(l2.get_regret(), dtype='object')
    reg2 = np.vstack((reg2, iter_reg2[1:]))

# Drop the placeholder rows.
reg1, reg2 = reg1[1:,:], reg2[1:,:]

# Everything needed to re-plot or re-analyse the run: regrets, per-round
# timings, and the environment that produced them.
run_dict = {
    'reg1': reg1,
    'reg2': reg2,
    'time_l1': time_l1,
    'time_l2': time_l2,
    'arms_u': env.arms_u,
    'arms_z': env.arms_z,
    'theta_u': env.theta_u,
    'theta_z': env.theta_z,
    'ub': env.new_ub,
    'lb': env.new_lb,
}

# Create the output directory if needed so a finished experiment is not
# lost to a missing folder; a missing suffix is treated as empty.
save_dir = './Files/LinBanditRun4'
os.makedirs(save_dir, exist_ok=True)
savepath = save_dir + '/regfile' + (suffix or '') + '.p'
with open(savepath, 'wb') as f:
    pickle.dump(run_dict, f, protocol=pickle.HIGHEST_PROTOCOL)