import numpy as np
from env import OptionMDP
import scipy.optimize as opt
import copy
import argparse
import time
import os
from itertools import product
from multiprocessing import Pool

# Sample-count constant (1e6, stored as a float). No definition visible in this
# file reads it -- NOTE(review): confirm it is still needed before relying on it.
max_sample = 1e6

def drrenyi_operator(P, V, discount, radius, renyi_k):
    """Minimise a scalar dual objective in ``eta`` and return the negated minimum.

    The objective is ``-eta + C * (P . max(eta - V, 0)**k_star)**(1 / k_star)``
    with ``C = (1 + k * (k - 1) * radius)**(1 / k)`` and ``k_star = k / (k - 1)``.
    It is minimised with ``scipy.optimize.fminbound`` on the fixed interval
    ``[-1e3, 1e3]``.  Going by the name, this is presumably the dual form of a
    worst-case expectation over a Renyi-type uncertainty ball -- TODO confirm.

    Args:
        P: weight vector dotted with the clipped powers (the caller in this
            file passes a normalised empirical transition row).
        V: vector of values, same length as ``P``.
        discount: accepted for interface compatibility; not used.
        radius: uncertainty radius entering ``C``.
        renyi_k: order ``k``; ``k = 1`` raises ``ZeroDivisionError``.

    Returns:
        The negated minimum of the objective found by ``fminbound``.
    """
    conj_exp = renyi_k / (renyi_k - 1)
    scale = (1 + renyi_k * (renyi_k - 1) * radius) ** (1 / renyi_k)

    def dual_objective(eta):
        # Only the shortfall of V below eta contributes.
        shortfall = np.clip(eta - V, a_max=None, a_min=0)
        return -eta + scale * np.dot(P, shortfall ** conj_exp) ** (1 / conj_exp)

    # full_output=True yields (xopt, fval, ierr, numfunc); only fval is needed.
    _, best_value, _, _ = opt.fminbound(dual_objective, -1e3, 1e3, full_output=True)
    return -best_value

# def tri_search(s_sample, upper, rob_q_value, delta):
#     v_value_rob_sample = np.zeros(len(s_sample))
#     for s in range(len(s_sample)):
#         v_value_rob_sample[s] = np.max(rob_q_value[s_sample[s]])
#     min_v_value_rob_sample = np.min(v_value_rob_sample)
#     max_v_value_rob_sample = np.max(v_value_rob_sample)
#     diff_v_value_rob_sample = -v_value_rob_sample + min_v_value_rob_sample
#     b_max = (max_v_value_rob_sample - min_v_value_rob_sample) / delta
#     b_min = 1e-5
#     f1 = 0
#     f2 = 1
#     # while (b_min + 1e-3 < b_max):
#     num_iter = 0
#     while (np.abs(f1 - f2) > 1e-5) and (num_iter < 1e+3):
#         b1 = (b_min + b_max) / 2
#         b2 = (b1 + b_max) / 2
#         total1 = np.mean(np.exp(diff_v_value_rob_sample / b1))
#         total2 = np.mean(np.exp(diff_v_value_rob_sample / b2))
#         f1 = b1 * np.log(total1) + b1 * delta - min_v_value_rob_sample
#         f2 = b2 * np.log(total2) + b2 * delta - min_v_value_rob_sample
#         if f1 < f2:
#             b_max = b2
#         else:
#             b_min = b1
#         num_iter += num_iter
#     return f1

def tri_search_r(r_sample, R, delta, tol=1e-5, max_iter=1000):
    """Search a scalar ``b`` for a small value of a log-mean-exp objective.

    The objective is::

        f(b) = b * log(mean(exp((min(r) - r) / b))) + b * delta - min(r)

    Each pass probes ``b1`` (midpoint of ``[b_min, b_max]``) and ``b2``
    (midpoint of ``[b1, b_max]``).  If ``f(b1) < f(b2)`` the upper end moves to
    ``b2``, otherwise the lower end moves to ``b1``.  The search stops once the
    two probes differ by at most ``tol`` or after ``max_iter`` passes.

    The pass counter is advanced by one per iteration.  The previous
    ``num_iter += num_iter`` kept it at zero, so the iteration cap never
    took effect.

    Args:
        r_sample: array-like of sampled values.
        R: scale used for the initial upper end ``R / delta`` of the search.
        delta: radius; must be non-zero because of the division above.
        tol: stop once ``|f(b1) - f(b2)| <= tol``.
        max_iter: maximum number of passes (at least 1).

    Returns:
        ``f(b1)`` from the last pass.

    Raises:
        ValueError: if ``max_iter`` is smaller than 1.
    """
    if max_iter < 1:
        raise ValueError("max_iter must be at least 1")

    r_sample = np.asarray(r_sample, dtype=float)
    min_r_sample = np.min(r_sample)
    # Shifting by the minimum keeps every exponent <= 0, so exp() cannot overflow.
    diff_r_sample = min_r_sample - r_sample

    def objective(b):
        return b * np.log(np.mean(np.exp(diff_r_sample / b))) + b * delta - min_r_sample

    b_min = 1e-5
    b_max = R / delta
    for _ in range(int(max_iter)):
        b1 = (b_min + b_max) / 2
        b2 = (b1 + b_max) / 2
        f1 = objective(b1)
        f2 = objective(b2)
        if f1 < f2:
            b_max = b2
        else:
            b_min = b1
        # Written as "not >" so a NaN difference also ends the search.
        if not abs(f1 - f2) > tol:
            break
    return f1

