import os
import yaml
import argparse
from datetime import datetime

from fqf_iqn_qrdqn.env import make_pytorch_env
from fqf_iqn_qrdqn.agent import QRDQNAgent

import ray
from ray import tune
import wandb
from ray.tune.integration.wandb import wandb_mixin

import time
import torch
import gym
# import CustomNChain

# Working directory captured when this module is loaded.
# NOTE(review): not referenced anywhere in the visible code — confirm whether
# it is still needed.
current_path = os.getcwd()

@wandb_mixin
def run(config):
    """Train one QR-DQN agent for a single Ray Tune trial.

    Besides ``config``, this function reads two module-level globals that the
    ``__main__`` section sets up before calling ``tune.run``: ``args`` (the
    parsed command-line options) and ``config_yaml`` (agent settings loaded
    from the YAML file).

    Args:
        config: Trial configuration supplied by Ray Tune. Must contain the
            keys ``'batch_size'``, ``'lr'`` and ``'seed'``.
    """
    # Only touch CUDA when a GPU index was requested *and* CUDA is usable.
    # torch.cuda.set_device() rejects a CPU device, so the former fallback to
    # 'cpu' crashed on machines without CUDA instead of running on the CPU.
    use_cuda = isinstance(args.cuda, int) and torch.cuda.is_available()
    if use_cuda:
        torch.cuda.set_device(torch.device(f'cuda:{args.cuda}'))

    # Create environments. The test environment is built with
    # episode_life=False and clip_rewards=False.
    # (A gym 'CustomNChain-v0' environment was used here as an alternative
    # during development.)
    env = make_pytorch_env(args.env_id)
    test_env = make_pytorch_env(
        args.env_id, episode_life=False, clip_rewards=False)

    # Hyperparameters chosen by Ray Tune for this trial.
    batch_size = config['batch_size']
    seed = config['seed']
    lr = config['lr']

    # One log directory per trial; the timestamp keeps reruns apart.
    time_local = datetime.now().strftime("%Y%m%d-%H%M%S")
    log_dir = os.path.join(
        'logs', args.env_id,
        f'qrdqn-batchsize{batch_size}-lr{lr}-seed{seed}-{time_local}')

    # Start from the YAML settings and let the per-trial values override
    # them. Building a single dict (instead of mixing explicit keywords with
    # **config_yaml) avoids a "got multiple values for keyword argument"
    # TypeError when the YAML also defines e.g. batch_size or seed.
    agent_kwargs = dict(config_yaml)
    # Forward the tuned learning rate only when the YAML already passes an
    # 'lr' entry to the agent; that is the only evidence available here that
    # QRDQNAgent accepts it. Without this, the value shown in the log
    # directory name would not be the one actually used for training.
    if 'lr' in agent_kwargs:
        agent_kwargs['lr'] = lr
    agent_kwargs.update(
        env=env, test_env=test_env, log_dir=log_dir, seed=seed,
        cuda=use_cuda, batch_size=batch_size)

    # Create the agent and run.
    agent = QRDQNAgent(**agent_kwargs)
    agent.run()


def parser():
    """Parse the command-line options of this script.

    Recognised options are ``--config`` (path to a YAML file, default
    ``config/qrdqn.yaml``), ``--env_id`` (default ``EnduroNoFrameskip-v4``)
    and ``--cuda`` (integer, default ``None``). Arguments that are not
    recognised are ignored rather than rejected.

    Returns:
        argparse.Namespace: The parsed options.
    """
    # Alternatives used during development: config 'lunar.yaml' and the
    # environment id 'CustomNChain-v0'.
    default_config = os.path.join('config', 'qrdqn.yaml')

    cli = argparse.ArgumentParser()
    cli.add_argument('--config', type=str, default=default_config)
    cli.add_argument('--env_id', type=str, default='EnduroNoFrameskip-v4')
    cli.add_argument('--cuda', type=int, default=None)

    # parse_known_args() returns (namespace, leftover_argv); the leftovers
    # are intentionally discarded.
    known, _leftover = cli.parse_known_args()
    return known

if __name__ == '__main__':
    # Script entry point: run a Ray Tune grid search of QR-DQN trials and
    # print the total wall-clock time.
    start = time.time()

    # `args` and `config_yaml` must stay module-level names: run() reads
    # them as globals. Parse and load them before starting Ray so that a bad
    # argument or a missing config file fails before any workers are spawned.
    args = parser()

    # Honour --config (previously parsed but ignored in favour of a
    # hard-coded path; the default value is that same path).
    with open(args.config) as f:
        config_yaml = yaml.load(f, Loader=yaml.SafeLoader)

    # W&B group name reflects whether double Q-learning is enabled.
    if config_yaml['double_q_learning']:
        group_name = "double_q_QRDQN"
    else:
        group_name = "QRDQN"
    print("GROUP NAME :", group_name)

    # Every combination of these values becomes one trial.
    ray_config = {
        'batch_size': tune.grid_search([64]),
        'lr': tune.grid_search([5e-5]),
        'seed': tune.grid_search([234, 567, 789]),
        # Settings passed through the trial config under the 'wandb' key.
        'wandb': {
            "project": "DRL_{}".format(args.env_id),
            "group": group_name,
        },
    }

    # Reserve half a GPU per trial only when a GPU was requested via --cuda;
    # otherwise trials could never be scheduled on a machine without GPUs
    # even though run() would train on the CPU.
    resources = {"cpu": 5}
    if args.cuda is not None:
        resources["gpu"] = 0.5

    ray.init()

    analysis = tune.run(
        run,
        config=ray_config,
        resources_per_trial=resources,
        name="{}-experiment".format(args.env_id),
    )

    finish = time.time()
    print("Time :", finish - start)
