import os
import yaml
import argparse
from datetime import datetime

from fqf_iqn_qrdqn.env import make_pytorch_env
from fqf_iqn_qrdqn.agent import EEAgent

import ray
from ray import tune
import wandb
from ray.tune.integration.wandb import wandb_mixin

import time
import torch


# Working directory of the process at the time this module is loaded.
# NOTE(review): the active code in this file never reads this value; only the
# commented-out legacy `run` near the end of the file uses it
# (`os.chdir(current_path)`). Confirm whether it is still needed.
current_path = os.getcwd()

@wandb_mixin
def run(config):
    """Build the environments and an ``EEAgent`` for one trial, then train it.

    Args:
        config: Trial configuration dict. The keys ``'batch_size'``,
            ``'seed'``, ``'lr'`` and ``'args'`` are read; ``'args'`` is the
            parsed command-line namespace providing ``env_id`` and ``cuda``.

    Besides ``config``, the module-level ``config_yaml`` dict supplies the
    remaining keyword arguments for ``EEAgent``.
    """
    # Take the CLI namespace from the trial config instead of relying on a
    # module-level global that only exists when this file runs as a script.
    args = config['args']

    use_cuda = isinstance(args.cuda, int)
    # Select the requested GPU only when CUDA is really available:
    # torch.cuda.set_device() rejects a CPU device.
    if use_cuda and torch.cuda.is_available():
        torch.cuda.set_device(torch.device(f'cuda:{args.cuda}'))

    # Create environments.
    env = make_pytorch_env(args.env_id)
    test_env = make_pytorch_env(
        args.env_id, episode_life=False, clip_rewards=False)

    batch_size = config['batch_size']
    seed = config['seed']
    lr = config['lr']

    # Specify the directory to log (relative to the current working
    # directory); the timestamp keeps repeated runs apart.
    time_local = datetime.now().strftime("%Y%m%d-%H%M%S")
    log_dir = os.path.join(
        'logs', args.env_id,
        f'EE-batchsize{batch_size}-lr{lr}-seed{seed}-{time_local}')

    # The swept values take precedence over whatever the YAML file holds.
    # Merging into one dict (rather than passing both as keywords) also avoids
    # a duplicate-keyword TypeError when the YAML defines the same key.
    # The learning rate is forwarded so the `lr` sweep actually takes effect.
    agent_kwargs = dict(config_yaml)
    agent_kwargs.update(batch_size=batch_size, lr=lr, seed=seed)

    # Create the agent and run.
    agent = EEAgent(
        env=env, test_env=test_env, log_dir=log_dir, cuda=use_cuda,
        **agent_kwargs)
    agent.run()


def parser():
    """Parse this script's command-line options.

    Recognized options are ``--config`` (path to the YAML config file),
    ``--env_id`` (environment identifier) and ``--cuda`` (integer device
    index, ``None`` when not given). Arguments that are not recognized are
    discarded rather than treated as errors.

    Returns:
        argparse.Namespace: namespace with ``config``, ``env_id`` and ``cuda``.
    """
    default_config = os.path.join('config', 'EE.yaml')

    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('--config', type=str, default=default_config)
    # Pong -> Tennis
    arg_parser.add_argument(
        '--env_id', type=str, default='YarsRevengeNoFrameskip-v4')
    arg_parser.add_argument('--cuda', type=int, default=None)

    # parse_known_args() returns (namespace, leftovers); leftovers are dropped.
    known, _leftover = arg_parser.parse_known_args()
    return known

