import numpy as np
import torch
from scipy.linalg import lstsq
import os
import copy


class ExUCB:
    """Epoch-based explore-then-UCB pricing policy and its simulation driver.

    Each episode is split into epochs whose length doubles every time.  An
    epoch of length ``l`` has two phases:

    1. Exploration (``ceil(C1 * l**beta)`` rounds): prices are drawn at random
       with ``torch.rand`` and the observed responses are stored.  A linear
       least-squares fit of response on ``[context, 1]`` gives ``theta`` (the
       intercept coefficient is dropped).
    2. UCB phase (``max(l - l_expr, 0)`` rounds): a grid of offsets ``m`` is
       built, candidate prices are ``m + <context, theta>``, and the price
       maximising ``price * UCB`` is played (ties broken at random).

    The environment object passed to :meth:`run` must provide ``reset()``,
    ``gen_context()`` (returning an object with ``.numpy(force=True)``),
    ``act(x, price)`` returning ``(realization, probability)`` and
    ``optimal_action(x)`` returning a pair whose second item is logged as the
    optimal reward.
    """

    def __init__(self, l0, C1, d, T, generator):
        """Store the hyper-parameters and allocate the per-round logs.

        Args:
            l0: Length of the first epoch (later epochs double it).
            C1: Multiplier of the exploration-phase length.
            d: Context dimension.
            T: Number of rounds per episode.
            generator: Generator handed to ``torch.rand`` when drawing
                exploration prices.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        self.beta = 2/3    # exponent of the exploration-phase length
        self.gamma = 1/6   # exponent of the number of UCB arms
        self.l0 = l0
        self.C1 = C1
        self.C2 = 20       # multiplier of the number of UCB arms
        self.Cu = 0.1      # weight of the confidence width in the UCB index
        self.lam = 0.1     # ridge regularisation used by the UCB index
        self.d = d
        self.T = T
        self.generator = generator

        # Per-round logs of the current episode.
        self.reward = np.zeros((T,))
        self.optimal_reward = np.zeros((T,))

        self.t = 0

    def init_buffer(self, n):
        """Allocate zeroed buffers for ``n`` exploration rounds.

        ``X`` holds the design rows (``d + 1`` columns), ``A`` the prices and
        ``Y`` the observed responses.
        """
        self.X = np.zeros((n, self.d + 1))
        self.A = np.zeros(n)
        self.Y = np.zeros(n)

    def init_ucb(self, theta, d, T0):
        """Allocate the state of a UCB phase with ``d`` arms and ``T0`` rounds.

        Args:
            theta: Coefficient vector; its L1 norm sets the range of the grid.
            d: Number of arms (grid points).  Note this is not ``self.d``.
            T0: Number of rounds in the phase.
        """
        theta_l1 = np.sum(np.abs(theta))
        # Grid of offsets spanning [-|theta|_1, 1 + |theta|_1].
        self.m = np.linspace(-theta_l1, 1 + theta_l1, d)
        self.D = np.zeros((T0, d))   # D[s, j] == 1 iff arm j was played in round s
        self.p = np.zeros((T0,))     # price played in each round
        self.y = np.zeros((T0,))     # response observed in each round

    def ucb(self, j, beta_t):
        """Return the upper confidence index of arm ``j``.

        An arm that has never been played gets ``1e9`` so that it is tried
        first.  Otherwise the index is a ridge-regularised, ``p**2``-weighted
        mean of the responses plus ``Cu * sqrt(beta_t / denominator)``.
        """
        pulls = self.D[:, j]
        if np.sum(pulls) == 0:
            return 1e9
        # Computed once; the original evaluated both quantities twice.
        weights = pulls * self.p**2
        denom = self.lam + np.sum(weights)
        mean_term = np.sum(weights * self.y) / denom
        width_term = np.sqrt(beta_t / denom)
        return mean_term + self.Cu * width_term

    def run(self, rep, env, basedir):
        """Simulate ``rep`` episodes of ``T`` rounds and save the logs.

        Writes ``reward.npy`` and ``optimal_reward.npy`` (arrays of shape
        ``(rep, T)``) into ``basedir``, creating the directory if needed.

        Args:
            rep: Number of independent episodes.
            env: Environment object (see the class docstring for the methods
                it must provide).
            basedir: Output directory.

        Raises:
            ValueError: If ``l0`` truncates to 0 while there are rounds to
                play; the epoch length would then stay 0 forever.
        """
        # Doubling a zero epoch length never makes progress, so the while
        # loop below would spin forever.  Fail fast instead.
        if rep > 0 and self.T > 0 and int(self.l0 / 2 * 2) == 0:
            raise ValueError(
                f'l0 must truncate to a non-zero integer, got {self.l0!r}'
            )

        rewards = np.zeros((rep, self.T))
        optimal_rewards = np.zeros((rep, self.T))

        # Invariant parts of the confidence radius beta_t.
        lam_d = self.lam * self.d
        sqrt_lam_d = np.sqrt(lam_d)

        for r in range(rep):
            print(f'run {r}')
            env.reset()

            epoch_len = self.l0 / 2   # doubled before first use
            t = 0
            theta = np.zeros((self.d))
            while t < self.T:
                epoch_len = int(epoch_len * 2)
                l_expr = int(np.ceil(self.C1 * np.power(epoch_len, self.beta)))
                self.init_buffer(l_expr)

                # ---- exploration phase: random prices ----
                for t_in in range(l_expr):
                    if t >= self.T:
                        break
                    # receive a context
                    x = env.gen_context()
                    # random price
                    price = torch.rand(1, generator=self.generator, device=self.device).item()
                    # take the action and receive response of env
                    realization, probability = env.act(x, price)
                    # store the result (context augmented with an intercept column)
                    self.X[t_in] = np.concatenate([x.numpy(force=True), np.array([1])])
                    self.A[t_in] = price
                    self.Y[t_in] = realization
                    # log data
                    self.reward[t] = price * probability
                    _, self.optimal_reward[t] = env.optimal_action(x)
                    t += 1

                # Least-squares fit; only the context coefficients are kept.
                theta, _, _, _ = lstsq(self.X, self.Y)
                theta = theta[:self.d]
                T0 = max(epoch_len - l_expr, 0)
                n_arms = int(np.ceil(self.C2 * np.power(T0, self.gamma)))
                self.init_ucb(theta, n_arms, T0)

                # ---- UCB phase over the price grid ----
                for t_in in range(T0):
                    if t >= self.T:
                        break
                    # receive a context
                    x = env.gen_context()

                    # select an arm (price) by calculating UCB
                    beta_t = np.maximum(
                        1,
                        (sqrt_lam_d
                         + np.sqrt(2*np.log(T0) + self.d*np.log((lam_d + t_in)/lam_d)))**2,
                    )
                    S = self.m + np.dot(x.numpy(force=True), theta)
                    # Arms whose candidate price falls outside [0, 1] get index 0.
                    UCB = np.array([self.ucb(j, beta_t) for j in range(n_arms)]) * (S >= 0) * (S <= 1)
                    utility = S * UCB
                    max_util = np.max(utility)
                    # Random tie-break among near-maximal arms (global NumPy RNG).
                    candidates = np.argwhere(utility >= max_util - 1e-4).reshape(-1)
                    idx = np.random.choice(candidates, 1)[0]
                    price = S[idx]
                    self.D[t_in, idx] = 1

                    # take the action and receive response of env
                    realization, probability = env.act(x, price)
                    # log data
                    self.reward[t] = price * probability
                    self.p[t_in] = price
                    self.y[t_in] = realization
                    _, self.optimal_reward[t] = env.optimal_action(x)
                    t += 1

            # Row assignment copies the values, so no explicit copy is needed.
            rewards[r] = self.reward
            optimal_rewards[r] = self.optimal_reward

        # save results
        os.makedirs(basedir, exist_ok=True)
        np.save(os.path.join(basedir, 'reward.npy'), rewards)
        np.save(os.path.join(basedir, 'optimal_reward.npy'), optimal_rewards)