# coding=utf-8
# Copyright 2023 The Uncertainty Baselines Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities to store and load results generated by the evaluation.py script."""
# pylint: disable=g-bare-generic
# pylint: disable=g-doc-args
# pylint: disable=g-doc-return-or-yield
# pylint: disable=g-importing-member
# pylint: disable=g-no-space-after-docstring-summary
# pylint: disable=g-short-docstring-punctuation
# pylint: disable=logging-format-interpolation
# pylint: disable=logging-fstring-interpolation
# pylint: disable=missing-function-docstring
import collections
import json
import os
import pathlib
import pdb
import pickle
from typing import Dict, List, Union
from typing import Optional
from typing import Tuple
from absl import logging
import numpy as np
import pandas as pd
import tensorflow as tf
from tqdm import tqdm

# Joint split name -> [in-domain constituent split, OOD constituent split].
# The in-domain split is listed first; add_joint_dicts relies on that order.
JOINT_SPLIT_TO_CONSTITUENT_SPLITS = dict(
    joint_validation=['in_domain_validation', 'ood_validation'],
    joint_test=['in_domain_test', 'ood_test'],
)


def merge_and_store_scalar_results(scalar_results_list: List[Dict],
                                   output_dir,
                                   allow_overwrite=True):
  """Aggregate scalar metrics across runs and store summary statistics.

  For every metric key of the first dict in `scalar_results_list`, collects the
  non-None values over all dicts and writes one row per metric with its mean,
  variance and standard error to `<output_dir>/scalar_results.tsv`.
  Variance and standard error use the population standard deviation
  (numpy's default, ddof=0).

  Args:
    scalar_results_list: non-empty list of dicts mapping metric name to a
      scalar value or None. Keys missing from a dict are treated like None.
    output_dir: directory in which `scalar_results.tsv` is written.
    allow_overwrite: if False, refuse to overwrite an existing TSV.

  Raises:
    ValueError: if `scalar_results_list` is empty, or (from
      `store_dataframe_gfile`) if the file exists and `allow_overwrite` is
      False.
  """
  if not scalar_results_list:
    raise ValueError('scalar_results_list must contain at least one dict.')

  keys = scalar_results_list[0].keys()
  total_results = collections.defaultdict(list)
  for key in keys:
    for scalar_results_dict in scalar_results_list:
      # A run that did not report a metric is skipped exactly like a None
      # value, instead of aborting the whole aggregation with a KeyError.
      val = scalar_results_dict.get(key)
      if val is not None:
        total_results[key].append(val)

  metric = []
  mean = []
  var = []
  stderr = []
  for key, values in total_results.items():
    values = np.asarray(values)
    std = values.std()
    metric.append(key)
    mean.append(values.mean())
    var.append(std**2)
    stderr.append(std / np.sqrt(len(values)))

  total_results_df = pd.DataFrame(data={
      'metric': metric,
      'mean': mean,
      'var': var,
      'stderr': stderr
  })
  store_dataframe_gfile(
      output_dir,
      'scalar_results.tsv',
      total_results_df,
      allow_overwrite=allow_overwrite)


def save_per_prediction_results(output_dir,
                                epoch,
                                per_prediction_results,
                                verbose=True,
                                allow_overwrite=True):
  """Persist the per-prediction arrays of every dataset split.

  Each entry of `per_prediction_results` (split name -> dict of arrays) is
  handed to `store_eval_results` with `<output_dir>/<split name>` as target.
  """
  for split_name in per_prediction_results:
    split_results = per_prediction_results[split_name]
    if verbose:
      logging.info(
          f'Storing per-prediction metrics for dataset split {split_name}.')
      logging.info(f'Keys: {list(split_results.keys())}')

    store_eval_results(
        os.path.join(output_dir, split_name),
        split_results,
        epoch=epoch,
        allow_overwrite=allow_overwrite)


