import os.path as osp
import sys
import numpy as np

proj_path = osp.abspath(osp.dirname(__file__)).split('src')[0]
sys.path.append(proj_path + 'src')
import importlib

import argparse
import os
from functools import partial
from utils.exp_utils.summarizer import get_tuner
import subprocess



def gen_commands(module_name, exp_name, gpu_config, n_workers=1, tuner_generator=None, workers='1-1', re_run=False,
                 run_times=None):
    """Generate the per-GPU tuning commands for every dataset of an experiment.

    Args:
        module_name: tune module name, emitted as the ``-m`` flag.
        exp_name: experiment name, emitted as ``-x`` and passed to ``tuner_generator``.
        gpu_config: mapping ``dataset -> list of GPU ids``; one command is produced per list entry
            (duplicate ids are allowed). The mapping is NOT modified.
        n_workers: processes per GPU; each dataset's GPU list is repeated this many times.
        tuner_generator: callable ``(dataset=..., exp_name=...)`` returning an object whose
            ``check_running_status()`` returns ``((finished_trials, total_trials), _)``.
        workers: ``'<total_workers>-<worker_id>'`` with a 1-based ``worker_id``. The remaining trials
            are split evenly across the workers and only this worker's share is emitted.
        re_run: if True, previously finished trials are ignored and all trials are scheduled.
        run_times: value of the ``-r`` flag. When None, falls back to ``args.run_times`` from the
            module-level ``args`` parsed in ``__main__``.

    Returns:
        dict ``dataset -> list of command strings``. The list is empty when nothing is left for this
        worker. The first command of each non-empty list gets an extra ``-b`` flag.

    Raises:
        ValueError: if ``workers`` is malformed or ``worker_id`` is not in ``1..total_workers``.
    """
    total_workers, worker_id = (int(part) for part in workers.split('-'))
    if not 0 < worker_id <= total_workers:
        raise ValueError(f'worker_id must be an integer in 1..total_workers, got workers={workers!r}')

    def toint(x):
        return int(round(x))

    def gen_single_dataset_commands(dataset, gpu_list):
        # ! Step 1: find trials left
        (finished_trials, total_trials), _ = tuner_generator(dataset=dataset, exp_name=exp_name).check_running_status()
        print(f'Found {len(finished_trials)}/{total_trials} previous finished {exp_name} trials of finetune dataset={dataset}.')
        if re_run:
            finished_trials = []
            print('Rerunning experiments')
        trials_to_run = total_trials - len(finished_trials)

        # ! Step 2: generate start and end points
        if trials_to_run <= 0:
            print('Experiments finished, no trial left to run.')
            return []
        if not gpu_list:
            # Without this guard the '-b' append below would raise IndexError.
            print(f'No GPU configured for dataset={dataset}, nothing to run.')
            return []
        # Separate name: rebinding total_workers here would shadow the enclosing variable.
        active_workers = total_workers
        if trials_to_run < active_workers:
            active_workers = trials_to_run
            print(f'Only {trials_to_run} trials left to run, assigning one to each worker.')
            if active_workers < worker_id:
                print('Too few trials left, no need for this worker to run')
                return []
        trials = trials_to_run / active_workers  # this worker's share of trials (>= 1 here)
        if trials < len(gpu_list):  # Only a part of the GPU resources is needed
            n_procs = toint(trials)
            # replace=False: draw distinct slots so no GPU is doubled up while another stays idle.
            gpu_inds = np.random.choice(len(gpu_list), n_procs, replace=False).tolist()
            gpu_list = [gpu_list[i] for i in gpu_inds]
            print(f'Only {n_procs} processes are required on this device. Processes are ran on random selected GPU ids {gpu_list}')
        # Split this worker's range [trials*(id-1), trials*id] evenly across the selected GPUs.
        bounds = np.linspace(trials * (worker_id - 1), trials * worker_id, len(gpu_list) + 1)

        # ! Step 3: generate tune commands
        rt = args.run_times if run_times is None else run_times
        command_list = [f'tu{gpu} -r{rt} -m{module_name} -d{dataset} -x{exp_name} -s{toint(start)} -e{toint(end)}'
                        for gpu, start, end in zip(gpu_list, bounds[:-1], bounds[1:])]
        command_list[0] += ' -b'
        return command_list

    if n_workers > 1:
        # Build a new mapping rather than mutating gpu_config, which may be shared config (e.g. EXP_DICT).
        gpu_config = {dataset: list(gpus) * n_workers for dataset, gpus in gpu_config.items()}

    return {dataset: gen_single_dataset_commands(dataset, gpu_list)
            for dataset, gpu_list in gpu_config.items()}


