import pyrallis
import os
import torch
import numpy as np
import time
import pickle
import argparse
from pathlib import Path
from offline_method.iql import TrainConfig, set_seed, ReplayBuffer, TwinQ, ValueFunction, DeterministicPolicy, GaussianPolicy, ImplicitQLearning, eval_actor
from run_edac import CityFlowWrapper
from offline_method.edac import make_d4rl_dataset
from utils.pipeline import copy_cityflow_file, copy_conf_file, path_check
from utils.utils import merge
from utils import config

def train(config: TrainConfig, dic_traffic_env_conf: dict, dic_path: dict) -> None:
    """Train an IQL agent on pickled offline data and checkpoint it periodically.

    Steps, in order:
      1. Build the CityFlow environment wrapper used for evaluation.
      2. Unpickle every file in ``dic_path["PATH_TO_OFFLINE_DATA"]`` and convert
         the result with ``make_d4rl_dataset``.
      3. Load the converted data into a ``ReplayBuffer``.
      4. Build the Q, V and actor networks with Adam optimizers (lr 3e-4).
      5. Run ``config.num_epochs`` gradient steps; every ``config.eval_every``
         steps call ``eval_actor`` and save ``trainer.state_dict()`` under
         ``dic_path["PATH_TO_MODEL"]``.

    Args:
        config: IQL training configuration. Note that inside this function the
            name ``config`` refers to this object, not the ``utils.config``
            module imported at file level.
        dic_traffic_env_conf: Traffic environment configuration, forwarded to
            ``CityFlowWrapper`` and ``make_d4rl_dataset``.
        dic_path: Path dictionary; the keys ``"PATH_TO_OFFLINE_DATA"`` and
            ``"PATH_TO_MODEL"`` are read here.
    """
    env = CityFlowWrapper(config.checkpoints_path, dic_traffic_env_conf)

    offline_dir = dic_path["PATH_TO_OFFLINE_DATA"]
    dataset = []
    # os.listdir returns entries in arbitrary order; sorting keeps the buffer
    # layout (and therefore seeded sampling) identical from run to run.
    for file_name in sorted(os.listdir(offline_dir)):
        with open(os.path.join(offline_dir, file_name), "rb") as f:
            # NOTE: pickle.load can execute arbitrary code; only point this
            # directory at trusted, locally generated data.
            dataset.append(pickle.load(f))

    d4rl_dataset = make_d4rl_dataset(dataset, dic_traffic_env_conf)

    # Feature sizes are taken from the last axis of the converted arrays.
    state_dim = d4rl_dataset["observations"].shape[-1]
    action_dim = d4rl_dataset["actions"].shape[-1]

    replay_buffer = ReplayBuffer(
        state_dim,
        action_dim,
        config.buffer_size,
        config.device,
    )
    replay_buffer.load_d4rl_dataset(d4rl_dataset)

    max_action = config.max_action

    if config.checkpoints_path is not None:
        print(f"Checkpoints path: {config.checkpoints_path}")
        os.makedirs(config.checkpoints_path, exist_ok=True)
        # Keep a copy of the exact configuration next to the run outputs.
        with open(os.path.join(config.checkpoints_path, "config.yaml"), "w") as f:
            pyrallis.dump(config, f)

    # Set seeds
    seed = config.seed
    set_seed(seed)

    q_network = TwinQ(state_dim, action_dim).to(config.device)
    v_network = ValueFunction(state_dim).to(config.device)
    actor = (
        DeterministicPolicy(state_dim, action_dim, max_action)
        if config.iql_deterministic
        else GaussianPolicy(state_dim, action_dim, max_action)
    ).to(config.device)
    v_optimizer = torch.optim.Adam(v_network.parameters(), lr=3e-4)
    q_optimizer = torch.optim.Adam(q_network.parameters(), lr=3e-4)
    actor_optimizer = torch.optim.Adam(actor.parameters(), lr=3e-4)

    kwargs = {
        "max_action": max_action,
        "actor": actor,
        "actor_optimizer": actor_optimizer,
        "q_network": q_network,
        "q_optimizer": q_optimizer,
        "v_network": v_network,
        "v_optimizer": v_optimizer,
        "discount": config.discount,
        "tau": config.tau,
        "device": config.device,
        # IQL
        "beta": config.beta,
        "iql_tau": config.iql_tau,
        "max_steps": config.num_epochs,
    }

    print("---------------------------------------")
    print(f"Training IQL, Env: {config.env}, Seed: {seed}")
    print("---------------------------------------")

    # Initialize actor
    trainer = ImplicitQLearning(**kwargs)

    if config.load_model != "":
        policy_file = Path(config.load_model)
        # map_location lets a checkpoint saved on one device (e.g. a GPU) be
        # resumed on whatever device this run is configured for.
        trainer.load_state_dict(torch.load(policy_file, map_location=config.device))
        actor = trainer.actor

    for t in range(int(config.num_epochs)):
        env.update_epoch(t)
        batch = replay_buffer.sample(config.batch_size)
        batch = [b.to(config.device) for b in batch]
        # The per-step log returned by the trainer is not consumed here.
        trainer.train(batch)
        # Evaluate episode
        if (t + 1) % config.eval_every == 0:
            print(f"Time steps: {t + 1}")
            # The returned scores are not used in this function; the call is
            # kept because it runs the actor in the environment.
            eval_actor(
                env,
                actor,
                device=config.device,
                n_episodes=config.eval_episodes,
                seed=config.seed,
            )
            # Directory and file names use the zero-based step index ``t``,
            # while the progress message above prints ``t + 1``.
            save_path = os.path.join(dic_path["PATH_TO_MODEL"], "round_{}".format(t))
            os.makedirs(save_path, exist_ok=True)
            torch.save(
                trainer.state_dict(),
                os.path.join(save_path, f"checkpoint_{t}.pt"),
            )

