import numpy as np
from env.cliffwalking import CliffWalkingEnv
import argparse
import os
from math import sqrt
import scipy.optimize as opt
from multiprocessing import Pool
from itertools import product

# Small epsilon constant. Not referenced in the visible part of this file
# (the name suggests guarding log() arguments -- TODO confirm or remove).
LOG_EPS = 1e-5
# Symmetric bounds -1/(1 - 0.9) and 1/(1 - 0.9). This has the same form as the
# +/- 1/(1 - discount) search interval used in drrenyi_operator, evaluated for
# a discount of 0.9. Not referenced in the visible part of this file.
BMIN, BMAX = -1/(1-0.9), 1/(1-0.9)


def drrenyi_operator(P, V, discount, radius, renyi_k):
    """Return the robust expectation of ``V`` under ``P`` via a 1-D dual search.

    The value returned is the maximum over ``eta`` of

        eta - C * (sum_i P[i] * max(eta - V[i], 0) ** k_star) ** (1 / k_star)

    with ``C = (1 + k * (k - 1) * radius) ** (1 / k)`` and
    ``k_star = k / (k - 1)``, where ``k = renyi_k``. The scalar search is done
    with ``scipy.optimize.fminbound`` on the negated expression, restricted to
    ``eta`` in ``[-1 / (1 - discount), 1 / (1 - discount)]``.

    Args:
        P: Probability weights, one per entry of ``V``.
        V: Values whose robust expectation is evaluated.
        discount: Only used to set the search interval for ``eta``.
        radius: Uncertainty radius entering the constant ``C``.
        renyi_k: Order ``k``; ``k == 1`` divides by zero in ``k_star``.

    Returns:
        The negated minimum found by ``fminbound`` (i.e. the dual maximum).
    """
    conjugate_order = renyi_k / (renyi_k - 1)
    scale = (1 + renyi_k * (renyi_k - 1) * radius) ** (1 / renyi_k)
    eta_low = -1 / (1 - discount)
    eta_high = 1 / (1 - discount)

    def negated_dual(eta):
        # Only the part of V lying below eta contributes to the penalty term.
        shortfall = np.clip(eta - V, 0, None)
        weighted_moment = np.dot(P, shortfall ** conjugate_order)
        return -eta + scale * weighted_moment ** (1 / conjugate_order)

    # full_output=True yields (xopt, fval, ierr, numfunc); only fval is needed.
    _, min_value, _, _ = opt.fminbound(
        negated_dual, eta_low, eta_high, full_output=True
    )
    return -min_value

