import argparse
import datetime
import time

from tqdm import tqdm
from pathlib import Path

import os.path
from configparser import ConfigParser

from agents import *
from behavioral_cloning.dataset import TrajectoryDataset
from envs import *
from utils import *
from torch.multiprocessing import Pipe

from tensorboardX import SummaryWriter
import wandb
import numpy as np

# Absolute path of the directory holding this script; the "models" and "runs"
# output folders used by main() are created beneath it.
ROOT_DIR = os.path.split(os.path.abspath(__file__))[0]


def parse_args():
    """Build the run configuration from an INI file plus command-line overrides.

    ``-c/--conf_file`` names an INI file whose ``[DEFAULT]`` section supplies the
    settings. Every key of that section is also registered as a ``--<key>``
    command-line option whose default is the file value, so any setting can be
    overridden from the shell. ConfigParser lower-cases keys, so the options
    are lower-case too (e.g. ``--envid``). All values are kept as strings.

    Returns:
        configparser.SectionProxy: the merged ``DEFAULT`` section. Key lookups are
        case-insensitive and ``getboolean``/``getint``/``getfloat`` work on it.
        ``conf_file`` is included (as the string ``"None"`` when no file was given).

    Raises:
        SystemExit: if ``--conf_file`` is given but the file cannot be read.
    """
    conf_parser = argparse.ArgumentParser(add_help=False)
    conf_parser.add_argument("-c", "--conf_file",
                             help="Specify config file", metavar="FILE")
    pre_args, remaining_argv = conf_parser.parse_known_args()
    conf_file = pre_args.conf_file

    defaults = {}
    if conf_file:
        config = ConfigParser()
        # ConfigParser.read() silently skips files it cannot open and returns the
        # list of files it did read; fail loudly instead of running with an
        # empty configuration.
        if not config.read([conf_file]):
            conf_parser.error(f"could not read config file: {conf_file}")
        defaults |= dict(config.items("DEFAULT"))

    # Dynamically add one --<key> option per config-file entry; the file value
    # is the default, so an explicit command-line value wins.
    parser = argparse.ArgumentParser(parents=[conf_parser])
    for key, value in defaults.items():
        parser.add_argument(f"--{key}", default=value)

    args = parser.parse_args(remaining_argv)
    # remaining_argv no longer contains -c/--conf_file, so restore the value
    # parsed in the first pass.
    args.conf_file = conf_file

    # Values were already interpolated when read from the file. Disable
    # interpolation here so literal '%' characters are stored verbatim instead
    # of making set() raise ValueError (and are not interpolated twice).
    config_proxy = ConfigParser(interpolation=None)
    for key, value in vars(args).items():
        config_proxy.set("DEFAULT", key, str(value))

    return config_proxy["DEFAULT"]

def _find_latest_checkpoint(model_dir):
    """Return the file stem of the newest checkpoint saved in *model_dir*.

    Checkpoints are written by ``main`` as
    ``gl_st_{global_update}_iter_{save_iter}_.{model,pred,target}``. The newest
    one is the ``.model`` file with the largest ``save_iter``; ties are broken by
    ``global_update``. Files that do not follow this pattern are ignored.

    Args:
        model_dir: directory to scan.

    Returns:
        The stem without extension, e.g. ``"gl_st_480_iter_12_"``.

    Raises:
        FileNotFoundError: if *model_dir* contains no checkpoint in that format.
    """
    candidates = []
    for name in os.listdir(model_dir):
        stem, ext = os.path.splitext(name)
        if ext != ".model":
            continue
        # "gl_st_480_iter_12_" -> ['gl', 'st', '480', 'iter', '12', '']
        parts = stem.split("_")
        try:
            save_iter, update = int(parts[-2]), int(parts[2])
        except (IndexError, ValueError):
            continue
        candidates.append(((save_iter, update), stem))
    if not candidates:
        raise FileNotFoundError(
            f"no checkpoint matching gl_st_<update>_iter_<iter>_.model in {model_dir}"
        )
    # Compare parsed integers exactly; substring matching on the iteration
    # number could select the wrong file (e.g. "12" inside "gl_st_120_...").
    return max(candidates)[1]


