import numpy as np
import torch
import copy
import os
from tqdm import tqdm


def gaussian_cdf(x, sigma=0.2):
    """Return the CDF of a zero-mean Gaussian with std ``sigma``, evaluated at ``x``.

    Args:
        x: evaluation point(s); a tensor (evaluated elementwise) or anything
            ``torch.as_tensor`` accepts, e.g. a Python float.
        sigma: standard deviation of the Gaussian. Defaults to 0.2, the value
            this function previously hard-coded.

    Returns:
        A tensor with the same shape as ``x`` and values in [0, 1].
    """
    x = torch.as_tensor(x)
    # Phi(x / sigma) = 0.5 * (1 + erf(x / (sigma * sqrt(2))))
    return 0.5 * (1 + torch.erf(x / (sigma * 2.0 ** 0.5)))

def linear_model(cdf, price, value):
    """Evaluate ``cdf`` at ``price`` shifted down by the threshold ``0.5 + value``.

    Args:
        cdf: callable applied to the shifted price.
        price: price(s) to evaluate; scalar or tensor.
        value: score added to the fixed 0.5 offset (within this file it is
            called with ``torch.dot(x, theta)``).

    Returns:
        ``cdf(price - (0.5 + value))``.
    """
    threshold = 0.5 + value
    return cdf(price - threshold)

class ONSP:
    """Online-Newton-Step pricing agent.

    Each round the agent observes a context ``x`` and posts the grid price in
    [0, 1] that maximises ``price * (1 - model(cdf, price, <x, theta>))``. It
    observes the environment's response, then takes an Online-Newton-Step on
    the negative log-likelihood of that response. The step is followed by a
    projection of ``theta`` onto the Euclidean unit ball.
    """

    def __init__(self, d, gamma, eps, T):
        """
        Args:
            d: dimension of the context / parameter vector ``theta``.
            gamma: ONS step parameter; each step is scaled by ``1 / gamma``.
            eps: initial regulariser; ``A`` starts as ``eps * I``.
            T: number of rounds per repetition.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        self.d = d
        self.gamma = gamma
        self.T = T
        self.eps = eps
        self.model = linear_model
        self.cdf = gaussian_cdf
        self.reset()  # initialises theta and A

        # per-round logs for the current repetition
        self.reward = np.zeros((T,))
        self.optimal_reward = np.zeros((T,))

    def reset(self):
        """Reset ``theta`` to zero and ``A`` to ``eps * I``."""
        # theta is updated in place by the ONS step, so it must NOT require
        # grad: an in-place op on a grad-requiring leaf raises. jacobian()
        # does its own gradient tracking on a copy of its input.
        self.theta = torch.zeros(self.d, device=self.device)
        self.A = self.eps * torch.eye(self.d, device=self.device)

    def _choose_price(self, x, grid):
        """Return the grid price maximising ``p * (1 - model(cdf, p, <x, theta>))``."""
        value = torch.dot(x, self.theta)
        revenue = grid * (1 - self.model(self.cdf, grid, value))
        return grid[torch.argmax(revenue)].item()

    def _update(self, x, price, realization):
        """Apply one ONS step for the observed ``(x, price, realization)``."""
        def loss(theta):
            # The pricing rule treats 1 - model(...) as the probability that
            # realization == 1, so this is the negative log-likelihood.
            q = self.model(self.cdf, price, torch.dot(x, theta))
            # 1e-3 keeps both log arguments away from zero.
            return (-realization * torch.log(1 - q + 1e-3)
                    - (1 - realization) * torch.log(q + 1e-3))

        grad = torch.autograd.functional.jacobian(loss, self.theta)
        self.A += torch.outer(grad, grad)
        self.theta -= (1 / self.gamma) * torch.linalg.solve(self.A, grad)
        # project back onto the Euclidean unit ball
        norm = torch.norm(self.theta)
        if norm > 1:
            self.theta /= norm

    def run(self, rep, env, basedir):
        """Run ``rep`` independent repetitions of ``T`` rounds and save the logs.

        Args:
            rep: number of repetitions.
            env: environment providing ``reset()``, ``gen_context()``,
                ``act(x, price) -> (realization, probability)`` and
                ``optimal_action(x) -> (_, optimal_reward)``.
            basedir: output directory; created if it does not exist.

        Writes ``reward.npy`` and ``optimal_reward.npy``, each of shape
        ``(rep, T)``, into ``basedir``.
        """
        rewards = np.zeros((rep, self.T))
        optimal_rewards = np.zeros((rep, self.T))
        # candidate prices; loop-invariant, so build once
        grid = torch.linspace(0, 1, 100, device=self.device)

        for r in range(rep):
            print(f'run {r}')

            env.reset()
            self.reset()

            for t in range(self.T):
                # receive a context; the original object is what env receives,
                # a tensor matching theta's dtype/device is used for the math
                x = env.gen_context()
                x_t = torch.as_tensor(x, dtype=self.theta.dtype, device=self.device)

                price = self._choose_price(x_t, grid)

                # take the action, receive the env's response, update theta
                realization, probability = env.act(x, price)
                self._update(x_t, price, realization)

                # log price * probability reported by env, and env's optimum
                self.reward[t] = price * probability
                _, self.optimal_reward[t] = env.optimal_action(x)

            # NumPy row assignment copies the data, so no deepcopy is needed
            rewards[r] = self.reward
            optimal_rewards[r] = self.optimal_reward

        # save results
        os.makedirs(basedir, exist_ok=True)
        np.save(os.path.join(basedir, 'reward.npy'), rewards)
        np.save(os.path.join(basedir, 'optimal_reward.npy'), optimal_rewards)