def tri_search_renyi(s_primes, rob_q_value, radius, renyi_k):
    """Sample-average version of the scalar dual minimisation over ``eta``.

    For every sampled next state the greedy value ``max(rob_q_value[s'])`` is
    collected.  The objective
    ``-eta + C * mean(max(eta - v, 0)**k_star)**(1 / k_star)`` is then
    minimised with ``scipy.optimize.fminbound`` on ``[-1e3, 1e3]``, where
    ``C = (1 + k * (k - 1) * radius)**(1 / k)`` and ``k_star = k / (k - 1)``.

    Args:
        s_primes: sized sequence of next-state indices.
        rob_q_value: table indexed by state whose rows hold action values.
        radius: uncertainty radius entering ``C``.
        renyi_k: order ``k``; ``k = 1`` raises ``ZeroDivisionError``.

    Returns:
        The negated minimum of the objective found by ``fminbound``.
    """
    greedy_values = np.fromiter(
        (np.max(rob_q_value[nxt]) for nxt in s_primes),
        dtype=float,
        count=len(s_primes),
    )
    scale = (1 + renyi_k * (renyi_k - 1) * radius) ** (1 / renyi_k)
    conj_exp = renyi_k / (renyi_k - 1)

    def dual_objective(eta):
        shortfall = np.clip(eta - greedy_values, a_max=None, a_min=0)
        return -eta + scale * np.mean(shortfall ** conj_exp) ** (1 / conj_exp)

    # full_output=True yields (xopt, fval, ierr, numfunc); only fval is needed.
    _, best_value, _, _ = opt.fminbound(dual_objective, -1e3, 1e3, full_output=True)
    return -best_value

# def tri_search_r_renyi(r, discount, radius, renyi_k):
#     k = renyi_k
#     C = (1 + k * (k - 1) * radius)**(1 / k)
#     k_star = k / (k - 1)

#     res = opt.fminbound(
#         lambda eta: -eta + C * np.mean(
#                                       np.clip(eta - r, a_max=None, a_min=0)**k_star)**(1 / k_star),
#         -1e3,
#         1e3,
#         full_output=True,
#     )
#     return  -res[1]

