import argparse
from configparser import ConfigParser

from agents import *
from envs import *
from utils import *
from torch.multiprocessing import Pipe

from tensorboardX import SummaryWriter
import wandb

import numpy as np
import pickle

# Room index used by main() to pick saved game states: only files in
# "game_states" whose name contains "room_<LOAD_ROOM>" are probed.
LOAD_ROOM = 10

def parse_args():
    """Build the run configuration from a config file plus CLI overrides.

    A first, minimal parser extracts ``-c/--conf_file`` (default
    ``config_eval.conf``).  Every option found in that file's ``DEFAULT``
    section is then exposed as a ``--<key>`` command-line flag whose default
    is the file's value, so any setting can be overridden on the command
    line.  ``ConfigParser`` lower-cases option names, so the generated flags
    are lower-case (e.g. ``--envid``); lookups on the returned proxy go
    through the same lower-casing, so ``cfg['EnvID']`` still works.

    Returns:
        configparser.SectionProxy: the ``DEFAULT`` section of a fresh parser
        holding every resolved setting as a string, including ``conf_file``.
        The proxy is built without interpolation, so values are returned
        exactly as resolved (a literal ``%`` is allowed).
    """
    conf_parser = argparse.ArgumentParser(add_help=False)
    conf_parser.add_argument("-c", "--conf_file", default="config_eval.conf",
                             help="Specify config file", metavar="FILE")
    conf_args, remaining_argv = conf_parser.parse_known_args()
    conf_file = conf_args.conf_file

    defaults = {}
    if conf_file:
        config = ConfigParser()
        # read() silently skips files that do not exist; defaults then stay
        # empty and only -c/--conf_file is accepted on the command line.
        config.read([conf_file])
        defaults |= dict(config.items("DEFAULT"))

    # Dynamically add one flag per option found in the configuration file.
    parser = argparse.ArgumentParser(parents=[conf_parser])
    for key, value in defaults.items():
        parser.add_argument(f'--{key}', default=value)

    parser.set_defaults(**defaults)
    args = parser.parse_args(remaining_argv)
    # remaining_argv no longer contains -c, so the second parser would report
    # the built-in default; restore the value seen by the first pass instead
    # of parsing sys.argv a second time.
    args.conf_file = conf_file

    # Re-pack the resolved values as a SectionProxy so callers keep the
    # getboolean()/case-insensitive access they would have on a config file.
    # Interpolation is disabled: values read from the file were already
    # interpolated once, and a literal '%' would otherwise make set() raise
    # ValueError ("invalid interpolation syntax").
    config_proxy = ConfigParser(interpolation=None)
    for key, value in vars(args).items():
        config_proxy.set('DEFAULT', key, str(value))

    return config_proxy["DEFAULT"]


