import numpy as np
import scipy.optimize as so
import torch
import copy
from tqdm import tqdm
import os


class SmoothIGW:
    """Sampler in the style of smoothed inverse-gap weighting over actions in [0, 1).

    Each round, a candidate action ``a`` is drawn uniformly from [0, 1). It is
    played with probability

        1 / (1 + gamma * max(f(x, ahat) - f(x, a), 0))

    where ``f(x, .)`` is the oracle's prediction and ``ahat`` is the action
    returned by ``oracle.compute_ahat(x)``. Otherwise ``ahat`` is played.

    Args:
        gamma: exploration parameter. Larger values make the sampler fall back
            to ``ahat`` more often when the oracle predicts the candidate is
            worse than ``ahat``.
        T: number of rounds per repetition.
        oracle: model exposing ``to``, ``compute_ahat``, ``__call__(x, a)``,
            ``reset`` and ``update``. It is moved to the selected device on
            construction.
        generator: ``torch.Generator`` that drives all of the sampler's
            randomness. It must live on the same device as ``self.device``.
    """

    def __init__(self, gamma, T, oracle, generator):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        self.gamma = gamma
        self.T = T
        self.oracle = oracle
        oracle.to(self.device)
        self.generator = generator

        # Per-round logs for the current repetition; overwritten every run.
        self.reward = np.zeros((T,))
        self.optimal_reward = np.zeros((T,))

        self.t = 0  # number of sample() calls since the last reset()

    def reset(self):
        """Reset the round counter."""
        self.t = 0

    def sample(self, x):
        """Select an action for context ``x`` and advance the round counter.

        Returns:
            The selected action as a Python scalar (via ``Tensor.item``).
        """
        ahat = self.oracle.compute_ahat(x)

        with torch.no_grad():
            # Draw the candidate from the injected generator (not the global
            # RNG), so a whole run is reproducible from the generator's seed.
            a = torch.rand(1, generator=self.generator, device=self.device)
            # Predicted shortfall of the candidate relative to ahat. A candidate
            # whose predicted value is not below ahat's gets prob == 1 and is
            # always accepted.
            gap = torch.clamp(self.oracle(x, ahat) - self.oracle(x, a), min=0)
            prob = 1 / (1 + self.gamma * gap)
            unif = torch.rand(1, generator=self.generator, device=self.device)

            shouldexplore = (unif <= prob).long()
            # Branch-free select: a when exploring, ahat otherwise.
            selected_action = (ahat + shouldexplore * (a - ahat)).squeeze().item()

        self.t += 1
        return selected_action

    def run(self, rep, env, basedir):
        """Run ``rep`` repetitions of ``T`` rounds and save the logged rewards.

        Each repetition resets ``env``, the oracle and this sampler. Then, every
        round, it:

        1. draws a context;
        2. samples a price;
        3. acts in ``env``;
        4. updates the oracle with ``price * realization``;
        5. logs ``price * probability`` as the round's reward, together with
           the optimal reward reported by ``env.optimal_action(x)``.

        Writes ``reward.npy`` and ``optimal_reward.npy`` (each shaped
        ``(rep, T)``) into ``basedir``, creating the directory if needed. An
        empty ``basedir`` means the current directory.
        """
        rewards = np.zeros((rep, self.T))
        optimal_rewards = np.zeros((rep, self.T))

        for r in range(rep):
            print(f'run {r}')

            env.reset()
            self.oracle.reset()
            self.reset()

            for t in range(self.T):
                # receive a context
                x = env.gen_context()

                # sample an action
                price = self.sample(x)

                # take the action and receive response of env
                realization, probability = env.act(x, price)

                # update the oracle with the realized revenue
                self.oracle.update(x, price, price * realization)

                # log data
                self.reward[t] = price * probability
                _, self.optimal_reward[t] = env.optimal_action(x)

            # Row assignment copies the values, so no deepcopy is needed.
            rewards[r] = self.reward
            optimal_rewards[r] = self.optimal_reward

        # save results; exist_ok avoids the exists()/makedirs() race.
        if basedir:
            os.makedirs(basedir, exist_ok=True)
        np.save(os.path.join(basedir, 'reward.npy'), rewards)
        np.save(os.path.join(basedir, 'optimal_reward.npy'), optimal_rewards)