import numpy as np
import argparse
import os
from env.cliffwalking import CliffWalkingEnv
from math import log10, log2

if __name__ == "__main__":
    # ----------------------------------------------------------------- CLI
    # NOTE: argparse applies `type` only to *string* defaults, so the numeric
    # defaults below are written as real ints instead of float literals.
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", default=0, type=int,
                        help="NumPy random seed")
    parser.add_argument("--eval_freq", default=1000, type=int,
                        help="How often (in time steps) the value table is evaluated")
    parser.add_argument("--max_timesteps", default=10_000_000, type=int,
                        help="Number of environment time steps to run")
    parser.add_argument("--expl_noise", default=0.5, type=float,
                        help="Epsilon of the epsilon-greedy behaviour policy")
    parser.add_argument("--discount", default=0.9, type=float,
                        help="Discount factor")
    parser.add_argument("--radius", default=0.05, type=float,
                        help="Radius of the uncertainty set")
    parser.add_argument("--type", default="Renyi", type=str, choices=("Renyi", "Chi"),
                        help="Divergence that defines the uncertainty set")
    parser.add_argument("--renyi_k", default=2.0, type=float,
                        help="Order k of the Renyi divergence (must be > 1)")
    parser.add_argument("--save_model", action="store_true",
                        help="Also save the learned Q table under ./models")
    parser.add_argument("--reward_scale", default=1., type=float,
                        help="Rewards are divided by this value before learning")
    args = parser.parse_args()

    # Validate with a real error instead of `assert`, which is stripped by -O.
    if args.type == "Renyi" and not args.renyi_k > 1:
        parser.error("--renyi_k must be greater than 1 for the Renyi uncertainty set")

    # The Renyi runs carry their order in the label; other types do not.
    if args.type == "Renyi":
        set_label = f"{args.type}({args.renyi_k})"
    else:
        set_label = args.type
    file_name = f"{set_label}_{args.radius}_{args.seed}"
    print("---------------------------------------")
    print(f"Type: {set_label}, Radius: {args.radius}, Seed: {args.seed}")
    print("---------------------------------------")

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs("./results", exist_ok=True)
    if args.save_model:
        os.makedirs("./models", exist_ok=True)

    env = CliffWalkingEnv()
    # Reference values used to measure the distance to the optimum.
    # NOTE(review): the file name is always the Renyi one, even when
    # --type is "Chi" -- confirm this is intended.
    optimal_v = np.load(f'./results/Renyi({args.renyi_k:.1f})_{args.radius}_optimal.npy')

    # Set seeds
    np.random.seed(args.seed)

    state_dim = env.N_S
    action_dim = env.N_A

    # Rows of [t, value of the initial state, distance to optimal_v].
    evaluations = []
    discount = args.discount
    # Learning rates
    def zeta1(t):
        """Step size labelled "fast": 1 / (1 + (1 - discount) * t**0.6).

        Used by the training loop for the Z1 / Z2 moment estimates.
        Reads `discount` from the enclosing script scope.
        Earlier variant: 0.03 / (1 + (1 - discount) / 100 * t).
        """
        growth = t ** 0.6
        denominator = 1 + (1 - discount) * growth
        return 1 / denominator

    def zeta2(t):
        """Step size labelled "medium": 1 / (1 + (1 - discount) / 10 * t**0.8).

        Used by the training loop for the eta table E.
        Reads `discount` from the enclosing script scope.
        Earlier variant: 0.01 / (1 + (1 - discount) / 100 * t).
        """
        scale = (1 - discount) / 10
        denominator = 1 + scale * t ** 0.8
        return 1 / denominator

    def zeta3(t):
        """Step size labelled "slow": 1 / (1 + (1 - discount) / 20 * t).

        Used by the training loop for the Q table.
        Reads `discount` from the enclosing script scope.
        Earlier variants: 0.003 / (1 + (1 - discount) / 100 * t), constant 0.003.
        """
        scale = (1 - discount) / 20
        denominator = 1 + scale * t
        return 1 / denominator

    # Fail fast on an unsupported uncertainty set.  The former in-loop
    # `assert NotImplementedError` was always truthy, so an unknown --type
    # silently ran every time step without ever updating Q.
    if args.type not in ("Renyi", "Chi"):
        raise NotImplementedError(f"Unsupported uncertainty set type: {args.type!r}")

    # Loop-invariant constants of the dual objective
    #     eta - C * (E[max(eta - V, 0)**k_star])**(1 / k_star)
    if args.type == "Renyi":
        k = args.renyi_k
        k_star = k / (k - 1)  # Conjugate exponent: 1/k + 1/k_star = 1
        C = (1 + k * (k - 1) * args.radius)**(1 / k)
    else:
        # Chi-square uses the same constant as the Renyi branch with k = 2.
        # The eta step and the robust value previously used two different
        # constants (sqrt(1 + 2*radius) vs. sqrt(radius + 1)), so eta was
        # optimised for a different objective than the one plugged into Q.
        C = np.sqrt(1 + 2 * args.radius)

    state = env.reset()
    done = False

    # Tabular quantities, all indexed by [state, action]:
    #   Q  - robust action values (in units of the scaled reward)
    #   Z1 - running estimate of E[max(eta - V, 0)**(k_star - 1)]
    #   Z2 - running estimate of E[max(eta - V, 0)**k_star]
    #   E  - the dual variable eta
    Q = np.random.randn(state_dim, action_dim) / args.reward_scale
    Z1 = np.zeros((state_dim, action_dim))
    Z2 = np.zeros((state_dim, action_dim))
    E = np.zeros((state_dim, action_dim))

    # int() keeps range() happy if max_timesteps arrives as a float default.
    for t in range(1, int(args.max_timesteps) + 1):
        # Epsilon-greedy behaviour policy
        if np.random.rand() < args.expl_noise:
            action = np.random.randint(action_dim)
        else:
            action = np.argmax(Q[state])
        next_state, reward, done, timeout = env.step(action)
        # Normalize the reward
        reward /= args.reward_scale
        # Bootstrap with zero at terminal states
        next_value = np.max(Q[next_state]) if not done else 0

        # Snapshot the current estimates; every update below uses these
        # pre-update values.
        z1, z2 = Z1[state, action], Z2[state, action]
        eta = E[state, action]
        z2 = max(z2, 1e-8)  # For computation stability
        surplus = max(eta - next_value, 0)
        lr_fast = zeta1(t)

        if args.type == "Renyi":
            Z1[state, action] += lr_fast * (surplus**(k_star - 1) - z1)
            Z2[state, action] += lr_fast * (surplus**k_star - z2)
            # Derivative of the dual objective with respect to eta
            D = 1 - C * z1 * z2**(1 / k_star - 1)
            robust_value = eta - C * z2**(1 / k_star)
        else:  # "Chi"
            Z1[state, action] += lr_fast * (surplus - z1)
            Z2[state, action] += lr_fast * (surplus**2 - z2)
            # Derivative of the dual objective with respect to eta
            D = 1 - C * z1 / np.sqrt(z2)
            robust_value = eta - C * np.sqrt(z2)

        # Ascent step on eta (called "beta" in the original comments)
        E[state, action] += zeta2(t) * D

        # Update Q towards the robust one-step target
        Q[state, action] += zeta3(t) * (reward + discount * robust_value - Q[state, action])

        if done or timeout:
            state = env.reset()
        else:
            state = next_state

        # Evaluate only every eval_freq steps; the rescaled table is built
        # here rather than on every step because nothing else reads it.
        if (t - 1) % args.eval_freq == 0:
            # Re-normalize the reward
            Q_ = Q * args.reward_scale
            greedy_values = Q_.max(-1)
            # Value of initial state
            value_init = np.dot(env.initial_state_dist, greedy_values)
            opt_diff = np.linalg.norm(optimal_v - greedy_values)
            evaluations.append([t, value_init, opt_diff])
            np.save(f"./results/{file_name}", evaluations)
            if args.save_model:
                np.save(f"./models/{file_name}", Q_)