import numpy as np
import scipy.optimize as so
import torch
import copy
from tqdm import tqdm
import os


class SquareCB:
    """SquareCB-style sampler over a uniform grid of prices in ``[0, 1]``.

    For each context the oracle scores the ``K`` evenly spaced actions
    ``linspace(0, 1, K)``. The highest-scoring action ``b`` is the greedy
    choice, and an action is drawn with inverse-gap weighting::

        p(a) = 1 / (K + gamma * (pred(b) - pred(a)))    for a != b
        p(b) = 1 - sum_{a != b} p(a)

    A larger ``gamma`` therefore concentrates probability on the greedy action.

    Args:
        gamma: exploration parameter of the inverse-gap weighting.
        T: number of rounds per repetition in :meth:`run`.
        oracle: callable ``oracle(contexts, actions)`` returning one score per
            row. It must support ``.to(device)``; :meth:`run` also calls
            ``reset()`` and ``update(x, price, reward)`` on it.
        generator: ``torch.Generator`` used by ``torch.multinomial`` (may be
            ``None`` to use the global RNG).
        log: stored as ``self.log``; not read by any method of this class.
        K: number of grid actions (default 100, the former hard-coded value).

    Raises:
        ValueError: if ``K < 1``.
    """

    def __init__(self, gamma, T, oracle, generator, log=True, K=100):
        if K < 1:
            raise ValueError(f"K must be >= 1, got {K}")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.log = log  # NOTE(review): not read anywhere in this class

        self.gamma = gamma
        self.T = T
        self.oracle = oracle
        oracle.to(self.device)
        self.K = K
        self.generator = generator

        # Per-round buffers for the current repetition of `run`.
        self.reward = np.zeros((T,))
        self.optimal_reward = np.zeros((T,))

        # Number of calls to `sample` since the last `reset`.
        self.t = 0

    def reset(self):
        """Reset the round counter."""
        self.t = 0

    def _action_grid(self):
        """Return the ``(K, 1)`` column of candidate prices ``linspace(0, 1, K)``."""
        return torch.linspace(0, 1, self.K).reshape(-1, 1).to(self.device)

    def _scores(self, x):
        """Score every grid action for context ``x`` in one oracle pass.

        ``x`` is tiled to ``K`` rows (presumably a single context row or a
        1-D context vector -- TODO confirm against the environment).

        Returns:
            ``(a, pred)``: the ``(K, 1)`` action grid and the oracle output
            flattened (expected to hold one score per action).
        """
        a = self._action_grid()
        with torch.no_grad():
            pred = self.oracle(torch.tile(x, (self.K, 1)), a).flatten()
        return a, pred

    def best_arm(self, x):
        """Return ``(idx, price)`` of the grid action the oracle scores highest."""
        a, pred = self._scores(x)
        idx = int(torch.argmax(pred).item())
        return idx, a[idx].item()

    def sample(self, x):
        """Draw a price for context ``x`` via inverse-gap weighting.

        Increments ``self.t``.

        Returns:
            The sampled price as a Python float.
        """
        a, pred = self._scores(x)
        idx = int(torch.argmax(pred).item())

        # Gap to the greedy score, taken from the SAME prediction vector, so
        # it is >= 0 and the greedy action is not re-scored separately.
        gap = pred[idx] - pred
        prob = 1 / (self.K + self.gamma * gap)
        # The greedy action takes the remaining probability mass.
        prob[idx] = 1 - (torch.sum(prob) - prob[idx])

        # torch.multinomial requires the tensor and generator on one device.
        if self.generator is not None:
            prob = prob.to(self.generator.device)
        selected_action = torch.multinomial(prob, 1, generator=self.generator).item()

        self.t += 1

        return a[selected_action].item()

    def run(self, rep, env, basedir):
        """Run ``rep`` independent repetitions of ``T`` rounds and save results.

        Each repetition resets ``env``, the oracle and this sampler. Each round
        does the following:

        - draw a context;
        - sample a price;
        - feed ``price * realization`` to ``oracle.update``;
        - record ``price * probability`` (as returned by ``env.act``) and the
          optimal reward from ``env.optimal_action``.

        Writes ``reward.npy`` and ``optimal_reward.npy`` (each of shape
        ``(rep, T)``) into ``basedir``, creating it if needed.
        """
        rewards = np.zeros((rep, self.T))
        optimal_rewards = np.zeros((rep, self.T))

        for r in range(rep):
            print(f'run {r}')

            env.reset()
            self.oracle.reset()
            self.reset()

            for t in range(self.T):
                # receive a context
                x = env.gen_context()

                # sample an action
                price = self.sample(x)

                # take the action and receive response of env
                realization, probability = env.act(x, price)

                # update the oracle with the realized revenue
                self.oracle.update(x, price, price * realization)

                # log data
                self.reward[t] = price * probability
                _, self.optimal_reward[t] = env.optimal_action(x)

            # Row assignment copies the buffers, so later repetitions cannot
            # overwrite stored results.
            rewards[r] = self.reward
            optimal_rewards[r] = self.optimal_reward

        # save results
        if basedir:
            os.makedirs(basedir, exist_ok=True)
        np.save(os.path.join(basedir, 'reward.npy'), rewards)
        np.save(os.path.join(basedir, 'optimal_reward.npy'), optimal_rewards)