def sync_q_learning(delta, renyi_k, seed, save_model=None):
    """Run model-based robust Q-value iteration for a range of sample budgets.

    For each per-(state, action) sample budget in ``100, 200, ..., 9900`` the
    function reseeds NumPy, builds a fresh ``CliffWalkingEnv``, estimates an
    empirical transition kernel from ``env.step_multiple`` during the first
    sweep, and then keeps applying the robust backup ``drrenyi_operator`` with
    step size ``1 / (1 + (1 - gamma) * niter)`` until the largest per-entry
    change drops below ``2e-5`` or the sweep count exceeds 1000.

    After every budget the list of ``[total_samples, initial_value,
    distance_to_reference]`` rows collected so far is written to
    ``./results_model/<file_name>`` (``np.save`` appends ``.npy``), so partial
    results survive an interrupted run.

    Args:
        delta: Uncertainty radius forwarded to ``drrenyi_operator``; also part
            of the reference-file and output-file names.
        renyi_k: Order ``k`` forwarded to ``drrenyi_operator``; also part of
            the reference-file and output-file names.
        seed: Seed passed to ``np.random.seed`` before each budget.
        save_model: If true, also save the learned Q table to
            ``./models_model/<file_name>``. ``None`` (the default) falls back
            to ``args.save_model`` of a module-level ``args`` namespace when
            one exists, and to ``False`` otherwise (e.g. in worker processes
            that did not inherit the entry point's globals).

    Raises:
        FileNotFoundError: If the reference values
            ``./results/Renyi(<k>)_<delta>_optimal.npy`` are missing.
    """
    if save_model is None:
        # The CLI namespace only exists when the script entry point has run in
        # this process; do not crash with NameError when it has not.
        save_model = bool(getattr(globals().get("args"), "save_model", False))

    file_name = f"RenyiDROQZ_radius:{delta}_k:{renyi_k}_seed:{seed}"
    results_dir = "./results_model"
    models_dir = "./models_model"
    # Create the output directories this function actually writes to;
    # exist_ok keeps this safe when several workers start at once.
    os.makedirs(results_dir, exist_ok=True)
    if save_model:
        os.makedirs(models_dir, exist_ok=True)

    gamma = 0.9
    evaluations = []
    optimal_v = np.load(f'./results/Renyi({renyi_k:.1f})_{delta}_optimal.npy')

    for sample_num in np.arange(100, 10000, 100):
        np.random.seed(seed)
        env = CliffWalkingEnv()
        state_dim = env.N_S
        action_dim = env.N_A
        env.reset()
        # Earlier versions drew a random (never trained) Q table at this point.
        # The draw is kept and discarded so the global NumPy random stream that
        # follows is unchanged, in case the environment samples from it.
        np.random.randn(state_dim, action_dim)

        rob_q_value = np.zeros((state_dim, action_dim))
        diff = np.zeros((state_dim, action_dim))
        P_hat = np.zeros((state_dim, action_dim, state_dim))
        rs = np.zeros((state_dim, action_dim))
        total_sample_num = 0
        niter = 0

        while True:
            alpha = 1 / (1 + (1 - gamma) * niter)
            # The last state index is never updated (its Q row stays zero).
            for s in range(state_dim - 1):
                for a in range(action_dim):
                    env.set_state(s)
                    if niter == 0:
                        # First sweep only: draw the samples for (s, a) and turn
                        # the next-state counts into empirical probabilities.
                        s_prime, r, _, _ = env.step_multiple(a, sample_num)
                        rs[s, a] = r[0]
                        total_sample_num += sample_num
                        for i in range(sample_num):
                            P_hat[s, a, s_prime[i]] += 1
                        P_hat[s, a] = P_hat[s, a] / np.sum(P_hat[s, a])

                    # rob_q_value is updated in place inside this sweep, so the
                    # state values must be recomputed for every (s, a).
                    next_v = drrenyi_operator(
                        P_hat[s, a], rob_q_value.max(-1), gamma, delta, renyi_k
                    )
                    updated = (1 - alpha) * rob_q_value[s, a] + alpha * (rs[s, a] + gamma * next_v)
                    diff[s, a] = abs(updated - rob_q_value[s, a])
                    rob_q_value[s, a] = updated
            niter += 1

            if np.max(diff) < 1e-4 / 5 or niter > 1e3:
                break

        state_values = rob_q_value.max(-1)
        opt_diff = np.linalg.norm(optimal_v - state_values)
        value_init = np.dot(env.initial_state_dist, state_values)
        evaluations.append([total_sample_num, value_init, opt_diff])
        np.save(f"{results_dir}/{file_name}", evaluations)
        if save_model:
            # Save the learned robust Q table, not the random initial draw.
            np.save(f"{models_dir}/{file_name}", rob_q_value)


if __name__ == "__main__":
    # Command-line entry point: parse options, make sure the output directories
    # exist, then run sync_q_learning over the full (radius, order, seed) grid.
    parser = argparse.ArgumentParser()
    # Max time steps to run environment. The default is an int literal because
    # argparse does not apply `type` to non-string defaults (5e6 stayed a float).
    parser.add_argument("--max_timesteps", default=5_000_000, type=int)
    parser.add_argument("--discount", default=0.9, type=float)  # Discount factor
    parser.add_argument("--save_model", action="store_true")  # Also save a model file per run
    parser.add_argument("--reward_scale", default=1., type=float)  # Normalize the reward
    parser.add_argument("--num_process", default=1, type=int)  # Number of worker processes
    args = parser.parse_args()

    # sync_q_learning writes to ./results_model and ./models_model, so those are
    # the directories that must exist before the workers start.
    os.makedirs("./results_model", exist_ok=True)
    if args.save_model:
        os.makedirs("./models_model", exist_ok=True)

    deltas = [0.05, 0.1, 0.2, 0.4]
    renyi_ks = [1.2, 2.0, 3.0, 4.0]
    seeds = list(range(10))

    # product() yields (delta, renyi_k, seed) tuples, matching the positional
    # parameters of sync_q_learning.
    with Pool(processes=args.num_process) as pool:
        pool.starmap(sync_q_learning, product(deltas, renyi_ks, seeds))
