import numpy as np
import argparse
import os
from env.cliffwalking import CliffWalkingEnv
from math import log10, log2
from multiprocessing import Pool
from itertools import product

def run(R, seed, args=None):
    """Train robust tabular Q-learning on CliffWalkingEnv and log evaluations.

    The TD target mixes the worst state value with the sampled next value:

        robust_value = R * min_s max_a Q(s, a) + (1 - R) * max_a Q(s', a)
        Q(s, a) += lr(t) * (r + discount * robust_value - Q(s, a))

    where the sampled term max_a Q(s', a) is replaced by 0 on terminal
    transitions, r is the reward divided by ``args.reward_scale``, and
    lr(t) = 1 / (1 + (1 - discount) / 20 * t).

    Every ``args.eval_freq`` steps (at t = 1, 1 + eval_freq, ...) the row
    ``[t, value_init, opt_diff]`` is appended to the history, and the whole
    history is written to ``./results_R/R_{R}_{seed}.npy``. If
    ``args.save_model`` is set, the reward-rescaled Q table is also written
    to ``./models/R_{R}_{seed}.npy``.

    Args:
        R: Weight on the worst-case state value in the target (the sweep in
            ``__main__`` uses 0.1 ... 0.9).
        seed: Value passed to ``np.random.seed`` for this run.
        args: Parsed CLI namespace (reads ``discount``, ``reward_scale``,
            ``expl_noise``, ``max_timesteps``, ``eval_freq``, ``save_model``).
            When omitted, the module-level ``args`` created by the
            ``__main__`` block is used; that global only exists in worker
            processes started with ``fork``.

    Raises:
        RuntimeError: If no argument namespace is available.
    """
    if args is None:
        # Backward-compatible fallback: forked workers inherit the global
        # ``args``; spawned/forkserver workers never execute __main__.
        args = globals().get("args")
        if args is None:
            raise RuntimeError(
                "run() needs the parsed arguments; call run(R, seed, args)."
            )

    env = CliffWalkingEnv()
    file_name = f'R_{R}_{seed}'

    # Seed after constructing the env (as originally) so the RNG stream is unchanged.
    np.random.seed(seed)

    state_dim = env.N_S
    action_dim = env.N_A
    discount = args.discount
    reward_scale = args.reward_scale

    def zeta3(t):
        """Slowly decaying step size: 1 / (1 + (1 - discount) / 20 * t)."""
        return 1 / (1 + (1 - discount) / 20 * t)

    evaluations = []
    state = env.reset()
    # Random initial Q table, expressed in normalized-reward units.
    Q = np.random.randn(state_dim, action_dim) / reward_scale

    t = 0
    while t < args.max_timesteps:
        t += 1
        # Epsilon-greedy action selection.
        if np.random.rand() < args.expl_noise:
            action = np.random.randint(action_dim)
        else:
            action = np.argmax(Q[state])
        next_state, reward, done, timeout = env.step(action)
        reward /= reward_scale
        # Only a terminal transition drops the bootstrap; a timeout still bootstraps.
        next_value = np.max(Q[next_state]) if not done else 0

        robust_value = R * np.min(np.max(Q, axis=-1)) + (1 - R) * next_value
        Q[state, action] += zeta3(t) * (reward + discount * robust_value - Q[state, action])

        state = env.reset() if done or timeout else next_state

        if (t - 1) % args.eval_freq == 0:
            # Undo reward normalization only when evaluating, instead of
            # copying the whole Q table on every step.
            Q_ = Q * reward_scale
            # Greedy value averaged over the env's initial-state distribution.
            value_init = np.dot(env.initial_state_dist, Q_.max(-1))
            opt_diff = 0  # placeholder: comparison against an optimal value is disabled
            evaluations.append([t, value_init, opt_diff])
            np.save(f"./results_R/{file_name}", evaluations)
            if args.save_model:
                np.save(f"./models/{file_name}", Q_)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", default=0, type=int)  # First seed of the sweep (seed .. seed+9)
    parser.add_argument("--eval_freq", default=1000, type=int)  # How often (time steps) we evaluate
    parser.add_argument("--max_timesteps", default=3_000_000, type=int)  # Max time steps to run environment
    parser.add_argument("--expl_noise", default=0.5, type=float)  # Epsilon-greedy exploration probability
    parser.add_argument("--discount", default=0.9, type=float)  # Discount factor
    # Saving is on by default; --no_save_model disables it (--save_model kept for compatibility).
    parser.add_argument("--save_model", dest="save_model", action="store_true")
    parser.add_argument("--no_save_model", dest="save_model", action="store_false")
    parser.set_defaults(save_model=True)
    parser.add_argument("--reward_scale", default=1., type=float)  # Rewards are divided by this value
    parser.add_argument("--num_process", default=30, type=int)  # Number of worker processes
    args = parser.parse_args()

    os.makedirs("./results_R", exist_ok=True)
    if args.save_model:
        os.makedirs("./models", exist_ok=True)

    Rs = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
    seeds = list(range(args.seed, args.seed + 10))
    # Pass args explicitly: workers started with "spawn"/"forkserver" do not
    # run this __main__ block, so they would not see a global ``args``.
    with Pool(processes=args.num_process) as pool:
        pool.starmap(run, product(Rs, seeds, [args]))