import argparse
import numpy
import os
import sys
import itertools
import copy
from noda_sac_main import noda_main
import pynvml
import time
import copy
import shutil
import pdb


def get_search_space():
    """Expand the command-line arguments into a grid of run configurations.

    Every option declared with ``nargs='+'`` is a sweep axis holding a list of
    candidate values; the remaining options are scalars shared by all runs.
    The Cartesian product of all sweep axes is enumerated and one namespace is
    produced per combination.

    Each returned namespace is a deep copy of the parsed arguments in which

    * every sweep attribute holds the single value chosen for that run,
    * ``command`` holds the ``' --flag value'`` fragments for the sweep
      options (``env`` is translated to an environment name via ``get_env``)
      followed by ``' --exp-name <name>'`` when the name is non-empty,
    * ``exp_name`` is the user-supplied ``--exp-name`` (if any) joined by
      ``'_'`` with ``<key><value>`` for every axis that has more than one
      candidate value.

    Returns:
        list[argparse.Namespace]: one configuration per grid point, in
        ``itertools.product`` order.
    """
    parser = argparse.ArgumentParser()
    # Sweep axes: list-valued options. Values given on the command line arrive
    # as strings (no ``type=``), while the defaults below are numbers.
    parser.add_argument('--env', nargs='+', default=[1, 2, 3, 4, 5, 6, 7, 8])
    parser.add_argument('--epochs', nargs='+', default=[250])
    parser.add_argument('--lat-noda', nargs='+', default=[128])
    parser.add_argument('--hid-noda-ae', nargs='+', default=[256])
    parser.add_argument('--hid-noda-ode', nargs='+', default=[128])
    parser.add_argument('--steps-per-epoch', nargs='+', default=[4000])
    parser.add_argument('--model-step', nargs='+', default=[1])
    parser.add_argument('--update-action-turns', nargs='+', default=[0])
    parser.add_argument('--update-model-interval', nargs='+', default=[1])
    parser.add_argument('--model-data-ratio', nargs='+', default=[1.0])
    parser.add_argument('--explore-lr', nargs='+', default=[0.000])
    parser.add_argument('--update-every', nargs='+', default=[50])
    parser.add_argument('--noise', nargs='+', default=[0])
    parser.add_argument('--use-ode', nargs='+', default=[1])

    # Scalar options shared by every configuration.
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--seed-init', type=int, default=0)
    parser.add_argument('--seed-num', type=int, default=8)
    parser.add_argument('--hid', type=int, default=256)
    parser.add_argument('--l', type=int, default=2)
    parser.add_argument('--gamma', type=float, default=0.99)
    parser.add_argument('--exp-name', type=str, default='')
    parser.add_argument('--device', type=str, default='cuda')
    search_args = parser.parse_args()

    parsed = vars(search_args)
    # All list-valued options take part in the product ...
    keys = [key for key, value in parsed.items() if isinstance(value, list)]
    # ... but only those with several candidates distinguish runs by name.
    variable_keys = [key for key in keys if len(parsed[key]) > 1]

    search_space = []
    for combination in itertools.product(*(parsed[key] for key in keys)):
        args = copy.deepcopy(search_args)
        fragments = []
        for key, value in zip(keys, combination):
            setattr(args, key, value)
            rendered = get_env(value) if key == 'env' else str(value)
            fragments.append(' --' + key.replace('_', '-') + ' ' + rendered)

        # Keep a user-supplied prefix intact; the previous implementation
        # always dropped the first character, which is only correct when the
        # prefix is empty and the name starts with the joining underscore.
        name_parts = [args.exp_name] if args.exp_name else []
        name_parts.extend(key + str(getattr(args, key)) for key in variable_keys)
        args.exp_name = '_'.join(name_parts)

        args.command = ''.join(fragments)
        if args.exp_name:
            # An empty name would leave ``--exp-name`` without a value in the
            # generated command line, so the flag is omitted in that case.
            args.command += ' --exp-name ' + args.exp_name
        search_space.append(args)
    return search_space


