import numpy as np
import argparse
import os
from env.cliffwalking import CliffWalkingEnv

if __name__ == "__main__":
    # Tabular Q-learning on CliffWalkingEnv with epsilon-greedy exploration.
    # Every `eval_freq` steps the current estimate of the start-state value,
    # dot(env.initial_state_dist, max_a Q[s, a]), is appended to a results
    # array saved under ./results_nonrobust (and, optionally, the Q table is
    # saved under ./models_nonrobust).
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", default=0, type=int)  # Seeds NumPy's global RNG
    # argparse applies `type` only to string values, never to a non-string
    # default, so these defaults are written as ints explicitly (1e5 / 1e7
    # would otherwise stay floats).
    parser.add_argument("--eval_freq", default=int(1e5), type=int)  # How often (time steps) we evaluate
    parser.add_argument("--max_timesteps", default=int(1e7), type=int)  # Max time steps to run environment
    parser.add_argument("--expl_noise", default=0.2, type=float)  # Epsilon greedy noise
    parser.add_argument("--discount", default=0.9, type=float)  # Discount factor
    parser.add_argument("--alpha", default=0.95, type=float)  # Constant Q-learning step size
    parser.add_argument("--save_model", action="store_true")  # Save the Q table
    # NOTE(review): --reward_scale is parsed but not used anywhere below.
    parser.add_argument("--reward_scale", default=1., type=float)  # Normalize the reward
    args = parser.parse_args()

    os.makedirs("./results_nonrobust", exist_ok=True)
    if args.save_model:
        os.makedirs("./models_nonrobust", exist_ok=True)

    # NOTE(review): ':' is not a valid filename character on Windows.
    file_name = f"nonrobust_seed:{args.seed}"
    env = CliffWalkingEnv()

    # Set seeds (only NumPy's global RNG is seeded here).
    np.random.seed(args.seed)

    state_dim = env.N_S
    action_dim = env.N_A
    discount = args.discount
    # Constant step size; a decaying alternative would be
    # 1 / (1 + (1 - discount) * t).
    alpha = args.alpha

    evaluations = []

    # reset() is called before the random Q initialisation to keep the same
    # RNG consumption order as before (reset may draw from np.random).
    state = env.reset()
    Q = np.random.randn(state_dim, action_dim)

    # t counts environment steps taken so far: 1, 2, ..., max_timesteps.
    for t in range(1, args.max_timesteps + 1):
        # Epsilon-greedy action selection.
        if np.random.rand() < args.expl_noise:
            action = np.random.randint(action_dim)
        else:
            action = np.argmax(Q[state])

        next_state, reward, done, timeout = env.step(action)
        episode_over = done or timeout

        # No bootstrapping once the episode ends.
        # NOTE(review): if `timeout` is a time-limit truncation rather than a
        # true terminal, the standard target would still bootstrap from
        # Q[next_state] on timeout; confirm the env's semantics.
        next_value = 0 if episode_over else np.max(Q[next_state])
        Q[state, action] = (1 - alpha) * Q[state, action] + alpha * (reward + discount * next_value)

        state = env.reset() if episode_over else next_state

        # Evaluate after step 1, 1 + eval_freq, 1 + 2 * eval_freq, ...
        if (t - 1) % args.eval_freq == 0:
            # Value of initial state: expectation over the start distribution
            # of the greedy (max over actions) Q value.
            value_init = np.dot(env.initial_state_dist, Q.max(-1))
            evaluations.append(value_init)
            print(f"Total T: {t} Value: {value_init:.4f}")
            np.save(f"./results_nonrobust/{file_name}", evaluations)
            if args.save_model:
                np.save(f"./models_nonrobust/{file_name}", Q)