def add_joint_dicts(dataset_split_containers: Dict[str, Dict],
                    is_deterministic):
  """Add per-example timing and, when possible, joint in-domain + OOD splits.

  First sets `ms_per_example = total_ms_elapsed / dataset_size` on every split.
  Then, for each joint split in `JOINT_SPLIT_TO_CONSTITUENT_SPLITS` whose
  in-domain and OOD constituent splits are both present, builds a joint dict
  by concatenating the constituents' prediction arrays (in-domain first) and
  stores it under the joint split name.

  Args:
    dataset_split_containers: split name -> dict holding `total_ms_elapsed`,
      `dataset_size`, `y_true`, `y_pred` and `y_pred_entropy` (plus
      `y_pred_variance`, `y_aleatoric_uncert` and `y_epistemic_uncert` when
      `is_deterministic` is False). Mutated in place.
    is_deterministic: if False, the extra uncertainty arrays listed above are
      also concatenated into the joint split.

  Returns:
    The same `dataset_split_containers` object, updated in place.

  Raises:
    ValueError: if a constituent split's number of `y_true` entries differs
      from its `dataset_size`.
  """
  # NOTE(review): joint dicts built below carry no `total_ms_elapsed`, so
  # calling this function a second time on its own output raises KeyError here.
  for dataset_split_dict in dataset_split_containers.values():
    dataset_split_dict['ms_per_example'] = (
        dataset_split_dict['total_ms_elapsed'] /
        dataset_split_dict['dataset_size'])

  # Arrays concatenated into every joint split (loop-invariant).
  entries_to_concat = ['y_true', 'y_pred', 'y_pred_entropy']
  if not is_deterministic:
    entries_to_concat += [
        'y_pred_variance', 'y_aleatoric_uncert', 'y_epistemic_uncert'
    ]

  for joint_split, constituent_splits in (
      JOINT_SPLIT_TO_CONSTITUENT_SPLITS.items()):
    # Only build a joint split when all of its constituents were evaluated.
    if not all(split in dataset_split_containers
               for split in constituent_splits):
      continue

    # Invariant of the module-level mapping: in-domain first, OOD second.
    assert 'in_domain' in constituent_splits[0]
    assert 'ood' in constituent_splits[1]
    in_domain_dict = dataset_split_containers[constituent_splits[0]]
    ood_dict = dataset_split_containers[constituent_splits[1]]

    # Check that all expected predictions were produced. This is a data
    # check, so raise explicitly rather than `assert` (stripped under -O).
    for split_name, split_dict in zip(constituent_splits,
                                      (in_domain_dict, ood_dict)):
      expected_len = split_dict['dataset_size']
      actual_len = len(split_dict['y_true'])
      if actual_len != expected_len:
        raise ValueError(
            f'Split {split_name} has {actual_len} predictions but '
            f'dataset_size={expected_len}.')

    id_len = in_domain_dict['dataset_size']
    ood_len = ood_dict['dataset_size']

    joint_split_dict = {
        entry: np.concatenate((in_domain_dict[entry], ood_dict[entry]))
        for entry in entries_to_concat
    }
    joint_split_dict['dataset_size'] = id_len + ood_len

    # Per-example OOD indicator, aligned with the concatenation order above.
    joint_split_dict['is_ood'] = ([False] * id_len) + ([True] * ood_len)

    # Overall milliseconds per example across both constituent splits.
    joint_split_dict['ms_per_example'] = (
        (in_domain_dict['total_ms_elapsed'] + ood_dict['total_ms_elapsed']) /
        (id_len + ood_len))

    dataset_split_containers[joint_split] = joint_split_dict

  return dataset_split_containers


def get_eval_run_str(model_type,
                     dataset_key,
                     eval_seed,
                     train_seed: Optional[int] = None,
                     k: Optional[int] = None):
  """Build the run identifier used to name an evaluation results directory.

  Single models (a `train_seed` is given, which takes precedence over `k`):
    '{model_type}__{dataset_key}__evalseed_{eval_seed}__trainseed_{train_seed}'
  Ensembles (a positive ensemble size `k` is given):
    '{model_type}__{dataset_key}__evalseed{eval_seed}__k{k}'

  The two formats differ in the separator after 'evalseed'; they are kept
  as-is so existing result directories remain addressable.

  Raises:
    ValueError: if neither `train_seed` nor a positive `k` is provided.
  """
  # Compare against None: a train seed of 0 is valid and must not be treated
  # as "not provided".
  if train_seed is not None:
    eval_run_str = (f'{model_type}__{dataset_key}__evalseed_{eval_seed}'
                    f'__trainseed_{train_seed}')
  elif k:
    eval_run_str = (f'{model_type}__{dataset_key}__evalseed{eval_seed}'
                    f'__k{k}')
  else:
    raise ValueError(
        'Must provide either a train_seed or the size of the ensemble k.')

  return eval_run_str


