import os
import subprocess
import yaml
import hashlib
import threading
import time
import datetime


def get_hash(obj):
    """Return a short, deterministic hex digest identifying ``obj``.

    Dicts are hashed over their items sorted by key, so two dicts with the
    same contents produce the same digest regardless of insertion order.
    Every other object is hashed through its ``str()`` form.

    Args:
        obj: Any object. For dicts, the keys must be mutually orderable.

    Returns:
        The first 8 hex characters of the SHA-256 digest of the textual form.
    """
    if isinstance(obj, dict):
        # Sort so insertion order does not change the digest.
        msg = str(sorted(obj.items()))
    else:
        msg = str(obj)

    return hashlib.sha256(msg.encode('utf-8')).hexdigest()[:8]

class GPUAllocator:
    """Hand out (gpu, run-slot) pairs described by a YAML config file.

    The config is read as a mapping of GPU identifier -> number of run slots
    on that GPU. Every (gpu, run_index) pair gets a short id derived from
    ``get_hash``. ``allocate`` marks a free id as in use and ``free``
    releases it. The config is re-read on every ``allocate``, so slots can be
    added or removed while runs are in flight.

    ``allocate``, ``free`` and ``get_gpu`` may be called from different
    threads; every access to the bookkeeping dicts happens under
    ``self.lock``.
    """

    def __init__(self, config_path='gpu_config.yaml'):
        self.config_path = config_path
        self.lock = threading.Lock()

        self.id_to_pair = {}   # slot id -> (gpu, run_index) for the current config
        self.gpu_to_runs = {}  # parsed config: gpu -> number of run slots
        self.id_in_use = {}    # slot id -> True while a run holds that slot

        with self.lock:
            self._sync_config()

    def allocate(self):
        """Reserve a free slot and return its id, or None if all are busy.

        The config is reloaded first. The lock is held for the whole
        operation so a concurrent ``free`` cannot interleave with the resync.
        """
        with self.lock:
            self._sync_config()
            for slot_id in self.id_to_pair:
                if not self.id_in_use.get(slot_id, False):
                    self.id_in_use[slot_id] = True
                    return slot_id
        return None

    def free(self, id):
        """Release slot ``id``.

        If the slot still exists in the current config it becomes available
        again. If it was removed from the config since it was allocated, its
        bookkeeping entry is dropped. Freeing an unknown id is a no-op.
        """
        with self.lock:
            if id in self.id_to_pair:
                self.id_in_use[id] = False
            else:
                self.id_in_use.pop(id, None)

    def get_gpu(self, id):
        """Return the GPU of the in-use slot ``id``, or None.

        Returns None when ``id`` is unknown, not currently allocated, or no
        longer present in the config.
        """
        with self.lock:
            if not self.id_in_use.get(id, False):
                return None
            pair = self.id_to_pair.get(id)
        return None if pair is None else pair[0]

    def _sync_config(self):
        """Reload the config file and rebuild the slot bookkeeping.

        The caller must hold ``self.lock``. ``threading.Lock`` is not
        re-entrant, so this method does not acquire it itself.
        """
        with open(self.config_path, 'r') as in_file:
            # NOTE(review): FullLoader builds more than plain data; if this file
            # is not fully trusted, yaml.safe_load is sufficient for a
            # {gpu: num_runs} mapping.
            # An empty file parses to None; treat it as "no GPUs".
            self.gpu_to_runs = yaml.load(in_file, Loader=yaml.FullLoader) or {}

        id_to_pair = {}
        for gpu, num_runs in self.gpu_to_runs.items():
            for run in range(num_runs):
                pair = (gpu, run)
                id_to_pair[self._get_id_from_pair(pair)] = pair

        # Forget idle slots that disappeared from the config. Busy ones are
        # kept until their run calls free(), which then removes them.
        for slot_id in list(self.id_in_use):
            if slot_id not in id_to_pair and not self.id_in_use[slot_id]:
                del self.id_in_use[slot_id]
        for slot_id in id_to_pair:
            self.id_in_use.setdefault(slot_id, False)

        self.id_to_pair = id_to_pair

    def _get_id_from_pair(self, pair):
        """Derive the short slot id for a (gpu, run_index) pair."""
        return get_hash(get_hash(str(pair[0])) + get_hash(str(pair[1])))


