import argparse
import os
from itertools import product
from multiprocessing import Pool

import numpy as np
import scipy.optimize as opt

def drrenyi_operator(P, V, radius, renyi_k, *, bounds=(-1e3, 1e3)):
    """Distributionally robust expectation of ``V`` around ``P`` via a 1-D dual.

    Maximises over ``eta``::

        eta - C * (P . (eta - V)_+ ** k*) ** (1 / k*)
        C  = (1 + k (k - 1) radius) ** (1 / k)
        k* = k / (k - 1)

    by minimising its negation with ``scipy.optimize.fminbound`` on ``bounds``.

    Parameters
    ----------
    P : array_like, shape (n,)
        Nominal probability weights (one transition row in this file).
    V : array_like, shape (n,)
        Values whose robust expectation is taken.
    radius : float
        Size of the divergence ball; must be non-negative.
    renyi_k : float
        Divergence order ``k``; must be greater than 1 so that ``k*`` is finite.
    bounds : tuple of float, keyword-only
        Search interval ``(lower, upper)`` for ``eta``. The default is the
        previously hard-coded interval; widen it when ``V`` can leave
        ``[-1e3, 1e3]``, otherwise the optimum may be cut off.

    Returns
    -------
    float
        The maximal value of the dual objective found on ``bounds``.

    Raises
    ------
    ValueError
        If ``renyi_k <= 1``, ``radius < 0`` or ``bounds`` is not increasing.
    """
    k = renyi_k
    if k <= 1:
        # k == 1 would divide by zero in k_star; k < 1 makes k_star negative.
        raise ValueError(f"renyi_k must be > 1, got {renyi_k}")
    if radius < 0:
        raise ValueError(f"radius must be non-negative, got {radius}")
    lower, upper = bounds
    if not lower < upper:
        raise ValueError(f"bounds must satisfy lower < upper, got {bounds}")

    P = np.asarray(P, dtype=float)
    V = np.asarray(V, dtype=float)
    C = (1 + k * (k - 1) * radius) ** (1 / k)
    k_star = k / (k - 1)

    def negated_dual(eta):
        # (eta - V)_+ : amount by which each value falls below the threshold eta.
        shortfall = np.maximum(eta - V, 0.0)
        return -eta + C * np.dot(P, shortfall ** k_star) ** (1 / k_star)

    # full_output=True returns (xopt, fval, ierr, numfunc).
    _, fval, _, _ = opt.fminbound(negated_dual, lower, upper, full_output=True)
    return -fval

