import numpy as np
import scipy
import scipy.optimize as so
import torch
import copy
import os
from tqdm import tqdm

def normal_pdf(x):
    """Return the density of the normal distribution with mean 0 and scale 1 at ``x``."""
    standard_normal = scipy.stats.norm
    return standard_normal.pdf(x, loc=0, scale=1)

def normal_cdf(x):
    """Return the cumulative distribution of the normal with mean 0 and scale 1 at ``x``."""
    standard_normal = scipy.stats.norm
    return standard_normal.cdf(x, loc=0, scale=1)

class RMLP2:
    """Pricing learner that refits a normal-link demand model on doubling episodes.

    Rounds are grouped into episodes whose length ``tau`` doubles after each
    refit.  The (context, price, response) triples of the current episode are
    kept in a buffer; when the episode ends, ``loss`` is minimised over that
    buffer with COBYLA and the fitted parameters are used to price every
    round of the next episode.
    """

    def __init__(self, d, W, T):
        """Store the problem sizes and allocate the per-round logs.

        Args:
            d: Length of each context vector.
            W: Bound used by ``l1_norm_constraint`` and in the ``u_F`` constant.
            T: Number of rounds played in each repetition.
        """
        self.d = d
        self.pdf = normal_pdf
        self.cdf = normal_cdf
        self.W = W
        self.T = T
        # Larger of pdf/cdf at -2W and pdf/(1 - cdf) at +2W; the 1e-9 keeps
        # both denominators away from zero.
        self.u_F = max(
            self.pdf(-2 * self.W) / (self.cdf(-2 * self.W) + 1e-9),
            self.pdf(2 * self.W) / (1 - self.cdf(2 * self.W) + 1e-9),
        )
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Per-round logs of the most recent repetition.
        self.reward = np.zeros((T,))
        self.optimal_reward = np.zeros((T,))

        # Length of the current episode (also the size of the buffer).
        self.tau = 1

    def l1_norm_constraint(self, x):
        """Return ``W - sum(|x|)``, which is non-negative iff the L1 norm of ``x`` is at most ``W``."""
        return self.W - np.sum(np.abs(x))

    def init_buffer(self, n):
        """Replace the episode buffer with zero-filled arrays holding ``n`` rounds."""
        self.X = np.zeros((n, self.d))  # contexts, one row per round
        self.P = np.zeros(n)            # posted prices
        self.Y = np.zeros(n)            # responses returned by the environment

    def loss(self, theta):
        """Return the penalised negative log-likelihood of the buffered episode.

        Args:
            theta: Array of length ``d + 1``; ``theta[0]`` scales the index and
                ``theta[1:]`` are the context coefficients.

        Returns:
            The negative log-likelihood averaged over ``tau`` plus a penalty
            proportional to ``sum(|theta[0] * theta[1:]|)``.
        """
        scale, coef = theta[0], theta[1:]
        F = self.cdf(scale * (self.P - 0.5 - np.matmul(self.X, coef)))
        # 1e-3 inside the logs keeps them finite when F reaches 0 or 1.
        log_lik = np.sum(self.Y * np.log(1 - F + 1e-3) + (1 - self.Y) * np.log(F + 1e-3))
        penalty_weight = 4 * self.u_F * np.sqrt(np.log(self.d) * 2 / self.tau)
        return -log_lik / self.tau + penalty_weight * np.sum(np.abs(scale * coef))

    def _fit(self, initial_guess, constraint):
        """Minimise ``loss`` on the current buffer and return ``(beta, mu)``."""
        ret = so.minimize(self.loss, initial_guess, method='COBYLA', constraints=constraint)
        beta = ret.x[0]
        # mu is the coefficient vector rescaled by beta, matching the index
        # beta * price - (mu . x + 0.5 * beta) used by ``loss``.
        return beta, ret.x[1:] * beta

    def _choose_price(self, beta, mu, x_np):
        """Return the price in [0, 1] chosen for context ``x_np`` under the fitted model."""
        estimation = np.dot(mu, x_np) + beta * 0.5
        # Coarse search: best of 20 evenly spaced prices under the fitted model.
        grid = np.linspace(0, 1, 20)
        initial = grid[np.argmax(grid * (1 - self.cdf(beta * grid - estimation)))]

        def stationarity(y):
            # Derivative condition of price * (1 - cdf(beta * price - estimation))
            # expressed in y = beta * price - estimation.
            return y - (1 - self.cdf(y)) / (self.pdf(y) + 1e-9) + estimation

        root = so.root(stationarity, beta * initial - estimation)
        if root.success:
            return np.clip((root.x[0] + estimation) / beta, 0, 1)
        # Solver did not converge: keep the grid price.
        return initial

    def run(self, rep, env, basedir):
        """Play ``rep`` independent repetitions of ``T`` rounds and save the logs.

        Args:
            rep: Number of repetitions.
            env: Environment object; this method calls ``reset()``,
                ``gen_context()`` (whose result must offer
                ``.numpy(force=True)``), ``act(x, price)`` returning
                ``(realization, probability)`` and ``optimal_action(x)``
                returning a pair whose second item is logged.
            basedir: Directory that receives ``reward.npy`` and
                ``optimal_reward.npy``; it is created when missing.
        """
        rewards = np.zeros((rep, self.T))
        optimal_rewards = np.zeros((rep, self.T))

        for r in range(rep):
            print(f'run {r}')
            env.reset()

            self.tau = 1
            t_in = 0
            self.init_buffer(int(self.tau))
            initial_guess = np.zeros((self.d + 1))
            initial_guess[0] = 0.2
            constraint = {'type': 'ineq', 'fun': self.l1_norm_constraint}

            for t in range(self.T):
                if t == 2 * self.tau - 1:
                    # Episode finished: refit on its data, then start a new
                    # episode that is twice as long with an empty buffer.
                    beta, mu = self._fit(initial_guess, constraint)
                    self.tau *= 2
                    t_in = 0
                    self.init_buffer(int(self.tau))

                x = env.gen_context()  # receive context
                x_np = x.numpy(force=True)
                # No fitted model exists in the very first round.
                if t == 0:
                    price = 0
                else:
                    price = self._choose_price(beta, mu, x_np)

                realization, probability = env.act(x, price)  # get response of env
                self.X[t_in] = x_np
                self.P[t_in] = price
                self.Y[t_in] = realization
                t_in += 1
                # log data
                self.reward[t] = price * probability
                _, self.optimal_reward[t] = env.optimal_action(x)

            # Row assignment copies the values, so later repetitions cannot
            # overwrite what is stored here.
            rewards[r] = self.reward
            optimal_rewards[r] = self.optimal_reward

        # save results
        os.makedirs(basedir, exist_ok=True)
        np.save(os.path.join(basedir, 'reward.npy'), rewards)
        np.save(os.path.join(basedir, 'optimal_reward.npy'), optimal_rewards)
