"""
Run experiment trials defined in config files.
"""

import argparse
import dataclasses
import hashlib
import json
import typing as tp
from datetime import datetime

import flatten_dict
import hydra.utils
import numpy as np
import pandas as pd
import ray
import upath
from loguru import logger
from omegaconf import OmegaConf
from tqdm.auto import tqdm

from generative_prediction_sets import data_loaders, sampling_metrics, utils
from generative_prediction_sets.core import Calibrator
from generative_prediction_sets.data_loaders import DatasetCachingUtility, HiddenStateDataset

T = tp.TypeVar("T")


def parse_config(cfg, schema: tp.Type[T]) -> T:
  """Merge ``cfg`` with the structured ``schema`` and return it as an object.

  Args:
    cfg: Config to parse; it is the first (base) operand of the merge.
    schema: Class handed to ``OmegaConf.structured`` to build the schema
      config that is merged on top of ``cfg``.

  Returns:
    The result of ``OmegaConf.to_object`` on the merged config, cast to
    ``schema``'s type for the benefit of type checkers.
  """
  schema_conf = OmegaConf.structured(schema)
  # Operand order matters to OmegaConf.merge: cfg first, schema second.
  merged_conf = OmegaConf.merge(cfg, schema_conf)
  parsed = OmegaConf.to_object(merged_conf)
  return tp.cast(T, parsed)


def resolve_method_cfg(method_cfg, experiment_cfg):
  """Resolve interpolations in ``method_cfg`` with the experiment config in scope.

  The experiment config is mounted under a temporary top-level ``"config"``
  key, merged with ``method_cfg``, and the merged config is resolved in place
  (presumably so method entries can interpolate ``${config...}`` values --
  confirm against the method config files).

  Args:
    method_cfg: Method configuration to resolve.
    experiment_cfg: Experiment configuration exposed under ``"config"``.

  Returns:
    The merged, resolved config with the temporary ``"config"`` key removed.
  """
  scope = OmegaConf.create({"config": experiment_cfg})
  merged = OmegaConf.merge(scope, method_cfg)
  OmegaConf.resolve(merged)

  # The experiment config was only mounted for resolution; drop it again.
  merged.pop("config")
  return merged


def resolve_all_method_configs_for_experiment(method_cfg, experiment_cfg):
  """Resolve every entry of ``method_cfg`` against ``experiment_cfg``.

  Args:
    method_cfg: Mapping whose values are individual method configs.
    experiment_cfg: Experiment configuration passed to ``resolve_method_cfg``
      for each entry.

  Returns:
    A list with one resolved config per entry, in iteration order. The keys
    of ``method_cfg`` are not part of the result.
  """
  return [
    resolve_method_cfg(single_method_cfg, experiment_cfg)
    for _name, single_method_cfg in method_cfg.items()
  ]


@dataclasses.dataclass
class ExperimentConfig:
  """Schema for a single experiment entry.

  Within this file, ``infer`` is used as the ``experiment_dir`` of the
  ``HiddenStateDataset`` and ``splits`` is passed to
  ``data_loaders.load_splits``; ``train`` is not read here.
  """

  # Free-form experiment settings; must contain a "task" entry (see `task`).
  cfg: dict
  infer: str
  splits: str
  train: str

  @property
  def task(self) -> str:
    """Return the ``"task"`` entry of ``cfg`` (``KeyError`` if it is absent)."""
    task_name = self.cfg["task"]
    return task_name


@dataclasses.dataclass
class MethodConfig:
  """Schema for a single calibration method entry.

  Within this file, ``calibrator`` is handed to ``hydra.utils.instantiate``
  and ``input_key`` names the dataset column used as the method's input.
  """

  calibrator: dict
  input_key: str

  def get_hash(self) -> str:
    """Return a stable SHA-256 hex digest of this config.

    The instance is converted to a plain dict and serialised to JSON with
    sorted keys, so the digest does not depend on key insertion order.

    Returns:
      str: The hexadecimal SHA-256 digest of the serialised config.
    """
    # sort_keys makes the serialisation independent of dict ordering.
    serialised = json.dumps(dataclasses.asdict(self), sort_keys=True)

    digest = hashlib.sha256()
    digest.update(serialised.encode())
    return digest.hexdigest()


