#!/usr/bin/env python3
from pathlib import Path
import itertools
import functools
import argparse
import concurrent.futures
import pickle
import pandas as pd
import tabulate
import torch
from uimnet import utils
from uimnet import measures

def parse_arguments():
  """Parse the command-line options of the record consolidation script.

  Returns:
    argparse.Namespace with the attributes ``sweep_dir`` (str, required),
    ``output`` (str, required) and ``max_workers`` (int, or None when the
    option is not given).
  """
  cli = argparse.ArgumentParser(description='Consolidates records')

  # The two path-like options share the same shape: mandatory strings.
  required_options = (('-s', '--sweep_dir'), ('-o', '--output'))
  for short_flag, long_flag in required_options:
    cli.add_argument(short_flag, long_flag, type=str, required=True)

  cli.add_argument('--max_workers', type=int, default=None)
  return cli.parse_args()


def _post_process_record(record):
  if 'ensemble.k' in record:
    return record
  record['ensemble.k'] = 1
  return record

def _is_not_none(el):
  return not (el is None)

@utils.timeit
def collect_outdomain_records(models_paths, Measures, max_workers):
  """Gather the out-of-domain records of every (model, measure) pair.

  A pair contributes only when utils.trace_exists reports the trace
  ``ood_<MeasureName>.done`` in the model directory; its records are then
  read from ``<MeasureName>_results.pkl`` inside that directory.

  Args:
    models_paths: iterable of model directories supporting ``path / str``
      (the caller in this file passes pathlib.Path objects).
    Measures: iterable of objects exposing ``__name__``.
    max_workers: forwarded to ThreadPoolExecutor (None lets it pick).

  Returns:
    A flat list of records, each passed through ``_post_process_record``.
  """
  records_filepaths = []
  for model_path, Measure in itertools.product(models_paths, Measures):
    measure_name = Measure.__name__
    if utils.trace_exists(f'ood_{measure_name}.done', dir_=str(model_path)):
      records_filepaths.append(model_path / f'{measure_name}_results.pkl')

  with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
    extracted = list(executor.map(utils.ExtractRecords(), records_filepaths))

  # Skip the None results (presumably files that yielded no records --
  # confirm against utils.ExtractRecords), then flatten the per-file batches.
  # chain.from_iterable is linear in the number of records, whereas
  # sum(batches, []) copies the accumulated list once per batch.
  batches = filter(_is_not_none, extracted)
  return [
    _post_process_record(record)
    for record in itertools.chain.from_iterable(batches)
  ]

@utils.timeit
def collect_indomain_records(models_paths, max_workers):
  """Gather the in-domain records stored next to every model.

  Each model directory is expected to hold a ``predictive_records.pkl`` file,
  which is handed to utils.ExtractRecords in a pool of worker processes.

  Args:
    models_paths: iterable of model directories supporting ``path / str``
      (the caller in this file passes pathlib.Path objects).
    max_workers: forwarded to ProcessPoolExecutor (None lets it pick).

  Returns:
    A flat list of records, each passed through ``_post_process_record``.
  """
  records_filepaths = [
    model_path / 'predictive_records.pkl' for model_path in models_paths
  ]

  with concurrent.futures.ProcessPoolExecutor(max_workers=max_workers) as executor:
    extracted = list(executor.map(utils.ExtractRecords(), records_filepaths))

  # Removing missing records (None results), then flattening the per-file
  # batches. chain.from_iterable is linear in the number of records, whereas
  # sum(batches, []) copies the accumulated list once per batch.
  batches = filter(_is_not_none, extracted)
  return [
    _post_process_record(record)
    for record in itertools.chain.from_iterable(batches)
  ]


def collect_records(sweep_dir, max_workers):
  """Collect in-domain and out-of-domain records for the models of a sweep.

  Args:
    sweep_dir: directory whose immediate sub-directories are scanned; only
      those accepted by utils.is_model are treated as models.
    max_workers: forwarded to both record collectors.

  Returns:
    dict with the keys ``'indomain'`` and ``'outdomain'``, each holding the
    value returned by the corresponding collector.
  """
  # First keep the sub-directories, then narrow them down to actual models.
  candidates = []
  for entry in Path(sweep_dir).iterdir():
    if entry.is_dir():
      candidates.append(entry)
  models_paths = [candidate for candidate in candidates if utils.is_model(candidate)]

  indomain = collect_indomain_records(models_paths, max_workers=max_workers)

  # Every measure registered in measures.__MEASURES__ is considered.
  all_measures = list(measures.__MEASURES__.values())
  outdomain = collect_outdomain_records(
    models_paths,
    Measures=all_measures,
    max_workers=max_workers,
  )
  return {'indomain': indomain, 'outdomain': outdomain}

def main(sweep_dir, output, max_workers):
  """Collect the records of a sweep and pickle them to ``output``.

  Args:
    sweep_dir: sweep directory handed to ``collect_records``.
    output: path of the pickle file to write (opened in binary write mode).
    max_workers: forwarded to ``collect_records``.

  Returns:
    The collected records, i.e. the same object that was pickled.
  """
  consolidated = collect_records(sweep_dir, max_workers=max_workers)

  # Serialise with the most recent pickle protocol this interpreter supports.
  with open(output, 'wb') as sink:
    pickle.dump(consolidated, sink, protocol=pickle.HIGHEST_PROTOCOL)

  return consolidated



if __name__ == '__main__':
  # Script entry point: read the command line, then consolidate the records.
  # The result stays bound at module level (handy with ``python -i``).
  args = parse_arguments()
  records = main(
    sweep_dir=args.sweep_dir,
    output=args.output,
    max_workers=args.max_workers,
  )
