#import DQN
#import C51
import gymnasium as gym
import argparse
import numpy as np
import os
import yaml
from datetime import datetime
from distributionalrl.env import make_pytorch_env
from distributionalrl.agent import IQNAgent
from distributionalrl.agent import FQFAgent
from distributionalrl.agent import QRDQNAgent
from distributionalrl.agent import MOGNAgent
import torch

# yaml tuple constructor
def tuple_constructor(loader, node):
    """Turn a YAML sequence node into a Python ``tuple``.

    Meant to be registered as a PyYAML constructor so that ``!!python/tuple``
    entries in config files can be loaded with the safe loader.

    Args:
        loader: The PyYAML loader currently constructing the document.
        node: The sequence node tagged as a tuple.

    Returns:
        A tuple holding the sequence's constructed elements, in order.
    """
    return tuple(loader.construct_sequence(node))

# Register the constructor with PyYAML
# Make the safe loader understand ``!!python/tuple`` tags in the YAML configs.
# NOTE: add_constructor is invoked on the SafeLoader class itself, so this
# presumably affects every SafeLoader / yaml.safe_load user in the process.
yaml.SafeLoader.add_constructor(
    'tag:yaml.org,2002:python/tuple',
    tuple_constructor,
)



parser = argparse.ArgumentParser()
parser.add_argument('--config_iqn', type=str, default=os.path.join('config', 'iqn.yaml'),
                    help='YAML config used when --algorithm is IQN')
parser.add_argument('--config_qrdqn', type=str, default=os.path.join('config', 'qrdqn.yaml'),
                    help='YAML config used when --algorithm is QRDQN')
parser.add_argument('--config_fqf', type=str, default=os.path.join('config', 'fqf.yaml'),
                    help='YAML config used for FQF (the fallback algorithm)')
parser.add_argument('--config_mogn', type=str, default=os.path.join('config', 'mogn.yaml'),
                    help='YAML config used when --algorithm is MOGN')

parser.add_argument('--algorithm', type=str,
                    help='IQN, QRDQN or MOGN; any other value, or omitting it, selects FQF')
# NOTE: --deterministic and --cuda are store_false flags: both options default
# to True and passing the flag turns them OFF. The original spellings are kept
# for backward compatibility; the aliases --stochastic / --no_cuda name the
# actual effect of passing the flag.
parser.add_argument('--deterministic', '--stochastic', dest='deterministic', action='store_false',
                    help='passing this flag sets deterministic=False (stochastic envs); '
                         'deterministic is the default')
parser.add_argument('--env_id', type=str, default='SeaquestNoFrameskip-v4')

parser.add_argument('--cuda', '--no_cuda', dest='cuda', action='store_false',
                    help='passing this flag sets cuda=False; CUDA is requested by default')
# The default is drawn once when the parser is built, so each run without
# --seed gets a fresh random seed in [0, 1000).
parser.add_argument('--seed', type=int, default=np.random.randint(0, 1000))

args = parser.parse_args()


def train_torch(algorithm="IQN", env_id='BreakoutNoFrameskip-v4', is_atari=True, frame_stack=4):
    """Build the train/test environments, construct the chosen agent and run it.

    Reads the module-level ``args`` namespace for the seed, the CUDA switch,
    the deterministic/stochastic switch and the per-algorithm YAML config paths.

    Args:
        algorithm: "IQN", "QRDQN" or "MOGN". Any other value (including None)
            falls back to FQF.
        env_id: Environment id passed to ``make_pytorch_env``.
        is_atari: Forwarded to ``make_pytorch_env``. Also used as ``episode_life``
            for the training env. When True, both envs are explicitly seeded.
        frame_stack: Number of stacked frames, forwarded to ``make_pytorch_env``.
    """
    is_stochastic = not args.deterministic

    # Create environments. The test env never uses episodic life and keeps
    # unclipped rewards.
    env = make_pytorch_env(env_id, episode_life=is_atari, is_atari=is_atari,
                           frame_stack=frame_stack, is_stochastic=is_stochastic)
    test_env = make_pytorch_env(env_id, episode_life=False, clip_rewards=False,
                                is_atari=is_atari, frame_stack=frame_stack,
                                is_stochastic=is_stochastic)

    if is_atari:
        # NOTE(review): assumes the wrapper returned by make_pytorch_env still
        # exposes a legacy seed() method -- confirm (plain gymnasium envs are
        # seeded through reset(seed=...)).
        env.seed(args.seed)
        # Give the test env a different seed, derived deterministically from args.seed.
        test_env.seed(2**31 - 1 - args.seed)

    # algorithm -> (YAML config path, agent class); unknown names fall back to FQF.
    registry = {
        "IQN": (args.config_iqn, IQNAgent),
        "QRDQN": (args.config_qrdqn, QRDQNAgent),
        "MOGN": (args.config_mogn, MOGNAgent),
    }
    config_path, agent_cls = registry.get(algorithm, (args.config_fqf, FQFAgent))

    # Name the log directory after the config file's stem (config/iqn.yaml -> "iqn").
    # basename/splitext are used instead of split('/') + rstrip('.yaml'):
    # rstrip removes a *character set* (mangling stems that end in a/l/m/y), and
    # splitting on '/' misses Windows separators.
    name = os.path.splitext(os.path.basename(config_path))[0]
    log_dir = os.path.join('logs', env_id, f'{name}-seed{args.seed}_stocha{is_stochastic}-')

    with open(config_path, encoding='utf-8') as f:
        # An empty YAML file loads as None; treat it as "no extra kwargs".
        config = yaml.load(f, Loader=yaml.SafeLoader) or {}

    agent = agent_cls(env=env, test_env=test_env, log_dir=log_dir,
                      seed=args.seed, cuda=args.cuda, **config)
    agent.run()

def main():
    """Command-line entry point: train the agent selected by --algorithm on --env_id."""
    train_torch(
        algorithm=args.algorithm,
        env_id=args.env_id,
        is_atari=True,  # this launcher always treats the env as Atari
        frame_stack=4,
    )


if __name__ == "__main__":
    main()