def create_eval_results_dir(output_dir,
                            dataset_key,
                            model_type,
                            date_time,
                            eval_seed,
                            train_seed: Optional[int] = None,
                            k: Optional[int] = None):
  """Create (if needed) and return the results directory for one eval run.

  The directory is `<output_dir>/<run id>/<date_time>`, where the run id comes
  from `get_eval_run_str` and encodes
    (dataset_key, model_type, train_seed, eval_seed) for a single model, or
    (dataset_key, model_type, k, eval_seed) for an ensemble.

  For an ensemble, pass the ensemble size `k`; the train seed may be omitted
  (the list of all train seeds will be stored in the metadata).
  """
  run_dir = os.path.join(
      output_dir,
      get_eval_run_str(model_type, dataset_key, eval_seed, train_seed, k),
      date_time)
  try:
    pathlib.Path(run_dir).mkdir(parents=True, exist_ok=False)
  except FileExistsError:
    # Already present from an earlier call; reuse it without logging.
    pass
  else:
    logging.info(f'Created eval results dir: {run_dir}')

  return run_dir


def store_dataframe_gfile(folder, file_name, df_to_store, allow_overwrite=True):
  """Write `df_to_store` as a tab-separated file at `<folder>/<file_name>`.

  Raises:
    ValueError: if the file already exists and `allow_overwrite` is False.
  """
  target = os.path.join(folder, file_name)
  must_not_clobber = not allow_overwrite
  if must_not_clobber and tf.io.gfile.exists(target):
    raise ValueError(f'{target} exists already!!!')

  with tf.io.gfile.GFile(target, 'w') as out_file:
    df_to_store.to_csv(path_or_buf=out_file, sep='\t', index=None)
  logging.info(f'Stored {file_name} to {folder}')


def load_dataframe_gfile(file_path, sep='\t'):
  """Read a delimited text file (tab-separated by default) into a DataFrame."""
  with tf.io.gfile.GFile(file_path, 'r') as in_file:
    loaded_df = pd.read_csv(in_file, sep=sep)

  logging.info(f'Loaded dataframe from {file_path}.')
  return loaded_df


def store_json_gfile(folder, file_name, dict_to_store):
  """Serialize `dict_to_store` as JSON into `<folder>/<file_name>`."""
  destination = os.path.join(folder, file_name)
  with tf.io.gfile.GFile(destination, 'w') as out_file:
    json.dump(dict_to_store, out_file)
  logging.info(f'Stored {file_name} to {folder}')


def store_eval_metadata(eval_results_dir,
                        ms_per_example,
                        k: Optional[int] = None):
  """Write run metadata to `<eval_results_dir>/metadata.json`.

  The JSON object has the keys `ms_per_example` and `k` (null when `k` is not
  given).
  """
  metadata = dict(ms_per_example=ms_per_example, k=k)
  target_path = os.path.join(eval_results_dir, 'metadata.json')
  with tf.io.gfile.GFile(target_path, 'w') as out_file:
    json.dump(metadata, out_file)
  logging.info(f'Stored eval metadata to {target_path}')


