import numpy as np
import torch
import copy
import os
from time import time


class DP_IGW:
    """Online pricing loop with gap-dependent random exploration.

    Each round the oracle proposes a price ``phat``; a second candidate price
    ``p`` is drawn uniformly from [0, 1).  The candidate replaces ``phat``
    with probability ``1 / (1 + gamma * max(gap, 0))`` where
    ``gap = phat * oracle(x, phat) - p * oracle(x, p)``.
    NOTE(review): the name suggests inverse-gap weighting for dynamic
    pricing, with ``oracle(x, price)`` being an estimated purchase
    probability -- confirm against the oracle implementation.

    Collaborator interfaces, as used by this class:

    * ``oracle``: ``to(device)``, ``reset()``, ``compute_phat(x)``,
      ``__call__(x, price)``, ``update(x, price, realization)``.
    * ``env`` (passed to :meth:`run`): ``reset()``, ``gen_context()``,
      ``act(x, price) -> (realization, probability)``,
      ``optimal_action(x) -> (_, optimal_reward)``.

    Args:
        gamma0: Initial value of the exploration parameter ``gamma``.
        T: Number of rounds per repetition (length of the log buffers).
        oracle: Estimator object; moved to the selected device on init.
        generator: ``torch.Generator`` used for every random draw made in
            :meth:`sample`; it must live on the same device as the class
            (``cuda`` when available, else ``cpu``).
        gamma_scheduling: When true, :meth:`run` multiplies ``gamma`` by
            ``eta ** (1/3)`` at the end of each epoch, where epoch ``k``
            lasts ``ceil(eta ** k)`` rounds.
    """

    def __init__(self, gamma0, T, oracle, generator, gamma_scheduling=True):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        self.gamma_scheduling = gamma_scheduling
        self.gamma0 = gamma0
        self.gamma = gamma0
        # Base of the epoch-length / gamma growth schedule used in run().
        self.eta = 2
        self.T = T
        self.oracle = oracle
        oracle.to(self.device)
        self.generator = generator

        # Per-round logs; overwritten in place on every repetition.
        self.reward = np.zeros((T,))
        self.optimal_reward = np.zeros((T,))
        self.expr = np.zeros((T,))  # 1 where the random candidate was played

        self.t = 0

    def reset(self):
        """Rewind the round counter used to index the exploration log."""
        self.t = 0

    def sample(self, x):
        """Choose a price for context ``x`` and log whether it explored.

        Args:
            x: Context, forwarded unchanged to the oracle.

        Returns:
            The selected price as a Python scalar: the random candidate
            when exploring, otherwise the oracle's ``phat``.

        Raises:
            IndexError: If called more than ``T`` times without
                :meth:`reset`, because the exploration log has length ``T``.
        """
        phat = self.oracle.compute_phat(x)

        # Draw the candidate from the injected generator (not torch's global
        # RNG) so that a run is reproducible from the generator state alone.
        p = torch.rand(1, generator=self.generator, device=self.device)
        with torch.no_grad():
            # Only a positive gap lowers the exploration probability; a
            # candidate that looks at least as good is always accepted.
            gap = torch.clamp(
                phat * self.oracle(x, phat) - p * self.oracle(x, p), min=0
            )
            prob = 1 / (1 + self.gamma * gap)
            unif = torch.rand(1, generator=self.generator, device=self.device)

            shouldexplore = (unif <= prob).long()
            selected_action = (phat + shouldexplore * (p - phat)).squeeze().item()

        self.expr[self.t] = shouldexplore.item()
        self.t += 1

        return selected_action

    def run(self, rep, env, basedir):
        """Run ``rep`` independent repetitions of ``T`` rounds and save logs.

        Before each repetition ``gamma``, the environment, the oracle and the
        round counter are reset.  After all repetitions three arrays are
        written into ``basedir`` (created if missing): ``reward.npy`` and
        ``optimal_reward.npy`` of shape ``(rep, T)`` and ``time.npy`` of
        shape ``(rep,)`` holding wall-clock seconds per repetition.

        Args:
            rep: Number of repetitions.
            env: Environment object (see the class docstring for the
                methods it must provide).
            basedir: Output directory for the ``.npy`` files.
        """
        rewards = np.zeros((rep, self.T))
        optimal_rewards = np.zeros((rep, self.T))
        times = np.zeros(rep)

        for r in range(rep):
            print(f'run {r}')
            start_time = time()

            # Epoch k lasts ceil(eta**k) rounds; tau_sum is the round index
            # at which the current epoch ends.
            k = 1
            tau_sum = int(np.ceil(np.power(self.eta, k)))
            self.gamma = self.gamma0
            env.reset()
            self.oracle.reset()
            self.reset()

            for t in range(self.T):
                if self.gamma_scheduling and t == tau_sum:
                    k += 1
                    tau_sum += int(np.ceil(np.power(self.eta, k)))
                    self.gamma *= np.power(self.eta, 1 / 3)

                # receive a context
                x = env.gen_context()

                # sample an action
                price = self.sample(x)

                # take the action and receive response of env
                realization, probability = env.act(x, price)

                # update the oracle and the pricing algorithm
                self.oracle.update(x, price, realization)

                # log data
                self.reward[t] = price * probability
                _, self.optimal_reward[t] = env.optimal_action(x)

            # Row assignment copies the values, so the reusable per-round
            # buffers can be overwritten by the next repetition.
            rewards[r] = self.reward
            optimal_rewards[r] = self.optimal_reward
            times[r] = time() - start_time

        # save results
        os.makedirs(basedir, exist_ok=True)
        np.save(os.path.join(basedir, 'reward.npy'), rewards)
        np.save(os.path.join(basedir, 'optimal_reward.npy'), optimal_rewards)
        np.save(os.path.join(basedir, 'time.npy'), times)
        print(f'elapsed time: {np.mean(times)} \\pm {np.std(times)}')