class DRQM:
    """Stage-indexed tabular Q estimator updated from batches of samples.

    Holds three lists of ``H + 1`` arrays of shape ``(n, n)``: the current
    estimate ``rob_q_value``, a scratch copy ``temp_rob_q_value`` and the last
    absolute change ``diff``.
    """

    def __init__(self, n, delta, renyi_k, H, gamma, k = 0) -> None:
        self.n, self.H = n, H
        table_count = self.H + 1

        def blank_tables():
            return [np.zeros([n, n]) for _ in range(table_count)]

        self.rob_q_value = blank_tables()
        self.diff = blank_tables()
        self.temp_rob_q_value = blank_tables()
        self.delta = delta
        self.renyi_k = renyi_k
        self.gamma = gamma
        # 2 ** k samples feed the "base" term of the batched estimators below.
        self.k = k

    def init_q(self, env):
        """Copy ``env.r[:, 1]`` into column 1 of every stage and column 0 of the last."""
        for stage in reversed(range(self.H)):
            table = self.rob_q_value[stage]
            for state in range(self.n):
                table[state, 1] = env.r[state, 1]

        last = self.rob_q_value[-1]
        for column in (0, 1):
            last[:, column] = env.r[:, 1]

    def update_q(self, s, a, s_prime, r, p_N, N, niter, h):
        """Batched update built on ``tri_search_r`` / ``tri_search``.

        NOTE(review): this method reads ``self.C`` and ``self.upper``, which
        ``__init__`` never sets, and calls ``tri_search``, which has no live
        definition in this file.  It cannot run as written.
        """
        step = 1 / (1 + (1 - self.gamma) * niter)
        head, split = 2 ** self.k, 2 ** N

        base_r = -tri_search_r(r[:head], self.C, self.delta)
        base_q = -tri_search(s_prime[:head], self.upper, self.rob_q_value, self.delta)

        # Full-batch estimate minus the average of the two half-batch estimates.
        delta_r = -(
            tri_search_r(r, self.C, self.delta)
            - 0.5 * tri_search_r(r[:split], self.C, self.delta)
            - 0.5 * tri_search_r(r[split:], self.C, self.delta)
        )
        delta_q = -(
            tri_search(s_prime, self.upper, self.rob_q_value, self.delta)
            - 0.5 * tri_search(s_prime[:split], self.upper, self.rob_q_value, self.delta)
            - 0.5 * tri_search(s_prime[split:], self.upper, self.rob_q_value, self.delta)
        )

        target = base_r + self.gamma * base_q + (delta_r + self.gamma * delta_q) / p_N
        self.temp_rob_q_value[h][s, a] = (1 - step) * self.rob_q_value[h][s, a] + step * target
        self.rob_q_value[h][s, a] = self.temp_rob_q_value[h][s, a]
        return None

    def update_q_renyi(self, s, a, s_prime, r, p_N, N, niter, h):
        """Blend ``Q[h][s, a]`` towards a batched ``tri_search_renyi`` target.

        The target is ``r[0] + gamma * base + gamma * correction / p_N``.
        ``base`` uses the first ``2 ** k`` next states.  ``correction`` is the
        full-batch estimate minus the mean of the estimates on the first
        ``2 ** N`` next states and on the remainder.  Next-state values are
        read from stage ``h + 1``.  Also records the absolute change in
        ``self.diff``.
        """
        step = 1 / (1 + (1 - self.gamma) * niter)
        base_r = r[0]
        next_table = self.rob_q_value[h + 1]
        head, split = 2 ** self.k, 2 ** N

        def robust_next(states):
            return tri_search_renyi(states, next_table, self.delta, self.renyi_k)

        base_q = robust_next(s_prime[:head])
        delta_r = 0
        delta_q = robust_next(s_prime) - 0.5 * robust_next(s_prime[:split]) - 0.5 * robust_next(s_prime[split:])

        previous = self.rob_q_value[h][s, a]
        target = base_r + self.gamma * base_q + (delta_r + self.gamma * delta_q) / p_N
        self.temp_rob_q_value[h][s, a] = (1 - step) * previous + step * target
        self.diff[h][s, a] = abs(self.temp_rob_q_value[h][s, a] - previous)
        self.rob_q_value[h][s, a] = self.temp_rob_q_value[h][s, a]
        return None

def sync_q_learning(delta, renyi_k, max_iter, seed = 0, verbose = None):
    """Run synchronous sweeps of ``DRQM.update_q_renyi`` on an ``OptionMDP``.

    Every sweep visits each stage ``h`` and state ``s`` with action 0, draws a
    batch through ``env.step_batch`` and applies one update.  After each sweep
    the initial-state value and the Q tables are written under ``./results``
    and ``./models``.  The loop ends once the cumulative number of samples
    exceeds ``max_iter``.

    Args:
        delta: uncertainty radius passed to the environment and the learner.
        renyi_k: order ``k`` passed to the environment and the learner.
        max_iter: sample budget (compared against the cumulative sample count).
        seed: NumPy random seed.
        verbose: print progress after each sweep.  With the default ``None``
            the module-level ``args`` object decides (``not args.multiple``)
            when it exists.  Otherwise nothing is printed, so the function can
            be imported or run in worker processes where ``args`` was never
            created.

    Returns:
        The learner's list of per-stage Q tables.
    """
    if verbose is None:
        # `args` only exists when this file runs as a script; do not assume it.
        cli_args = globals().get("args")
        verbose = cli_args is not None and not cli_args.multiple

    np.random.seed(seed)
    total_sample_num = 0
    niter = 0
    env = OptionMDP(radius = delta, renyi_k= renyi_k)
    drqm = DRQM(n = env.N_S, gamma = env.gamma, delta = delta, renyi_k = renyi_k, H = env.H)
    drqm.init_q(env)
    file_name = f"DRQM_radius_{delta}_renyik_{renyi_k}_seed_{seed}"

    value_init = np.dot(env.initial_state_dist, drqm.rob_q_value[0].max(-1))
    evaluations = [[total_sample_num, value_init]]

    while True:
        for h in range(env.H):
            for s in range(env.N_S):
                a = 0  # only action 0 is learned; column 1 is filled by init_q
                s_prime, r, N, p_N, sample_num = env.step_batch(s, a)
                total_sample_num += sample_num
                drqm.update_q_renyi(s, a, s_prime, r, p_N, N, niter, h)

        niter += 1
        if total_sample_num > max_iter:
            break

        value_init = np.dot(env.initial_state_dist, drqm.rob_q_value[0].max(-1))
        evaluations.append([total_sample_num, value_init])
        np.save(f"./results/{file_name}", evaluations)
        np.save(f"./models/{file_name}", drqm.rob_q_value)
        if verbose:
            print("number of iterations: ", niter)
            print("number of samples: ", total_sample_num)
            print("value: ", value_init)

    return drqm.rob_q_value