def _load_trajectories(dataset_path):
    """Unpickle every file in *dataset_path* and return the loaded objects as a list.

    Files that fail to load are reported on stdout and skipped.

    NOTE(security): ``pickle.load`` can execute arbitrary code; only point
    ``KLTrajectoriesPath`` at trusted data.
    """
    dataset_path = Path(dataset_path)
    trajectories = []
    n_total_files = len(os.listdir(dataset_path))
    for file_path in tqdm(dataset_path.iterdir(), desc="Loading Trajectories", total=n_total_files):
        try:
            with open(file_path, "rb") as f:
                loaded_trajectories = pickle.load(f)
            trajectories.append(loaded_trajectories)
        except Exception as e:
            # Deliberate best effort: one bad file should not abort the run.
            print(f"Error loading trajectories from file: {file_path}\nbecause {str(e)}")
    return trajectories


def main():
    """Train an RND agent, optionally with a KL term against a trajectory dataset.

    All settings come from ``parse_args()``. The function:

    * builds the environments and the agent;
    * optionally loads a checkpoint (``LoadModel``), or resumes the newest
      checkpoint of this run (``Resume``);
    * warms up the observation normaliser with random actions;
    * then repeats rollout / advantage computation / ``agent.train_model``
      until ``MaxGlobalUpdate`` is reached.

    A checkpoint is saved every ``SavingFreq`` updates.

    Returns:
        0 once ``global_update`` reaches ``MaxGlobalUpdate``.
    """
    default_config = parse_args()
    print(dict(default_config))

    # Rooms whose intrinsic reward is zeroed when ZeroIntrinsic is enabled.
    # Blank entries are skipped, so an empty setting means "no rooms" instead
    # of crashing on int("").
    rooms_to_zero = [int(x) for x in default_config["ZeroIntrinsicRooms"].split(",") if x.strip()]

    env_id = default_config["EnvID"]
    env_type = default_config["EnvType"]

    # Throwaway environment, used only to read the observation/action spaces.
    if env_type == "mario":
        env = BinarySpaceToDiscreteSpaceEnv(gym_super_mario_bros.make(env_id), COMPLEX_MOVEMENT)
    elif env_type == "atari":
        env = gym.make(env_id)
    else:
        raise NotImplementedError
    input_size = env.observation_space.shape
    output_size = env.action_space.n

    if "Breakout" in env_id:
        output_size -= 1

    env.close()

    is_render = False

    # Checkpoints live under <ROOT_DIR>/models/<env_id>/<folder>/.
    models_path = os.path.join(ROOT_DIR, "models")
    os.makedirs(models_path, exist_ok=True)

    is_load_model = default_config.getboolean("LoadModel")
    model_load_dir = os.path.join(models_path, env_id, default_config["LoadModelFolder"])
    load_model_suffix = default_config["LoadModelSuffix"]

    # LoadModelSuffix == "find" selects the newest checkpoint in the load folder.
    if load_model_suffix == "find":
        load_model_suffix = _find_latest_checkpoint(model_load_dir)
    save_model_name_suffix = default_config["SaveModelNameSuffix"]

    # To have more explicit naming of checkpoints
    if is_load_model:
        save_model_name_suffix = f"{save_model_name_suffix}_from_{load_model_suffix}"

    # Resume: reload the newest checkpoint from this run's own save folder and
    # reattach to the wandb run given by RunId.
    run_id = None
    resume = default_config.getboolean("Resume")
    if resume:
        model_load_dir = os.path.join(models_path, env_id, save_model_name_suffix)
        load_model_suffix = _find_latest_checkpoint(model_load_dir)
        run_id = default_config["RunId"]

    model_save_dir = os.path.join(models_path, env_id, save_model_name_suffix)
    os.makedirs(model_save_dir, exist_ok=True)

    model_path = os.path.join(model_load_dir, f"{load_model_suffix}.model")
    predictor_path = os.path.join(model_load_dir, f"{load_model_suffix}.pred")
    target_path = os.path.join(model_load_dir, f"{load_model_suffix}.target")

    # Store the resolved suffix so the logged config names the actual checkpoint.
    default_config["LoadModelSuffix"] = load_model_suffix

    wandb.init(
        name=default_config["RunName"],
        group=default_config["RunGroup"],
        config=dict(default_config),
        id=run_id,
        resume=resume,
        reinit=True,
        project="montezuma_finetuning",
        entity="<name>",
    )
    wandb.tensorboard.patch(root_logdir="./tensorboard", tensorboard_x=True)

    run_name = datetime.datetime.now().strftime("%b%d_%H-%M-%S") + default_config["LoadModelFolder"]
    writer = SummaryWriter(logdir=os.path.join(ROOT_DIR, "runs", run_name))

    use_cuda = default_config.getboolean("UseGPU")
    use_gae = default_config.getboolean("UseGAE")
    use_noisy_net = default_config.getboolean("UseNoisyNet")

    lam = float(default_config["Lambda"])
    num_worker = int(default_config["NumEnv"])

    num_step = int(default_config["NumStep"])

    ppo_eps = float(default_config["PPOEps"])
    epoch = int(default_config["Epoch"])
    mini_batch = int(default_config["MiniBatch"])
    batch_size = int(num_step * num_worker / mini_batch)
    learning_rate = float(default_config["LearningRate"])
    entropy_coef = float(default_config["Entropy"])
    gamma = float(default_config["Gamma"])
    int_gamma = float(default_config["IntGamma"])
    clip_grad_norm = float(default_config["ClipGradNorm"])
    ext_coef = float(default_config["ExtCoef"])
    int_coef = float(default_config["IntCoef"])

    sticky_action = default_config.getboolean("StickyAction")
    action_prob = float(default_config["ActionProb"])
    life_done = default_config.getboolean("LifeDone")

    # Loop-invariant settings, parsed once instead of on every update.
    max_global_update = int(default_config["MaxGlobalUpdate"])
    saving_freq = int(default_config["SavingFreq"]) or 40  # 0 falls back to 40

    reward_rms = RunningMeanStd()
    obs_rms = RunningMeanStd(shape=(1, 1, 84, 84))
    pre_obs_norm_step = int(default_config["ObsNormStep"])
    discounted_reward = RewardForwardFilter(int_gamma)

    zero_intrinsic = default_config.getboolean("ZeroIntrinsic")
    use_kl = default_config.getboolean("UseKL")

    if default_config["EnvType"] == "atari":
        env_class = AtariEnvironment
    elif default_config["EnvType"] == "mario":
        env_class = MarioEnvironment
    else:
        raise NotImplementedError

    agent = RNDAgent(
        input_size,
        output_size,
        num_worker,
        num_step,
        gamma,
        lam=lam,
        learning_rate=learning_rate,
        ent_coef=entropy_coef,
        clip_grad_norm=clip_grad_norm,
        epoch=epoch,
        batch_size=batch_size,
        ppo_eps=ppo_eps,
        use_cuda=use_cuda,
        use_gae=use_gae,
        use_noisy_net=use_noisy_net,
        use_kl=use_kl,
        kl_coef=float(default_config["KLCoef"]),
        writer=writer,
    )

    dataset = None
    if use_kl:
        input_size_ds = (84, 84)
        output_size_ds = 18
        trajectories = _load_trajectories(default_config["KLTrajectoriesPath"])
        dataset = TrajectoryDataset(
            trajectories=trajectories,
            state_dim=input_size_ds,
            act_dim=output_size_ds,
            batch_size=batch_size,
            use_kl=True,
        )

    if is_load_model:
        print(f"load model: {model_path}...")
        # None (torch's default) restores tensors to the device they were saved
        # from; "cpu" remaps everything onto the CPU when no GPU is used.
        map_location = None if use_cuda else "cpu"
        agent.model.load_state_dict(torch.load(model_path, map_location=map_location))
        agent.rnd.predictor.load_state_dict(torch.load(predictor_path, map_location=map_location))
        agent.rnd.target.load_state_dict(torch.load(target_path, map_location=map_location))
        print("load finished!")

    works = []
    parent_conns = []
    child_conns = []
    for idx in range(num_worker):
        parent_conn, child_conn = Pipe()
        work = env_class(
            env_id,
            is_render,
            idx,
            child_conn,
            sticky_action=sticky_action,
            p=action_prob,
            life_done=life_done,
            writer=writer,
            use_state_loading=default_config.getboolean("UseStateLoading"),
            load_room=default_config["LoadRoom"],
            room_saving=default_config["RoomSaving"],
            should_calc_additional_metrics=True,
        )
        work.start()
        works.append(work)
        parent_conns.append(parent_conn)
        child_conns.append(child_conn)

    states = np.zeros([num_worker, 4, 84, 84])

    sample_episode = 0
    sample_rall = 0
    sample_step = 0
    sample_env_idx = 0
    sample_i_rall = 0
    global_update = 0
    if resume and is_load_model:
        # Continue the update counter encoded in the checkpoint name (gl_st_<update>_...).
        global_update = int(load_model_suffix.split("_")[2])
    global_step = 0

    # normalize obs: warm up obs_rms on observations collected with random actions.
    print("Start to initailize observation normalization parameter.....")
    next_obs = []
    for _ in range(num_step * pre_obs_norm_step):
        actions = np.random.randint(0, output_size, size=(num_worker,))

        for parent_conn, action in zip(parent_conns, actions):
            parent_conn.send(action)

        for parent_conn in parent_conns:
            s, r, d, rd, lr, cr = parent_conn.recv()
            # Only channel 3 of the 4-channel state is used for RND normalisation.
            next_obs.append(s[3, :, :].reshape([1, 84, 84]))

        if len(next_obs) % (num_step * num_worker) == 0:
            next_obs = np.stack(next_obs)
            obs_rms.update(next_obs)
            next_obs = []
    print("End to initalize...")

    while True:
        time_start = time.time()

        (
            total_state,
            total_reward,
            total_done,
            total_next_state,
            total_action,
            total_int_reward,
            total_next_obs,
            total_ext_values,
            total_int_values,
            total_policy,
            total_policy_np,
        ) = ([], [], [], [], [], [], [], [], [], [], [])
        global_step += num_worker * num_step
        global_update += 1

        if global_update >= max_global_update:
            # Flush pending TensorBoard events before leaving.
            # NOTE(review): worker processes are not stopped here; confirm they
            # are daemonic or shut them down explicitly.
            writer.close()
            return 0

        # Step 1. n-step rollout
        for _ in range(num_step):
            actions, value_ext, value_int, policy = agent.get_action(np.float32(states) / 255.0)

            for parent_conn, action in zip(parent_conns, actions):
                parent_conn.send(action)

            next_states, rewards, dones, real_dones, log_rewards, next_obs, current_rooms = [], [], [], [], [], [], []
            for parent_conn in parent_conns:
                s, r, d, rd, lr, cr = parent_conn.recv()
                next_states.append(s)
                rewards.append(r)
                dones.append(d)
                real_dones.append(rd)
                log_rewards.append(lr)
                next_obs.append(s[3, :, :].reshape([1, 84, 84]))
                current_rooms.append(cr)

            next_states = np.stack(next_states)
            rewards = np.hstack(rewards)
            dones = np.hstack(dones)
            real_dones = np.hstack(real_dones)
            next_obs = np.stack(next_obs)
            current_rooms = np.stack(current_rooms)

            # total reward = int reward + ext Reward
            intrinsic_reward = agent.compute_intrinsic_reward(
                ((next_obs - obs_rms.mean) / np.sqrt(obs_rms.var)).clip(-5, 5)
            )
            intrinsic_reward = np.hstack(intrinsic_reward)

            if zero_intrinsic:
                mask = np.isin(current_rooms, rooms_to_zero)
                intrinsic_reward[mask] = 0

            sample_i_rall += intrinsic_reward[sample_env_idx]

            total_next_obs.append(next_obs)
            total_int_reward.append(intrinsic_reward)
            total_state.append(states)
            total_reward.append(rewards)
            total_done.append(dones)
            total_action.append(actions)
            total_ext_values.append(value_ext)
            total_int_values.append(value_int)
            total_policy.append(policy)
            total_policy_np.append(policy.cpu().numpy())

            states = next_states[:, :, :, :]

            sample_rall += log_rewards[sample_env_idx]

            sample_step += 1
            if real_dones[sample_env_idx]:
                sample_episode += 1
                writer.add_scalar("data/reward_per_epi", sample_rall, sample_episode)
                writer.add_scalar("data/reward_per_rollout", sample_rall, global_update)
                writer.add_scalar("data/step", sample_step, sample_episode)
                sample_rall = 0
                sample_step = 0
                sample_i_rall = 0

        # calculate last next value (bootstrap values for the final states)
        _, value_ext, value_int, _ = agent.get_action(np.float32(states) / 255.0)
        total_ext_values.append(value_ext)
        total_int_values.append(value_int)
        # --------------------------------------------------

        # Reorder from (step, worker, ...) to (worker, step, ...) before flattening.
        total_state = np.stack(total_state).transpose([1, 0, 2, 3, 4]).reshape([-1, 4, 84, 84])
        total_reward = np.stack(total_reward).transpose().clip(-1, 1)
        total_action = np.stack(total_action).transpose().reshape([-1])
        total_done = np.stack(total_done).transpose()
        total_next_obs = np.stack(total_next_obs).transpose([1, 0, 2, 3, 4]).reshape([-1, 1, 84, 84])
        total_ext_values = np.stack(total_ext_values).transpose()
        total_int_values = np.stack(total_int_values).transpose()
        total_logging_policy = np.vstack(total_policy_np)

        # Step 2. calculate intrinsic reward
        # running mean intrinsic reward
        total_int_reward = np.stack(total_int_reward).transpose()
        total_reward_per_env = np.array(
            [discounted_reward.update(reward_per_step) for reward_per_step in total_int_reward.T]
        )
        mean, std, count = np.mean(total_reward_per_env), np.std(total_reward_per_env), len(total_reward_per_env)
        reward_rms.update_from_moments(mean, std**2, count)

        # normalize intrinsic reward
        total_int_reward /= np.sqrt(reward_rms.var)
        writer.add_scalar("data/int_reward_per_epi", np.sum(total_int_reward) / num_worker, sample_episode)
        writer.add_scalar("data/int_reward_per_rollout", np.sum(total_int_reward) / num_worker, global_update)
        # -------------------------------------------------------------------------------------------

        # logging Max action probability
        writer.add_scalar("data/max_prob", softmax(total_logging_policy).max(1).mean(), sample_episode)
        writer.add_scalar("data/global_update", global_update, sample_episode)

        # Step 3. make target and advantage
        # extrinsic reward calculate
        ext_target, ext_adv = make_train_data(total_reward, total_done, total_ext_values, gamma, num_step, num_worker)

        # intrinsic reward calculate
        # Non-episodic: dones are all zero for the intrinsic stream.
        int_target, int_adv = make_train_data(
            total_int_reward, np.zeros_like(total_int_reward), total_int_values, int_gamma, num_step, num_worker
        )

        # add ext adv and int adv
        total_adv = int_adv * int_coef + ext_adv * ext_coef
        # -----------------------------------------------

        # Step 4. update obs normalize param
        obs_rms.update(total_next_obs)
        # -----------------------------------------------

        # Step 5. Training!
        agent.train_model(
            np.float32(total_state) / 255.0,
            ext_target,
            int_target,
            total_action,
            total_adv,
            ((total_next_obs - obs_rms.mean) / np.sqrt(obs_rms.var)).clip(-5, 5),
            total_policy,
            dataset=dataset,
            global_update=global_update,
        )

        if global_update % saving_freq == 0:
            save_iter = global_update // saving_freq
            print("Now Global Step :{}".format(global_step))
            # Naming must stay in sync with _find_latest_checkpoint.
            ckpt_prefix = os.path.join(model_save_dir, f"gl_st_{global_update}_iter_{save_iter}_")
            torch.save(agent.model.state_dict(), f"{ckpt_prefix}.model")
            torch.save(agent.rnd.predictor.state_dict(), f"{ckpt_prefix}.pred")
            torch.save(agent.rnd.target.state_dict(), f"{ckpt_prefix}.target")

        time_end = time.time()
        print(f"Time of one iter {time_end-time_start}")

# Script entry point: start training only when this file is executed directly,
# not when it is imported.
if __name__ == "__main__":
    main()
