import numpy as np
import argparse
import os
from env.cliffwalking import CliffWalkingEnv

def describe_run(args):
    """Return ``(file_name, banner)`` for the parsed command-line arguments.

    ``file_name`` is the stem under which results are saved in
    ``./evaluations``; ``banner`` is the one-line run summary printed at
    start-up. Both depend on ``args.type``:

    * ``"Renyi"`` embeds ``renyi_k`` and ``radius``;
    * ``"R"`` embeds ``R``;
    * anything else embeds ``radius``.

    ``prob`` is always part of both strings.
    """
    if args.type == "Renyi":
        label = f"{args.type}({args.renyi_k})"
        return (
            f"{label}_{args.radius}_{args.prob}",
            f"Type: {label}, Radius: {args.radius}, Prob: {args.prob}",
        )
    if args.type == 'R':
        return (
            f"{args.type}_{args.R}_{args.prob}",
            f"Type: {args.type}, R: {args.R}, Prob: {args.prob}",
        )
    return (
        f"{args.type}_{args.radius}_{args.prob}",
        f"Type: {args.type}, Radius: {args.radius}, Prob: {args.prob}",
    )


def load_q_table(args):
    """Load the Q table selected by ``args.type`` from ``./models``.

    * ``"Renyi"``: the single ``..._optimal_Q.npy`` file for the given
      ``renyi_k`` and ``radius``.
    * ``"nonrobust"``: ``nonrobust_optimal_Q.npy``.
    * ``"R"``: the file for the given ``R`` and ``seed``.
    * anything else: the element-wise mean of the files for seeds 0..9.
    """
    if args.type == "Renyi":
        q_table = np.load(f"./models/{args.type}({args.renyi_k})_{args.radius}_optimal_Q.npy")
        print('Load optimal Q')
        return q_table
    if args.type == 'nonrobust':
        return np.load('./models/nonrobust_optimal_Q.npy')
    if args.type == 'R':
        return np.load(f'./models/R_{args.R}_{args.seed}.npy')
    per_seed = [np.load(f"./models/{args.type}_{args.radius}_{seed}.npy") for seed in range(10)]
    return np.mean(per_seed, axis=0)


def run_episode(env, Q):
    """Roll out the greedy policy of ``Q`` for one episode.

    Returns ``(total_reward, steps)``: the undiscounted sum of rewards and
    the number of ``env.step`` calls made until the environment reports
    ``done``.
    """
    state, done = env.reset(), False
    total_reward, steps = 0., 0
    while not done:
        action = np.argmax(Q[state])
        # NOTE(review): the fourth value returned by env.step (timeout) is
        # ignored, as in the original script; this loop relies on ``done``
        # alone to terminate -- confirm the environment sets it on timeout.
        state, reward, done, _timeout = env.step(action)
        total_reward += reward
        steps += 1
    return total_reward, steps


def average_reward_and_steps(evaluations):
    """Return ``(mean_reward, mean_steps)`` over ``(reward, steps)`` pairs.

    ``evaluations`` must contain at least one pair.
    """
    results = np.array(evaluations)  # converted once, shape (episodes, 2)
    return np.mean(results[:, 0]), np.mean(results[:, 1])


def main():
    """Evaluate a stored Q table's greedy policy on CliffWalkingEnv.

    Parses the command line, loads the Q table, prints the greedy policy,
    runs ``--eval_episodes`` episodes, prints the average reward and episode
    length, and saves the per-episode ``(reward, steps)`` pairs to
    ``./evaluations/<file_name>.npy``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", default=0, type=int,
                        help="Seed for NumPy's RNG; also selects the model file when --type is R")
    parser.add_argument("--eval_episodes", default=100, type=int,
                        help="Number of evaluation episodes (must be at least 1)")
    parser.add_argument("--radius", default=0.1, type=float,
                        help="Radius of the uncertainty set (part of model and result file names)")
    parser.add_argument("--prob", default=0.25, type=float,
                        help="Value passed as 'prob' to CliffWalkingEnv")
    parser.add_argument("--discount", default=0.9, type=float,
                        help="Discount factor (not used by this evaluation script)")
    parser.add_argument("--type", default="Renyi", type=str,
                        help="Model type: 'Renyi', 'R', 'nonrobust', or another prefix of files in ./models")
    parser.add_argument("--renyi_k", default=2.0, type=float,
                        help="Parameter k embedded in Renyi model and result file names")
    parser.add_argument("--R", default=0.1, type=float,
                        help="Parameter R embedded in R-type model and result file names")
    args = parser.parse_args()

    # Fail early with a usage error: averaging zero episodes would otherwise
    # crash with an IndexError only after the model has been loaded.
    if args.eval_episodes < 1:
        parser.error("--eval_episodes must be at least 1")

    file_name, banner = describe_run(args)
    print("---------------------------------------")
    print(banner)
    print("---------------------------------------")

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs("./evaluations", exist_ok=True)

    Q = load_q_table(args)
    print("----------------Policy-----------------")
    # The last row of Q is dropped and the remaining greedy actions are shown
    # as a 4x4 grid, so Q must have exactly 17 rows.
    print(Q[:-1].argmax(axis=-1).reshape(4, 4))

    # Seed NumPy before the environment is created, as in the original order.
    np.random.seed(args.seed)

    env = CliffWalkingEnv(prob=args.prob)
    evaluations = [run_episode(env, Q) for _ in range(args.eval_episodes)]

    avg_reward, avg_time = average_reward_and_steps(evaluations)
    print("---------------------------------------")
    print(f"Evaluation over {args.eval_episodes} episodes: {avg_reward:.3f} in {avg_time:.3f} steps")
    print("---------------------------------------")

    np.save(f"./evaluations/{file_name}", evaluations)


if __name__ == "__main__":
    main()