def _ensure_2d(x: np.ndarray):
  """Return ``x`` as a column matrix when it is 1-D, otherwise unchanged.

  A 1-D array of length ``n`` is reshaped to ``(n, 1)``; arrays of any other
  rank are returned as the very same object.
  """
  return x.reshape(-1, 1) if x.ndim == 1 else x


def run_trial_for_experiment(
  Xs: np.ndarray,
  ys: np.ndarray,
  calibrator: Calibrator,
  *,
  alpha: float,
  test_size: float = 0.5,
  context: dict[str, np.ndarray] | None = None,
):
  """Run one random calibration/test split and evaluate ``calibrator`` on it.

  The rows of ``Xs``/``ys`` are shuffled with NumPy's global RNG; the first
  ``int(len(Xs) * test_size)`` shuffled rows form the test set and the rest
  form the calibration set.

  Args:
    Xs: Inputs; a 1-D array is treated as a single-column matrix.
    ys: Labels, indexed in lockstep with ``Xs``.
    calibrator: Callable invoked as
      ``calibrator(cal_Xs, cal_ys, test_Xs, alpha=alpha)``.
    alpha: Level passed to the calibrator and recorded in both outputs.
    test_size: Fraction of rows assigned to the test set.
    context: Extra arrays forwarded unchanged to
      ``sampling_metrics.compute_metrics``.

  Returns:
    A ``(cal_trial_df, test_trial_df)`` tuple of DataFrames. The first holds
    the calibration indices, quantile, predictions, auxiliary values and
    scores; the second holds the test indices, predictions, auxiliary values,
    prediction sets and metrics. Nested dicts are flattened into dot-joined
    column names.

  Raises:
    Exception: Whatever ``pd.DataFrame`` raises while building the test
      frame, re-raised unchanged after the per-column shapes are printed.
  """
  Xs = _ensure_2d(Xs)

  # Shuffle once and cut at a single split point so the two index sets are
  # disjoint and together cover every row.
  indices = np.random.permutation(len(Xs))
  num_test = int(len(Xs) * test_size)
  test_indices = indices[:num_test]
  cal_indices = indices[num_test:]

  cal_Xs, cal_ys = Xs[cal_indices], ys[cal_indices]
  test_Xs, test_ys = Xs[test_indices], ys[test_indices]

  # Compute conformal predictions
  results = calibrator(cal_Xs, cal_ys, test_Xs, alpha=alpha)

  # Compute metrics
  # NOTE(review): `context` is forwarded as-is while `test_ys` has been sliced
  # by `test_indices` -- confirm compute_metrics expects unsliced context.
  metrics = sampling_metrics.compute_metrics(
    test_ys,
    results.cis,
    metrics_list=[
      sampling_metrics.set_coverage,
      sampling_metrics.abstention_rate,
      sampling_metrics.base_model_abstention_rate,
      sampling_metrics.false_positive_abstention_rate,
      sampling_metrics.non_abstention_coverage,
      sampling_metrics.set_sizes,
      sampling_metrics.effective_set_sizes,
    ],
    context=context,
  )

  cal_trial_df = pd.DataFrame(
    flatten_dict.flatten(
      {
        "alpha": alpha,
        "cal_indices": cal_indices.tolist(),
        "quantile": results.quantile,
        "cal_preds": results.cal_preds.tolist(),
        "cal_aux": results.cal_aux,
        "cal_scores": results.cal_scores.tolist(),
      },
      reducer="dot",
    )
  )

  test_array_dict = flatten_dict.flatten(
    {
      "alpha": alpha,
      "test_indices": test_indices.tolist(),
      "test_preds": results.test_preds.tolist(),
      "test_aux": results.test_aux,
      "cis": results.cis.tolist(),  # cis are 2-d and pandas doesn't like it
      "metrics": metrics,
    },
    reducer="dot",
  )
  try:
    test_trial_df = pd.DataFrame(test_array_dict)
  except Exception:
    # Dump every column's shape to make length mismatches easy to spot, then
    # re-raise the original exception with its traceback and context intact.
    # print() is kept on purpose: this function is called from inside a Ray
    # remote task in this file.
    print(
      "ERROR: SHAPES MISMATCH IN TEST: \n{}".format(
        {k: np.shape(v) for k, v in test_array_dict.items()}
      )
    )
    raise

  return cal_trial_df, test_trial_df


