import json
import os
import pandas as pd
from time import *
import torch
import random
import numpy as np
import argparse

from cityflow_env import CityFlowEnvM
from utility import parse_roadnet
from ppo_agent import *

# Process-wide environment flags, set before any CUDA work is started.
# CUDA_VISIBLE_DEVICES limits the process to GPU index 0.
# TF_FORCE_GPU_ALLOW_GROWTH is a TensorFlow setting; no TensorFlow import is
# visible in this file, so it presumably targets a dependency -- confirm.
os.environ.update(
    {
        "CUDA_VISIBLE_DEVICES": "0",
        "TF_FORCE_GPU_ALLOW_GROWTH": "true",
    }
)

# One fixed seed shared by every random number generator the script touches
# (Python's `random`, NumPy, and PyTorch on CPU and all CUDA devices).
seed = 2024
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# Handed unchanged to `MPPOAgent.random_initialize` in `main`; the meaning of
# the three entries is defined in ppo_agent (not visible here).
pruning_ratio = [0, 0, 0]

class PPOConfig:
    """Hyperparameter bundle handed to the PPO agent through its ``cfg`` argument.

    Every constructor argument is keyword-only and defaults to the value that
    used to be hard-coded, so ``PPOConfig()`` yields exactly the previous
    configuration while individual settings can now be overridden, e.g.
    ``PPOConfig(batch_size=16, device="cpu")``.

    Attribute names follow the usual PPO vocabulary (discount ``gamma``,
    ``gae_lambda``, ``policy_clip``, ...); how each one is consumed is decided
    by the agent implementation in ``ppo_agent``, which is not visible here.
    ``update_fre`` is the one read by this script: the training loop triggers
    ``replay`` whenever the global step counter is a multiple of it.

    Args:
        batch_size: Stored as ``self.batch_size``.
        gamma: Stored as ``self.gamma``.
        n_epochs: Stored as ``self.n_epochs``.
        actor_lr: Stored as ``self.actor_lr``.
        critic_lr: Stored as ``self.critic_lr``.
        gae_lambda: Stored as ``self.gae_lambda``.
        policy_clip: Stored as ``self.policy_clip``.
        hidden_dim: Stored as ``self.hidden_dim``.
        update_fre: Stored as ``self.update_fre``.
        device: A ``torch.device`` or a device string. ``None`` (the default)
            selects ``"cuda"`` when CUDA is available and ``"cpu"`` otherwise.
    """

    def __init__(
        self,
        *,
        batch_size: int = 5,
        gamma: float = 0.99,
        n_epochs: int = 4,
        actor_lr: float = 0.0005,
        critic_lr: float = 0.001,
        gae_lambda: float = 0.95,
        policy_clip: float = 0.2,
        hidden_dim: int = 32,
        update_fre: int = 30,
        device: "torch.device | str | None" = None,
    ) -> None:
        self.batch_size = batch_size
        self.gamma = gamma
        self.n_epochs = n_epochs
        self.actor_lr = actor_lr
        self.critic_lr = critic_lr
        self.gae_lambda = gae_lambda
        self.policy_clip = policy_clip
        self.hidden_dim = hidden_dim
        self.update_fre = update_fre
        if device is None:
            # Same auto-detection the class performed before it took arguments.
            device = "cuda" if torch.cuda.is_available() else "cpu"
        # torch.device() accepts both strings and existing device objects.
        self.device = torch.device(device)

# Exact dataset names mapped to the string passed as the environment's
# ``dataset`` argument; any other name maps to the empty string.
_DATASET_FAMILIES = {
    'Jinan1': "jinan",
    'Jinan2': "jinan",
    'Jinan3': "jinan",
    'Hangzhou1': "hangzhou",
    'Hangzhou2': "hangzhou",
}


def _leading_argmax_labels(state, intersection_id):
    """Return, per intersection, the argmax over the first four state entries.

    The original code called these four entries "pressure"; presumably they
    are per-phase pressures -- confirm against the environment's state layout.
    """
    return {
        id_: np.argmax(np.squeeze(state[id_])[:4])
        for id_ in intersection_id
    }


