#!/usr/bin/env python3
import time
import itertools
import functools
import uuid
from pathlib import Path
import pickle
import os

from uimnet.algorithms.base import Algorithm

import submitit
import submitit.core

import copy
from uimnet import utils
from uimnet import workers
from uimnet import algorithms
from uimnet import ensembles
from uimnet import __SLURM_CONFIGS__, __DEBUG__

from omegaconf.omegaconf import OmegaConf


# Dispatch platforms accepted in `cfg.experiment.platform`.
__PLATFORMS__ = ['debug', 'local', 'slurm']

def check_platform(platform):
  """Validate a dispatch platform name.

  Args:
    platform: Platform identifier to validate.

  Returns:
    True when `platform` is listed in `__PLATFORMS__`.

  Raises:
    ValueError: If `platform` is not listed in `__PLATFORMS__`.
  """
  if platform not in __PLATFORMS__:
    err_msg = f'Unrecognized dispatch platform {platform}'
    raise ValueError(err_msg)
  return True



def dispatch_args(worker, cfg, logs_path, args=None, map_array=True, **kwargs):
  """Run `worker` on the platform selected by `cfg.experiment.platform`.

  Args:
    worker: Callable to execute.
    cfg: Configuration. `cfg.experiment.platform` selects the backend and
      `cfg.slurm` supplies the executor parameters.
    logs_path: Folder handed to the submitit executor.
    args: Optional sequence of positional-argument lists, one per job. When
      None, a single `worker(cfg, **kwargs)` call is dispatched instead.
    map_array: Must be true for `args` to be submitted (through
      `executor.map_array`) on the non-debug platforms.
    **kwargs: Extra keyword arguments for `worker`; only used when `args` is
      None.

  Returns:
    On the 'debug' platform, the worker's return value passed through
    `utils.pack`, or an empty list when `args` is empty. On the other
    platforms, a list holding `job.results()` for every job that did not
    fail; failures are reported with `utils.message` and skipped.

  Raises:
    ValueError: If the platform is unknown, or if `args` is given while
      `map_array` is false on a non-debug platform.
  """
  _platform = cfg.experiment.platform
  check_platform(_platform)

  if _platform == 'debug':
    # Debug mode calls the worker in-process, without submitit.
    if args is None:
      return utils.pack(worker(cfg, **kwargs))
    if len(args) == 0:
      # Nothing was scheduled (e.g. every stage is already done). Indexing
      # args[0] here used to raise IndexError.
      return []
    # Only the first argument list is executed in debug mode.
    return utils.pack(worker(*args[0]))

  executor_cls = {'local': submitit.AutoExecutor,
                  'slurm': submitit.SlurmExecutor}[_platform]
  executor = executor_cls(folder=logs_path)

  executor.update_parameters(
    **__SLURM_CONFIGS__[cfg.slurm.preset]
  )
  # NOTE(review): these are Slurm-style parameter names; confirm that the
  # 'local' AutoExecutor accepts them without a `slurm_` prefix.
  executor.update_parameters(
    time=cfg.slurm.time,
    mem_per_gpu=cfg.slurm.mem_per_gpu,
    cpus_per_task=cfg.slurm.cpus_per_task,
    partition=cfg.slurm.partition,
    comment=cfg.slurm.comment if 'comment' in cfg.slurm else 'Yippee ki-yay'
  )

  if args is None:
    jobs = utils.pack(executor.submit(worker, cfg, **kwargs))
  elif map_array:
    # Transpose the per-job argument lists into per-parameter iterables.
    jobs = executor.map_array(worker, *zip(*args))
  else:
    # Reached only when args is provided but map_array is false.
    err_msg = 'args were given but map_array is False; set map_array=True to dispatch them'
    raise ValueError(err_msg)

  data = []
  for job in jobs:
    try:
      data.append(job.results())
    except submitit.core.utils.FailedSubmissionError as e:
      # TODO(XXX): Resubmit failed jobs
      utils.message(e)
    except submitit.core.utils.FailedJobError as e:
      utils.message(e)

  return data