@ray.remote(num_gpus=0.2)
def run_trials_for_method_and_alpha(
  train_Xs: np.ndarray,
  train_ys: np.ndarray,
  eval_Xs: np.ndarray,
  eval_ys: np.ndarray,
  context: dict[str, np.ndarray],
  method_config: MethodConfig,
  alpha: float,
  num_trials: int,
  seed: int,
):
  """Fit one calibrator and evaluate it over ``num_trials`` random splits.

  Runs as a Ray remote task that requests 0.2 GPU. All warnings are silenced
  and ``utils.seed_everything(seed)`` is called before the calibrator is
  instantiated from ``method_config.calibrator`` and fitted on the training
  data.

  Args:
    train_Xs: Training inputs (a 1-D array is treated as one column).
    train_ys: Training labels.
    eval_Xs: Evaluation inputs, re-split on every trial.
    eval_ys: Evaluation labels.
    context: Extra arrays forwarded to every trial.
    method_config: Method whose ``calibrator`` entry is instantiated.
    alpha: Level forwarded to every trial.
    num_trials: Number of random splits to run.
    seed: Seed passed to ``utils.seed_everything``.

  Returns:
    A ``(cal_trial_df, test_trial_df)`` tuple: the per-trial frames
    concatenated, each with an added ``trial_idx`` column.
  """
  import warnings

  warnings.filterwarnings("ignore")
  utils.seed_everything(seed)

  fitted_calibrator = hydra.utils.instantiate(method_config.calibrator)
  fitted_calibrator.fit(_ensure_2d(train_Xs), train_ys)

  cal_frames: list[pd.DataFrame] = []
  test_frames: list[pd.DataFrame] = []
  for trial_idx in range(num_trials):
    cal_frame, test_frame = run_trial_for_experiment(
      eval_Xs, eval_ys, fitted_calibrator, alpha=alpha, context=context
    )
    # Tag both frames with the trial number before collecting them.
    for frame, bucket in ((cal_frame, cal_frames), (test_frame, test_frames)):
      frame["trial_idx"] = trial_idx
      bucket.append(frame)

  return pd.concat(cal_frames), pd.concat(test_frames)


