import numpy as np
from env.cliffwalking import CliffWalkingEnv
import argparse
import os
from math import sqrt
import scipy.optimize as opt
from multiprocessing import Pool
from itertools import product

# Numerical constants. In this file they are only referenced by the
# commented-out tri_search / tri_search_r routines.
LOG_EPS = 1e-5  # small offset added inside np.log(...) to avoid log(0)
BMAX = 1 / (1 - 0.9)  # 1 / (1 - gamma) with gamma = 0.9
BMIN = -BMAX

# def tri_search(s_sample, done, upper, rob_q_value, delta):
#     v_value_rob_sample = np.zeros(len(s_sample))
#     for s in range(len(s_sample)):
#         v_value_rob_sample[s] = (1 - done[s]) * np.max(rob_q_value[s_sample[s]])
#     min_v_value_rob_sample = np.min(rob_q_value)
#     max_v_value_rob_sample = np.max(rob_q_value)
#     diff_v_value_rob_sample = -v_value_rob_sample + min_v_value_rob_sample
#     b_max = (max_v_value_rob_sample - min_v_value_rob_sample) / delta
#     b_min = 1e-5
#     f1 = 0
#     f2 = 1
#     # while (b_min + 1e-3 < b_max):
#     num_iter = 0
#     while (np.abs(f1 - f2) > 1e-5) and (num_iter < 1e+3):
#         b1 = (b_min + b_max) / 2
#         b2 = (b1 + b_max) / 2
#         total1 = np.mean(np.exp(diff_v_value_rob_sample / b1))
#         total2 = np.mean(np.exp(diff_v_value_rob_sample / b2))
#         f1 = b1 * np.log(total1 + LOG_EPS) + b1 * delta - min_v_value_rob_sample
#         f2 = b2 * np.log(total2 + LOG_EPS) + b2 * delta - min_v_value_rob_sample
#         if f1 < f2:
#             b_max = b2
#         else:
#             b_min = b1
#         num_iter += num_iter
#     return f1

###################################################
# def tri_search(s_sample, done, upper, rob_q_value, delta):
#     v_value_rob_sample = np.zeros(len(s_sample))
#     for s in range(len(s_sample)):
#         v_value_rob_sample[s] = (1 - done[s]) * np.max(rob_q_value[s_sample[s]])
#     min_v_value_rob_sample = np.min(rob_q_value)
#     max_v_value_rob_sample = np.max(rob_q_value)
#     diff_v_value_rob_sample = -v_value_rob_sample + min_v_value_rob_sample
#     b_max = BMAX # (max_v_value_rob_sample - min_v_value_rob_sample) / delta
#     b_min = BMIN
#     f1 = 0
#     f2 = 1
#     # while (b_min + 1e-3 < b_max):
#     num_iter = 0
#     while (np.abs(f1 - f2) > 1e-5) and (num_iter < 1e+3):
#         b1 = (b_min + b_max) / 2
#         b2 = (b1 + b_max) / 2
#         # total1 = np.mean(np.exp(diff_v_value_rob_sample / b1))
#         # total2 = np.mean(np.exp(diff_v_value_rob_sample / b2))
#         # f1 = b1 * np.log(total1 + LOG_EPS) + b1 * delta - min_v_value_rob_sample
#         # f2 = b2 * np.log(total2 + LOG_EPS) + b2 * delta - min_v_value_rob_sample

#         # chi-square
#         f1 = b1 - sqrt(1+delta) * sqrt(np.mean(np.power(np.maximum(b1 - diff_v_value_rob_sample, 0), 2))) - min_v_value_rob_sample
#         f2 = b2 - sqrt(1+delta) * sqrt(np.mean(np.power(np.maximum(b2 - diff_v_value_rob_sample, 0), 2))) - min_v_value_rob_sample
#         if f1 > f2:
#             b_max = b2
#         else:
#             b_min = b1
#         num_iter += num_iter
#     return f1