def submit(cfg, output_dir, worker, **kwargs):
  """Save `cfg` under `output_dir` and dispatch a single `worker` call.

  Args:
    cfg: Configuration, written to `<output_dir>/cfg.yaml` and forwarded to
      `dispatch_args`.
    output_dir: Output directory; created if missing.
    worker: Callable forwarded to `dispatch_args`.
    **kwargs: Extra keyword arguments forwarded to `dispatch_args`.

  Returns:
    dict with keys `data` (what `dispatch_args` returned), `output_dir`
    (absolute path as a string), `cfg` and `elapsed_time` (seconds).
  """
  started_at = time.time()

  out_root = Path(output_dir)
  out_root.mkdir(parents=True, exist_ok=True)

  # Keep a copy of the configuration next to the outputs.
  with open(out_root / 'cfg.yaml', 'w') as handle:
    OmegaConf.save(cfg, f=handle.name)

  submit_logs = out_root / 'logs' / 'submit'
  submit_logs.mkdir(parents=True, exist_ok=True)

  results = dispatch_args(worker, cfg, logs_path=submit_logs, map_array=False, **kwargs)

  summary = {
    'data': results,
    'output_dir': str(out_root.absolute()),
    'cfg': cfg,
  }
  summary['elapsed_time'] = time.time() - started_at
  return summary


@utils.checkpoint()
def evaluate_ood(eval_cfg, sweep_dir, partitions):
  """Dispatch `workers.Evaluator` jobs for the models found in a sweep.

  One job is queued per (model directory, partition, temperature mode), where
  the partition named 'difficult' is skipped and the temperature modes are
  'initial' and 'learned'. Model directories that already carry an
  'ood_evaluation.done' trace are skipped unless the FORCE_STAGE environment
  variable is set to a non-zero integer.

  Args:
    eval_cfg: Evaluation configuration; `eval_cfg.dataset` supplies the
      arguments of `utils.partition_dataset`.
    sweep_dir: Directory whose sub-directories are the candidate models.
    partitions: Forwarded to `utils.partition_dataset`.

  Returns:
    dict with keys `data` (result of `dispatch_args`), `sweep_dir` (as a
    `Path`) and `elapsed_time` (seconds).

  Raises:
    RuntimeError: If the 'train' and 'val' splits do not expose the same
      partition names.
  """
  elapsed_time = time.time()
  sweep_path = Path(sweep_dir)
  # `logs` is where this module stores its own submitit logs (see `logs_path`
  # below); it is not a model directory and must not count as unfinished.
  subpaths = [el for el in sweep_path.iterdir()
              if el.is_dir() and el.name != 'logs']

  all_datasets = {
    split_name: utils.partition_dataset(name=eval_cfg.dataset.name,
                                        root=eval_cfg.dataset.root,
                                        split=split_name,
                                        partitions=partitions,
                                        equalize_partitions=eval_cfg.dataset.equalize_partitions)
    for split_name in ['train', 'val']
  }

  if all_datasets['train'].keys() != all_datasets['val'].keys():
    err_msg = "datasets don't have the same partitions"
    raise RuntimeError(err_msg)

  # Iterate over every candidate directly: pre-filtering with
  # `utils.is_valid_subpath` made the branch below unreachable, so the
  # reported count was always 0.
  unfinished_training = 0
  filtered_subpaths = []
  for subpath in subpaths:
    if not utils.is_valid_subpath(subpath):
      unfinished_training += 1
      utils.message(f'Unfinished training at {subpath}.')
      continue
    filtered_subpaths.append(subpath)
  utils.message(f'{unfinished_training} models have not finished training.')

  partition_names = [name for name in all_datasets['train'] if name != 'difficult']

  args = []
  for subpath in filtered_subpaths:
    # The done trace, the training config and the model class depend only on
    # the model directory, so they are resolved once per model.
    if utils.trace_exists('ood_evaluation.done', dir_=str(subpath)) and not int(os.getenv('FORCE_STAGE', 0)):
      continue

    # Loading train config
    filename = 'cfg_rank_0.yaml'
    with open(subpath / filename, 'r') as fp:
      train_cfg = OmegaConf.load(fp.name)

    # loading algorithm
    Algorithm = utils.load_model_cls(train_cfg=train_cfg)

    for partition_name in partition_names:
      train_dataset = all_datasets['train'][partition_name]
      val_dataset = all_datasets['val'][partition_name]
      for temperature_mode in ['initial', 'learned']:
        # Each job gets its own copy of the evaluation config.
        _eval_cfg = copy.deepcopy(eval_cfg)
        OmegaConf.set_readonly(_eval_cfg, False)
        _eval_cfg.dataset.partition = partition_name
        _eval_cfg.temperature_mode = temperature_mode
        OmegaConf.set_readonly(_eval_cfg, True)

        args.append([_eval_cfg, train_cfg, Algorithm, train_dataset, val_dataset])
        utils.write_trace('ood_evaluation.pending', dir_=str(subpath))

  logs_path = sweep_path / 'logs' / 'eval'
  logs_path.mkdir(parents=True, exist_ok=True)
  data = dispatch_args(workers.Evaluator(),
                       eval_cfg,
                       args=args,
                       logs_path=logs_path,
                       map_array=True
                       )

  return dict(data=data,
              sweep_dir=sweep_path,
              elapsed_time=time.time() - elapsed_time
              )