def run_commands(commands, model_name, args):
    """Start (or, with ``args.stop``, only delete) one detached screen session per command.

    Each session is named ``'<model_name>-<dataset>-S<index>'``. Any existing session with that
    name is quit first. Unless stopping, a new detached session is created and two lines are
    typed into it: the init command (cd to the project, source ``args.init_file``, activate the
    ``jbdada`` conda env), then the command itself with `` -i`` appended when ``args.re_run``.

    Args:
        commands: mapping ``dataset -> list of command strings``.
        model_name: label used in session names and the final summary line.
        args: namespace providing ``init_file``, ``re_run`` and ``stop``.
    """
    init_file = args.init_file
    ignore_prev = args.re_run
    stop = args.stop
    init_command = f'cd {proj_path}&& source {init_file}&& conda activate jbdada'
    suffix = ' -i' if ignore_prev else ''
    total = 0
    for dataset, command_list in commands.items():
        for idx, cmd in enumerate(command_list):
            total += 1
            session = f'{model_name}-{dataset}-S{idx}'
            # Quit a previous session of the same name, if one is listed.
            os.system(f'if screen -list | grep -q {session};\n then screen -S {session} -X quit\nfi')
            if stop:
                print(f'Screen {session} deleted!')
                continue
            os.system(f'screen -mdS {session}\n')
            # Type the environment setup, then the actual command, into the new session.
            os.system(f'screen -S "{session}" -X stuff "{init_command}\r"')
            full_cmd = f'{cmd}{suffix}'
            os.system(f'screen -S "{session}" -X stuff "{full_cmd}\r"')
            print(f'Screen {session} created, command running: {full_cmd}')

    print(f'{total} {model_name} tune_commands.\n')


def generate_commands(exp_name, n_workers=1, workers='1-1', re_run=False):
    '''
    Build the tuning commands for a single experiment.

    The text of ``exp_name`` before its first '_' is the module name. That module is imported as
    ``tune.<module_name>``. Its ``EXP_DICT[exp_name]['gpu_conf']`` supplies the per-dataset GPU
    lists. Its ``model_settings`` and ``EXP_DICT`` are bound into the tuner factory handed to
    ``gen_commands``.

    Returns: a tuple ``(module_name, exp_name, commands)``, where ``commands`` maps each dataset
    to its list of command strings.
    '''
    module_name = exp_name.partition('_')[0]
    tune_module = importlib.import_module(f'tune.{module_name}')
    exp_dict = tune_module.EXP_DICT
    gpu_config = exp_dict[exp_name]['gpu_conf']
    tuner_generator = partial(get_tuner, model_settings=tune_module.model_settings, EXP_DICT=exp_dict)
    return module_name, exp_name, gen_commands(module_name, exp_name, gpu_config, n_workers,
                                               tuner_generator, workers, re_run)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-e', '--exp_name', type=str, default='GraphHD_Debug')
    parser.add_argument('-i', '--init_file', type=str, default='others/yeslab.cf')
    parser.add_argument('-R', '--re_run', action="store_true",
                        help='rerun all trials, ignoring previously finished ones')
    parser.add_argument('-c', '--check_status_current', action="store_true", help='show running status of current exp only')
    parser.add_argument('-C', '--check_status_all', action="store_true", help='show running status of all running trials')
    parser.add_argument('-s', '--stop', action="store_true", help='stop the experiments')
    parser.add_argument('-n', '--n_process', type=int, default=1, help='how many process per gpu')
    parser.add_argument('-w', '--workers', type=str, default='1-1', help='total_workers and worker_id')
    parser.add_argument('-r', '--run_times', type=int, default=10, help='Run times')
    args = parser.parse_args()
    # generate_commands also prints the status of the current experiment (covers -c).
    module_name, exp_name, commands = generate_commands(args.exp_name, args.n_process, args.workers, args.re_run)

    if not (args.check_status_current or args.check_status_all):
        run_commands(commands, exp_name, args)
    elif args.check_status_all:
        # Decode the real output, rather than splitting the bytes repr on a literal backslash-n.
        screen_ls = subprocess.run(['screen', '-ls'], stdout=subprocess.PIPE).stdout.decode(errors='replace')
        running_screens = screen_ls.splitlines()

        def parse_screen(m, s):
            # Session names look like '<exp_name>-<dataset>-S<i>'; strip '-<dataset>' to recover exp_name.
            return f"{m}{'-'.join(s.split(m)[1].split('-S')[0].split('-')[:-1])}"

        exp_lists = [parse_screen(module_name, s) for s in running_screens if module_name in s]
        for exp_name in set(exp_lists):
            print(f'\nChecking status of {exp_name}')
            generate_commands(exp_name)