# def tri_search_r(r_sample, R, delta):
#     min_r_sample = np.min(r_sample)
#     diff_r_sample = -r_sample + min_r_sample
#     b_max = BMAX #R / delta
#     b_min = BMIN #1e-5
#     f1 = 0
#     f2 = 1
#     # while (b_min + 1e-3 < b_max):
#     num_iter = 0
#     while (np.abs(f1 - f2) > 1e-5) and (num_iter < 1e+3):
#         b1 = (b_min + b_max) / 2
#         b2 = (b1 + b_max) / 2
#         # total1 = np.mean(np.exp(diff_r_sample/ b1))
#         # total2 = np.mean(np.exp(diff_r_sample / b2))
#         # f1 = b1 * np.log(total1 + LOG_EPS) + b1 * delta - min_r_sample
#         # f2 = b2 * np.log(total2 + LOG_EPS) + b2 * delta - min_r_sample


#         # chi-square
#         f1 = b1 - sqrt(1+delta) * sqrt(np.mean(np.power(np.maximum(b1 - diff_r_sample, 0), 2))) - min_r_sample
#         f2 = b2 - sqrt(1+delta) * sqrt(np.mean(np.power(np.maximum(b2 - diff_r_sample, 0), 2))) - min_r_sample

#         # 
#         if f1 > f2:
#             b_max = b2
#         else:
#             b_min = b1
#         num_iter += num_iter
#     return f1
###################################################


# def tri_search_r(r_sample, R, delta):
#     min_r_sample = np.min(r_sample)
#     diff_r_sample = -r_sample + min_r_sample
#     b_max = R / delta
#     b_min = 1e-5
#     f1 = 0
#     f2 = 1
#     # while (b_min + 1e-3 < b_max):
#     num_iter = 0
#     while (np.abs(f1 - f2) > 1e-5) and (num_iter < 1e+3):
#         b1 = (b_min + b_max) / 2
#         b2 = (b1 + b_max) / 2
#         total1 = np.mean(np.exp(diff_r_sample/ b1))
#         total2 = np.mean(np.exp(diff_r_sample / b2))
#         f1 = b1 * np.log(total1 + LOG_EPS) + b1 * delta - min_r_sample
#         f2 = b2 * np.log(total2 + LOG_EPS) + b2 * delta - min_r_sample
#         if f1 < f2:
#             b_max = b2
#         else:
#             b_min = b1
#         num_iter += num_iter
#     return f1


def drrenyi_operator(s_sample, done, rob_q_value, discount, radius, renyi_k):
    """Worst-case mean of sampled next-state values over a Renyi-divergence ball.

    Each sample contributes ``v_i = (1 - done[i]) * max(rob_q_value[s_sample[i]])``.
    The robust mean is obtained from the one-dimensional dual problem

        sup_eta  eta - C * mean((eta - v)_+ ** k_star) ** (1 / k_star)

    with ``C = (1 + k (k - 1) radius) ** (1 / k)`` and ``k_star = k / (k - 1)``.
    ``eta`` is searched over ``[-1 / (1 - discount), 1 / (1 - discount)]``
    with ``scipy.optimize.fminbound``.

    Args:
        s_sample: sampled next-state indices into ``rob_q_value``.
        done: per-sample flags; a sample with ``done[i] == 1`` contributes 0.
        rob_q_value: current robust Q-table, indexed ``rob_q_value[state]``.
        discount: discount factor; only sets the search interval for ``eta``.
        radius: radius of the uncertainty set.
        renyi_k: divergence order ``k`` (``k == 1`` divides by zero).

    Returns:
        The dual objective value at the maximiser found on the interval.
    """
    order = renyi_k
    scale = (1 + order * (order - 1) * radius) ** (1 / order)
    order_star = order / (order - 1)

    next_values = np.zeros(len(s_sample))
    for i, state in enumerate(s_sample):
        next_values[i] = (1 - done[i]) * np.max(rob_q_value[state])

    def negated_dual(eta):
        # fminbound minimises, so minimise the negated dual objective.
        shortfall = np.maximum(eta - next_values, 0)
        return -eta + scale * np.mean(shortfall ** order_star) ** (1 / order_star)

    limit = 1 / (1 - discount)
    fval = opt.fminbound(negated_dual, -limit, limit, full_output=True)[1]
    return -fval