class DRQ:
    """Tabular, stage-indexed Q-learning agent with a dual-variable robust backup.

    Every table is a list of ``H + 1`` arrays of shape
    ``(state_dim, action_dim)``, one per stage ``h``:

    * ``rob_q_value`` -- Q estimates, initialised with standard-normal draws;
    * ``Z1`` / ``Z2`` -- running averages of ``max(eta - v', 0)**(k_star - 1)``
      and ``max(eta - v', 0)**k_star``;
    * ``E`` -- the dual variable ``eta`` kept per (stage, state, action).

    The underscore-suffixed twins hold the second estimator used by
    ``update_q_double``.  For any action other than 0 the Q entry is simply
    overwritten with the observed reward.
    """

    def __init__(self, reward_scale, state_dim, action_dim, expl_noise, delta, k, H, gamma = 0.95) -> None:
        self.reward_scale = reward_scale  # stored only; no method of this class reads it
        self.state_dim = state_dim
        self.action_dim = action_dim

        self.expl_noise = expl_noise
        self.radius = delta
        self.k = k
        self.H = H
        self.gamma = gamma

        # Draw order (primary table first, then its twin) fixes the initial
        # values obtained under a given NumPy seed.
        self.rob_q_value = self._random_tables()
        self.Z1 = self._zero_tables()
        self.Z2 = self._zero_tables()
        self.E = self._zero_tables()

        self.rob_q_value_ = self._random_tables()
        self.Z1_ = self._zero_tables()
        self.Z2_ = self._zero_tables()
        self.E_ = self._zero_tables()

    def _random_tables(self):
        """Return ``H + 1`` standard-normal tables."""
        return [np.random.randn(self.state_dim, self.action_dim) for _ in range(self.H + 1)]

    def _zero_tables(self):
        """Return ``H + 1`` all-zero tables."""
        return [np.zeros((self.state_dim, self.action_dim)) for _ in range(self.H + 1)]

    def _load_action1_rewards(self, tables, env):
        """Fill column 1 of every stage, and column 0 of the last, from ``env.r[:, 1]``."""
        for stage in range(self.H):
            for s in range(self.state_dim):
                tables[stage][s, 1] = env.r[s, 1]
        tables[-1][:, 0] = env.r[:, 1]
        tables[-1][:, 1] = env.r[:, 1]

    def _dual_step(self, q_tabs, z1_tabs, z2_tabs, eta_tabs, h, state, action, reward, next_value, rates):
        """Apply one three-timescale update to a single (stage, state, action) cell.

        ``rates`` is ``(fast, medium, slow)``: the step sizes for the ``Z``
        averages, the dual variable and the Q entry.  All quantities are read
        before any table is written, so the dual-variable and Q updates use
        the pre-update ``z1``, ``z2`` and ``eta``.
        """
        fast, medium, slow = rates
        C = (1 + self.k * (self.k - 1) * self.radius) ** (1 / self.k)
        k_star = self.k / (self.k - 1)

        z1 = z1_tabs[h][state, action]
        z2 = max(z2_tabs[h][state, action], 1e-8)  # keeps the negative power below finite
        eta = eta_tabs[h][state, action]
        gap = max(eta - next_value, 0)

        z1_tabs[h][state, action] += fast * (gap ** (k_star - 1) - z1)
        z2_tabs[h][state, action] += fast * (gap ** k_star - z2)

        # Update applied to the dual variable eta.
        eta_tabs[h][state, action] += medium * (1 - C * z1 * z2 ** (1 / k_star - 1))

        robust_value = eta - C * z2 ** (1 / k_star)
        q_tabs[h][state, action] += slow * (
            reward + self.gamma * robust_value - q_tabs[h][state, action]
        )

    def decide(self, state, h):
        """Epsilon-greedy action: random with probability ``expl_noise``, else greedy at stage ``h``."""
        if np.random.rand() < self.expl_noise:
            action = np.random.randint(self.action_dim)
        else:
            action = np.argmax(self.rob_q_value[h][state])
        return action

    def init_q(self, env):
        """Seed the primary Q tables from ``env.r[:, 1]``."""
        self._load_action1_rewards(self.rob_q_value, env)

    def init_double_q(self, env):
        """Seed both the primary and the twin Q tables from ``env.r[:, 1]``."""
        self._load_action1_rewards(self.rob_q_value, env)
        self._load_action1_rewards(self.rob_q_value_, env)

    def update_q(self, state, action, reward, next_state, t, h):
        """Single-estimator update with step sizes decaying in ``t``.

        For ``action == 0`` the next-state value is the greedy value at stage
        ``h + 1``.  For any other action the Q entry is set to ``reward``.
        """
        if action != 0:
            self.rob_q_value[h][state, action] = reward
            return

        rates = (
            1 / (1 + (1 - self.gamma) * t ** 0.6),        # fast
            1 / (1 + (1 - self.gamma) / 10 * t ** 0.8),   # medium
            1 / (1 + (1 - self.gamma) / 100 * t),         # slow
        )
        next_value = np.max(self.rob_q_value[h + 1][next_state])
        self._dual_step(self.rob_q_value, self.Z1, self.Z2, self.E,
                        h, state, action, reward, next_value, rates)

    def update_q_double(self, state, action, reward, next_state, t, h):
        """Double-estimator update with constant step sizes (``t`` is unused).

        Each estimator evaluates the other estimator's greedy action.  The
        next-state tables are read at stage ``h + 1``, matching ``update_q``.
        Reading stage ``h`` instead would bootstrap a stage from itself and
        never use the terminal values written by ``init_double_q``.
        """
        if action != 0:
            self.rob_q_value[h][state, action] = reward
            self.rob_q_value_[h][state, action] = reward
            return

        rates = (0.5, 0.1, 0.005)  # fast, medium, slow

        next_q = self.rob_q_value[h + 1][next_state]
        next_q_ = self.rob_q_value_[h + 1][next_state]
        # Both bootstrap values are fixed before either estimator is modified.
        next_value = next_q[np.argmax(next_q_)]
        next_value_ = next_q_[np.argmax(next_q)]

        self._dual_step(self.rob_q_value, self.Z1, self.Z2, self.E,
                        h, state, action, reward, next_value, rates)
        self._dual_step(self.rob_q_value_, self.Z1_, self.Z2_, self.E_,
                        h, state, action, reward, next_value_, rates)

