import sys
import time

import pygame


import argparse
from configparser import ConfigParser

from agents import *
from envs import *
from utils import *

# Absolute path of the directory that contains this script.
# NOTE(review): `os` is not imported explicitly in this file; presumably one
# of the star imports above re-exports it -- confirm.
ROOT_DIR = os.path.split(os.path.abspath(__file__))[0]


def parse_args():
    """Build the run configuration from an optional INI file plus CLI overrides.

    Parsing happens in two passes. A minimal parser first extracts
    ``-c/--conf_file`` from ``sys.argv``. Every option found in the
    ``DEFAULT`` section of that file is then registered as a ``--<option>``
    argument on a second parser, with the file value as its default, so any
    of them can be overridden on the command line. ``ConfigParser``
    lower-cases option names, so the generated flags are lower-case
    (e.g. ``EnvID`` in the file becomes ``--envid``).

    Returns:
        The ``DEFAULT`` section proxy of a fresh ``ConfigParser`` holding
        every parsed argument (including ``conf_file``) as a string, so
        callers can use ``proxy[...]`` and ``proxy.getboolean(...)``.

    Raises:
        SystemExit: raised by ``argparse`` on unknown arguments, or when the
            requested config file cannot be read.
    """
    conf_parser = argparse.ArgumentParser(add_help=False)
    conf_parser.add_argument("-c", "--conf_file",
                             help="Specify config file", metavar="FILE")
    known_args, remaining_argv = conf_parser.parse_known_args()
    conf_file = known_args.conf_file

    defaults = {}
    if conf_file:
        config = ConfigParser()
        # ConfigParser.read() silently skips files it cannot open and returns
        # the list of files it did parse. Fail here instead of continuing
        # with empty defaults and hitting an unrelated KeyError later.
        if not config.read([conf_file]):
            conf_parser.error(f"cannot read config file: {conf_file}")
        defaults.update(config.items("DEFAULT"))

    # Dynamically add arguments from the configuration file
    parser = argparse.ArgumentParser(parents=[conf_parser])
    for key, value in defaults.items():
        # Use the key from the config file as the argument name
        parser.add_argument(f'--{key}', default=value)

    # Also seed the namespace under the exact option names; add_argument
    # derives its destination by replacing '-' with '_'.
    parser.set_defaults(**defaults)
    args = parser.parse_args(remaining_argv)
    # The first pass consumed -c/--conf_file, so it is absent from
    # remaining_argv; restore the value parsed there.
    args.conf_file = conf_file

    # Transform args into a SectionProxy. The DEFAULT section always exists,
    # so no add_section() call is needed.
    config_proxy = ConfigParser()
    for key, value in vars(args).items():
        config_proxy.set('DEFAULT', key, str(value))

    return config_proxy["DEFAULT"]


def main():
    """Construct an ``RNDAgent`` from the parsed configuration and load its weights.

    A temporary environment is created only to read the observation shape and
    the number of discrete actions; it is closed before the agent is built.
    Weights are read from the relative paths ``model``, ``pred`` and
    ``target`` and mapped to the CPU, because ``use_cuda`` is fixed to
    ``False`` here.

    Returns:
        The agent, with the model, RND predictor and RND target state dicts
        loaded.

    Raises:
        NotImplementedError: if ``EnvType`` is neither ``'mario'`` nor
            ``'atari'``.
    """
    cfg = parse_args()
    print(dict(cfg))

    env_id = cfg['EnvID']
    env_type = cfg['EnvType']

    # Probe environment: used only to discover the spaces, then discarded.
    if env_type == 'atari':
        probe_env = gym.make(env_id)
    elif env_type == 'mario':
        probe_env = BinarySpaceToDiscreteSpaceEnv(
            gym_super_mario_bros.make(env_id), COMPLEX_MOVEMENT)
    else:
        raise NotImplementedError

    obs_shape = probe_env.observation_space.shape
    n_actions = probe_env.action_space.n
    # Breakout environments are given one action fewer than the env reports.
    n_actions = n_actions - 1 if 'Breakout' in env_id else n_actions
    probe_env.close()

    # Checkpoint files, resolved relative to the current working directory.
    model_path, predictor_path, target_path = "model", "pred", "target"

    use_cuda = False
    use_gae = cfg.getboolean('UseGAE')
    use_noisy_net = cfg.getboolean('UseNoisyNet')

    lam = float(cfg['Lambda'])
    num_worker = 1
    num_step = int(cfg['NumStep'])

    ppo_eps = float(cfg['PPOEps'])
    epoch = int(cfg['Epoch'])
    mini_batch = int(cfg['MiniBatch'])
    batch_size = int(num_step * num_worker / mini_batch)
    learning_rate = float(cfg['LearningRate'])
    entropy_coef = float(cfg['Entropy'])
    gamma = float(cfg['Gamma'])
    clip_grad_norm = float(cfg['ClipGradNorm'])

    agent = RNDAgent(
        obs_shape,
        n_actions,
        num_worker,
        num_step,
        gamma,
        lam=lam,
        learning_rate=learning_rate,
        ent_coef=entropy_coef,
        clip_grad_norm=clip_grad_norm,
        epoch=epoch,
        batch_size=batch_size,
        ppo_eps=ppo_eps,
        use_cuda=use_cuda,
        use_gae=use_gae,
        use_noisy_net=use_noisy_net,
    )

    print('Loading Pre-trained model....')
    print(f"Loading: {model_path}")
    # Without CUDA the tensors stored in the checkpoints are remapped to CPU.
    load_kwargs = {} if use_cuda else {'map_location': 'cpu'}
    agent.model.load_state_dict(torch.load(model_path, **load_kwargs))
    agent.rnd.predictor.load_state_dict(torch.load(predictor_path, **load_kwargs))
    agent.rnd.target.load_state_dict(torch.load(target_path, **load_kwargs))
    print('End load...')
    return agent