def drrenyi_operator_r(r, discount, radius, renyi_k):
    """Worst-case mean of reward samples ``r`` over a Renyi-divergence ball.

    Uses the same dual formulation as ``drrenyi_operator``:

        sup_eta  eta - C * mean((eta - r)_+ ** k_star) ** (1 / k_star)

    with ``C = (1 + k (k - 1) radius) ** (1 / k)`` and ``k_star = k / (k - 1)``.
    ``eta`` is searched over ``[-1 / (1 - discount), 1 / (1 - discount)]``.

    Args:
        r: array of sampled rewards.
        discount: discount factor; only sets the search interval for ``eta``.
        radius: radius of the uncertainty set.
        renyi_k: divergence order ``k`` (``k == 1`` divides by zero).

    Returns:
        The dual objective value at the maximiser found on the interval.
    """
    order = renyi_k
    scale = (1 + order * (order - 1) * radius) ** (1 / order)
    order_star = order / (order - 1)
    limit = 1 / (1 - discount)

    def negated_dual(eta):
        # fminbound minimises, so minimise the negated dual objective.
        shortfall = np.maximum(eta - r, 0)
        return -eta + scale * np.mean(shortfall ** order_star) ** (1 / order_star)

    fval = opt.fminbound(negated_dual, -limit, limit, full_output=True)[1]
    return -fval


