import pyrallis
import os
import torch
import numpy as np
import time
import pickle
import argparse
import pandas as pd
from tqdm import trange
from offline_method.edac import TrainConfig, set_seed, wrap_env, modify_reward, ReplayBuffer, Actor, VectorizedCritic, EDAC, eval_actor, make_d4rl_dataset
from utils.mopo.mopo_buffer import mopo_buffer
from utils.cityflow import CityFlowEnv
from utils.travel_time import travel_time_metric
from utils.pipeline import copy_cityflow_file, copy_conf_file, path_check
from utils.utils import merge
from utils import config

class CityFlowWrapper():
    """Gym-style adapter around ``CityFlowEnv`` used as the EDAC evaluation env.

    Observations are per-intersection feature dicts flattened into a 2-D numpy
    array (one row per intersection). Actions are converted to integer indices
    before being handed to the simulator. Per-step rewards, and at episode end
    the vehicle records and travel-time metric, are forwarded to an
    ``edac_logger`` rooted at ``path_to_work_directory``.
    """

    def __init__(self, path_to_work_directory, dic_traffic_env_conf):
        self.env = CityFlowEnv(path_to_work_directory, dic_traffic_env_conf)
        self.logger = edac_logger(path_to_work_directory)
        self.metric = travel_time_metric(self.env)
        self.dic_traffic_env_conf = dic_traffic_env_conf
        # Episode ends once the simulator clock reaches this many steps.
        self.run_counts = dic_traffic_env_conf["RUN_COUNTS"]

    def seed(self, seed):
        """Accept and ignore a seed (presumably called by the evaluation loop)."""

    def _format_state_output(self, state):
        """Concatenate each intersection's feature lists, in dict order, into one row."""
        return np.array([
            [value for key in one_state for value in one_state[key]]
            for one_state in state
        ])

    def update_epoch(self, epoch):
        """Tell the logger which ``round_<epoch>`` directory to write into."""
        self.logger.update_epoch(epoch)

    def reset(self):
        """Reset the simulator and return the flattened initial observation."""
        return self._format_state_output(self.env.reset())

    def step(self, action):
        """Apply one joint action and return ``(state, mean_reward, done, info)``.

        ``action`` is reshaped to one entry per intersection. Each entry is
        truncated with ``int``, and anything outside ``0..3`` is replaced by 0
        before being sent to the simulator. When the simulator time reaches
        ``RUN_COUNTS`` the episode is done, and the collected rewards, vehicle
        records and travel-time metric are dumped for the current epoch.
        """
        num_inters = self.dic_traffic_env_conf["NUM_INTERSECTIONS"]
        chosen = []
        for raw in np.reshape(action, (num_inters, )):
            idx = int(raw)
            chosen.append(idx if 0 <= idx <= 3 else 0)

        next_state, reward = self.env.step(np.array(chosen))
        done = self.env.get_current_time() >= self.run_counts

        for feature in ("lane_num_waiting_vehicle_in",
                        "traffic_movement_pressure_queue_efficient"):
            self.logger.log_reward(feature, next_state)
        self.logger.log_reward("real_reward", reward)
        self.metric.update(done)

        if done:
            for inter_ind in range(num_inters):
                self.logger.log_vehicle(self.env.get_dic_vehicle_arrive_leave_time(inter_ind))
            self.logger.dump_reward()
            self.logger.dump_vehicle()
            self.metric.log_travel_time(self.logger.get_log_path())

        return self._format_state_output(next_state), np.mean(reward), done, {}
            