class Q:
    """Tabular, stage-indexed Q-learning agent with a dual-variable backup.

    Keeps lists of ``H + 1`` arrays of shape ``(state_dim, action_dim)``: the
    Q estimates ``rob_q_value``, the running averages ``Z1`` / ``Z2``, the
    dual variable ``E``, and an underscore-suffixed twin of each.  No method of
    this class updates the twins apart from ``init_double_q``.
    """

    def __init__(self, reward_scale, state_dim, action_dim, expl_noise, delta, k, H, gamma = 0.95) -> None:
        self.reward_scale, self.state_dim, self.action_dim = reward_scale, state_dim, action_dim
        self.expl_noise = expl_noise
        self.radius, self.k = delta, k
        self.H, self.gamma = H, gamma

        shape = (self.state_dim, self.action_dim)
        table_count = self.H + 1

        def gaussian_tables():
            return [np.random.randn(*shape) for _ in range(table_count)]

        def zero_tables():
            return [np.zeros(shape) for _ in range(table_count)]

        # Random draws happen in this order: primary Q tables, then the twin.
        self.rob_q_value = gaussian_tables()
        self.Z1, self.Z2, self.E = zero_tables(), zero_tables(), zero_tables()

        self.rob_q_value_ = gaussian_tables()
        self.Z1_, self.Z2_, self.E_ = zero_tables(), zero_tables(), zero_tables()

    def decide(self, state, h):
        """Return a random action with probability ``expl_noise``, else the greedy one."""
        explore = np.random.rand() < self.expl_noise
        if explore:
            return np.random.randint(self.action_dim)
        return np.argmax(self.rob_q_value[h][state])

    def init_q(self, env):
        """Write ``env.r[:, 1]`` into column 1 of every stage and column 0 of the last."""
        for stage in reversed(range(self.H)):
            table = self.rob_q_value[stage]
            for s in range(self.state_dim):
                table[s, 1] = env.r[s, 1]

        last = self.rob_q_value[-1]
        for column in (0, 1):
            last[:, column] = env.r[:, 1]

    def init_double_q(self, env):
        """Same seeding as ``init_q``, applied to both the primary and twin tables."""
        for stage in reversed(range(self.H)):
            for s in range(self.state_dim):
                payoff = env.r[s, 1]
                self.rob_q_value[stage][s, 1] = payoff
                self.rob_q_value_[stage][s, 1] = payoff

        for tables in (self.rob_q_value, self.rob_q_value_):
            for column in (0, 1):
                tables[-1][:, column] = env.r[:, 1]

    def update_q(self, state, action, reward, next_state, t, h):
        """Three-timescale update of one (stage, state, action) cell.

        Any action other than 0 just stores ``reward``.  For action 0 the
        ``Z`` averages, the dual variable and the Q entry are updated with
        step sizes decaying in ``t``, using the greedy value at stage
        ``h + 1`` and the pre-update ``z1``, ``z2`` and ``eta``.
        """
        if action != 0:
            self.rob_q_value[h][state, action] = reward
            return

        # Step sizes, fastest to slowest.
        fast = 1 / (1 + (1 - self.gamma) * t ** (0.6))
        medium = 1 / (1 + (1 - self.gamma) / 10 * t ** (0.8))
        slow = 1 / (1 + (1 - self.gamma) / 100 * t)

        scale = (1 + self.k * (self.k - 1) * self.radius) ** (1 / self.k)
        conj_exp = self.k / (self.k - 1)

        cell = (state, action)
        avg1 = self.Z1[h][cell]
        avg2 = max(self.Z2[h][cell], 1e-8)  # floor for numerical stability
        dual = self.E[h][cell]

        greedy_next = np.max(self.rob_q_value[h + 1][next_state])
        gap = max(dual - greedy_next, 0)
        self.Z1[h][cell] += fast * (gap ** (conj_exp - 1) - avg1)
        self.Z2[h][cell] += fast * (gap ** conj_exp - avg2)

        # Update applied to the dual variable.
        self.E[h][cell] += medium * (1 - scale * avg1 * avg2 ** (1 / conj_exp - 1))

        robust_value = dual - scale * avg2 ** (1 / conj_exp)
        self.rob_q_value[h][cell] += slow * (
            reward + self.gamma * robust_value - self.rob_q_value[h][cell]
        )

