import json
import os
import itertools
import copy
from .trainer import Config
import tempfile

import subprocess

# Experiment keys whose value is used as a leading directory component of the
# run name (via os.path.join in Experiment.generate_configs_and_names) instead
# of being appended as a "key-value" tag. Empty by default.
FOLDER_KEYS = []

class Experiment(dict):
    """A grid of config overrides layered on top of a base config.

    The mapping itself holds ``dotted.config.key -> list of candidate values``.
    The base config is obtained through ``Config.load`` and every variant of
    the grid is produced by overriding entries in a copy of it.
    """

    def __init__(self, base=None, name=None):
        """Create an empty experiment.

        Args:
            base: Forwarded unchanged to ``Config.load``.
            name: Experiment name, used as the first component of run names.
        """
        super().__init__()
        self._name = name
        self.base_config = Config.load(base)

    @property
    def name(self):
        """The experiment name given at construction time."""
        return self._name

    @classmethod
    def load(cls, path):
        """Build an experiment from a JSON file.

        The file must contain a ``'base'`` entry (passed on as the base
        config); every other entry must map a string key to a list of values.
        The experiment name is the file's base name without its extension.

        Args:
            path: Path to the JSON experiment file.

        Returns:
            A new instance of ``cls`` populated with the sweep entries.
        """
        stem = os.path.splitext(os.path.basename(path))[0]
        with open(path, 'r') as handle:
            spec = json.load(handle)

        # Formatting checks on the raw specification.
        assert 'base' in spec, "Did not supply a base config"
        base = spec['base']
        del spec['base']  # Only the swept keys remain afterwards.
        for key, choices in spec.items():
            assert isinstance(key, str)
            assert isinstance(choices, list)

        sweep = cls(base=base, name=stem)
        sweep.update(spec)
        return sweep

    def get_variants(self):
        """Return the cartesian product of all swept values.

        Returns:
            A list of dicts, one per combination, each mapping every key of
            this experiment to one of its candidate values.
        """
        keys = list(self.keys())
        return [dict(zip(keys, combo)) for combo in itertools.product(*self.values())]

    def generate_configs_and_names(self):
        """Materialise one config and one run name per variant.

        Each variant's values are written into a copy of the base config,
        following dotted keys into nested containers. The run name is built
        from the keys that actually vary (more than one candidate value), as
        ``leaf-value`` tags joined by underscores, and is prefixed with the
        experiment name.

        Returns:
            A list of ``(name, config)`` tuples.

        Raises:
            ValueError: If a swept key does not exist in the config, or a
                varying value cannot be turned into a string for the name.
        """
        results = []
        for variant in self.get_variants():
            # NOTE(review): assumes Config.copy() yields a copy that can be
            # mutated (including nested entries) independently per variant;
            # confirm against Config.
            config = self.base_config.copy()
            run_name = ""
            has_tag = False

            for dotted_key, value in variant.items():
                *parents, leaf = dotted_key.split('.')

                # Walk down to the container that owns the leaf entry.
                node = config
                for part in parents:
                    if part not in node:
                        raise ValueError("Experiment specified key not in config: " + str(dotted_key))
                    node = node[part]
                if leaf not in node:
                    raise ValueError("Experiment specified key not in config: " + str(dotted_key))
                node[leaf] = value

                if dotted_key in FOLDER_KEYS:
                    # Folder keys become a directory in front of the name.
                    run_name = os.path.join(value, run_name)
                    continue
                if len(self[dotted_key]) <= 1:
                    # Constant across runs: does not need to appear in the name.
                    continue

                if isinstance(value, str):
                    text = value
                elif value is None or isinstance(value, (int, float, bool)):
                    text = str(value)
                elif isinstance(value, list):
                    # Lists are labelled by their last element.
                    text = value[-1]
                    assert isinstance(text, str)
                else:
                    raise ValueError("Could not convert config value to str.")

                run_name += str(leaf) + '-' + text + '_'
                has_tag = True

            if has_tag:
                run_name = run_name[:-1]  # Drop the trailing underscore.
            results.append((os.path.join(self.name, run_name), config))

        return results