def parse_args():
    """Parse the command-line flags of this script from ``sys.argv``.

    Returns:
        argparse.Namespace with the attributes ``memo``, ``eightphase``,
        ``gen``, ``multi_process``, ``workers``, ``dataset``, ``offline`` and
        ``data_size``.
    """
    # One (flag, add_argument keyword arguments) entry per option, in the order
    # they are registered on the parser.
    option_table = (
        ("-memo", {"type": str, "default": 'offline'}),
        ("-eightphase", {"action": "store_true", "default": False}),
        ("-gen", {"type": int, "default": 1}),
        # store_true combined with default=True: the value is True whether or
        # not the flag is given.
        ("-multi_process", {"action": "store_true", "default": True}),
        ("-workers", {"type": int, "default": 3}),
        ("-dataset", {"type": str,
                      "choices": ['jinan', 'hangzhou', 'newyork'],
                      "default": 'newyork'}),
        ("-offline", {"type": str,
                      "choices": ["Fixedtime", "MaxPressure"],
                      "default": "Fixedtime"}),
        # The empty-string default is outside ``choices``; argparse only checks
        # choices for values actually supplied on the command line.
        ("-data_size", {"type": str,
                        "choices": ["20", "40", "60", "80", "100"],
                        "default": ""}),
    )

    cli = argparse.ArgumentParser()
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli.parse_args()