def main():
    """Evaluate a series of saved RND checkpoints and log per-episode results.

    Steps, in order:
      1. Resolve settings with ``parse_args`` and build a temporary
         environment to read the observation shape and action count.
      2. Start a wandb run and a tensorboardX ``SummaryWriter``.
      3. Build the ``RNDAgent`` from the configured hyper-parameters.
      4. Probe every file in ``game_states`` whose name contains
         ``room_<LOAD_ROOM>`` and keep those that do not report room 1 after
         one random step.
      5. Start a single environment worker restricted to those states.
      6. For each checkpoint index, load the model / predictor / target
         weights, play until the worker's fourth reply field is truthy
         (presumably the "real done" flag -- confirm against the worker),
         then log the episode reward and step count.

    Raises:
        NotImplementedError: if ``EnvType`` is neither ``mario`` nor ``atari``.
    """
    default_config = parse_args()
    print(dict(default_config))

    root_dir = os.path.dirname(os.path.abspath(__file__))

    env_id = default_config['EnvID']
    env_kind = default_config['EnvType']

    # One dispatch selects both the probing environment and the worker class.
    if env_kind == 'mario':
        env = BinarySpaceToDiscreteSpaceEnv(gym_super_mario_bros.make(env_id), COMPLEX_MOVEMENT)
        worker_cls = MarioEnvironment
    elif env_kind == 'atari':
        env = gym.make(env_id)
        worker_cls = AtariEnvironment
    else:
        raise NotImplementedError
    input_size = env.observation_space.shape
    output_size = env.action_space.n

    # NOTE(review): one action is dropped for Breakout; the reason is not
    # visible in this file -- confirm against the training script.
    if 'Breakout' in env_id:
        output_size -= 1

    is_render = False

    load_model_name_suffix = default_config["LoadModelNameSuffix"]
    models_path = os.path.join(root_dir, "models", f"{env_id}", load_model_name_suffix)

    wandb.init(
        name=default_config["RunName"],
        group=default_config["RunGroup"],
        config=dict(default_config),
        reinit=True,
        resume=False,
        project="montezuma_eval_final",
        entity="<name>",
    )
    wandb.tensorboard.patch(root_logdir="./tensorboard", tensorboard_x=True)
    writer = SummaryWriter(flush_secs=1)

    use_cuda = True
    use_gae = default_config.getboolean('UseGAE')
    use_noisy_net = default_config.getboolean('UseNoisyNet')

    lam = float(default_config['Lambda'])
    num_worker = 1  # evaluation drives exactly one environment

    num_step = int(default_config['NumStep'])

    ppo_eps = float(default_config['PPOEps'])
    epoch = int(default_config['Epoch'])
    mini_batch = int(default_config['MiniBatch'])
    batch_size = int(num_step * num_worker / mini_batch)
    learning_rate = float(default_config['LearningRate'])
    entropy_coef = float(default_config['Entropy'])
    gamma = float(default_config['Gamma'])
    clip_grad_norm = float(default_config['ClipGradNorm'])

    sticky_action = False
    action_prob = float(default_config['ActionProb'])
    life_done = default_config.getboolean('LifeDone')

    agent = RNDAgent(
        input_size,
        output_size,
        num_worker,
        num_step,
        gamma,
        lam=lam,
        learning_rate=learning_rate,
        ent_coef=entropy_coef,
        clip_grad_norm=clip_grad_norm,
        epoch=epoch,
        batch_size=batch_size,
        ppo_eps=ppo_eps,
        use_cuda=use_cuda,
        use_gae=use_gae,
        use_noisy_net=use_noisy_net
    )

    # Probe the saved game states: load each one, take a single random step
    # and record which room the environment reports.
    load_room = LOAD_ROOM
    files_path = "game_states"
    room_states_paths = [os.path.join(files_path, elem) for elem in os.listdir(files_path) if
                         f"room_{load_room}" in elem]
    rooms = []
    legit_paths = []
    for room_path in room_states_paths:
        env.env.reset(load_state_path=room_path)
        env.env.step(env.env.action_space.sample())
        current_room = env.env.get_current_room()
        print(f"{room_path}\nCurrent room: {current_room}\n")
        rooms.append(current_room)
        # States that report room 1 are rejected as invalid checkpoints.
        if current_room != 1:
            legit_paths.append(room_path)

    print(f"Possible states: {np.unique(rooms, return_counts=True)}")

    # Close the probing environment only now: it was previously closed right
    # after reading the space sizes and then used again for the probing above.
    env.close()

    works = []
    parent_conns = []
    child_conns = []
    for idx in range(num_worker):
        parent_conn, child_conn = Pipe()
        work = worker_cls(
            env_id,
            is_render,
            idx,
            child_conn,
            sticky_action=sticky_action,
            p=action_prob,
            life_done=life_done,
            writer=writer,
            use_state_loading=default_config.getboolean("UseStateLoading"),
            load_room=default_config["LoadRoom"],
            room_saving=default_config["RoomSaving"],

        )
        # Make sure we take only legit state checkpoints
        work.env.possible_loading_states = legit_paths
        work.start()
        works.append(work)
        parent_conns.append(parent_conn)
        child_conns.append(child_conn)

    # map_location=None is torch.load's default (keep the saved device).
    map_location = None if use_cuda else 'cpu'

    for ckpt_iter in range(1, int(default_config['NumberOfCheckpoints']), int(default_config['StepOfCheckpoints'])):
        ckpt_stem = f"gl_st_{ckpt_iter*25}_iter_{ckpt_iter}_"
        model_path = os.path.join(models_path, ckpt_stem + ".model")
        predictor_path = os.path.join(models_path, ckpt_stem + ".pred")
        target_path = os.path.join(models_path, ckpt_stem + ".target")

        print('Loading Pre-trained model....')
        agent.model.load_state_dict(torch.load(model_path, map_location=map_location))
        agent.rnd.predictor.load_state_dict(torch.load(predictor_path, map_location=map_location))
        agent.rnd.target.load_state_dict(torch.load(target_path, map_location=map_location))
        print('End load...')

        states = np.zeros([num_worker, 4, 84, 84])

        steps = 0
        log_rall = 0
        rd = False
        intrinsic_reward_list = []
        while not rd:
            steps += 1
            actions, _, _, _ = agent.get_action(np.float32(states) / 255.)

            for parent_conn, action in zip(parent_conns, actions):
                parent_conn.send(action)

            # With a single worker the last (only) reply defines the next
            # state; the reshapes below hard-code a batch of one.
            for parent_conn in parent_conns:
                s, _, _, rd, lr, cr = parent_conn.recv()
                log_rall += lr
                next_states = s.reshape([1, 4, 84, 84])
                next_obs = s[3, :, :].reshape([1, 1, 84, 84])

            print(f"Current room: {cr}")
            intrinsic_reward = agent.compute_intrinsic_reward(next_obs)
            intrinsic_reward_list.append(intrinsic_reward)
            states = next_states

        # The loop above only exits once rd is truthy, i.e. the episode ended.
        # Standardise the intrinsic rewards of this episode before dumping.
        intrinsic_reward_list = (intrinsic_reward_list - np.mean(intrinsic_reward_list)) / np.std(
            intrinsic_reward_list)
        # NOTE(review): the same file is overwritten for every checkpoint, so
        # only the last episode's rewards survive -- confirm this is intended.
        with open('int_reward', 'wb') as f:
            pickle.dump(intrinsic_reward_list, f)
        writer.add_scalar("data/reward_per_epi", log_rall, ckpt_iter)
        writer.add_scalar("data/step", steps, ckpt_iter)

    # NOTE(review): the workers, pipes and writer are never shut down here;
    # check whether the worker class needs an explicit stop/terminate.


# Script entry point: run the evaluation only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()