def run_single_experiment(
  experiment_config: ExperimentConfig,
  *,
  num_trials: int,
  method_configs,
  alphas,
  detail_path=None,
  debug: bool = False,
  data_cache_dir: str | None = None,
  context_map: dict[str, str] | None = None,
):
  """Run every (method, alpha) combination for one experiment via Ray.

  Loads the experiment's dataset and splits, submits one
  ``run_trials_for_method_and_alpha`` task per method and alpha, then collects
  the results in submission order. Nothing is returned; results are only
  persisted when ``detail_path`` is given.

  Args:
    experiment_config: Experiment to run. ``infer`` locates the dataset,
      ``splits`` the split definition, and ``task`` labels the progress bar.
    num_trials: Number of random splits per (method, alpha) task.
    method_configs: Mapping of method name to ``MethodConfig``.
    alphas: Sequence of alpha levels. The position of an alpha in this
      sequence is used as the task's seed, so every method gets the same seed
      for a given alpha.
    detail_path: Optional output directory. When set, ``eval_data.parquet``
      and per-task ``cal/chunk_<i>.parquet`` / ``test/chunk_<i>.parquet``
      files are written beneath it.
    debug: Use the ``DEBUG`` log level instead of ``INFO``.
    data_cache_dir: Optional base cache directory for dataset caching.
    context_map: Maps evaluation-dataset column names to the context names
      passed to the trials. ``None`` is treated as an empty mapping.
  """
  import warnings
  from loguru import logger

  utils.set_log_level("DEBUG" if debug else "INFO")
  if data_cache_dir is not None:
    DatasetCachingUtility.set_base_cache_dir(data_cache_dir)
  if context_map is None:
    logger.warning("No context map provided, using empty context map")
    context_map = {}
  logger.info(f"Context map: {context_map}")
  warnings.filterwarnings("ignore")

  dataset = HiddenStateDataset(experiment_dir=experiment_config.infer).load()
  dataset = dataset.with_format("numpy")

  splits = data_loaders.load_splits(experiment_config.splits)
  train_ds = data_loaders.get_dataset_split(dataset, splits.train)
  eval_ds = data_loaders.get_dataset_split(dataset, splits.test)

  label_columns = ["admissible"]
  # Several methods may share one input_key; dict.fromkeys drops the repeats
  # (keeping first-seen order) so each column is materialised only once.
  columns = list(
    dict.fromkeys(
      [method_config.input_key for method_config in method_configs.values()]
      + label_columns
    )
  )
  train_data = {k: np.array(train_ds[k]) for k in columns}
  eval_data = {k: np.array(eval_ds[k]) for k in columns}
  context_columns = {v: np.array(eval_ds[k]) for k, v in context_map.items()}

  # NOTE(review): every task below receives the same arrays by value; if
  # submission becomes a bottleneck, consider sharing them via ray.put.
  futures = []
  for method_name, method_config in method_configs.items():
    for i, alpha in enumerate(alphas):
      futures.append(
        (
          method_name,
          run_trials_for_method_and_alpha.remote(
            train_data[method_config.input_key],
            train_data["admissible"],
            eval_data[method_config.input_key],
            eval_data["admissible"],
            context=context_columns,  # type: ignore
            method_config=method_config,
            alpha=alpha,
            num_trials=num_trials,
            seed=i,
          ),
        )
      )

  cal_dir = test_dir = None
  if detail_path is not None:
    detail_dir = upath.UPath(detail_path)
    cal_dir = detail_dir / "cal"
    test_dir = detail_dir / "test"
    # Create the output directories once instead of once per result chunk.
    for directory in (detail_dir, cal_dir, test_dir):
      directory.mkdir(parents=True, exist_ok=True)

    # Save the evaluation data alongside the results; the "hidden_states"
    # column is left out of the file.
    eval_data_path = detail_dir / "eval_data.parquet"
    pd.DataFrame(
      {k: v.tolist() for k, v in eval_data.items() if k != "hidden_states"}
    ).to_parquet(eval_data_path)

  with tqdm(total=len(futures), desc=experiment_config.task) as pbar:
    for chunk_idx, (method_name, future) in enumerate(futures):
      # Blocks until this task finishes; results are consumed in submission
      # order so chunk numbers are deterministic.
      cal_trial_df, test_trial_df = ray.get(future)
      cal_trial_df["method_name"] = method_name
      test_trial_df["method_name"] = method_name
      if cal_dir is not None and test_dir is not None:
        cal_trial_df.to_parquet(cal_dir / f"chunk_{chunk_idx}.parquet")
        test_trial_df.to_parquet(test_dir / f"chunk_{chunk_idx}.parquet")
      pbar.update(1)