def train_ensembles(ensembles_cfg, sweep_dir, partitions, ensembles_paths, train_on_in_only=True):
  """Dispatch `workers.Trainer` jobs that train ensembles of existing models.

  One job is queued per (dataset, algorithm key, data seed, ensemble class,
  ensemble size). Each requested size is clipped to the number of available
  member paths. Ensemble directories that already carry a 'train.done' trace
  are skipped unless the FORCE_STAGE environment variable is set to a
  non-zero integer.

  Args:
    ensembles_cfg: Ensemble configuration. `ensemble.name` and `ensemble.k`
      are iterated over; `dataset` supplies the arguments of
      `utils.partition_dataset`.
    sweep_dir: Directory under which the ensemble directories are created.
    partitions: Forwarded to `utils.partition_dataset`.
    ensembles_paths: Mapping from `(algorithm_name, algorithm_arch, sn)` to a
      mapping from data seed to a sliceable sequence of member paths.
    train_on_in_only: When true, only the dataset named 'in' is used.

  Returns:
    dict with keys `data` (result of `dispatch_args`), `sweep_dir` (as a
    `Path`) and `elapsed_time` (seconds).
  """
  elapsed_time = time.time()

  # Algorithms/dataseed/list of paths
  sweep_path = Path(sweep_dir)
  _datasets = utils.partition_dataset(name=ensembles_cfg.dataset.name,
                                      root=ensembles_cfg.dataset.root,
                                      split=ensembles_cfg.dataset.split,
                                      partitions=partitions,
                                      equalize_partitions=ensembles_cfg.dataset.equalize_partitions)

  Ensembles = [ensembles.__dict__[name] for name in ensembles_cfg.ensemble.name]

  args = []
  for dataset in _datasets.values():
    if train_on_in_only and dataset.name != 'in':
      continue

    for (algorithm_name, algorithm_arch, sn), dataset_seed_paths in ensembles_paths.items():
      for data_seed, paths in dataset_seed_paths.items():
        for Ensemble in Ensembles:
          ensemble_name = Ensemble.__name__
          # Several requested sizes can clip to the same effective size.
          # Remember the ones already scheduled so the same ensemble
          # directory is not queued twice.
          scheduled_sizes = set()
          for k in ensembles_cfg.ensemble.k:

            k = min(k, len(paths))
            if k in scheduled_sizes:
              continue
            scheduled_sizes.add(k)
            _paths = paths[:k]

            # NOTE(review): the directory name does not include the dataset
            # name, so with train_on_in_only=False several datasets would
            # share it -- confirm this is intended.
            ensemble_dir = f'{ensemble_name}_{algorithm_name}_{algorithm_arch}_{sn}_{data_seed}_{k}'
            ensemble_path = sweep_path / ensemble_dir

            ensemble_path.mkdir(parents=True, exist_ok=True)
            if utils.trace_exists('train.done', dir_=str(ensemble_path)) and not int(os.getenv('FORCE_STAGE', 0)):
              continue

            # Serializing paths
            with open(ensemble_path / 'paths.pkl', 'wb') as fp:
              pickle.dump(_paths, fp, protocol=pickle.HIGHEST_PROTOCOL)

            # Copy the config only for ensembles that are actually scheduled.
            _ensembles_cfg = copy.deepcopy(ensembles_cfg)
            OmegaConf.set_readonly(_ensembles_cfg, False)

            _ensembles_cfg.dataset.name = dataset.name
            _ensembles_cfg.dataset.seed = data_seed
            _ensembles_cfg.output_dir = str(ensemble_path.absolute())

            _ensembles_cfg.ensemble.name = ensemble_name
            _ensembles_cfg.ensemble.k = k

            _ensembles_cfg.algorithm.sn = sn
            _ensembles_cfg.algorithm.name = algorithm_name
            _ensembles_cfg.algorithm.arch = algorithm_arch

            OmegaConf.set_readonly(_ensembles_cfg, True)
            # Use partial evaluation to make Ensemble have the same
            # signature as Algorithm
            Algorithm = functools.partial(
              Ensemble, paths=_paths)
            args.append([_ensembles_cfg, Algorithm, dataset])
            utils.write_trace('train.pending', dir_=_ensembles_cfg.output_dir)

  logs_path = sweep_path / 'logs' / 'train_ensembles'
  logs_path.mkdir(parents=True, exist_ok=True)

  data = dispatch_args(workers.Trainer(),
                       ensembles_cfg,
                       args=args,
                       logs_path=logs_path,
                       map_array=True
                       )

  return dict(data=data,
              sweep_dir=sweep_path,
              elapsed_time=time.time() - elapsed_time
              )