class Hardware(object):
    """The CPU cores and GPUs that sweep jobs may be scheduled on.

    Attributes are ``cpu_list`` (a flat list of integer core ids) and
    ``gpu_list`` (the ``gpus`` argument, stored as given).
    """

    def __init__(self, gpus=None, cpus=None):
        """Parse the hardware description.

        Args:
            gpus: List of GPU ids, or None for CPU-only execution.
            cpus: List whose items are either single core ids (int or numeric
                string) or ``"min-max"`` range strings. Ranges are half-open:
                ``"0-4"`` expands to cores 0, 1, 2, 3. Defaults to
                ``"0-<cpu count>"``, i.e. every core reported by
                ``os.cpu_count()``.
        """
        if cpus is None:
            # os.cpu_count() may return None when the count is undetermined;
            # fall back to a single core rather than building "0-None".
            cpus = ['0-' + str(os.cpu_count() or 1)]
        assert isinstance(gpus, list) or gpus is None, "GPUs must be a list of ints or None."
        assert isinstance(cpus, list), "CPUs must be a list"
        self.cpu_list = []
        for cpu_item in cpus:
            if isinstance(cpu_item, str) and '-' in cpu_item:
                # We have a CPU range (upper bound excluded).
                cpu_min, cpu_max = cpu_item.split('-')
                self.cpu_list.extend(range(int(cpu_min), int(cpu_max)))
            else:
                self.cpu_list.append(int(cpu_item))
        self.gpu_list = gpus

    def stripe(self, num_jobs):
        """Split the hardware between ``num_jobs`` jobs.

        Cores are divided into equally sized contiguous slices (leftover cores
        stay unused). When there are more jobs than cores, jobs share cores
        round-robin so that no job ends up with an empty core list. GPUs are
        always assigned round-robin; a missing or empty GPU list yields None.

        Args:
            num_jobs: Number of jobs to provision.

        Returns:
            A list of ``(cpu_ids, gpu)`` tuples, one per job, where
            ``cpu_ids`` is a list of ints and ``gpu`` is a GPU id or None.
            Empty when ``num_jobs`` is not positive.
        """
        if num_jobs <= 0:
            return []

        num_cpus = len(self.cpu_list)
        cores_per_job = num_cpus // num_jobs
        hardware_configs = []
        for i in range(num_jobs):
            if cores_per_job > 0:
                cpu = self.cpu_list[i * cores_per_job:(i + 1) * cores_per_job]
            elif num_cpus > 0:
                # More jobs than cores: oversubscribe instead of handing out
                # an empty core list.
                cpu = [self.cpu_list[i % num_cpus]]
            else:
                cpu = []
            gpu = self.gpu_list[i % len(self.gpu_list)] if self.gpu_list else None
            hardware_configs.append((cpu, gpu))
        return hardware_configs

class Sweeper(object):
    """Launches every variant of an experiment as its own subprocess."""

    def __init__(self, experiment, hardware, path):
        """Store the sweep description.

        Args:
            experiment: Object providing ``generate_configs_and_names()``.
            hardware: Object providing ``stripe(num_jobs)``.
            path: Root directory under which each run's save path is placed.
        """
        self.experiment = experiment
        self.hardware = hardware
        self.path = path

    def run(self):
        """Start one training subprocess per config and wait for all of them.

        Each config is written to a temporary JSON file and passed to
        ``train_subproc.py`` (located next to this module), pinned to its CPU
        slice with ``taskset``. Jobs with a GPU get ``CUDA_VISIBLE_DEVICES``
        set in their own environment. On ``KeyboardInterrupt`` while waiting,
        all started processes are terminated and reaped.

        NOTE: the temporary config files are left on disk after the run.
        """
        configs = self.experiment.generate_configs_and_names()
        hardware_stripe = self.hardware.stripe(len(configs))
        run_path = os.path.join(os.path.dirname(__file__), 'train_subproc.py')
        processes = []
        for i, ((name, config), (cpus, gpus)) in enumerate(zip(configs, hardware_stripe)):
            fd, config_path = tempfile.mkstemp(text=True, prefix='config', suffix='.json')
            # mkstemp hands back an open descriptor; only the path is needed,
            # so close it right away instead of leaking one per job.
            os.close(fd)
            config.save(config_path)
            save_path = os.path.join(self.path, name)
            command_list = [
                'taskset', '-c', ','.join([str(c) for c in cpus]),
                'python', run_path,
                '--config', config_path,
                '--save-path', save_path,
            ]

            if gpus is None:
                command_list.extend(['--device', 'cpu'])
                env = None  # Inherit the parent environment unchanged.
            else:
                command_list.extend(['--device', 'cuda'])
                # Work on a copy so the parent process environment is not
                # modified as a side effect of launching a job.
                env = os.environ.copy()
                env["CUDA_VISIBLE_DEVICES"] = str(gpus)  # TODO: this doesn't support multi-gpu
            proc = subprocess.Popen(command_list, env=env)
            processes.append(proc)
            print("[GRID SWEEPER] Started job", i+1, "on gpus:", gpus, "saving to", save_path)

        try:
            # Announce before blocking, not after everything has finished.
            print("[GRID SWEEPER] Waiting for completion.")
            for p in processes:
                p.wait()
        except KeyboardInterrupt:
            for p in processes:
                try:
                    p.terminate()
                except OSError:
                    # Process already gone; nothing left to terminate.
                    pass
                p.wait()