class ExperimentManager:
    """Run commands in background threads, one per allocated GPU slot.

    Each queued command runs with ``CUDA_VISIBLE_DEVICES`` set to the GPU of
    its slot. The slot is released when the command exits.
    """

    def __init__(self):
        self.gpu_allocator = GPUAllocator()
        self.id_to_thread = {}  # slot id -> most recent thread launched on it

    def queue(self, command):
        """Start ``command`` on a free GPU slot.

        Args:
            command: argv list passed to ``subprocess.call``.

        Returns:
            True if the command was launched, False if no slot is free.
        """
        id = self.gpu_allocator.allocate()
        if id is None:
            return False

        # Resolve the GPU now, before a later allocate() can reload the config
        # and drop this slot's mapping.
        gpu = self.gpu_allocator.get_gpu(id)

        thread = threading.Thread(target=self._run_subprocess, args=(command, id, gpu))
        thread.start()
        self.id_to_thread[id] = thread

        return True

    def finish(self):
        """Block until every thread recorded in ``id_to_thread`` has exited."""
        for thread in self.id_to_thread.values():
            thread.join()

    def _run_subprocess(self, command, id, gpu=None):
        """Run ``command`` pinned to ``gpu`` and always release slot ``id``.

        If ``gpu`` is None it is looked up on the allocator. The child's
        stdout and stderr are discarded.
        """
        try:
            if gpu is None:
                gpu = self.gpu_allocator.get_gpu(id)

            environ = os.environ.copy()
            environ["CUDA_VISIBLE_DEVICES"] = str(gpu)

            subprocess.call(command, env=environ, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
        finally:
            # Release even if the command could not be started (e.g. missing
            # executable); otherwise the slot stays busy forever.
            self.gpu_allocator.free(id)
    


# Abbreviated per-env keys in the run config: ntt = num_task_transitions, gs = grad_steps, ab = action_budget
def load_run_data(run_data_path):
    """Read and return the parsed contents of the YAML file at ``run_data_path``.

    The script entry point indexes the result as
    experiment -> env -> {'ntt', 'gs', and 'ab' for '4_IBC_random'}.
    """
    with open(run_data_path, 'r') as config_file:
        return yaml.load(config_file, Loader=yaml.FullLoader)


def build_command(args_dict):
    """Assemble the argv list that launches a single ``train`` run.

    Args:
        args_dict: Run description with keys 'env', 'experiment', 'seed',
            'num_task_transitions', 'grad_steps', and 'action_budget'
            ('action_budget' is only read for the '4_IBC_random' experiment).

    Returns:
        A list suitable for ``subprocess.call``.

    Raises:
        ValueError: if the env or experiment name is not recognised.
        NotImplementedError: for the '5_IBC_energy' experiment.
    """
    # (env name, argument file) pairs, checked in order.
    env_arg_files = (
        ('Humanoid', '@scripts/args_humanoid.txt'),
        ('BallBalance', '@scripts/args_ball_balance.txt'),
        ('Ant', '@scripts/args_ant.txt'),
        ('Ingenuity', '@scripts/args_ingenuity.txt'),
        ('Anymal', '@scripts/args_anymal.txt'),
    )
    # Experiments whose flags do not depend on args_dict.
    fixed_experiment_flags = (
        ('1_BC_C', ['--method', 'bc', '--order', 'C',
                    '--updates_per_step', '0', '--no_pretrain_qrisk']),
        ('2_IBC_C', ['--method', 'ibc', '--order', 'C',
                     '--updates_per_step', '0', '--no_pretrain_qrisk']),
        ('3_BC_UC', ['--method', 'bc', '--order', 'UC']),
    )

    env = args_dict['env']
    arg_file = next((path for name, path in env_arg_files if env == name), None)
    if arg_file is None:
        raise ValueError('Unknown env: {}'.format(env))

    experiment = args_dict['experiment']
    argv = [
        'python', '-m', 'train',
        arg_file,
        '--logdir_suffix', experiment,
        '--seed', str(args_dict['seed']),
    ]

    if experiment == '4_IBC_random':
        argv.extend([
            '--method', 'ibc',
            '--allocation', 'random',
            '--no_pretrain_qrisk',
            '--action_budget', str(args_dict['action_budget']),
        ])
    elif experiment == '5_IBC_energy':
        raise NotImplementedError
    else:
        flags = next((f for name, f in fixed_experiment_flags if experiment == name), None)
        if flags is None:
            raise ValueError('Unknown experiment: {}'.format(experiment))
        argv.extend(flags)

    argv.extend([
        '--num_task_transitions', str(args_dict['num_task_transitions']),
        '--grad_steps', str(args_dict['grad_steps']),
        # Multi-expert setup.
        '--num_envs', '10',
        '--num_humans', '3',
        '--num_players', '3',
    ])
    return argv


if __name__ == '__main__':
    run_data = load_run_data('runs_config.yaml')

    # One command per (experiment, env, seed) combination, in config order.
    commands = []
    for experiment, env_configs in run_data.items():
        for env, env_cfg in env_configs.items():
            for seed in (1000, 2000, 3000):
                args_dict = {
                    'experiment': experiment,
                    'env': env,
                    'seed': seed,
                    'num_task_transitions': env_cfg['ntt'],
                    'grad_steps': env_cfg['gs'],
                }
                if experiment == '4_IBC_random':
                    args_dict['action_budget'] = env_cfg['ab']
                commands.append(build_command(args_dict))

    retry_freq = 60  # seconds to wait before retrying when no GPU slot is free

    em = ExperimentManager()

    start_stamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    print("Starting runs (time: {})".format(start_stamp))
    for command in commands:
        # Keep retrying until a GPU slot frees up for this command.
        while not em.queue(command):
            time.sleep(retry_freq)
        print(f'(ExperimentManager) executing command: {" ".join(command)}')

    em.finish()