class edac_logger():
    """Collect evaluation rewards and vehicle records and dump them per epoch.

    Everything is written under ``<path_to_work_directory>/round_<epoch>``:
    one ``reward_record_<feature>.pkl`` per reward feature and one
    ``vehicle_inter_<i>.csv`` per intersection.
    """

    def __init__(self, path_to_work_directory):
        # Attribute name (sic) kept unchanged for any external readers.
        self.path_to_work_direction = path_to_work_directory
        self.rewards = {"real_reward": []}
        self.vehicle_all_inters = []
        self.epoch = 0

    def update_epoch(self, epoch):
        """Select the ``round_<epoch>`` directory used by subsequent dumps."""
        self.epoch = epoch

    def log_reward(self, feature, data, weight=-0.25):
        """Record one step's reward for ``feature``.

        For ``"real_reward"`` the value ``data`` is stored as-is. For any other
        feature, ``data`` is a sequence of per-intersection dicts, and
        ``sum(one_data[feature]) * weight`` is stored for each of them.
        ``weight`` is ignored for ``"real_reward"``.
        """
        if feature == "real_reward":
            self.rewards[feature].append(data)
        else:
            reward = [np.sum(one_data[feature]) * weight for one_data in data]
            self.rewards.setdefault(feature, []).append(reward)

    def dump_vehicle(self):
        """Write each logged intersection's vehicle dict to CSV, then clear the log."""
        dump_path = self.get_log_path()
        # Create the directory here too, so this no longer depends on
        # dump_reward having been called first.
        os.makedirs(dump_path, exist_ok=True)
        for inter_ind, vehicle_one_inter in enumerate(self.vehicle_all_inters):
            path_to_log_file = os.path.join(dump_path, "vehicle_inter_{0}.csv".format(inter_ind))
            df = pd.DataFrame.from_dict(vehicle_one_inter, orient="index")
            df.to_csv(path_to_log_file, na_rep="nan")

        # clear() rather than rebinding, so existing references see the reset.
        self.vehicle_all_inters.clear()

    def log_vehicle(self, dic_vehicle):
        """Queue one intersection's vehicle record for the next ``dump_vehicle``."""
        self.vehicle_all_inters.append(dic_vehicle)

    def dump_reward(self):
        """Pickle every reward series to the round directory, then reset the series."""
        dump_path = self.get_log_path()
        os.makedirs(dump_path, exist_ok=True)

        for key, values in self.rewards.items():
            with open(os.path.join(dump_path, "reward_record_{}.pkl".format(key)), "wb") as f:
                pickle.dump(values, f)

        self.rewards = {"real_reward": []}

    def get_log_path(self):
        """Return ``<work dir>/round_<epoch>`` for the current epoch."""
        return os.path.join(self.path_to_work_direction, "round_{}".format(self.epoch))

def _load_offline_dataset(data_dir):
    """Unpickle every entry of ``data_dir`` in sorted filename order.

    ``os.listdir`` order is filesystem-dependent. Sorting keeps the dataset
    order, and presumably the replay-buffer layout that seeded sampling
    indexes into, the same from run to run.

    NOTE(review): ``pickle.load`` can execute arbitrary code; only point this
    at trusted, locally generated offline data.
    """
    dataset = []
    for file_name in sorted(os.listdir(data_dir)):
        with open(os.path.join(data_dir, file_name), "rb") as f:
            dataset.append(pickle.load(f))
    return dataset