def train(train_cfg, sweep_dir, datasets, train_on_in_only=True):
  """Dispatch `workers.Trainer` jobs for every configuration in a sweep.

  The sweep is the cartesian product of the datasets, the data seeds, the
  algorithms named in `train_cfg.algorithm.name`, the algorithm seeds, the
  architectures in `train_cfg.algorithm.arch` and the values of
  `train_cfg.algorithm.sn`. Output directories that already carry a
  'train.done' trace are skipped unless the FORCE_STAGE environment variable
  is set to a non-zero integer.

  Args:
    train_cfg: Training configuration describing the sweep.
    sweep_dir: Directory under which one output directory per job is created.
    datasets: Mapping whose values are datasets exposing a `name` attribute.
    train_on_in_only: When true, only the dataset named 'in' is used.

  Returns:
    dict with keys `data` (result of `dispatch_args`), `sweep_dir` (as a
    `Path`) and `elapsed_time` (seconds).
  """
  elapsed_time = time.time()

  sweep_path = Path(sweep_dir)
  sweep_path.mkdir(parents=True, exist_ok=True)
  # NOTE(review): this writes under `train_cfg.sweep_dir`, not the `sweep_dir`
  # argument created just above -- confirm the two always agree.
  with open(Path(train_cfg.sweep_dir) / 'train_cfg.yaml', 'w') as fp:
    OmegaConf.save(train_cfg, f=fp.name)

  data_seeds = range(train_cfg.dataset.nseeds)
  Algorithms = [algorithms.__dict__[name] for name in train_cfg.algorithm.name]
  alg_seeds = range(train_cfg.algorithm.nseeds)
  sns = train_cfg.algorithm.sn

  args_iter = itertools.product(
    datasets.values(), data_seeds, Algorithms, alg_seeds, train_cfg.algorithm.arch, sns
  )

  args = []
  for (dataset, data_seed, Algorithm, alg_seed, arch, sn) in args_iter:

    if train_on_in_only and dataset.name != 'in':
      continue

    output_dir = f'{Algorithm.__name__}_{dataset.name}_{arch}_{sn}_{data_seed}_{alg_seed}'
    output_path = sweep_path.absolute() / output_dir
    output_path.mkdir(parents=True, exist_ok=True)

    # Skip finished runs unless FORCE_STAGE is set. The previous condition
    # (`not done and not force`) also skipped everything whenever FORCE_STAGE
    # was set, the opposite of what the other stages in this module do.
    if utils.trace_exists('train.done', dir_=str(output_path)) and not int(os.getenv('FORCE_STAGE', 0)):
      continue

    # Each job gets its own copy of the training config.
    _train_cfg = copy.deepcopy(train_cfg)
    OmegaConf.set_readonly(_train_cfg, False)
    _train_cfg.dataset.name = dataset.name
    _train_cfg.dataset.seed = data_seed

    _train_cfg.algorithm.name = Algorithm.__name__
    _train_cfg.algorithm.arch = arch
    _train_cfg.algorithm.seed = alg_seed

    _train_cfg.algorithm.sn = sn

    _train_cfg.output_dir = str(output_path)
    OmegaConf.set_readonly(_train_cfg, True)

    args.append([_train_cfg, Algorithm, dataset])
    utils.write_trace('train.pending', dir_=_train_cfg.output_dir)
    with open(output_path / 'train_cfg.yaml', 'w') as fp:
      OmegaConf.save(_train_cfg, f=fp.name)

  logs_path = sweep_path / 'logs' / 'train'
  logs_path.mkdir(parents=True, exist_ok=True)

  data = dispatch_args(workers.Trainer(),
                       train_cfg,
                       args=args,
                       logs_path=logs_path,
                       map_array=True
                       )

  return dict(data=data,
              sweep_dir=sweep_path,
              elapsed_time=time.time() - elapsed_time
              )


