import pickle
import os
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

from test_policy import parse_args, main
from calc_waiting_time import calc_waiting_time

def draw_reward(round_cnt, record_name, memo, fig_name="reward_70.jpg"):
    """Plot the mean of each round's reward record and save the figure.

    For ``i`` in ``range(round_cnt)`` this reads
    ``records/<memo>/<record_name>/test_round/round_<i>/reward_record.pkl``
    (relative to the current working directory) and averages its contents.
    It prints the best mean and the round index(es) that reach it, then
    writes the curve to ``records/<memo>/<record_name>/<fig_name>``.

    Args:
        round_cnt: Number of rounds to read, starting from round 0.
        record_name: Experiment record directory name.
        memo: Parent sub-directory of ``records``.
        fig_name: Output image file name (defaults to the previously
            hard-coded ``"reward_70.jpg"``).

    Returns:
        numpy array of per-round mean rewards. Rounds whose record file is
        empty hold ``nan``, so position ``i`` always corresponds to round ``i``.
    """
    record_dir = os.path.join(os.getcwd(), "records", memo, record_name, "test_round")

    mean_reward = []
    for i in range(round_cnt):
        record_file = os.path.join(record_dir, "round_{}".format(i), "reward_record.pkl")
        if os.path.getsize(record_file) == 0:
            # pickle.load raises EOFError on an empty file. Record NaN instead of
            # crashing so later indices still line up with round numbers.
            print("EMPTY FILE")
            mean_reward.append(np.nan)
            continue
        with open(record_file, "rb") as f:
            reward_record = pickle.load(f)
        mean_reward.append(np.mean(np.array(reward_record)))

    mean_reward = np.array(mean_reward, dtype=float)
    # nanmax ignores rounds that had empty record files.
    best = np.nanmax(mean_reward)
    print(best)
    print(np.where(mean_reward == best))

    save_dir = os.path.join(os.getcwd(), "records", memo, record_name)
    os.makedirs(save_dir, exist_ok=True)

    # Use a fresh figure and close it afterwards, so repeated calls do not draw
    # onto the same axes.
    plt.figure()
    # x positions are the round indices 0..N-1.
    plt.plot(np.arange(len(mean_reward)), mean_reward)
    plt.xlabel("Rounds")
    plt.ylabel("Reward (Pressure)")
    plt.savefig(os.path.join(save_dir, fig_name))
    plt.close()
    return mean_reward
    
def read_reward(round, record_name, feature, memo, test_round):
    """Return the mean of one round's pickled reward record for ``feature``.

    Loads
    ``records/<memo>/<record_name>/<test_round>/round_<round>/reward_record_<feature>.pkl``
    relative to the current working directory.

    Args:
        round: Round index. The name shadows the builtin; it is kept so
            keyword callers keep working.
        record_name: Experiment record directory name.
        feature: Reward feature name used in the pickle file name.
        memo: Parent sub-directory of ``records``.
        test_round: Sub-directory holding the ``round_<n>`` folders.

    Returns:
        ``np.mean`` of the unpickled record, or ``nan`` if the file is empty.
    """
    record_dir = os.path.join(os.getcwd(), "records", memo, record_name, test_round)
    record_file = os.path.join(record_dir, "round_{}".format(round), "reward_record_{}.pkl".format(feature))
    if os.path.getsize(record_file) == 0:
        # pickle.load would raise EOFError on an empty file. Warn and return NaN
        # so the missing round stays visible in downstream statistics.
        print("EMPTY FILE")
        return np.nan
    with open(record_file, "rb") as f:
        reward_record = pickle.load(f)
    return np.mean(reward_record)
    
def read_travel_time(round, record_name, memo, test_round):
    """Return the first ``"travel time"`` value from a round's ``travel_time.csv``.

    Reads ``records/<memo>/<record_name>/<test_round>/round_<round>/travel_time.csv``
    relative to the current working directory.

    Args:
        round: Round index. The name shadows the builtin; it is kept so
            keyword callers keep working.
        record_name: Experiment record directory name.
        memo: Parent sub-directory of ``records``.
        test_round: Sub-directory holding the ``round_<n>`` folders.

    Returns:
        The value in the first data row of the ``"travel time"`` column.
    """
    record_dir = os.path.join(os.getcwd(), "records", memo, record_name, test_round)
    record_file = os.path.join(record_dir, "round_{}".format(round), "travel_time.csv")
    df = pd.read_csv(record_file)
    # Explicit positional access to the first row, independent of the index labels.
    return df["travel time"].iloc[0]

def calc_reward_mean_var(round_cnt, record_name, feature, memo, start, size, test_round="test_round"):
    """Summarise a reward feature over ``size`` consecutive test rounds.

    Negative ``start`` values are clamped to 0. Each round's mean comes from
    ``read_reward``. Prints the mean and three times the (population) standard
    deviation. ``round_cnt`` is accepted for signature symmetry but unused.

    Returns:
        ``(mean, std)`` of the per-round means.
    """
    if not start > 0:
        start = 0
    window = range(start, start + size)
    per_round = [read_reward(r, record_name, feature, memo, test_round) for r in window]
    mean, std = np.mean(per_round), np.std(per_round)
    print(f"{feature} mean: {mean}")
    print(f"{feature} lim dev: {std * 3}")
    return mean, std