if __name__ == '__main__':
    # Script entry point: start Ray, build the Tune search space, load the
    # agent's YAML configuration, launch the trials and report the elapsed
    # wall-clock time.
    start = time.time()
    ray.init()

    args = parser()

    # Grid of values explored by Tune; `args` is carried along in the trial
    # config next to the swept values.
    ray_config = {
        'batch_size': tune.grid_search([64]),
        # 'batch_size': tune.grid_search([32, 64]),
        'lr': tune.grid_search([5e-5]),  # 25e-5
        # 'lr': tune.grid_search([25e-5, 5e-5]),
        'seed': tune.grid_search([234, 567, 789]),  # 234, 567
        # 'seed': tune.grid_search([234]),
        'args': args,
    }

    # Honor the --config option; its default is still config/EE.yaml.
    with open(args.config) as f:
        config_yaml = yaml.load(f, Loader=yaml.SafeLoader)

    # Currently modified for DLTV: runs are grouped under the DLTV labels,
    # the EE labels ("double_EE" / "EE") are disabled.
    # A missing 'double_q_learning' key is treated as disabled.
    if config_yaml.get('double_q_learning'):
        group_name = "double_DLTV"
    else:
        group_name = "DLTV"

    # Settings consumed by the W&B integration of the trial function.
    ray_config['wandb'] = {
        "project": f"DRL_{args.env_id}",
        "group": group_name,
    }
    print("GROUP NAME :", group_name)

    analysis = tune.run(
        run,
        config=ray_config,
        resources_per_trial={"cpu": 2, "gpu": 0.5},
        name=f"{args.env_id}-experiment",
    )

    finish = time.time()
    print("Time :", finish - start)


# current_path = os.getcwd()

# @wandb_mixin
# def run(config):
#     args = config['args']
#     os.chdir(current_path)
#     with open(args.config) as f:
#         config_yaml = yaml.load(f, Loader=yaml.SafeLoader)

#     if isinstance(args.cuda, int):
#         device = torch.device(f'cuda:{args.cuda}' if torch.cuda.is_available() else 'cpu')
#         torch.cuda.set_device(device)
#         print(torch.cuda.current_device())
#         import time
#         time.sleep(3)

#     # Create environments.
#     env = make_pytorch_env(args.env_id)
#     test_env = make_pytorch_env(
#         args.env_id, episode_life=False, clip_rewards=False)

#     # # Specify the directory to log.
#     # name = args.config.split('/')[-1].rstrip('.yaml')
#     time_local = datetime.now().strftime("%Y%m%d-%H%M%S")

#     batch_size = config['batch_size']
#     seed = config['seed']
#     lr = config['lr']

#     log_dir = os.path.join(
#         'logs', args.env_id, f'EE-batchsize{batch_size}-lr{lr}-seed{seed}-{time_local}')        

#     # Create the agent and run.
#     agent = EEAgent(
#         env=env, test_env=test_env, log_dir=log_dir, seed=seed,
#         cuda=isinstance(args.cuda, int), batch_size=batch_size, lr=lr, **config_yaml)
#     agent.run()

# def parser():
#     parser = argparse.ArgumentParser()
#     parser.add_argument(
#         '--config', type=str, default=os.path.join('config', 'EE.yaml'))
#     parser.add_argument('--env_id', type=str, default='YarsRevengeNoFrameskip-v4')
#     parser.add_argument('--cuda', type=int, default=None)
#     args,_ = parser.parse_known_args()
    
#     return args

# if __name__ == '__main__':    
#     start = time.time()
#     ray.init()

#     if config_yaml['double_q_learning']:
#         group_name = "double_q_QRDQN"
#     else:
#         group_name = "QRDQN"

#     args = parser()
#     ray_config = {
#         'batch_size' : tune.grid_search([32]),
#         'lr' : tune.grid_search([25e-5, 5e-4]),
#         'seed' : tune.grid_search([234,567,890]),
#         'args': args,
#         'wandb': {
#             "project":"DRL_{}".format(args.env_id),
#             "group" : group_name
#             }
#         }

#     analysis = tune.run(run, config = ray_config,
#          resources_per_trial={"cpu":2, "gpu":0.5}, 
#          name= "{}-experiment".format(args.env_id),
#          )

#     finish = time.time()
#     print("Time :", finish - start)

# """
#     Load the config from the yaml file and save it to logs.
# """