def async_q_learning(delta, max_iter, k, expl_noise, seed = 0, verbose = None):
    """Run single-transition (asynchronous) updates of a ``DRQ`` agent on ``OptionMDP``.

    The agent always plays action 0, updates the cell of the stage the action
    was taken in (``h - 1`` for the stage returned by ``env.step``) and resets
    the environment on termination.  Every 1000 steps the initial-state value
    and the Q tables are saved under ``./results`` and ``./models``.

    Args:
        delta: uncertainty radius handed to the agent.
        max_iter: number of environment steps to run.
        k: order ``k`` handed to the agent.
        expl_noise: exploration probability handed to the agent (the loop
            itself always plays action 0).
        seed: NumPy random seed.
        verbose: print progress at each checkpoint.  With the default ``None``
            the module-level ``args`` object decides (``not args.multiple``)
            when it exists.  Otherwise nothing is printed, so the function can
            be imported or run in worker processes where ``args`` was never
            created.

    Returns:
        The per-state greedy value at stage 0.
    """
    if verbose is None:
        # `args` only exists when this file runs as a script; do not assume it.
        cli_args = globals().get("args")
        verbose = cli_args is not None and not cli_args.multiple

    np.random.seed(seed)
    niter = 0
    env = OptionMDP()
    drq = DRQ(reward_scale = 1, state_dim = env.N_S, action_dim = env.N_A, expl_noise = expl_noise, delta = delta, k = k, H = env.H, gamma = env.gamma)
    s, h = env.reset()
    file_name = f"async_q_learning_radius_{delta}_renyik_{k}_seed_{seed}"
    evaluations = list()

    drq.init_q(env)

    while True:
        a = 0
        s_prime, r, terminal, h = env.step(a)
        # env.step returns the stage after the move; the update is for the stage before it.
        drq.update_q(s, a, r, s_prime, niter, h-1)
        s = copy.copy(s_prime)

        if terminal or a == 1:
            s, h = env.reset()

        niter += 1
        if niter > max_iter:
            return drq.rob_q_value[0].max(-1)

        if niter % 1000 == 0:
            value_init = np.dot(env.initial_state_dist, drq.rob_q_value[0].max(-1))
            evaluations.append([niter, value_init])
            if verbose:
                print("number of iterations: ", niter)
                print("value: ", value_init)
            np.save(f"./results/{file_name}", evaluations)
            np.save(f"./models/{file_name}", drq.rob_q_value)