def main():
    """Train the multi-intersection PPO agent on one CityFlow dataset.

    Command line:
        --d  Dataset directory name under ``Datasets/`` (default ``Syn1``).

    Steps performed:
        1. Write ``cityflow.config`` into the dataset directory and parse the
           road network it points to.
        2. Build the environment and the agent, then run ``config['epoch']``
           episodes of ``config['num_step']`` simulator steps each. Every
           chosen phase is held for ``config['phase_step']`` simulator steps.
        3. Write per-intersection episode rewards to
           ``Results/<dataset>/rewards.csv`` and per-episode average travel
           time to ``Results/<dataset>/travel time.csv``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--d', type=str, default='Syn1', help='dataset')
    args = parser.parse_args()

    dataset_name = args.d
    print(dataset_name)

    # The trailing slash matters: the road-network path below is built by
    # plain string concatenation of "dir" and "roadnetFile".
    dataset_path = "Datasets/" + dataset_name + "/"
    dataset = _DATASET_FAMILIES.get(dataset_name, '')

    cityflow_config = {
        "interval": 1,
        "seed": 0,
        "laneChange": False,
        "dir": dataset_path,
        "roadnetFile": "roadnet.json",
        "flowFile": "flow.json",
        "rlTrafficLight": True,
        "saveReplay": False,
        "roadnetLogFile": "replayRoadNet.json",
        "replayLogFile": "replayLogFile.txt"
    }

    config_file = os.path.join(dataset_path, "cityflow.config")
    with open(config_file, "w") as json_file:
        json.dump(cityflow_config, json_file)

    config = {
        'cityflow_config_file': config_file,
        'epoch': 200,
        'num_step': 3600,
        'save_freq': 1,
        'phase_step': 10,
        'model': 'PPO',
    }

    cfg = PPOConfig()

    # Re-read the file that was just written, closing the handle afterwards.
    with open(config['cityflow_config_file']) as json_file:
        cityflow_config = json.load(json_file)
    roadnetFile = cityflow_config['dir'] + cityflow_config['roadnetFile']
    config["lane_phase_info"] = parse_roadnet(roadnetFile)

    intersection_id = list(config['lane_phase_info'].keys())
    config["intersection_id"] = intersection_id
    phase_list = {
        id_: config["lane_phase_info"][id_]["phase"] for id_ in intersection_id
    }
    config["phase_list"] = phase_list

    # Create the output directory before training so a permissions or path
    # problem surfaces immediately instead of after the final episode.
    result_dir = "Results/{}".format(dataset_name)
    os.makedirs(result_dir, exist_ok=True)

    env = CityFlowEnvM(config["lane_phase_info"],
                       intersection_id,
                       num_step=config["num_step"],
                       thread_num=8,
                       cityflow_config_file=config["cityflow_config_file"],
                       dataset=dataset
                       )

    config["state_size"] = env.state_size

    Magent = MPPOAgent(intersection_id,
                       state_size=config["state_size"],
                       cfg=cfg,
                       phase_list=config["phase_list"]
                       )

    Magent.random_initialize(pruning_ratio)

    EPISODES = config['epoch']
    num_step = config['num_step']
    total_step = 0
    episode_rewards = {id_: [] for id_ in intersection_id}
    episode_travel_time = []

    for episode in range(EPISODES):
        env.reset()
        state = {id_: env.get_state_(id_) for id_ in intersection_id}

        episode_length = 0
        episode_reward = {id_: 0 for id_ in intersection_id}
        while episode_length < num_step:
            action, prob, val = Magent.choose_action(state)
            labels = _leading_argmax_labels(state, intersection_id)

            # Translate each agent's action index into the phase it selects.
            action_phase = {
                id_: phase_list[id_][a] for id_, a in action.items()
            }

            next_state, reward = env.step(action_phase, cur_step=episode_length)

            # Seed the "last sub-step" reward so it is defined even when
            # phase_step == 1 and the hold loop below never runs.
            last_reward = reward
            # Hold the chosen phase for the remaining phase_step - 1 steps.
            # NOTE(review): the first held step reuses the cur_step of the
            # call above because the counter is bumped after stepping --
            # confirm whether the environment relies on cur_step being unique.
            for _ in range(config['phase_step'] - 1):
                next_state, last_reward = env.step(action_phase, cur_step=episode_length)

                episode_length += 1
                total_step += 1
                for id_ in intersection_id:
                    reward[id_] += last_reward[id_]

            for id_ in intersection_id:
                episode_reward[id_] += reward[id_]

            episode_length += 1
            total_step += 1

            # Terminal flag follows the configured horizon rather than a
            # literal 3600, so it still fires if num_step is changed.
            done_flag = int(episode_length >= num_step)
            done = {id_: done_flag for id_ in intersection_id}

            # NOTE(review): the stored transition carries only the reward of
            # the last sub-step, whereas rewards.csv logs the reward summed
            # over the whole phase -- confirm this difference is intended.
            Magent.remember(state, action, prob, val, last_reward, done, labels)
            if total_step % cfg.update_fre == 0:
                Magent.replay(episode)

            state = next_state

        travel_time = env.eng.get_average_travel_time()
        print('Episode: {},travel time: {}'.format(episode, travel_time))
        episode_travel_time.append(travel_time)
        for id_ in intersection_id:
            episode_rewards[id_].append(episode_reward[id_])

    df = pd.DataFrame(episode_rewards)
    df.to_csv(result_dir + '/rewards.csv', index=False)

    df = pd.DataFrame({"travel time": episode_travel_time})
    df.to_csv(result_dir + '/travel time.csv', index=False)

if __name__ == '__main__':
    main()