def train(config: TrainConfig, dic_traffic_env_conf: dict, dic_path: dict):
    """Train EDAC on the offline CityFlow dataset with periodic evaluation and checkpoints.

    Args:
        config: EDAC settings (seeds, network sizes, learning rates, epoch and
            update counts, evaluation schedule, device, ``checkpoints_path``).
        dic_traffic_env_conf: CityFlow environment configuration, passed to the
            evaluation wrapper and to ``make_d4rl_dataset``.
        dic_path: Path configuration. ``PATH_TO_OFFLINE_DATA`` holds the
            pickled offline data; ``PATH_TO_MODEL`` is the checkpoint root.
    """
    set_seed(config.train_seed, deterministic_torch=config.deterministic_torch)

    # data, evaluation, env setup
    eval_env = CityFlowWrapper(config.checkpoints_path, dic_traffic_env_conf)
    dataset = _load_offline_dataset(dic_path["PATH_TO_OFFLINE_DATA"])
    d4rl_dataset = make_d4rl_dataset(dataset, dic_traffic_env_conf)

    state_dim = d4rl_dataset["observations"].shape[-1]
    action_dim = d4rl_dataset["actions"].shape[-1]

    buffer = ReplayBuffer(
        state_dim=state_dim,
        action_dim=action_dim,
        buffer_size=config.buffer_size,
        device=config.device,
    )
    buffer.load_d4rl_dataset(d4rl_dataset)

    # Actor & Critic setup
    actor = Actor(state_dim, action_dim, config.hidden_dim, config.max_action)
    actor.to(config.device)
    actor_optimizer = torch.optim.Adam(actor.parameters(), lr=config.actor_learning_rate)
    critic = VectorizedCritic(
        state_dim, action_dim, config.hidden_dim, config.num_critics
    )
    critic.to(config.device)
    critic_optimizer = torch.optim.Adam(
        critic.parameters(), lr=config.critic_learning_rate
    )

    trainer = EDAC(
        actor=actor,
        actor_optimizer=actor_optimizer,
        critic=critic,
        critic_optimizer=critic_optimizer,
        gamma=config.gamma,
        tau=config.tau,
        eta=config.eta,
        alpha_learning_rate=config.alpha_learning_rate,
        device=config.device,
    )
    # saving config to the checkpoint
    if config.checkpoints_path is not None:
        print(f"Checkpoints path: {config.checkpoints_path}")
        os.makedirs(config.checkpoints_path, exist_ok=True)
        with open(os.path.join(config.checkpoints_path, "config.yaml"), "w") as f:
            pyrallis.dump(config, f)

    for epoch in trange(config.num_epochs, desc="Training"):
        # Route this epoch's evaluation logs to round_<epoch>.
        eval_env.update_epoch(epoch)
        # training
        for _ in trange(config.num_updates_on_epoch, desc="Epoch", leave=False):
            batch = buffer.sample(config.batch_size)
            trainer.update(batch)

        # evaluation (also always on the final epoch)
        if epoch % config.eval_every == 0 or epoch == config.num_epochs - 1:
            eval_returns = eval_actor(
                env=eval_env,
                actor=actor,
                n_episodes=config.eval_episodes,
                seed=config.eval_seed,
                device=config.device,
            )
            eval_log = {
                "eval/reward_mean": np.mean(eval_returns),
                "eval/reward_std": np.std(eval_returns),
                "epoch": epoch,
            }
            if hasattr(eval_env, "get_normalized_score"):
                normalized_score = eval_env.get_normalized_score(eval_returns) * 100.0
                eval_log["eval/normalized_score_mean"] = np.mean(normalized_score)
                eval_log["eval/normalized_score_std"] = np.std(normalized_score)
            # Report the evaluation summary instead of silently discarding it.
            print(
                "Epoch {}: eval reward mean={:.3f}, std={:.3f}".format(
                    epoch, eval_log["eval/reward_mean"], eval_log["eval/reward_std"]
                )
            )

            # NOTE(review): gated on checkpoints_path but written under
            # PATH_TO_MODEL; kept as-is for compatibility.
            if config.checkpoints_path is not None:
                save_path = os.path.join(dic_path["PATH_TO_MODEL"], "round_{}".format(epoch))
                os.makedirs(save_path, exist_ok=True)
                torch.save(
                    trainer.state_dict(),
                    os.path.join(save_path, f"{epoch}.pt"),
                )

def parse_args():
    """Parse the command-line options for offline EDAC training.

    Several flags (``-eightphase``, ``-gen``, ``-multi_process``, ``-workers``)
    are not referenced elsewhere in this file.
    """
    option_specs = (
        ("-memo", dict(type=str, default='offline')),
        ("-eightphase", dict(action="store_true", default=False)),
        ("-gen", dict(type=int, default=1)),
        # NOTE(review): store_true with default=True means this is always True.
        ("-multi_process", dict(action="store_true", default=True)),
        ("-workers", dict(type=int, default=3)),
        ("-dataset", dict(type=str, choices=['jinan', 'hangzhou', 'newyork'], default='newyork')),
        ("-offline", dict(type=str, choices=["Fixedtime", "MaxPressure"], default="Fixedtime")),
        # The default "" is not in ``choices``. argparse only validates values
        # given on the command line, so the default is accepted as-is.
        ("-data_size", dict(type=str, choices=["20", "40", "60", "80", "100"], default="")),
    )
    cli = argparse.ArgumentParser()
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli.parse_args()