def async_q_learning_nonrobust(delta, max_iter, k, expl_noise, seed = 0, verbose = None):
    """Run single-transition updates of a ``DRQ`` agent, saving under a radius-free name.

    NOTE(review): despite the name, this builds the same ``DRQ`` agent with the
    same ``delta`` and ``k`` as ``async_q_learning``.  Only the output file
    name and one extra progress print differ.  Confirm whether a non-robust
    learner was intended.

    Args:
        delta: uncertainty radius handed to the agent.
        max_iter: number of environment steps to run.
        k: order ``k`` handed to the agent.
        expl_noise: exploration probability handed to the agent (the loop
            itself always plays action 0).
        seed: NumPy random seed.
        verbose: print progress at each checkpoint.  With the default ``None``
            the module-level ``args`` object decides (``not args.multiple``)
            when it exists.  Otherwise nothing is printed, so the function can
            be imported or run in worker processes where ``args`` was never
            created.

    Returns:
        The per-state greedy value at stage 0.
    """
    if verbose is None:
        # `args` only exists when this file runs as a script; do not assume it.
        cli_args = globals().get("args")
        verbose = cli_args is not None and not cli_args.multiple

    np.random.seed(seed)
    niter = 0
    env = OptionMDP()
    drq = DRQ(reward_scale = 1, state_dim = env.N_S, action_dim = env.N_A, expl_noise = expl_noise, delta = delta, k = k, H = env.H, gamma = env.gamma)
    s, h = env.reset()
    file_name = f"async_q_learning_seed_{seed}"
    evaluations = list()

    drq.init_q(env)

    while True:
        a = 0
        s_prime, r, terminal, h = env.step(a)
        # env.step returns the stage after the move; the update is for the stage before it.
        drq.update_q(s, a, r, s_prime, niter, h-1)
        s = copy.copy(s_prime)

        if terminal or a == 1:
            s, h = env.reset()

        niter += 1
        if niter > max_iter:
            return drq.rob_q_value[0].max(-1)

        if niter % 1000 == 0:
            value_init = np.dot(env.initial_state_dist, drq.rob_q_value[0].max(-1))
            evaluations.append([niter, value_init])
            if verbose:
                print("number of iterations: ", niter)
                print("value: ", value_init)
                print(drq.rob_q_value[0][:10])
            np.save(f"./results/{file_name}", evaluations)
            np.save(f"./models/{file_name}", drq.rob_q_value)

def model_learning(delta, renyi_k, seed = 0):
    """Model-based baseline: estimate transitions from samples, then iterate robust backups.

    For each batch size in ``np.arange(5, 1000, 10)`` the function builds an
    empirical transition estimate ``P_hat`` for action 0 from
    ``env.step_multiple`` and runs damped backups with ``drrenyi_operator``
    for at most 11 sweeps (or until the largest change drops below 1e-4).
    It then appends ``[total_sample_num, initial-state value]`` to the
    evaluation list.  The evaluations and the Q tables of the most recent
    batch size are saved under ``./results`` and ``./models`` after every
    batch size.

    Args:
        delta: uncertainty radius passed to the environment and the operator.
        renyi_k: order ``k`` passed to the environment and the operator.
        seed: NumPy random seed (set after the environment is constructed).

    Returns:
        None; results are only written to disk.
    """
    file_name = f"modelbase_radius:{delta}_k:{renyi_k}_seed:{seed}"
    evaluations = list()
    env = OptionMDP(radius = delta, renyi_k = renyi_k)
    state_dim = env.N_S
    action_dim = env.N_A
    gamma = env.gamma
    np.random.seed(seed)
    for sample_num in np.arange(5, 1000, 10):
        # The returned state is not used below; every visit calls env.set_state(s).
        state = env.reset()
        rs = np.zeros((state_dim, action_dim))
        P_hat = np.zeros((state_dim, action_dim, state_dim))
        Next_V = [np.zeros((state_dim, action_dim)) for _ in range(env.H+1)]
        rob_q_value, temp_rob_q_value, diff = [np.zeros([state_dim, action_dim]) for _ in range(env.H+1)], [np.zeros([state_dim, action_dim]) for _ in range(env.H+1)], [np.zeros([state_dim, action_dim]) for _ in range(env.H+1)]
        niter = 0
        total_sample_num = 0

        # Last-stage values: both columns come from env.r[:, 1].
        rob_q_value[-1][:, 0] = env.r[:, 1]
        rob_q_value[-1][:, 1] = env.r[:, 1]
        # Only referenced by the commented-out timing print further down.
        start_time = time.time()
        while True:
            # Damping weight; equals 1 on the first sweep (niter == 0).
            alpha = 1 / (1 + (1 - gamma) * niter)
            for h in range(0, env.H):
                for s in range(state_dim):
                    env.set_state(s)
                    # for a in range(action_dim):
                    a = 0
                    # Samples are drawn during the first sweep only.
                    # NOTE(review): this branch runs once per stage h for the
                    # same (s, a), so from the second stage on the new counts
                    # are added to an already-normalised P_hat row and the row
                    # is renormalised; earlier batches end up down-weighted
                    # while still being counted in total_sample_num. Confirm
                    # this is intended.
                    if niter == 0:
                        s_prime, r, done, timeout = env.step_multiple(a, sample_num)
                        rs[s,a] = r[0]
                        total_sample_num += sample_num
                        for i in range(sample_num):
                            P_hat[s, a, s_prime[i]] += 1
                        P_hat[s, a] = P_hat[s, a]/ np.sum(P_hat[s, a])
                    # Robust next-stage value from the greedy values of stage h + 1.
                    Next_V[h][s,a] = drrenyi_operator(P_hat[s, a], rob_q_value[h+1].max(-1), gamma, delta, renyi_k)
                    temp_rob_q_value[h][s, a] = (1 - alpha) * rob_q_value[h][s, a] + alpha * (rs[s,a] + gamma * Next_V[h][s,a])
                    diff[h][s,a] = abs(temp_rob_q_value[h][s,a] - rob_q_value[h][s,a])
                    rob_q_value[h][s,a] = copy.copy(temp_rob_q_value[h][s, a])

                    # Action 1 is not learned: its value is the reward table entry.
                    a = 1
                    rob_q_value[h][s,a] = env.r[s, a]

            niter += 1

            # Stop on convergence or after the 11th sweep.
            if np.max(diff)<1e-4 or niter>10:
                break
        value_init = np.dot(env.initial_state_dist, rob_q_value[0].max(-1))
        evaluations.append([total_sample_num, value_init])
        # print(f"sample_num: {total_sample_num}, time: {time.time() - start_time}, value_init: {value_init}")
        # print(f"{total_sample_num}, {value_init}")

        # Rewritten after every batch size; the model file holds only the latest tables.
        np.save(f"./results/{file_name}", evaluations)
        np.save(f"./models/{file_name}", rob_q_value)

