#!/usr/bin/env python3
import argparse
import prompt_toolkit
from prompt_toolkit import prompt
from pathlib import Path
from uimnet import utils
from uimnet import measures
import concurrent.futures
import pickle
import pandas as pd
import tabulate
import yaml
# Let print() show complete summary tables instead of truncating them.
pd.set_option('display.max_rows', None)
pd.set_option('display.max_columns', None)

# Thread-pool size used when scanning model directories; None lets
# concurrent.futures pick its default.
MAX_WORKERS = None

# Positional fields of a query command, in the order they must be typed.
KEYS = ('stage', 'model', 'k', 'state', 'query')
# Stage names checked for each model: the fixed pipeline stages followed by
# one 'ood_<name>' stage per key of measures.__MEASURES__.
STAGES = [
  'train',
  'calibration',
  'prediction',
  *(f'ood_{measure_name}' for measure_name in measures.__MEASURES__.keys()),
]

def parse_arguments():
  """Parse the command line.

  Returns:
    argparse.Namespace with a single required string attribute ``sweep_dir``
    (given as ``-s`` / ``--sweep_dir``).
  """
  cli = argparse.ArgumentParser(
    description='Check completions of sweeps stages',
  )
  cli.add_argument('-s', '--sweep_dir', required=True, type=str)
  parsed = cli.parse_args()
  return parsed

def main(args):
  """Start the interactive checker on the sweep directory ``args.sweep_dir``.

  Returns None once the interactive loop exits.
  """
  loop(Path(args.sweep_dir))

def _help():
  """Return the REPL command reference text (leading newline, 2-space indent)."""
  help_lines = [
    '',
    '  help: prints help.',
    '  q: quits.',
    '  summarize: prints summary.',
    f"  command: Of the form {KEYS}. Replace key by 'all' to select all",
    '  ',
  ]
  return '\n'.join(help_lines)

def get_df(records):
  """Group stage records into a per-(stage, model, k, state) summary.

  Args:
    records: Iterable of dicts with at least the keys ``stage``, ``model``,
      ``k``, ``state`` and ``path`` (as built by ``check_model_stages``).

  Returns:
    A DataFrame indexed by a ``(stage, model, k, state)`` MultiIndex with two
    columns: ``count`` (number of records in the group) and ``path`` (list of
    the records' paths). When there are no records, an empty frame with the
    same index names and columns is returned instead of raising.
  """
  group_keys = ['stage', 'model', 'k', 'state']
  records = list(records)
  if not records:
    # from_records([]) yields a frame with no columns, so grouping by
    # group_keys would raise KeyError (e.g. a sweep with no model dirs yet).
    empty_index = pd.MultiIndex.from_tuples([], names=group_keys)
    return pd.DataFrame(columns=['count', 'path'], index=empty_index)

  df = pd.DataFrame.from_records(records)
  return df.groupby(group_keys)['path'].agg(
    count='count',
    path=lambda paths: list(paths),
  )

def parse_command(string):
  """Parse a query command into a ``.loc`` index and a column selector.

  The command is whitespace-separated and must contain exactly one token per
  entry of ``KEYS``. The ``query`` token selects the column; the remaining
  tokens, in ``KEYS`` order, select the index levels. Any token may be
  ``'all'`` to select everything at that position.

  Args:
    string: Raw command text, e.g. ``'train all 1 done count'``.

  Returns:
    ``(idx, query)``: ``idx`` is a tuple for ``DataFrame.loc`` with ``'all'``
    replaced by ``slice(None)``; ``query`` is the column name, or
    ``slice(None)`` when the query token is ``'all'``.

  Raises:
    ValueError: if the number of tokens differs from ``len(KEYS)``.
  """
  # split() (no argument) tolerates repeated / surrounding whitespace.
  tokens = string.split()
  if len(tokens) != len(KEYS):
    # zip() would silently drop extra tokens or leave 'query' unset.
    raise ValueError(
      f'Expected {len(KEYS)} tokens of the form {KEYS}, '
      f'got {len(tokens)}: {string!r}')

  fields = dict(zip(KEYS, tokens))
  query = fields.pop('query')

  idx = tuple(
    slice(None) if token == 'all' else token for token in fields.values())
  if query == 'all':
    # Indexing the selection with slice(None) returns it unchanged, i.e. with
    # every column, matching the "'all' selects all" help text.
    query = slice(None)

  return idx, query