# Build the agent and load its pre-trained weights (this parses sys.argv).
agent = main()

print(agent)

env = AtariEnvironment(env_id="MontezumaRevengeNoFrameskip-v4",
                       is_render=False,
                       env_idx=0,
                       child_conn=None,
                       sticky_action=True,
                       p=0.25,
                       life_done=False,
                       use_state_loading=True,
                       load_room=10,
                       should_calc_additional_metrics=True
                       )


# Pygame provides the display window and the window/keyboard event queue.
pygame.init()


WIDTH=600
HEIGHT=480
screen = pygame.display.set_mode((WIDTH, HEIGHT))


# Main loop
target_fps = 20  # Set your target frame rate here

running = True
rall=0  # running sum of the rewards returned by env.step()


# NOTE(review): the observation returned by reset() is not fed to the agent;
# the first action is chosen from an all-zero frame stack -- confirm intended.
state = env.reset()
states = np.zeros([1, 4, 84, 84])
time.sleep(3)

while running:
    # Drain the event queue on every frame. Without this the window system
    # may flag the window as unresponsive, and `running` could never become
    # False, which left the clean-up code below unreachable.
    for event in pygame.event.get():
        if event.type == pygame.QUIT:
            running = False
        elif event.type == pygame.KEYDOWN and event.key == pygame.K_ESCAPE:
            running = False
    if not running:
        break

    # The agent is fed the current frame stack divided by 255
    # (assumes 8-bit pixel values -- TODO confirm).
    actions, value_ext, value_int, policy = agent.get_action(np.float32(states) / 255.0)
    action = actions[0]

    s, r, d, rd, lr, cr = env.step(action)
    rall += r
    # The stepped observation becomes the next input, with a leading batch
    # axis of size 1.
    states = s.reshape([1, 4, 84, 84])

    screen_data = unwrap(env).ale.getScreenRGB()

    # Convert screen_data to a Pygame surface
    screen_surface = pygame.surfarray.make_surface(screen_data)

    # Scaled with the two dimensions swapped; the rotation below swaps them
    # back so the final surface is WIDTH x HEIGHT.
    screen_surface = pygame.transform.scale(screen_surface, (HEIGHT, WIDTH))

    # Rotate the screen by 90 degrees clockwise
    screen_surface = pygame.transform.rotate(screen_surface, -90)

    # Mirror the screen by the Y-axis (flip horizontally)
    screen_surface = pygame.transform.flip(screen_surface, True, False)

    # Blit the screen surface onto the display
    screen.blit(screen_surface, (0, 0))

    # Restart the emulator once the game is over so playback continues.
    if unwrap(env).ale.game_over():
        unwrap(env).ale.reset_game()

    # Refresh the screen
    pygame.display.flip()

    # Control the frame rate: fixed delay of one frame period per iteration.
    pygame.time.delay(1000 // target_fps)

# Clean up
unwrap(env).ale.reset_game()
pygame.quit()
sys.exit()