def edac_main(in_args):
    """Build EDAC/CityFlow configuration from parsed CLI arguments and run training.

    Args:
        in_args: Namespace from ``parse_args``. Reads ``dataset``, ``memo``,
            ``data_size`` and ``offline`` (the last defaults to "Fixedtime" if
            absent).

    Raises:
        ValueError: if ``in_args.dataset`` is not a supported city.
    """
    # dataset -> (road network "<rows>_<cols>", traffic file, data template directory)
    dataset_specs = {
        'hangzhou': ("4_4", "anon_4_4_hangzhou_real_5816.json", "Hangzhou"),
        'jinan': ("3_4", "anon_3_4_jinan_real_2500.json", "Jinan"),
        'newyork': ("28_7", "anon_28_7_newyork_real_double.json", "NewYork"),
    }
    if in_args.dataset not in dataset_specs:
        raise ValueError("Unsupported dataset {!r}; expected one of {}".format(
            in_args.dataset, sorted(dataset_specs)))
    road_net, traffic_file, template = dataset_specs[in_args.dataset]
    memo = in_args.memo
    # Honour -offline (it was previously ignored in favour of a hard-coded "Fixedtime").
    offline_source = getattr(in_args, "offline", "Fixedtime")

    edac_config = TrainConfig()
    edac_config.traffic_file = traffic_file

    num_row, num_col = (int(part) for part in road_net.split('_'))
    num_intersections = num_row * num_col

    # Take the timestamp once. Separate strftime calls could straddle a second
    # boundary and give the model, work and checkpoint directories different names.
    run_name = "{}_{}_EDAC".format(
        traffic_file, time.strftime('%m_%d_%H_%M_%S', time.localtime(time.time())))
    work_dir = os.path.join("records", memo, run_name)
    # train() builds its CityFlow eval env in checkpoints_path, so point that at
    # the same directory as PATH_TO_WORK_DIRECTORY. With the default memo
    # ('offline') this is the original location.
    # NOTE(review): presumably copy_cityflow_file puts the simulator inputs there — confirm.
    edac_config.checkpoints_path = os.path.join(os.getcwd(), work_dir)
    edac_config.traffic_city = template.lower()

    dic_traffic_env_conf_extra = {
        "NUM_INTERSECTIONS": num_intersections,
        "RUN_COUNTS": edac_config.eval_steps,
        "NUM_ROW": num_row,
        "NUM_COL": num_col,
        "TRAFFIC_FILE": traffic_file,
        "ROADNET_FILE": "roadnet_{0}.json".format(road_net),
        "TRAFFIC_SEPARATE": traffic_file,
        "OFFLINE_DATA_SIZE": in_args.data_size,
        "LIST_STATE_FEATURE": [
            "lane_num_waiting_vehicle_in",
            "lane_enter_running_part",
            "traffic_movement_pressure_queue_efficient"
        ],
        "DIC_REWARD_INFO": {
            "traffic_movement_pressure_queue_efficient": -0.25
        },
        "OFFLINE_METHOD": 'EDAC',
        "LIST_INFO_FEATURE": [
            "lane_num_waiting_vehicle_in",
        ]
    }

    dic_path_extra = {
        "PATH_TO_MODEL": os.path.join("model", memo, run_name),
        "PATH_TO_WORK_DIRECTORY": work_dir,
        "PATH_TO_DATA": os.path.join("data", template, str(road_net)),
        "PATH_TO_ERROR": os.path.join("errors", memo),
        "PATH_TO_OFFLINE_DATA": os.path.join("offline_dataset", offline_source, traffic_file,
                                             in_args.data_size, "data"),
    }
    dic_agent_conf = config.DIC_BASE_AGENT_CONF
    dic_traffic_env_conf = merge(config.dic_traffic_env_conf, dic_traffic_env_conf_extra)
    dic_path = merge(config.DIC_PATH, dic_path_extra)

    path_check(dic_path)
    copy_conf_file(dic_path, dic_agent_conf, dic_traffic_env_conf)
    copy_cityflow_file(dic_path, dic_traffic_env_conf)

    train(config=edac_config, dic_traffic_env_conf=dic_traffic_env_conf, dic_path=dic_path)
    
if __name__ == "__main__":
    # Script entry point: parse the CLI flags, then build the configs and train.
    edac_main(parse_args())