import numpy as np
from env.cliffwalking import CliffWalkingEnv
import argparse
import os
from math import sqrt
import scipy.optimize as opt
from multiprocessing import Pool
from itertools import product

LOG_EPS = 1e-5

def drrenyi_operator(s_sample, done, rob_q_value, discount, radius, renyi_k):
    """Robust (Renyi-divergence DRO) estimate of the next-state value from samples.

    Builds per-sample bootstrap values
    ``V_i = (1 - done_i) * max_a rob_q_value[s_sample_i, a]`` and returns the
    maximum over ``eta`` of the dual objective

        eta - C * mean((eta - V)_+ ** k_star) ** (1 / k_star)

    with ``C = (1 + k (k - 1) radius) ** (1 / k)`` and ``k_star = k / (k - 1)``.
    The search over ``eta`` is restricted to
    ``[-1 / (1 - discount), 1 / (1 - discount)]``.

    Args:
        s_sample: integer next-state indices, one per sample.
        done: per-sample terminal flags; terminal samples contribute value 0.
        rob_q_value: Q-table indexed as ``rob_q_value[state]`` -> action values
            (2-D ``[state, action]`` as allocated by ``sync_q_learning``).
        discount: discount factor; only used to bound the search over ``eta``.
        radius: radius of the Renyi uncertainty set.
        renyi_k: Renyi order ``k`` (``k == 1`` divides by zero).

    Returns:
        The maximized dual objective (approximate, via ``fminbound``).

    Raises:
        ValueError: if no samples are given (the mean would be NaN).
    """
    k = renyi_k
    C = (1 + k * (k - 1) * radius) ** (1 / k)
    k_star = k / (k - 1)

    states = np.asarray(s_sample)
    if states.size == 0:
        raise ValueError("drrenyi_operator needs at least one sample")
    # Vectorized bootstrap: one row lookup per sample, max over actions,
    # zeroed for terminal samples. Same values as a per-sample loop.
    not_done = 1 - np.asarray(done, dtype=float)
    v_value_rob_sample = not_done * rob_q_value[states].max(axis=-1)

    # fminbound minimizes, so minimize the negated dual and flip the sign back.
    res = opt.fminbound(
        lambda eta: -eta + C * np.mean(
            np.clip(eta - v_value_rob_sample, a_max=None, a_min=0) ** k_star) ** (1 / k_star),
        -1 / (1 - discount),
        1 / (1 - discount),
        full_output=True,
    )
    return -res[1]

def drrenyi_operator_r(r, discount, radius, renyi_k):
    """Robust (Renyi-divergence DRO) estimate of the expected reward from samples.

    Returns the maximum over ``eta`` in ``[-1 / (1 - discount), 1 / (1 - discount)]``
    of ``eta - scale * mean((eta - r)_+ ** q) ** (1 / q)``, where
    ``q = k / (k - 1)`` and ``scale = (1 + k (k - 1) radius) ** (1 / k)``.

    Args:
        r: array of sampled rewards.
        discount: discount factor; only used to bound the search over ``eta``.
        radius: radius of the Renyi uncertainty set.
        renyi_k: Renyi order ``k``.
    """
    order = renyi_k
    conj_order = order / (order - 1)
    scale = (1 + order * (order - 1) * radius) ** (1 / order)
    eta_limit = 1 / (1 - discount)

    def negated_dual(eta):
        # Positive part of (eta - r), then a power mean of order conj_order.
        shortfall = np.clip(eta - r, a_max=None, a_min=0)
        return -eta + scale * np.mean(shortfall ** conj_order) ** (1 / conj_order)

    _, lowest, *_ = opt.fminbound(negated_dual, -eta_limit, eta_limit, full_output=True)
    return -lowest