def sync_q_learning(delta, max_timesteps, renyi_k, seed, save_model=None):
    """Run Renyi-robust Q-learning with a randomized MLMC target estimator.

    Each sweep updates every (state, action) pair except the last state index
    (presumably terminal -- confirm against CliffWalkingEnv). The robust
    Bellman target is estimated as follows:

    * A base estimate is computed from ``2**k`` samples.
    * An antithetic correction is computed from ``2**(N+1)`` samples: the
      full-sample estimate minus the average of the two half-sample
      estimates.
    * ``N`` is drawn from a geometric distribution, and the correction is
      reweighted by ``1 / P(N)``.

    Args:
        delta: radius of the uncertainty set. Also part of the output file
            name and of the reference-solution path.
        max_timesteps: stop after the sweep in which the cumulative number of
            environment samples reaches this budget.
        renyi_k: divergence order ``k``.
        seed: seed for NumPy's global RNG.
        save_model: if true, also save the learned ``rob_q_value`` table to
            ``./models_qlearning``. ``None`` (default) falls back to the
            module-global ``args.save_model`` CLI flag.

    Side effects:
        After every sweep, including the final one, saves the list of
        ``[total_sample_num, value_init, opt_diff]`` rows via ``np.save`` to
        ``./results_qlearning/{file_name}``.
    """
    np.random.seed(seed)
    if save_model is None:
        # Read the CLI flag up front so a missing global fails immediately.
        save_model = args.save_model
    env = CliffWalkingEnv()
    file_name = f"RenyiDROQZ_radius:{delta}_k:{renyi_k}_seed:{seed}"
    optimal_v = np.load(f'./results/Renyi({renyi_k:.1f})_{delta}_optimal.npy')
    state_dim = env.N_S
    action_dim = env.N_A
    env.reset()
    gamma, p = 0.9, 0.6

    bound = np.inf  # optional cap on the MLMC level N + k (inf = untruncated)
    k = 0  # base level: the base estimate uses 2**k samples
    rob_q_value = np.zeros([state_dim, action_dim])
    total_sample_num = 0
    niter = 0
    evaluations = []
    while True:
        alpha = 1 / (1 + (1 - gamma) / 20 * niter)  # decaying step size
        for s in range(state_dim - 1):
            for a in range(action_dim):
                env.set_state(s)
                # Draw the MLMC level N ~ Geometric(p) on {0, 1, ...},
                # redrawing while it exceeds the optional cap.
                N = np.random.geometric(p) - 1
                while N + k > bound:
                    N = np.random.geometric(p) - 1
                p_N = p * ((1 - p) ** N)  # probability of this level
                N = N + k
                half = 2 ** N
                sample_num = 2 * half
                total_sample_num += sample_num

                s_prime, r, done, _ = env.step_multiple(a, sample_num)
                base_r = drrenyi_operator_r(r[:2 ** k], gamma, delta, renyi_k)
                base_q = drrenyi_operator(s_prime[:2 ** k], done[:2 ** k],
                                          rob_q_value, gamma, delta, renyi_k)
                # Antithetic correction: the full-sample estimate minus the
                # average of the two half-sample estimates. Each half must
                # pair its next states with its own done flags.
                delta_r = (drrenyi_operator_r(r, gamma, delta, renyi_k)
                           - 0.5 * drrenyi_operator_r(r[:half], gamma, delta, renyi_k)
                           - 0.5 * drrenyi_operator_r(r[half:], gamma, delta, renyi_k))
                delta_q = (drrenyi_operator(s_prime, done,
                                            rob_q_value, gamma, delta, renyi_k)
                           - 0.5 * drrenyi_operator(s_prime[:half], done[:half],
                                                    rob_q_value, gamma, delta, renyi_k)
                           - 0.5 * drrenyi_operator(s_prime[half:], done[half:],
                                                    rob_q_value, gamma, delta, renyi_k))
                target = base_r + gamma * base_q + (delta_r + gamma * delta_q) / p_N
                # NOTE(review): rob_q_value is updated in place, so pairs
                # visited later in this sweep already see this update
                # (Gauss-Seidel style) despite the "sync" name. Confirm this
                # is intended.
                rob_q_value[s, a] = (1 - alpha) * rob_q_value[s, a] + alpha * target
        niter += 1

        # Record and save after every sweep, including the last one, before
        # checking the sample budget.
        greedy_v = rob_q_value.max(-1)
        value_init = np.dot(env.initial_state_dist, greedy_v)
        opt_diff = np.linalg.norm(optimal_v - greedy_v)
        evaluations.append([total_sample_num, value_init, opt_diff])
        np.save(f"./results_qlearning/{file_name}", evaluations)
        if save_model:
            np.save(f"./models_qlearning/{file_name}", rob_q_value)

        if total_sample_num >= max_timesteps:
            break


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # argparse applies `type` only to values given as strings, so the default
    # must already be an int (a bare 1e7 default would stay a float).
    parser.add_argument("--max_timesteps", default=int(1e7), type=int,
                        help="environment-sample budget per run")
    # parser.add_argument("--renyi_k", default=2.0, type=float)  # Order k of the Renyi divergence
    parser.add_argument("--discount", default=0.9, type=float,
                        help="discount factor (currently unused: sync_q_learning fixes gamma = 0.9)")
    # parser.add_argument("--radius", default=0.05, type=float)  # Radius of the uncertainty set
    parser.add_argument("--save_model", action="store_true",
                        help="also save a Q-table per run to ./models_qlearning")
    # NOTE(review): despite the flag name, no reward normalization in this
    # file reads this value -- confirm whether it is still needed.
    parser.add_argument("--reward_scale", default=1., type=float,
                        help="reward scale")
    parser.add_argument("--num_process", default=1, type=int,
                        help="number of worker processes")
    args = parser.parse_args()

    os.makedirs("./results_qlearning", exist_ok=True)
    if args.save_model:
        os.makedirs("./models_qlearning", exist_ok=True)

    max_timesteps = args.max_timesteps

    # sync_q_learning(delta = 0.05, max_timesteps=max_timesteps, renyi_k = 2, seed = 0)

    # One run per (radius, k, seed) combination.
    deltas = [0.05, 0.1, 0.2, 0.4, 1.0]
    seeds = list(range(10))
    renyi_ks = [1.4, 1.5, 2.0, 3.0, 4.0]
    # NOTE: sync_q_learning reads the module-global `args`. Workers only see
    # it when the pool uses the "fork" start method (the Linux default); under
    # "spawn" the workers would not have it.
    with Pool(processes=args.num_process) as pool:
        pool.starmap(sync_q_learning, product(deltas, [max_timesteps], renyi_ks, seeds))