def store_eval_results(eval_results_dir,
                       dict_of_lists,
                       epoch=None,
                       allow_overwrite=True):
  """Store per-prediction arrays as one `.npy` file per key.

  Typical keys are image names, predictions, ground truth, uncertainty
  estimates and, optionally, `is_ood` binary indicators. Each value is
  converted with `np.array` and written to
  `<eval_results_dir>/eval_results[_<epoch>]/<key>.npy`.

  Args:
    eval_results_dir: parent directory of the `eval_results*` subdirectory.
    dict_of_lists: mapping from array name to array-like values.
    epoch: if given, the subdirectory is named `eval_results_<epoch>`.
    allow_overwrite: if False, refuse to overwrite existing `.npy` files.

  Raises:
    ValueError: if a target file exists and `allow_overwrite` is False.
  """
  if epoch is None:
    eval_results_name = 'eval_results'
  else:
    eval_results_name = f'eval_results_{epoch}'

  eval_results_dir = os.path.join(eval_results_dir, eval_results_name)

  tf.io.gfile.makedirs(eval_results_dir)
  assert tf.io.gfile.isdir(eval_results_dir)

  for key, arr in dict_of_lists.items():
    np_eval_results_path = os.path.join(eval_results_dir, f'{key}.npy')
    if not allow_overwrite and tf.io.gfile.exists(np_eval_results_path):
      raise ValueError(f'The file {np_eval_results_path} exists already!!!')
    # `.npy` is a binary format: write in binary mode, matching the 'rb' mode
    # used by load_eval_results.
    with tf.io.gfile.GFile(np_eval_results_path, 'wb') as f:
      np.save(f, np.array(arr))

  logging.info(f'Stored eval results to {eval_results_dir}')


def load_eval_results(eval_results_dir, epoch=None, name_filter=None):
  """Load the arrays stored under `<eval_results_dir>/eval_results[_<epoch>]`.

  Args:
    eval_results_dir: parent directory of the `eval_results*` subdirectory.
    epoch: if given, read from `eval_results_<epoch>`.
    name_filter: optional predicate on file names; only matching files load.

  Returns:
    Dict mapping each file name with its extension removed (e.g. `y_true` for
    `y_true.npy`) to the loaded array.
  """
  if epoch is None:
    eval_results_name = 'eval_results'
  else:
    eval_results_name = f'eval_results_{epoch}'

  eval_results_dir = os.path.join(eval_results_dir, eval_results_name)

  arr_names = tf.io.gfile.listdir(eval_results_dir)
  if name_filter:
    arr_names = list(filter(name_filter, arr_names))
  eval_results = {}
  for arr_name in arr_names:
    np_eval_results_path = os.path.join(eval_results_dir, arr_name)
    with tf.io.gfile.GFile(np_eval_results_path, 'rb') as f:
      # NOTE: allow_pickle=True lets np.load unpickle object arrays, which can
      # execute arbitrary code; only load results from trusted directories.
      arr = np.load(f, allow_pickle=True)
      # Strip only the extension so keys containing '.' round-trip intact
      # (store_eval_results writes f'{key}.npy').
      eval_results[os.path.splitext(arr_name)[0]] = arr

  logging.info(f'Loaded eval results from {eval_results_dir}')
  return eval_results


def cache_eval_results(output_dir, metadata_df, results_df):
  """Append metadata and results rows to the TSV caches in `output_dir`.

  Each of `metadata.tsv` and `results.tsv` is read if present, concatenated
  with the new rows and written back; a missing file is created from the new
  rows alone. A results entry is meant to be uniquely identified by:

    model_type
    train_seed
    eval_seed
    run_datetime

  (no de-duplication is performed here).
  """
  targets = (
      ('metadata', metadata_df, os.path.join(output_dir, 'metadata.tsv')),
      ('results', results_df, os.path.join(output_dir, 'results.tsv')),
  )

  for storage_type, new_df, path in targets:
    # Merge with whatever is already cached, or start a fresh table.
    try:
      with tf.io.gfile.GFile(path, 'r') as in_file:
        combined_df = pd.concat([pd.read_csv(in_file, sep='\t'), new_df])
      action_str = 'updated'
    except (FileNotFoundError, tf.errors.NotFoundError):
      logging.info(f'No previous {storage_type} found at path {path}. '
                   f'Storing a new {storage_type} dataframe.')
      combined_df = new_df
      action_str = 'stored initial'

    with tf.io.gfile.GFile(path, 'w') as out_file:
      combined_df.to_csv(path_or_buf=out_file, sep='\t', index=False)

    logging.info(
        f'Successfully {action_str} {storage_type} dataframe at {path}.')


