import os
import yaml
import argparse
from datetime import datetime

from fqf_iqn_qrdqn.env import make_pytorch_env
from fqf_iqn_qrdqn.agent import PQRAgent

import ray
from ray import tune
import wandb
from ray.tune.integration.wandb import wandb_mixin

import time
import torch


# Absolute path of the directory the launcher was started from, captured once
# when the module is imported.
current_path = os.path.abspath(os.curdir)

@wandb_mixin
def run(config):
    """Ray Tune trainable: build the Atari envs and train one PQRAgent trial.

    Args:
        config: Trial config produced by Tune. This function reads the swept
            hyper-parameters ``batch_size``, ``seed`` and ``lr``, plus ``args``
            (the parsed CLI namespace placed there by the launcher). The
            ``wandb`` entry is meant for ``wandb_mixin``.

    The hyper-parameters in ``config_yaml`` are forwarded to ``PQRAgent``.
    ``config_yaml`` is the module-level dict loaded in the ``__main__``
    section, so it only exists when this file is run as a script. Values
    swept by Tune take precedence over the YAML ones.
    """
    # Prefer the namespace shipped inside the trial config. The module-level
    # ``args`` only exists when this file runs as ``__main__``.
    cli_args = config['args'] if 'args' in config else args

    use_cuda = cli_args.cuda is not None
    if use_cuda:
        if torch.cuda.is_available():
            # NOTE(review): Ray usually restricts CUDA_VISIBLE_DEVICES to the
            # GPUs assigned to the trial, so this index is presumably relative
            # to those GPUs. Confirm it stays below resources_per_trial["gpu"].
            torch.cuda.set_device(torch.device(f'cuda:{cli_args.cuda}'))
        else:
            # torch.cuda.set_device() rejects CPU devices, so fall back to CPU
            # training instead of crashing.
            use_cuda = False

    # Create environments. The test env keeps full episodes and raw rewards
    # for evaluation.
    env = make_pytorch_env(cli_args.env_id)
    test_env = make_pytorch_env(
        cli_args.env_id, episode_life=False, clip_rewards=False)

    batch_size = config['batch_size']
    seed = config['seed']
    lr = config['lr']

    # One timestamped log directory per trial, named after the swept values.
    time_local = datetime.now().strftime("%Y%m%d-%H%M%S")
    log_dir = os.path.join(
        'logs', cli_args.env_id,
        f'PQR-batchsize{batch_size}-lr{lr}-seed{seed}-{time_local}')

    # Merge into a single kwargs dict so that:
    # - keys also defined in the YAML do not raise "got multiple values for
    #   keyword argument";
    # - the Tune-swept ``lr`` actually reaches the agent, so the sweep has
    #   an effect.
    agent_kwargs = dict(config_yaml)
    agent_kwargs.update(
        batch_size=batch_size, lr=lr, seed=seed, cuda=use_cuda)

    # Create the agent and run.
    agent = PQRAgent(
        env=env, test_env=test_env, log_dir=log_dir, **agent_kwargs)
    agent.run()


def parser():
    """Parse the launcher's command-line options.

    Arguments the parser does not recognise are ignored rather than rejected
    (``parse_known_args``).

    Returns:
        argparse.Namespace with ``config`` (path of the YAML file), ``env_id``
        (environment id) and ``cuda`` (GPU index, ``None`` when not given).
    """
    cli = argparse.ArgumentParser()
    cli.add_argument('--config', type=str,
                     default=os.path.join('config', 'pqrdqn.yaml'))
    cli.add_argument('--env_id', type=str, default='PhoenixNoFrameskip-v4')
    cli.add_argument('--cuda', type=int, default=None)
    known, _unknown = cli.parse_known_args()
    return known

if __name__ == '__main__':
    start = time.time()

    # Parse the CLI first so --help or a bad argument exits before Ray starts.
    args = parser()
    ray.init()

    # Hyper-parameter grid swept by Tune; each combination becomes one trial.
    # Other values noted in earlier versions of this script: lr 25e-5 and
    # seeds 456, 567, 678.
    ray_config = {
        'batch_size': tune.grid_search([64]),
        'lr': tune.grid_search([5e-5]),
        'seed': tune.grid_search([234]),
        # Shipped to every trial so `run` can read the CLI options.
        'args': args,
    }

    # Honour --config. Its default is config/pqrdqn.yaml, the path that was
    # previously hard-coded here, so default runs are unchanged.
    # `config_yaml` must stay a module-level name: `run` reads it.
    with open(args.config, encoding='utf-8') as f:
        config_yaml = yaml.safe_load(f) or {}

    # (Translated from the original Korean note: "changed to the DLTV-style
    # setup".)
    # The W&B group name encodes the algorithm variant selected in the YAML.
    if config_yaml.get('double_q_learning', False):
        group_name = "PQR_double_q"
    else:
        group_name = "PQR"

    if config_yaml.get('egreedy', False):
        group_name += "/egreedy"

    ray_config['wandb'] = {
        "project": "DRL_{}".format(args.env_id),
        "group": group_name,
    }
    print("GROUP NAME :", group_name)

    analysis = tune.run(
        run,
        config=ray_config,
        # Per-trial share of the machine (noted as 32 CPUs / 8 GPUs in total).
        resources_per_trial={"cpu": 8, "gpu": 2},
        name="{}-experiment".format(args.env_id),
    )

    finish = time.time()
    print("Time :", finish - start)