def loop(sweep_path):
  """Run the interactive command loop for the sweep at ``sweep_path``.

  Commands:
    help       -- print the command reference.
    q          -- quit; end of input (Ctrl-D) also quits.
    summarize  -- print the grouped stage/state table.
    otherwise  -- a query parsed by ``parse_command`` and applied to the table.

  The sweep is rescanned for every data command, so results reflect the
  current on-disk state. Errors raised while scanning or querying are
  reported, and the loop keeps running.

  Args:
    sweep_path: ``pathlib.Path`` of the sweep directory.
  """
  prompt_msg = '>'
  while True:
    try:
      command = prompt(prompt_msg)
    except EOFError:
      # Ctrl-D: treat end of input like 'q' instead of crashing.
      break
    except KeyboardInterrupt:
      # Ctrl-C: abandon the current line but keep the session alive.
      continue

    command = command.strip()
    if not command:
      continue
    print(command)

    if command == 'q':
      break
    if command == 'help':
      # No need to rescan the sweep just to print help.
      print(_help())
      continue

    try:
      # Scanning is inside the try so a broken model dir or an empty sweep
      # reports an error instead of ending the session.
      df = get_df(check_sweep_stages(sweep_path))
      if command == 'summarize':
        print(df)
      else:
        idx, query = parse_command(command)
        print(df.loc[idx][query])
    except Exception as e:  # REPL boundary: report once and keep going.
      print(f'Invalid command {command}')
      print(e)
  return


def check_model_stages(model_path):
  """Collect one record per entry of ``STAGES`` for a single model directory.

  Reads ``train_cfg.yaml`` for the algorithm name and ensemble size. Then,
  for each stage, reads ``<stage>.state`` from the directory if it exists.

  Args:
    model_path: ``pathlib.Path`` of the model directory.

  Returns:
    A list of dicts with keys ``model``, ``k``, ``path``, ``stage`` and
    ``state``. ``k`` is a string (``'1'`` when the config has no ``ensemble``
    entry). ``state`` is the stripped file content, or ``'missing'`` when the
    state file does not exist.
  """
  with open(model_path / 'train_cfg.yaml', 'r') as fp:
    train_cfg = yaml.safe_load(fp)
  model_name = train_cfg['algorithm']['name']
  # Stored as str, presumably so it matches the text tokens typed at the prompt.
  k = str(train_cfg['ensemble']['k'] if 'ensemble' in train_cfg else 1)

  def _read_state(stage):
    """Return the stripped content of ``<stage>.state``, or 'missing'."""
    trace_path = model_path / f'{stage}.state'
    if not trace_path.exists():
      return 'missing'
    # Strip surrounding whitespace (e.g. a trailing newline) so the state
    # groups cleanly and can equal a whitespace-free query token.
    return trace_path.read_text().strip()

  return [
    dict(
      model=model_name,
      k=k,
      path=str(model_path),
      stage=stage,
      state=_read_state(stage),
    )
    for stage in STAGES
  ]

def check_sweep_stages(sweep_path):
  """Collect stage records for every model directory directly under ``sweep_path``.

  A sub-directory counts as a model directory when it contains an
  ``.algorithm`` or ``.ensemble`` entry. Model directories are checked with
  ``check_model_stages`` on a thread pool of ``MAX_WORKERS`` threads (``None``
  means the executor default). Their records are concatenated in sorted
  directory order.

  Args:
    sweep_path: ``pathlib.Path`` of the sweep directory.

  Returns:
    A flat list of record dicts.
  """
  def _is_model_path(path):
    return (path / '.algorithm').exists() or (path / '.ensemble').exists()

  # iterdir() order is filesystem-dependent; sort for reproducible output.
  models_paths = sorted(
    el for el in sweep_path.iterdir() if el.is_dir() and _is_model_path(el))

  with concurrent.futures.ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
    per_model = executor.map(check_model_stages, models_paths)
    # Flatten in a single pass; sum(lists, []) re-copies the accumulator at
    # every step and is quadratic in the number of records.
    all_records = [record for records in per_model for record in records]

  return all_records


if __name__ == '__main__':
  # main() runs the interactive loop and returns None, so there is no result
  # to keep.
  main(parse_arguments())