def get_results_from_model_dir(model_dir: str):
  """Read `<model_dir>/results.tsv` if it exists.

  Args:
    model_dir: `str`, subdirectory expected to contain a `results.tsv` file
      for one model type.

  Returns:
    The tab-separated results as a pd.DataFrame, or None when `results.tsv`
    is absent.
  """
  results_path = os.path.join(model_dir, 'results.tsv')
  if tf.io.gfile.exists(results_path):
    logging.info('Found results at %s.', results_path)
    with tf.io.gfile.GFile(results_path, 'r') as in_file:
      return pd.read_csv(in_file, sep='\t')
  return None


def load_directory_results(results_dir: str, model_type: Optional[str] = None):
  """Load evaluation results from the specified directory.

  Args:
    results_dir: `str`, directory from which evaluation results are loaded. If
      you aim to generate an evaluation for multiple models, this should point
      to a directory in which each subdirectory has a name corresponding to a
      model type. Otherwise, if you aim to generate a plot for a particular
      model, `results_dir` should point directly to a model type's subdirectory,
      and the `model_type` argument should be provided.
    model_type: `str`, should be provided if generating a plot for only one
      particular model.

  Returns:
    A pd.DataFrame with the contents of the `results.tsv` file(s), concatenated
    row-wise across model subdirectories when `model_type` is None.

  Raises:
    FileNotFoundError: if `results_dir` cannot be listed (multi-model mode) or
      has no `results.tsv` (single-model mode).
    ValueError: if no model subdirectory contains a `results.tsv`.
  """
  if model_type is None:
    # tf.io.gfile.walk yields nothing (instead of raising) for a missing
    # directory, so guard next() against a bare StopIteration.
    top_entry = next(tf.io.gfile.walk(results_dir), None)
    if top_entry is None:
      raise FileNotFoundError(f'No results directory found at {results_dir}.')
    dir_path, child_dir_suffixes, _ = top_entry

    model_dirs = []
    for child_dir_suffix in child_dir_suffixes:
      # Some filesystems list subdirectories with a trailing '/'; keep only
      # the leading component as the model type.
      child_model_type = child_dir_suffix.split('/')[0]
      model_dir = os.path.join(dir_path, child_model_type)
      logging.info(
          f'Found results directory for model {child_model_type} '
          f'at {model_dir}.')
      model_dirs.append(model_dir)

    results_dfs = []
    for model_dir in model_dirs:
      results = get_results_from_model_dir(model_dir)
      if results is not None:
        results_dfs.append(results)

    if not results_dfs:
      raise ValueError(
          f'No results.tsv found in any subdirectory of {results_dir}.')
    results_df = pd.concat(results_dfs, axis=0)
  else:
    logging.info(
        f'Plotting deferred prediction results for model type {model_type}.')
    model_results_path = os.path.join(results_dir, 'results.tsv')
    try:
      with tf.io.gfile.GFile(model_results_path, 'r') as f:
        results_df = pd.read_csv(f, sep='\t')
    except (FileNotFoundError, tf.errors.NotFoundError) as err:
      raise FileNotFoundError(
          f'No results found at path {model_results_path}.') from err

  return results_df


def load_dataset_dir(base_path, dataset_subdir):
  """Collect per-seed eval arrays stored under `<base_path>/<dataset_subdir>`.

  Every directory entry is parsed as '<prefix>_<seed>' (an optional trailing
  '/' is ignored). The `eval_results_<seed>` arrays are then loaded in
  ascending seed order. Only arrays with more than one element along their
  first axis are kept.

  Returns:
    defaultdict(list): array name -> list of arrays, one per seed that had it.
  """
  dataset_subdir_path = os.path.join(base_path, dataset_subdir)
  seeds = sorted(
      int(entry.split('_')[-1].split('/')[0])
      for entry in tf.io.gfile.listdir(dataset_subdir_path))

  results = collections.defaultdict(list)
  for seed in tqdm(seeds, desc='loading seed results...', disable=True):
    seed_results = load_eval_results(
        eval_results_dir=dataset_subdir_path, epoch=seed)
    for arr_name, arr in seed_results.items():
      # Skip scalars and length-1 arrays.
      is_multi_element = arr.ndim > 0 and arr.shape[0] > 1
      if is_multi_element:
        results[arr_name].append(arr)
  return results