def iql_main(in_args):
    """Configure an IQL run for the selected city dataset and start training.

    Builds a ``TrainConfig``, derives the road-network grid and traffic file
    from ``in_args.dataset``, assembles the traffic-environment and path
    dictionaries on top of the defaults in ``utils.config``, prepares the
    working directories via the pipeline helpers, and finally calls ``train``.

    Args:
        in_args: Parsed command-line namespace. Reads ``dataset``, ``memo`` and
            ``data_size``; ``offline`` is read when present and falls back to
            ``"Fixedtime"`` otherwise.

    Raises:
        ValueError: If ``in_args.dataset`` is not one of the known datasets.
    """
    iql_config = TrainConfig()

    # dataset name -> (road-net id, traffic flow file, data sub-directory)
    dataset_table = {
        'hangzhou': ("4_4", "anon_4_4_hangzhou_real_5816.json", "Hangzhou"),
        'jinan': ("3_4", "anon_3_4_jinan_real_2500.json", "Jinan"),
        'newyork': ("28_7", "anon_28_7_newyork_real_double.json", "NewYork"),
    }
    if in_args.dataset not in dataset_table:
        # Fail with a clear message instead of an unbound-local error below.
        raise ValueError(
            f"Unknown dataset {in_args.dataset!r}; "
            f"expected one of {sorted(dataset_table)}"
        )
    road_net, traffic_file, template = dataset_table[in_args.dataset]
    iql_config.traffic_file = traffic_file
    memo = in_args.memo

    # The first component of the road-net id is used as the row count and the
    # second as the column count.
    num_row, num_col = (int(part) for part in road_net.split('_'))
    num_intersections = num_row * num_col

    # Compute the run tag once so the checkpoint and model directories always
    # carry the same timestamp.
    run_tag = (
        iql_config.traffic_file
        + "_"
        + time.strftime('%m_%d_%H_%M_%S', time.localtime())
        + "_IQL"
    )
    iql_config.checkpoints_path = os.path.join(os.getcwd(), "records", memo, run_tag)
    iql_config.traffic_city = template.lower()

    dic_traffic_env_conf_extra = {
        "NUM_INTERSECTIONS": num_intersections,
        "RUN_COUNTS": iql_config.eval_steps,
        "NUM_ROW": num_row,
        "NUM_COL": num_col,
        "TRAFFIC_FILE": iql_config.traffic_file,
        "ROADNET_FILE": "roadnet_{0}.json".format(road_net),
        "TRAFFIC_SEPARATE": iql_config.traffic_file,
        "OFFLINE_DATA_SIZE": in_args.data_size,
        "LIST_STATE_FEATURE": [
            "lane_num_waiting_vehicle_in",
            "lane_enter_running_part",
            "traffic_movement_pressure_queue_efficient"
        ],
        "DIC_REWARD_INFO": {
            "traffic_movement_pressure_queue_efficient": -0.25
        },
        "OFFLINE_METHOD": 'IQL',
        "LIST_INFO_FEATURE": [
            "lane_num_waiting_vehicle_in",
        ]
    }

    # Honour the -offline flag when choosing the dataset directory; callers
    # that pass a namespace without it keep the previous "Fixedtime" location.
    offline_policy = getattr(in_args, "offline", "Fixedtime")

    dic_path_extra = {
        "PATH_TO_MODEL": os.path.join("model", memo, run_tag),
        "PATH_TO_WORK_DIRECTORY": iql_config.checkpoints_path,
        "PATH_TO_DATA": os.path.join("data", template, str(road_net)),
        "PATH_TO_ERROR": os.path.join("errors", memo),
        # An empty data_size contributes no path component to os.path.join.
        "PATH_TO_OFFLINE_DATA": os.path.join(
            "offline_dataset", offline_policy, iql_config.traffic_file,
            in_args.data_size, "data"
        ),
    }
    # ``config`` here is the utils.config module imported at file level.
    dic_agent_conf = getattr(config, "DIC_BASE_AGENT_CONF")
    dic_traffic_env_conf = merge(config.dic_traffic_env_conf, dic_traffic_env_conf_extra)
    dic_path = merge(config.DIC_PATH, dic_path_extra)

    path_check(dic_path)
    copy_conf_file(dic_path, dic_agent_conf, dic_traffic_env_conf)
    copy_cityflow_file(dic_path, dic_traffic_env_conf)

    train(config=iql_config, dic_traffic_env_conf=dic_traffic_env_conf, dic_path=dic_path)
    
if __name__ == "__main__":
    # Script entry point: parse the command-line flags and pass them straight
    # to the IQL driver.
    iql_main(parse_args())