def estimate_gaussian_mixtures(mog_cfg, sweep_dir, datasets):
  """Dispatch `workers.MOG` jobs for the models found in a sweep.

  One job is queued per valid model directory, using `datasets['in']`.
  Directories that already carry a 'mog.done' trace are skipped unless the
  FORCE_STAGE environment variable is set to a non-zero integer. Models whose
  training config has an `ensemble` entry are skipped as well.

  Args:
    mog_cfg: Configuration copied into every job and passed to
      `dispatch_args`.
    sweep_dir: Directory whose sub-directories are the candidate models.
    datasets: Mapping that provides the dataset under the key 'in'.

  Returns:
    dict with keys `data` (result of `dispatch_args`), `sweep_dir` (as a
    `Path`) and `elapsed_time` (seconds).
  """
  elapsed_time = time.time()
  sweep_path = Path(sweep_dir)
  # `logs` is where this module stores its own submitit logs (see `logs_path`
  # below); it is not a model directory and must not count as unfinished.
  subpaths = [el for el in sweep_path.iterdir()
              if el.is_dir() and el.name != 'logs']

  # Iterate over every candidate directly: pre-filtering with
  # `utils.is_valid_subpath` made the branch below unreachable, so the
  # reported count was always 0.
  unfinished_training = 0
  filtered_subpaths = []
  for subpath in subpaths:
    if not utils.is_valid_subpath(subpath):
      unfinished_training += 1
      utils.message(f'Unfinished training at {subpath}')
      continue
    filtered_subpaths.append(subpath)
  utils.message(f'{unfinished_training} models have not finished training')

  args = []
  for subpath in filtered_subpaths:
    if utils.trace_exists('mog.done', dir_=str(subpath)) and not int(os.getenv('FORCE_STAGE', 0)):
      continue
    # Loading train config
    filename = 'cfg_rank_0.yaml'
    with open(subpath / filename, 'r') as fp:
      train_cfg = OmegaConf.load(fp.name)

    # Key membership is the same test this module uses for optional config
    # entries elsewhere, and it does not depend on how the config object
    # answers attribute access for absent keys.
    if 'ensemble' in train_cfg:
      utils.message(f'Skipping ensemble')
      continue

    _mog_cfg = copy.deepcopy(mog_cfg)
    # loading algorithm
    Algorithm = utils.load_model_cls(train_cfg=train_cfg)
    dataset = datasets['in']
    args.append([_mog_cfg, train_cfg, Algorithm, dataset])
    utils.write_trace('mog.pending', dir_=str(subpath))

  logs_path = sweep_path / 'logs' / 'mog'
  logs_path.mkdir(parents=True, exist_ok=True)
  data = dispatch_args(workers.MOG(),
                      mog_cfg,
                      args=args,
                      logs_path=logs_path,
                      map_array=True
                      )

  return dict(data=data,
              sweep_dir=sweep_path,
              elapsed_time=time.time() - elapsed_time
              )