def load_list_datasets_dir(base_path):
  """Load results for every dataset subdirectory of `base_path`.

  Returns:
    Dict: dataset name (subdirectory name without surrounding '/') ->
    result of `load_dataset_dir` for that subdirectory.
  """
  def _is_subdir(entry):
    return tf.io.gfile.isdir(os.path.join(base_path, entry))

  dataset_subdirs = [
      entry for entry in tf.io.gfile.listdir(base_path) if _is_subdir(entry)
  ]

  results_by_dataset = {}
  for subdir in tqdm(
      dataset_subdirs, desc='loading datasets results..', disable=True):
    name = subdir.strip('/')
    logging.info(name)
    results_by_dataset[name] = load_dataset_dir(
        base_path=base_path, dataset_subdir=subdir)
  return results_by_dataset


def load_model_dir_result_with_cache(model_dir_path,
                                     cache_file_name='cache',
                                     invalid_cache=False):
  """Load all eval results of one model directory, using a pickle cache.

  On a cache miss, every subdirectory of `model_dir_path` is treated as an
  eval type and loaded with `load_list_datasets_dir`. If anything was found,
  the result is pickled to `<model_dir_path>/<cache_file_name>`.

  Args:
    model_dir_path: directory holding one model's results.
    cache_file_name: file name of the pickle cache inside `model_dir_path`.
    invalid_cache: if True, ignore any existing cache and rebuild it.

  Returns:
    Dict: eval type (trailing/leading '/' stripped) -> dataset results. An
    empty dict when the directory has no eval-type subdirectories.
  """
  cache_path = os.path.join(model_dir_path, cache_file_name)
  dataset_results = None
  if tf.io.gfile.exists(cache_path) and not invalid_cache:
    logging.info(f'Reading cache from {cache_path}')
    # NOTE: pickle.load can execute arbitrary code; only read caches written
    # by this tool in trusted directories.
    with tf.io.gfile.GFile(cache_path, 'rb') as f:
      dataset_results = pickle.load(f)
    if not dataset_results:
      # The previous emptiness check was inverted and only ever cached empty
      # results, so an empty cache cannot be trusted; rebuild it.
      logging.info(f'Cache at {cache_path} is empty; rebuilding it.')
      dataset_results = None

  if dataset_results is None:
    # Each subdirectory is an eval type (for v1 model names,
    # parse_model_dir_name expects 'single' or 'ensemble'). A plain dict is
    # used (no lambda-based defaultdict) so the result stays picklable.
    eval_types = [
        agg for agg in tf.io.gfile.listdir(model_dir_path)
        if tf.io.gfile.isdir(os.path.join(model_dir_path, agg))
    ]
    dataset_results = {}
    for eval_type in tqdm(eval_types):
      dataset_results[eval_type] = load_list_datasets_dir(
          os.path.join(model_dir_path, eval_type))
    if not dataset_results:
      logging.info(
          f"{model_dir_path} is empty directory, won't create cache file.")
      return {}
    logging.info(f'Caching result in {model_dir_path} in file {cache_path}...')
    with tf.io.gfile.GFile(cache_path, 'wb') as f:
      pickle.dump(dataset_results, f)

  return {k.strip('/'): v for k, v in dataset_results.items()}


