import math
import numpy as np
from numpy.linalg import inv
from scipy.linalg import lstsq
import torch
import os
import time
import copy

class Kernel:
    """Kernel (Nadaraya-Watson) regression of responses ``y_t`` on ``v_i``.

    The regression estimate at a point ``t`` is

        sum_i K((v_i - t)/h) * y_i  /  (sum_i K((v_i - t)/h) + 1e-9)

    with both sums scaled by ``1/(n*h)`` and bandwidth ``h = 3 * n**(-1/5)``.
    ``1 - estimate`` is used as a CDF estimate. A revenue curve and its grid
    maximizer are derived from it.

    Args:
        v_i: sample locations (flattened to 1-D in the kernel sums).
        y_t: responses paired element-wise with ``v_i``.
        n: sample size, used for the ``1/(n*h)`` scaling and the bandwidth.
    """

    # Added to the denominator so the ratio stays defined where no sample
    # lies within one bandwidth of t.
    _EPS = 1e-9
    # Slack so that floating-point error in (upp - low) / step does not drop
    # the last grid point in maxi_est.
    _GRID_TOL = 1e-9

    def __init__(self, v_i, y_t, n):
        self.v_i = np.array(v_i)
        self.y_t = np.array(y_t)
        self.n = np.array(n)
        self.h = np.array(3 * n ** (-1 / 5))  # bandwidth 3 * n^(-1/5)

    def _weights(self, t):
        """Return K((v_i - t_j) / h) as a (len(v_i), len(t)) matrix."""
        diffs = self.v_i.reshape(-1, 1) - np.array(t).reshape(1, -1)
        return self.kernel(diffs / self.h)

    def _scaled_sum(self, values):
        """Return the column sums of ``values`` scaled by 1/(n*h)."""
        return 1 / (self.n * self.h) * np.sum(values, axis=0)

    def sec_kernel_h(self, t):
        """Return the scaled kernel-weighted sum of ``y_t`` at each point of ``t``."""
        return self._scaled_sum(self._weights(t) * self.y_t.reshape(-1, 1))

    def sec_kernel_f(self, t):
        """Return the scaled kernel sum (density-style estimate) at each point of ``t``."""
        return self._scaled_sum(self._weights(t))

    def kernel(self, x):
        """Return ``35/12 * (1 - x**2)**5`` for ``|x| <= 1`` and 0 outside.

        ``x`` is clipped to [-1, 1] first. This yields exact zeros outside the
        support without evaluating ``(1 - x**2)**5`` at large ``|x|``, where
        it overflows to inf and ``inf * 0`` would give NaN.
        """
        # NOTE(review): 35/12 does not normalize this kernel to integrate to 1
        # (the integral of (1 - x**2)**5 over [-1, 1] is 512/693). The constant
        # cancels in sec_kernel_whole1 up to the 1e-9 guard. Confirm intent.
        u = np.clip(x, -1, 1)
        return 35 / 12 * (1 - u ** 2) ** 5

    def sec_kernel_whole1(self, t):
        """Return the kernel regression estimate of y at each point of ``t``.

        The weight matrix is computed once and shared by numerator and
        denominator. The values are identical to
        ``sec_kernel_h(t) / (sec_kernel_f(t) + 1e-9)``.
        """
        w = self._weights(t)
        num = self._scaled_sum(w * self.y_t.reshape(-1, 1))
        den = self._scaled_sum(w)
        return num / (den + self._EPS)

    # custom addition
    def cdf0_est(self, t):
        """Return the CDF estimate at ``t``: ``1 - sec_kernel_whole1(t)``."""
        return 1 - self.sec_kernel_whole1(t)

    def cdf_est(self, t, thetax):
        """Return the CDF estimate evaluated at the shifted point ``t - thetax``."""
        return self.cdf0_est(t - thetax)

    def revenue(self, thetax, x):
        """Return the estimated revenue ``x * (1 - cdf_est(x, thetax))`` at price(s) ``x``."""
        return x * (1 - self.cdf_est(x, thetax))

    def maxi_est(self, thetax, low, upp, step=0.01):
        """Return the grid price in [low, upp] with the highest estimated revenue.

        The grid is ``low, low + step, ...`` up to the largest point not
        exceeding ``upp`` (within floating-point tolerance). Ties resolve to
        the lowest such price, following ``np.argmax``.

        Args:
            thetax: shift passed through to ``revenue``.
            low: lower end of the price range.
            upp: upper end of the price range.
            step: grid spacing (default 0.01).

        Raises:
            ValueError: if ``step <= 0`` or ``upp < low``.
        """
        if step <= 0:
            raise ValueError("step must be positive")
        if upp < low:
            raise ValueError("upp must be >= low")
        # np.arange(low, upp + step, step) can emit a point above upp (e.g.
        # when upp - low is not a multiple of step), so build from a count.
        n_steps = int(np.floor((upp - low) / step + self._GRID_TOL))
        p_grid = low + step * np.arange(n_steps + 1)
        return p_grid[np.argmax(self.revenue(thetax, p_grid))]