class OptionMDP:
    """Finite-horizon put-option MDP on a discretised price grid.

    Prices ``x`` in ``[80, 140]`` are discretised with spacing ``10**-decimal``
    into states ``0 .. N_S - 2``; state ``N_S - 1`` is absorbing. Action 0
    (hold) moves the price up by a factor 1.02 with probability ``tau`` and
    down by a factor 0.98 otherwise (clipped to the grid). Action 1 (exercise)
    pays ``max(0, K - x)`` and moves to the absorbing state.
    """

    def __init__(self, H=5, K=100, d=12, tau=0.5, coef=1.02, decimal=1, normalize_reward=False, radius=0.05, renyi_k=2.0, gamma=0.95):
        """Build transitions ``P``, rewards ``r``, features ``phi`` and the reset distribution.

        Args:
            H: time horizon (number of decision steps).
            K: strike price.
            d: number of triangular features; must be >= 2.
            tau: probability of an up-move under action 0.
            coef: NOTE(review): accepted but currently unused; the up/down
                factors are hard-coded to 1.02 / 0.98 below.
            decimal: number of decimal digits of price resolution.
            normalize_reward: if True, divide rewards by ``10 * max(r)``.
            radius: divergence-ball radius used by ``compute_V``.
            renyi_k: divergence order used by ``compute_V``.
            gamma: discount factor used by ``compute_V``.
        """
        x_min, x_max = 80, 140
        self.x2s = lambda x: round((x - x_min) * 10**decimal)
        self.s2x = lambda s: s / 10**decimal + x_min

        self.N_S = (x_max - x_min) * 10**decimal + 2  # price grid + absorbing state
        self.N_A = 2  # 0: hold, 1: exercise
        self.d = d  # linear feature dimension
        self.H = H  # time horizon
        self.radius = radius
        self.renyi_k = renyi_k
        self.p = 0.5  # success probability of the geometric draw in step_batch
        self.gamma = gamma
        c_u = 1.02
        c_d = 0.98

        # Triangular ("hat") features centred on d evenly spaced prices. The
        # centres and width do not depend on the state, so build them once.
        # (Named `width` so it does not shadow the `radius` argument.)
        width = (x_max - x_min) / (d - 1)
        x_0 = np.linspace(x_min, x_max, d, endpoint=True)

        absorb = self.N_S - 1
        top = self.N_S - 2
        phi = np.zeros((self.N_S, d))
        P = np.zeros((self.N_S, self.N_A, self.N_S))
        r = np.zeros((self.N_S, self.N_A))
        for s in range(absorb):
            x = self.s2x(s)
            P[s, 0, min(self.x2s(c_u * x), top)] = tau
            P[s, 0, max(self.x2s(c_d * x), 0)] = 1 - tau
            P[s, 1, absorb] = 1
            r[s, 1] = max(0, K - x)
            phi[s] = np.maximum(1 - np.abs(x - x_0) / width, 0)
        # The absorbing state loops on itself under both actions; its feature
        # row is a constant placeholder.
        P[absorb, :, absorb] = 1
        phi[absorb, :] = 1

        self.phi = phi / phi.sum(-1, keepdims=True)
        self.P = P
        self.r = r
        if normalize_reward:
            self.r_scale = self.r.max() * 10
            self.r /= self.r_scale

        # Uniform reset over the states covering prices [K - 5, K + 5).
        self.initial_state_dist = np.zeros(self.N_S)
        self.initial_state_dist[self.x2s(K - 5):self.x2s(K + 5)] = 1
        self.initial_state_dist /= self.initial_state_dist.sum()

        # Q[h] stores state-action values for h = 0..H; filled by compute_V.
        self.Q = [np.zeros((self.N_S, self.N_A)) for _ in range(self.H + 1)]

    def _next_state_dist(self, state, action):
        """Return a normalised copy of the transition row ``P[state, action]``."""
        row = self.P[state, action]
        # Out-of-place division: self.P itself is never modified.
        return row / row.sum()

    def reset(self):
        """Start an episode at ``h = 0`` from ``initial_state_dist``; return ``(state, h)``."""
        self.h = 0
        self.state = np.random.choice(self.N_S, p=self.initial_state_dist)
        return self.state, self.h

    def set_state(self, state):
        """Overwrite the current state."""
        self.state = state

    def step(self, action):
        """Apply ``action`` and sample the next state.

        Returns:
            ``(next_state, reward, done, h)`` where ``done`` is ``h >= H``.
        """
        self.h += 1
        reward = self.r[self.state, action]
        self.state = np.random.choice(self.N_S, p=self._next_state_dist(self.state, action))
        return self.state, reward, self.h >= self.H, self.h

    def step_batch(self, state, action):
        """Sample a random-size batch of next states from ``(state, action)``.

        Draws ``N = Geometric(p) - 1`` (so ``N >= 0``) and samples
        ``2 ** (N + 1)`` next states. ``p_N = p * (1 - p) ** N`` is the
        probability of the drawn ``N``. Does not change ``self.state``.

        Returns:
            ``(s_prime, r, N, p_N, sample_num)``.
        """
        N = np.random.geometric(self.p) - 1
        p_N = self.p * ((1 - self.p) ** N)
        sample_num = 2 ** (N + 1)
        p = self._next_state_dist(state, action)
        s_prime = np.random.choice(self.N_S, p=p, size=sample_num)
        r = np.repeat(self.r[state, action], sample_num)
        return s_prime, r, N, p_N, sample_num

    def step_multiple(self, action, sample_num):
        """Sample ``sample_num`` next states from the current state without advancing ``h``.

        The environment moves to the first sampled state.

        Returns:
            ``(states, rewards, False, h >= H)``. NOTE(review): unlike ``step``,
            the done flag is in the fourth slot here; kept for existing callers.
        """
        reward = np.repeat(self.r[self.state, action], sample_num)
        p = self._next_state_dist(self.state, action)
        state = np.random.choice(self.N_S, size=sample_num, p=p)
        self.state = state[0]
        return state, reward, False, self.h >= self.H

    def compute_V(self):
        """Fill ``self.Q`` by robust backward induction and return the initial value.

        The continuation value of holding (action 0) is the robust expectation
        from ``drrenyi_operator`` with ``self.radius`` / ``self.renyi_k``;
        exercising (action 1) is worth its immediate reward. The initial value
        and the stacked ``Q`` tables are saved as ``.npy`` files under
        ``./results`` and ``./models`` (created if missing).

        Returns:
            Expected optimal value at ``h = 0`` under ``initial_state_dist``.
        """
        file_name = f'Renyi({self.renyi_k})_{self.radius}_optimal'
        # Terminal values: both actions are valued at the exercise payoff.
        self.Q[-1][:, 0] = self.r[:, 1]
        self.Q[-1][:, 1] = self.r[:, 1]
        for i in range(self.H - 1, -1, -1):
            # Next-stage optimal values are the same for every state; compute once.
            V_next = self.Q[i + 1].max(axis=-1)
            self.Q[i][:, 1] = self.r[:, 1]
            for s in range(self.N_S):
                self.Q[i][s, 0] = self.r[s, 0] + self.gamma * drrenyi_operator(
                    self.P[s, 0],
                    V_next,
                    self.radius,
                    self.renyi_k,
                )
        value_init = np.dot(self.initial_state_dist, self.Q[0].max(-1))
        # Make sure the output folders exist so the computation is not lost at
        # the save step (exist_ok keeps this safe under parallel workers).
        os.makedirs("./results", exist_ok=True)
        os.makedirs("./models", exist_ok=True)
        np.save(f"./results/{file_name}", value_init)
        np.save(f"./models/{file_name}", self.Q)
        return value_init
    
def run(delta, renyi_k):
    """Solve one OptionMDP with radius ``delta`` and Renyi order ``renyi_k``.

    Returns the value produced by ``OptionMDP.compute_V`` (which also writes
    its results to disk).
    """
    return OptionMDP(radius=delta, renyi_k=renyi_k).compute_V()

    
if __name__ == '__main__':
    cli = argparse.ArgumentParser()
    cli.add_argument("--num_process", default=20, type=int)
    num_process = cli.parse_args().num_process

    # Sweep every (radius, Renyi order) pair in parallel; each run saves its
    # own results, so the returned values are not collected here.
    radii = [0.01, 0.02, 0.03, 0.05, 0.1, 0.2, 0.4, 0.5]
    orders = [1.5, 2.0, 3.0, 4.0]
    jobs = list(product(radii, orders))
    with Pool(processes=num_process) as workers:
        workers.starmap(run, jobs)