def calc_waiting_mean_var(round_cnt, record_name, memo, start, size, test_round="test_round"):
    """Summarise waiting time over ``size`` consecutive test rounds.

    Negative ``start`` values are clamped to 0. Each round's value comes from
    ``calc_waiting_time``. Prints the mean and three times the (population)
    standard deviation. ``round_cnt`` is accepted for signature symmetry but
    unused.

    Returns:
        ``(mean, std)`` of the per-round waiting times.
    """
    if not start > 0:
        start = 0
    samples = []
    for rnd in range(start, start + size):
        samples.append(calc_waiting_time(rnd, record_name, memo, test_round))
    avg = np.mean(samples)
    spread = np.std(samples)
    print(f"waiting time mean: {avg}")
    print(f"waiting time lim dev: {spread * 3}")
    return avg, spread

def calc_travel_mean_var(round_cnt, record_name, memo, start, size, test_round="test_round"):
    """Summarise travel time over ``size`` consecutive test rounds.

    Negative ``start`` values are clamped to 0. Each round's value comes from
    ``read_travel_time``. Prints the mean and three times the (population)
    standard deviation. ``round_cnt`` is accepted for signature symmetry but
    unused.

    Returns:
        ``(mean, std)`` of the per-round travel times.
    """
    first = start if start > 0 else 0
    travel_times = list(map(
        lambda rnd: read_travel_time(rnd, record_name, memo, test_round),
        range(first, first + size),
    ))
    centre = np.mean(travel_times)
    deviation = np.std(travel_times)
    print(f"travel time mean: {centre}")
    print(f"travel time lim dev: {deviation * 3}")
    return centre, deviation

def test_policy_range(record_name, memo, mod, start, size):
    """Run ``test_policy.main`` once per model round in ``[start, start + size)``.

    For each round a fresh namespace is taken from ``parse_args()`` and
    overridden with:
    - ``dataset``: the fourth ``_``-separated token of ``record_name``
    - ``model_round``
    - ``memo``
    - ``record_name``
    - ``mod``

    NOTE(review): the ``test_`` prefix means pytest may try to collect this
    function as a test if it is imported into a test module.
    """
    for model_round in range(start, start + size):
        args = parse_args()
        overrides = {
            "dataset": record_name.split("_")[3],
            "model_round": model_round,
            "memo": memo,
            "record_name": record_name,
            "mod": mod,
        }
        for key, value in overrides.items():
            setattr(args, key, value)
        main(args)
        
    
    

# Earlier ad-hoc experiments, kept for reference.
# NOTE(review): the read_reward / read_travel_time calls below predate the
# required ``test_round`` argument and would need it added before re-use.
# in_args = parse_args()
# rewards = []
# travel_time = []
# for _ in range(10):
#     main(in_args)
#     round = 4
#     file_name = "anon_28_7_newyork_real_double.json_12_24_10_03_34_MOTSC"
#     feature = "lane_num_waiting_vehicle_in"
#     memo = "MOTSC"
#     rewards.append(read_reward(round, file_name, feature, memo))
#     print(rewards)
#     travel_time.append(read_travel_time(round, file_name, memo))
#     print(travel_time)
# draw_reward(75, "anon_28_7_newyork_real_double.json_01_09_17_56_51_MOAMPLIGHT", "MOTSC")
# read_reward(0, "anon_28_7_newyork_real_double.json_01_09_05_25_51_SOTL", "lane_num_waiting_vehicle_in", "benchmark_1001")

# Experiment configuration. These names stay at module level so they remain
# importable; the analysis itself runs only when the file is executed directly.
round_cnt = 53
record_name = "anon_3_4_jinan_real_2500.json_01_27_08_59_45_MOAMPLIGHT"
memo = "MOTSC"
mod = "AdvancedMPLight"
test_round = "test_round"
# Summarise window_size consecutive rounds starting at round_cnt - 25.
window_start = round_cnt - 25
window_size = 10

if __name__ == "__main__":
    # draw_reward(round_cnt, record_name, memo)
    # test_policy_range(record_name, memo, mod, window_start, window_size)
    calc_reward_mean_var(round_cnt, record_name, "traffic_movement_pressure_queue_efficient", memo, window_start, window_size, test_round)
    calc_reward_mean_var(round_cnt, record_name, "lane_num_waiting_vehicle_in", memo, window_start, window_size, test_round)
    calc_waiting_mean_var(round_cnt, record_name, memo, window_start, window_size, test_round)
    calc_travel_mean_var(round_cnt, record_name, memo, window_start, window_size, test_round)