def sync_q_learning(delta, max_timesteps, renyi_k, seed):
    """Run sampled robust Q-learning with an MLMC-debiased Renyi DRO target.

    Each iteration updates one (state, action) entry of the robust Q-table
    from ``2 ** (N + 1)`` transitions returned by ``env.step_multiple``. ``N``
    is ``Geometric(p) - 1`` shifted by the base level ``k``. The target is a
    base estimate on the first ``2 ** k`` samples plus a multilevel Monte
    Carlo correction: the full-batch estimate minus the mean of the two
    half-batch estimates, reweighted by ``1 / P(N)``.

    Every 1000 iterations, and once more when the sample budget is exhausted,
    ``[samples_used, initial-state value, ||reference - V||]`` is appended and
    the history is saved to ``./results_aqlearning/<file_name>.npy``. The
    reference values are loaded from
    ``./results/Renyi(<k>)_<delta>_optimal.npy``.

    Args:
        delta: radius of the Renyi uncertainty set.
        max_timesteps: total budget of environment samples.
        renyi_k: Renyi order k.
        seed: seed for NumPy's global RNG.

    Note:
        Reads the module-level ``args`` (``args.save_model``) parsed in
        ``__main__``; it must exist when this is called.
    """
    np.random.seed(seed)
    env = CliffWalkingEnv()
    file_name = f"RenyiDROQZA_radius:{delta}_k:{renyi_k}_seed:{seed}"
    print(f'Version: {file_name}')
    optimal_v = np.load(f'./results/Renyi({renyi_k:.1f})_{delta}_optimal.npy')
    # np.save below fails if these directories are missing; create them here so
    # the function also works when launched directly or from a Pool worker.
    os.makedirs("./results_aqlearning", exist_ok=True)
    if args.save_model:
        os.makedirs("./models_aqlearning", exist_ok=True)

    state_dim = env.N_S
    action_dim = env.N_A
    gamma, p = 0.9, 0.5
    rob_q_value = np.zeros([state_dim, action_dim])

    bound = np.inf    # optional cap on the MLMC level N + k
    k = 0             # base level: the base estimate uses 2 ** k samples
    expl_noise = 1.0  # probability of taking a uniformly random action
    total_sample_num = 0
    niter = 0
    evaluations = []
    s = env.reset()
    while True:
        alpha = 1 / (1 + (1 - gamma) / 20 * (niter + 1))

        # Behaviour policy: epsilon-greedy w.r.t. the learned robust Q-table.
        if np.random.rand() < expl_noise:
            a = np.random.randint(action_dim)
        else:
            a = np.argmax(rob_q_value[s])

        # Draw the MLMC level N ~ Geometric(p) - 1, rejecting N + k > bound.
        N = np.random.geometric(p) - 1
        while N + k > bound:
            N = np.random.geometric(p) - 1
        p_N = p * ((1 - p) ** N)
        N = N + k
        sample_num = 2 ** (N + 1)
        total_sample_num += sample_num

        s_prime, r, done, timeout = env.step_multiple(a, sample_num)
        if timeout:
            done = np.ones_like(done)

        base, half = 2 ** k, 2 ** N
        base_r = drrenyi_operator_r(r[:base], gamma, delta, renyi_k)
        base_q = drrenyi_operator(s_prime[:base], done[:base], rob_q_value, gamma, delta, renyi_k)
        # MLMC correction: full-batch estimate minus the average of the two
        # disjoint half-batch estimates. Each half must pair its next states
        # with the SAME slice of terminal flags (done[half:], not done[:half]).
        delta_r = (
            drrenyi_operator_r(r, gamma, delta, renyi_k)
            - 0.5 * drrenyi_operator_r(r[:half], gamma, delta, renyi_k)
            - 0.5 * drrenyi_operator_r(r[half:], gamma, delta, renyi_k)
        )
        delta_q = (
            drrenyi_operator(s_prime, done, rob_q_value, gamma, delta, renyi_k)
            - 0.5 * drrenyi_operator(s_prime[:half], done[:half], rob_q_value, gamma, delta, renyi_k)
            - 0.5 * drrenyi_operator(s_prime[half:], done[half:], rob_q_value, gamma, delta, renyi_k)
        )
        target = base_r + gamma * base_q + (delta_r + gamma * delta_q) / p_N
        rob_q_value[s, a] = (1 - alpha) * rob_q_value[s, a] + alpha * target

        # Follow the first sampled transition as the trajectory.
        if done[0] or timeout:
            s = env.reset()
        else:
            s = s_prime[0]

        niter += 1
        finished = total_sample_num >= max_timesteps

        # Log periodically, and always once at the end so the final iterate is kept.
        if niter % 1000 == 0 or finished:
            value_init = np.dot(env.initial_state_dist, rob_q_value.max(-1))
            opt_diff = np.linalg.norm(optimal_v - rob_q_value.max(-1))
            evaluations.append([total_sample_num, value_init, opt_diff])
            print(f"Total T: {niter+1} Value: {value_init:.4f} Diff: {opt_diff:.4f} Used Samples: {total_sample_num}")
            np.save(f"./results_aqlearning/{file_name}", evaluations)
            if args.save_model:
                np.save(f"./models_aqlearning/{file_name}", rob_q_value)

        if finished:
            break

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # argparse does not apply `type` to non-string defaults, so use an int literal.
    parser.add_argument("--max_timesteps", default=10_000_000, type=int)  # Total environment-sample budget
    # parser.add_argument("--renyi_k", default=2.0, type=float)  # Renyi divergence order k
    parser.add_argument("--discount", default=0.9, type=float)  # Discount factor (NOTE: not read by sync_q_learning, which fixes gamma = 0.9)
    # parser.add_argument("--radius", default=0.05, type=float)  # Radius of uncertainty set
    parser.add_argument("--save_model", action="store_true")  # Save model parameters alongside results
    parser.add_argument("--reward_scale", default=1., type=float)  # Normalize the reward
    parser.add_argument("--num_process", default=1, type=int)  # Number of worker processes for the sweep below
    args = parser.parse_args()

    # sync_q_learning saves into ./results_aqlearning and ./models_aqlearning,
    # so create exactly those directories.
    os.makedirs("./results_aqlearning", exist_ok=True)
    if args.save_model:
        os.makedirs("./models_aqlearning", exist_ok=True)

    max_timesteps = args.max_timesteps

    sync_q_learning(delta=0.05, max_timesteps=max_timesteps, renyi_k=2.0, seed=0)

    # Full parameter sweep (uncomment to run in parallel):
    # deltas, seeds, renyi_ks = [0.05, 0.1, 0.2, 0.4], list(range(10)), [1.4, 2.0, 3.0, 4.0]
    # with Pool(processes=args.num_process) as pool:
    #     pool.starmap(sync_q_learning, product(deltas, [max_timesteps], renyi_ks, seeds))