def get_env(index):
    """Return the environment name for a 1-based environment index.

    Args:
        index: position in the environment table, counted from 1. Anything
            accepted by ``int()`` works, so both ``3`` and ``'3'`` are valid.

    Returns:
        str: the environment name at that position.

    Raises:
        IndexError: if the index is outside ``1..len(table)``. Indices below 1
            are rejected explicitly; otherwise Python's negative indexing
            would silently map e.g. ``0`` to the last entry.
        ValueError: if ``index`` cannot be converted to an integer.
    """
    env_list = ['InvertedPendulum-v2', 'HalfCheetah-v3', 'Hopper-v3', 'Walker2d-v3',
                'Ant-v3', 'Humanoid-v3', 'Swimmer-v3', 'Thrower-v2']
    position = int(index) - 1
    if not 0 <= position < len(env_list):
        raise IndexError('env index must be between 1 and %d, got %r'
                         % (len(env_list), index))
    return env_list[position]


def run_gpu(search_space=None):
    """Write ``noda_sac.sh`` with one background launch line per (config, seed).

    For every configuration the seeds ``seed .. seed_init + seed_num - 1`` are
    emitted in turn, each as a ``python noda_sac_main.py ... &`` line pinned to
    a GPU chosen from the current NVML readings. A line is only written when
    the chosen GPU reports more than 3900 MiB free; otherwise the loop sleeps
    briefly and polls again, so this call blocks until every line is written.

    Args:
        search_space: list of configuration namespaces as produced by
            ``get_search_space``. When ``None`` the search space is built from
            the command line. The entries are modified in place: ``device`` is
            set to the assigned GPU and ``seed`` is advanced.

    Raises:
        IndexError: if ``search_space`` is empty.
        ValueError: if NVML reports no devices.
    """
    if search_space is None:
        search_space = get_search_space()
    job_index = 0
    pynvml.nvmlInit()
    try:
        device_count = pynvml.nvmlDeviceGetCount()
        devices = list(range(device_count))
        # The number of seeds per configuration is taken from the first entry
        # and applied to all of them.
        seed_num = search_space[0].seed_num
        free_gpu_memory = [0] * device_count
        start_index = 0
        # The context manager closes the script even if NVML raises mid-way.
        with open('noda_sac.sh', 'w+') as f:
            while job_index < len(search_space):
                for device in devices:
                    handle = pynvml.nvmlDeviceGetHandleByIndex(device)
                    # NVML reports bytes; convert to MiB.
                    free_gpu_memory[device] = pynvml.nvmlDeviceGetMemoryInfo(handle).free / 1048576
                    # A GPU already running 4 or more compute processes is
                    # treated as having no free memory.
                    if len(pynvml.nvmlDeviceGetComputeRunningProcesses(handle)) >= 4:
                        free_gpu_memory[device] = 0
                most_free = max(free_gpu_memory)
                try:
                    # Prefer an occurrence of the maximum at or after the
                    # rotating start index to spread jobs across GPUs.
                    target_device = free_gpu_memory.index(most_free, start_index)
                except ValueError:
                    # The maximum only occurs before start_index.
                    target_device = free_gpu_memory.index(most_free)
                if free_gpu_memory[target_device] > 3900:
                    job = search_space[job_index]
                    start_index = (start_index + 1) % len(devices)
                    job.device = 'cuda:' + str(target_device)
                    f.write('python noda_sac_main.py' + job.command + ' --device cuda:'
                            + str(target_device) + ' --seed ' + str(job.seed) + ' &\n')
                    if job.seed < job.seed_init + seed_num - 1:
                        job.seed += 1
                    else:
                        job_index += 1
                        if job_index == len(search_space):
                            # Everything is written; skip the final sleep.
                            break
                time.sleep(0.1)
    finally:
        # Pair the nvmlInit() above so the NVML reference is released.
        pynvml.nvmlShutdown()


# Script entry point: build the search space from the command line and write
# the launcher script.
if __name__ == "__main__":
    run_gpu(search_space=None)