if __name__ == '__main__':
    def _parse_flag(text):
        """Read a command-line string as a boolean.

        ``type=bool`` would treat every non-empty string (including "False")
        as True, so the common spellings of "false" are mapped explicitly.
        Any other value counts as True.
        """
        return str(text).strip().lower() not in {"false", "0", "no", "n", "f", "off", ""}

    parser = argparse.ArgumentParser()
    # parser.add_argument("--seed", default= 0, type = int)
    parser.add_argument("--eval_freq", default= 1000, type = int)
    # argparse does not apply `type` to non-string defaults, so give an int directly.
    parser.add_argument("--max_timesteps", default= int(5e7), type = int)
    parser.add_argument("--expl_noise", default= 1.0, type = float)
    parser.add_argument("--discount", default= 0.9, type = float)
    # parser.add_argument("--radius", default=0.2, type = float)
    # parser.add_argument("--renyi_k", default=2.0, type = float)
    parser.add_argument("--type", default= "drqm", type = str)
    parser.add_argument("--save_model", default= True, type = _parse_flag)
    parser.add_argument("--num_process", default= 60, type = int)
    parser.add_argument("--multiple", default= False, type = _parse_flag)

    # Kept as the module-level name `args`: the learners in this file read it.
    args = parser.parse_args()

    tolerance = 0.01

    print('-'*10)

    # Every learner in this file writes to both directories unconditionally,
    # so both must exist regardless of --save_model.
    os.makedirs("./results", exist_ok=True)
    os.makedirs("./models", exist_ok=True)

    if args.multiple:
        # Grid of experiments run in parallel worker processes.
        deltas  = [0.1, 0.2, 0.4] # args.radius # 0.05 #1
        renyi_ks = [2.0, 4.0]
        seeds = list(range(10))
        if args.type == 'drqm':
            with Pool(processes=args.num_process) as pool:
                pool.starmap(sync_q_learning, product(deltas, renyi_ks, [args.max_timesteps], seeds))
        elif args.type == 'drq':
            with Pool(processes=args.num_process) as pool:
                pool.starmap(async_q_learning, product(deltas, [args.max_timesteps], renyi_ks, [args.expl_noise], seeds))
            # rob_q_value = async_q_learning(n, host_cost, sell_price, order_cost, delta, gamma, args.max_timesteps, args.renyi_k, args.expl_noise)
        else:
            # Any other --type value selects the model-based baseline.
            with Pool(processes=args.num_process) as pool:
                pool.starmap(model_learning, product(deltas, renyi_ks, seeds))
    else:
        # Single run with fixed settings.
        delta = 0.05
        renyik = 2.0
        if args.type == 'drqm':
            rob_q_value = sync_q_learning(delta, renyik, args.max_timesteps)
        elif args.type == 'drq':
            rob_q_value = async_q_learning(delta, args.max_timesteps, renyik, args.expl_noise)
        else:
            rob_q_value = model_learning(delta, renyik)
    # for s in range(n+1):
    #     rob_v_value[s]=np.max(rob_q_value[s, : n-s + 1])
    #     greedy_policy[s] = np.argmax(rob_q_value[s, : n-s + 1])
    # print(greedy_policy)
    # print(rob_v_value)