if __name__ == "__main__":
  # Command-line entry point: run every experiment in --experiment-cfg with
  # every method in --method-cfg, writing results under a timestamped folder
  # inside --output-dir.

  # Context columns used when --context-map is not supplied.
  DEFAULT_CONTEXT_MAP = {"first_occurrence_indices": "first_occurrence_indices"}

  parser = argparse.ArgumentParser()

  parser.add_argument("--experiment-cfg", type=str)
  parser.add_argument("--method-cfg", type=str)
  parser.add_argument("--task-alphas-cfg", type=str, required=False)
  parser.add_argument("--alpha-grid-start", type=float)
  parser.add_argument("--alpha-grid-end", type=float)
  parser.add_argument("--alpha-grid-step", type=float)
  parser.add_argument("--output-dir", type=str, default="./outputs")
  parser.add_argument("--num-trials", type=int, default=100)
  parser.add_argument("--debug", action="store_true")
  parser.add_argument("--data-cache-dir", type=str)
  parser.add_argument("--batch-size", type=int, default=20)
  parser.add_argument("--context-map", type=str, help="Path to context map config")
  args = parser.parse_args()

  # The alpha grid is used only when all three grid options are given;
  # otherwise alphas come from the per-task config file.
  use_alpha_grid = (
    args.alpha_grid_start is not None
    and args.alpha_grid_end is not None
    and args.alpha_grid_step is not None
  )
  if not use_alpha_grid and args.task_alphas_cfg is None:
    # Without any alpha source no experiment can run, so fail once up front
    # instead of logging an error for (and skipping) every experiment.
    parser.error(
      "Either --task-alphas-cfg or --alpha-grid-{start,end,step} must be specified"
    )

  if args.data_cache_dir is not None:
    logger.info(f"Setting data cache dir to {args.data_cache_dir}")
    data_loaders.DatasetCachingUtility.set_base_cache_dir(args.data_cache_dir)

  experiments = OmegaConf.load(args.experiment_cfg)
  methods = OmegaConf.load(args.method_cfg)

  # Honour --context-map when given; otherwise keep the built-in default.
  if args.context_map:
    context_map = tp.cast(
      dict[str, str],
      OmegaConf.to_object(OmegaConf.load(args.context_map)),
    )
  else:
    context_map = dict(DEFAULT_CONTEXT_MAP)

  # Both alpha sources depend only on the CLI arguments, so prepare them once.
  grid_alphas = None
  task_alphas = None
  if use_alpha_grid:
    # Extend the end by half a step so the end point itself is included
    # despite floating-point rounding in np.arange.
    grid_alphas = sorted(
      np.arange(
        args.alpha_grid_start,
        args.alpha_grid_end + args.alpha_grid_step / 2,
        args.alpha_grid_step,
      )
    )
  else:
    task_alphas = OmegaConf.load(args.task_alphas_cfg)

  timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
  output_folder = upath.UPath(args.output_dir) / f"run_{timestamp}"
  output_folder.mkdir(parents=True, exist_ok=True)

  for experiment in tqdm(experiments, desc="Running experiments"):
    resolved_methods = resolve_method_cfg(methods, experiment)
    experiment_config = parse_config(experiment, ExperimentConfig)
    task = experiment_config.task

    # Get alphas either from the grid specification or the task config
    if grid_alphas is not None:
      alphas = list(grid_alphas)
    else:
      if task not in task_alphas:
        logger.warning(f"No alphas specified for task {task}, skipping...")
        continue
      alphas = sorted(task_alphas[task].alphas)

    logger.info(f"Running task {task} with alphas: {alphas}")

    method_configs = {
      k: parse_config(v, MethodConfig) for k, v in resolved_methods.items()
    }

    exp_dir = upath.UPath(experiment.infer).name
    logger.info(f"Running experiment {exp_dir}")
    detail_path = output_folder / f"{exp_dir}.parquet"
    run_single_experiment(
      experiment_config,
      num_trials=args.num_trials,
      detail_path=detail_path,
      alphas=alphas,
      method_configs=method_configs,
      debug=args.debug,
      data_cache_dir=args.data_cache_dir,
      context_map=context_map,
    )