def parse_model_dir_name_v1(model_dir: str) -> Tuple:
  """Parse a version-1 model directory name.

  Expected format: '{model_type}_k{k}_{tuning_domain}_mc{num_mc_samples}',
  optionally followed by a trailing '/' (as some directory listings return).

  Returns:
    (model_type, k, is_deterministic, tuning_domain, num_mc_samples), where
    `k` is an int, `is_deterministic` is True only for a 'deterministic' model
    with k == 1, and `num_mc_samples` is kept as a string.

  Raises:
    ValueError: if the name does not split into exactly four '_' parts, or if
      `k` is not an integer.
  """
  # Remove trailing slashes explicitly. The previous `[:-1]` on the last field
  # assumed a trailing '/', and silently dropped the last digit of the MC
  # sample count for names listed without one (e.g. 'mc10' -> '1').
  name = model_dir.rstrip('/')
  try:
    model_type, ensemble_str, tuning_domain, mc_str = name.split('_')
  except ValueError as err:
    raise ValueError('Expected model directory in format '
                     '{model_type}_k{k}_{tuning_domain}_mc{n_samples}') from err
  k = int(ensemble_str[1:])  # format f'k{k}'
  num_mc_samples = mc_str[2:]  # format f'mc{num_mc_samples}'
  is_deterministic = model_type == 'deterministic' and k == 1
  key = (model_type, k, is_deterministic, tuning_domain, num_mc_samples)
  return key


def parse_model_dir_name(model_name: str, eval_type: str) -> Union[Tuple, str]:
  """Build a key identifying a model directory combined with an eval type.

  Version-1 names (no '__') are parsed by `parse_model_dir_name_v1`. Their
  ensemble size is then replaced according to `eval_type`: 'single' -> 1,
  'ensemble' -> 3. Version-2 names become a string with
  '__eval-type_<eval_type>' appended. Surrounding '/' are removed from the
  version-2 name and from `eval_type`.

  Raises:
    NotImplementedError: for an unknown name version.
  """
  version = decide_model_name_version(model_name)
  if version == 2:
    clean_eval_type = eval_type.strip('/')
    return model_name.strip('/') + f'__eval-type_{clean_eval_type}'
  if version == 1:
    model_type, _, *remaining_fields = parse_model_dir_name_v1(model_name)
    ensemble_size = {'single': 1, 'ensemble': 3}[eval_type.strip('/')]
    return (model_type, ensemble_size, *remaining_fields)
  raise NotImplementedError(version)


def model_name_v2_to_dict(model_name: str) -> Dict:
  """Parse a version-2 model name into an ordered key -> value mapping.

  Version-2 names are '__'-separated '<key>_<value>' pairs, e.g.
  'model_dropout__k_3' -> OrderedDict([('model', 'dropout'), ('k', '3')]).

  Raises:
    ValueError: if a pair does not contain exactly one '_', or if a key
      appears more than once.
  """
  d = collections.OrderedDict()
  for pair in model_name.split('__'):
    k, v = pair.split('_')
    if k in d:
      # Fail loudly instead of dropping into an interactive debugger, which
      # would hang non-interactive (batch) jobs.
      raise ValueError(f'Duplicate key {k!r} in model name {model_name!r}.')
    d[k] = v
  return d


def decide_model_name_version(model_name: str) -> int:
  """Return 2 for names containing the '__' pair separator, else 1."""
  return 2 if '__' in model_name else 1


def fast_load_dataset_to_model_results(results_dir,
                                       model_dir_cache_file_name='cache',
                                       invalid_cache=False):
  """Load every model directory in `results_dir` and regroup by dataset.

  Each model directory is loaded with `load_model_dir_result_with_cache`, and
  its results are re-keyed by `parse_model_dir_name(model_dir, eval_type)`.

  Returns:
    Nested defaultdict: dataset name -> parsed model key -> array dict.

  Raises:
    ValueError: if two model directories / eval types map to the same parsed
      key for a dataset.
  """
  by_dataset = collections.defaultdict(
      lambda: collections.defaultdict(lambda: collections.defaultdict(list)))

  for model_dir in tqdm(
      tf.io.gfile.listdir(results_dir), desc='loading model results...'):
    model_result = load_model_dir_result_with_cache(
        model_dir_path=os.path.join(results_dir, model_dir),
        cache_file_name=model_dir_cache_file_name,
        invalid_cache=invalid_cache,
    )
    for eval_type, eval_dict in model_result.items():
      for dataset, array_dict in eval_dict.items():
        model_key = parse_model_dir_name(model_dir, eval_type=eval_type)
        models_for_dataset = by_dataset[dataset]
        if model_key in models_for_dataset:
          raise ValueError(
              'Already have keys {}.'.format(models_for_dataset.keys()))
        models_for_dataset[model_key] = array_dict

  return by_dataset