class Fan:
    """Epoch-based explore-then-exploit pricing policy with a linear valuation.

    Each epoch has length ``l``, which doubles from one epoch to the next,
    starting at ``l0``. Each epoch proceeds in three steps:

    1. Explore: for ``ceil(C1 * (l*d) ** ((2m+1)/(4m-1)))`` rounds, post
       uniformly random prices in [0, 1).
    2. Fit: estimate ``theta`` by least squares of the responses on the
       design rows ``[x, 1]``, then build a ``Kernel`` estimator on the
       residuals ``price - [x, 1] @ theta``.
    3. Exploit: for the rest of the epoch, post the price returned by
       ``Kernel.maxi_est`` on [0, 1].

    Args:
        l0: length of the first epoch.
        C1: multiplier of the exploration-phase length.
        d: context dimension (contexts are 1-D tensors of length ``d``).
        T: horizon, i.e. number of rounds per repetition.
        generator: ``torch.Generator`` used to draw exploration prices.
    """

    def __init__(self, l0, C1, d, T, generator):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        self.m = 2  # enters the exploration-length exponent (2m+1)/(4m-1)
        self.l0 = l0
        self.C1 = C1
        self.lam = 0.1  # not read by this class
        self.d = d
        self.T = T
        # NOTE(review): torch.rand is called with device=self.device, so this
        # generator presumably has to live on that same device. Confirm.
        self.generator = generator

        # Per-round logs: price * probability (from env.act) and the second
        # value of env.optimal_action(x). Every entry is rewritten each
        # repetition.
        self.reward = np.zeros((T,))
        self.optimal_reward = np.zeros((T,))

        self.t = 0  # not read by this class

    @staticmethod
    def _features(x):
        """Map a context tensor to the regression design row ``[x, 1]``."""
        return np.concatenate([x.numpy(force=True), np.array([1])])

    @classmethod
    def _valuation(cls, theta, x):
        """Return the fitted linear valuation ``[x, 1] @ theta`` for context ``x``.

        This uses the same feature layout as the regression, so the
        intercept is ``theta[-1]`` and the context coefficients are
        ``theta[:-1]``.
        """
        return np.dot(cls._features(x), theta)

    def _log(self, env, t, x, price, probability):
        """Record round ``t``'s reward and the environment's optimal reward."""
        self.reward[t] = price * probability
        _, self.optimal_reward[t] = env.optimal_action(x)

    def _explore(self, env, t, l_expr):
        """Post uniform random prices for up to ``l_expr`` rounds from round ``t``.

        Returns:
            ``(t, X, P, Y)``: the updated round counter, the design rows
            ``[x, 1]``, the posted prices and the realizations. Rows past
            the horizon are left as zeros.
        """
        X = np.zeros((l_expr, self.d + 1))
        P = np.zeros(l_expr)
        Y = np.zeros(l_expr)
        for t_in in range(l_expr):
            if t >= self.T:
                break
            x = env.gen_context()
            price = torch.rand(1, generator=self.generator, device=self.device).item()
            realization, probability = env.act(x, price)
            X[t_in] = self._features(x)
            P[t_in] = price
            Y[t_in] = realization
            self._log(env, t, x, price, probability)
            t += 1
        return t, X, P, Y

    def _exploit(self, env, t, l_expt, theta, ker):
        """Post kernel-optimal prices for up to ``l_expt`` rounds; return the new ``t``."""
        for _ in range(l_expt):
            if t >= self.T:
                break
            x = env.gen_context()
            price = ker.maxi_est(self._valuation(theta, x), 0, 1)
            _, probability = env.act(x, price)
            self._log(env, t, x, price, probability)
            t += 1
        return t

    def _run_once(self, env):
        """Play rounds 0..T-1 once, filling ``self.reward`` and ``self.optimal_reward``."""
        l = self.l0 / 2
        t = 0
        exponent = (2 * self.m + 1) / (4 * self.m - 1)
        while t < self.T:
            l = int(l * 2)  # epoch length doubles each epoch, starting at l0
            l_expr = int(np.ceil(self.C1 * np.power(l * self.d, exponent)))

            t, X, P, Y = self._explore(env, t, l_expr)
            if t >= self.T:
                # The horizon was reached during exploration. A model fitted
                # now would never be used, and X may hold zero-padded rows.
                break

            theta, _, _, _ = lstsq(X, Y)
            ker = Kernel(P - np.dot(X, theta), Y, l_expr)
            t = self._exploit(env, t, l - l_expr, theta, ker)

    def run(self, rep, env, basedir):
        """Run ``rep`` repetitions against ``env`` and save the logged rewards.

        ``env.reset()`` is called before each repetition. Arrays of shape
        ``(rep, T)`` are saved to ``reward.npy`` and ``optimal_reward.npy``
        inside ``basedir``, which is created if it does not exist.
        """
        rewards = np.zeros((rep, self.T))
        optimal_rewards = np.zeros((rep, self.T))

        for r in range(rep):
            print(f'run {r}')
            env.reset()
            self._run_once(env)
            # Row assignment copies the data, so later repetitions do not
            # alias these rows.
            rewards[r] = self.reward
            optimal_rewards[r] = self.optimal_reward

        os.makedirs(basedir, exist_ok=True)
        np.save(os.path.join(basedir, 'reward.npy'), rewards)
        np.save(os.path.join(basedir, 'optimal_reward.npy'), optimal_rewards)
    