def calibrate(calibration_cfg, sweep_dir, datasets):
  """Dispatch `workers.Calibrator` jobs for the models found in a sweep.

  One job is queued per valid model directory, using `datasets['in']`.
  Directories that already carry a 'calibration.done' trace are skipped
  unless the FORCE_STAGE environment variable is set to a non-zero integer.

  Args:
    calibration_cfg: Configuration copied into every job and passed to
      `dispatch_args`.
    sweep_dir: Directory whose sub-directories are the candidate models.
    datasets: Mapping that provides the dataset under the key 'in'.

  Returns:
    dict with keys `data` (result of `dispatch_args`), `sweep_dir` (as a
    `Path`) and `elapsed_time` (seconds).
  """
  elapsed_time = time.time()
  sweep_path = Path(sweep_dir)
  # `logs` is where this module stores its own submitit logs (see `logs_path`
  # below); it is not a model directory and must not count as unfinished.
  subpaths = [el for el in sweep_path.iterdir()
              if el.is_dir() and el.name != 'logs']

  # Iterate over every candidate directly: pre-filtering with
  # `utils.is_valid_subpath` made the branch below unreachable, so the
  # reported count was always 0.
  unfinished_training = 0
  filtered_subpaths = []
  for subpath in subpaths:
    if not utils.is_valid_subpath(subpath):
      unfinished_training += 1
      utils.message(f'Unfinished training at {subpath}')
      continue
    filtered_subpaths.append(subpath)
  utils.message(f'{unfinished_training} models have not finished training')

  args = []
  for subpath in filtered_subpaths:
    if utils.trace_exists('calibration.done', dir_=str(subpath)) and not int(os.getenv('FORCE_STAGE', 0)):
      continue

    # Loading train config
    filename = 'cfg_rank_0.yaml'

    _calibration_cfg = copy.deepcopy(calibration_cfg)
    with open(subpath / filename, 'r') as fp:
      train_cfg = OmegaConf.load(fp.name)
    # loading algorithm
    Algorithm = utils.load_model_cls(train_cfg=train_cfg)
    dataset = datasets['in']
    args.append([_calibration_cfg, train_cfg, Algorithm, dataset])
    utils.write_trace('calibration.pending', dir_=str(subpath))

  logs_path = sweep_path / 'logs' / 'calibration'
  logs_path.mkdir(parents=True, exist_ok=True)
  data = dispatch_args(workers.Calibrator(),
                      calibration_cfg,
                      args=args,
                      logs_path=logs_path,
                      map_array=True
                      )

  return dict(data=data,
              sweep_dir=sweep_path,
              elapsed_time=time.time() - elapsed_time
              )

# Running this module as a script is a no-op.
if __name__ == '__main__':
  pass
