import dataclasses
import multiprocessing as mp
import typing as tp

import numpy as np
import pandas as pd
import upath
from omegaconf import OmegaConf
from tqdm.auto import tqdm

from generative_prediction_sets import utils

# ---------------------------------------------------------------------------- #
#                          Loading collected clm data                          #
# ---------------------------------------------------------------------------- #


def get_single_trial_details(
  trial_data: dict[str, np.ndarray], epsilons: tp.Sequence[float] | None = None
):
  """Loads trial results dictionary containing evaluation metrics for each epsilon risk level.

  Keys:
      configs: np.ndarray [num_epsilons, 3]
          Selected (lambda_1, lambda_2, lambda_3) thresholds for each epsilon, where:
          - lambda_1: similarity threshold for diversity filtering
          - lambda_2: quality threshold for item rejection
          - lambda_3: set score threshold for determining set size

      # Loss metrics
      L_avg: np.ndarray [num_epsilons]
          Average loss (fraction of examples where true answer not in predicted set)
      L_worst_pred_avg: np.ndarray [num_epsilons]
          Worst-case average loss when examples are binned by predicted set size
      Ls: np.ndarray [num_examples]
          Binary loss for each example (0 if true answer in set, 1 if not)

      # Set size metrics
      C_size_avg: np.ndarray [num_epsilons]
          Average size of predicted sets, normalized by max_generations
      C_samples_avg: np.ndarray [num_epsilons]
          Average number of samples needed, normalized by max_generations
      C_excess_avg: np.ndarray [num_epsilons]
          Average number of excess samples beyond oracle minimum
      C_relative_excess_avg: np.ndarray [num_epsilons]
          Relative excess samples (excess_samples / total_samples)
      C_obj_avg: np.ndarray [num_epsilons]
          Combined objective (C_samples_avg + C_size_avg)

      # Detailed prediction info
      C_indices: np.ndarray [num_examples]
          Indices where each example's prediction set was cut off
      kept_mask: np.ndarray [num_examples, max_generations]
          Boolean mask indicating which items passed rejection criteria
          (both similarity and quality thresholds)

  Note:
      - All metrics are computed for each epsilon risk level in DEFAULT_EPSILONS
      - Normalized metrics are divided by max_generations
      - Oracle minimum refers to minimum samples needed to include true answer
  """
  SUMMARY_COLS = [
    "L_avg",
    "L_worst_pred_avg",
    "C_size_avg",
    "C_samples_avg",
    "C_excess_avg",
    "C_relative_excess_avg",
    "C_obj_avg",
    "configs",
  ]

  detail_dfs = []
  summaries = []
  for eps_idx in range(trial_data["configs"].shape[0]):
    kept_mask = trial_data["kept_mask"][eps_idx]
    set_sizes = np.cumsum(kept_mask, axis=-1)
    C_indices = trial_data["C_indices"][eps_idx].astype(int)
    num_examples = kept_mask.shape[0]
    # actual set size
    C_sizes = set_sizes[np.arange(num_examples), C_indices]
    # number of samples computed
    C_samples = C_indices + 1
    # print(C_indices[0])
    summary = {k: trial_data[k][eps_idx] for k in SUMMARY_COLS}
    is_cfg_valid = ~np.all(np.isnan(trial_data["configs"][eps_idx]))
    summary["configs"] = tuple(summary["configs"])
    summary["is_cfg_valid"] = is_cfg_valid
    detail_df = pd.DataFrame(
      {
        "eps_idx": eps_idx,
        "seq_idx": np.arange(num_examples),
        # "kept_mask": kept_mask.tolist(),
        # "C_indices": C_indices,
        "set_size": C_sizes,
        "num_samples": C_samples,
        "is_cfg_valid": is_cfg_valid,
        "Ls": trial_data["Ls"][eps_idx],
      }
    )
    if epsilons is not None:
      summary["epsilon"] = epsilons[eps_idx]
      detail_df["epsilon"] = epsilons[eps_idx]

    summaries.append(summary)
    detail_dfs.append(detail_df)
  return pd.concat(detail_dfs), pd.DataFrame(summaries)


def get_clm_method_name(method_cfg):
  """Builds a display name for a CLM method config.

  The name is the config's "scoring" value, with "_reject" appended when the
  config's "rejection" flag is truthy.
  """
  template = "{scoring}_reject" if method_cfg["rejection"] else "{scoring}"
  return template.format(**method_cfg)


def extract_trial_and_method_from_file_name(input_string):
  """Parses the trial and method ids out of a trial result file name.

  Args:
    input_string: A bare file name (no directory), e.g. "trial_3_method_1.npz".

  Returns:
    `(trial_id, method_id)` as ints, or `(None, None)` if the name is not
    exactly of the form `trial_<digits>_method_<digits>.npz`.
  """
  import re

  # Escape the dot so only a literal ".npz" suffix matches. `fullmatch` also
  # rejects a trailing newline, which a `$` anchor would accept.
  match = re.fullmatch(r"trial_(\d+)_method_(\d+)\.npz", input_string)
  if match is None:
    return None, None
  return int(match.group(1)), int(match.group(2))


@dataclasses.dataclass
class ClmResult:
  """Loaded results of a single CLM experiment, as built by `load_clm_trials`.

  Attributes:
    summary: Per-epsilon summary metrics, one row per (trial, method, epsilon).
    details: Per-example results, one row per (trial, method, epsilon, example).
    cfg: The experiment's `config.yaml`, converted to plain Python containers.
  """

  summary: pd.DataFrame
  details: pd.DataFrame
  cfg: dict


def load_clm_trials(trial_root):
  """Loads all per-trial, per-method result files of a single CLM experiment.

  Expects `trial_root` to contain a `config.yaml` with `epsilons` and `methods`
  entries, plus a `trial_results/` directory of
  `trial_<trial_id>_method_<method_id>.npz` files. `method_id` indexes into
  `methods`.

  Args:
    trial_root: Experiment directory; anything `upath.UPath` accepts.

  Returns:
    A `ClmResult` with:
    - `summary`: one row per (trial, method, epsilon).
    - `details`: one row per (trial, method, epsilon, example).
    - `cfg`: the experiment config as a plain container.
    Both frames carry "trial_id" and "method" columns ("method" is categorical
    in `details`). Rows are ordered by (trial_id, method_id).

  Raises:
    ValueError: If no parseable trial result file is found.
  """
  import warnings

  trial_root = upath.UPath(trial_root)
  trial_cfg = OmegaConf.load(trial_root / "config.yaml")
  epsilons = trial_cfg.epsilons
  methods = [get_clm_method_name(method_cfg) for method_cfg in trial_cfg.methods]
  trial_data_root = trial_root / "trial_results"

  # The glob is looser than the real naming scheme (it also matches e.g.
  # "trial_0_method_0_old.npz"), so keep only names that parse to integer ids.
  parsed_files = []
  for trial_file in trial_data_root.glob("trial_*_method_*.npz"):
    trial_id, method_id = extract_trial_and_method_from_file_name(trial_file.name)
    if trial_id is None:
      warnings.warn(
        f"Skipping unrecognized trial result file: {trial_file}", stacklevel=2
      )
      continue
    parsed_files.append((trial_id, method_id, trial_file))
  # Glob order is filesystem-dependent; sort for a reproducible row order.
  parsed_files.sort(key=lambda entry: (entry[0], entry[1]))

  if not parsed_files:
    raise ValueError(f"No trial result files found in {trial_data_root}")

  all_trial_details = []
  all_trial_summaries = []
  for trial_id, method_id, trial_file in parsed_files:
    trial_data = utils.load_npz(trial_file)
    trial_details, trial_summary = get_single_trial_details(
      trial_data, epsilons=epsilons
    )
    method = methods[method_id]
    trial_summary["trial_id"] = trial_id
    trial_summary["method"] = method
    trial_details["trial_id"] = trial_id
    trial_details["method"] = method

    all_trial_details.append(trial_details)
    all_trial_summaries.append(trial_summary)

  return ClmResult(
    summary=pd.concat(all_trial_summaries),
    details=pd.concat(all_trial_details).astype({"method": "category"}),
    cfg=OmegaConf.to_container(trial_cfg),
  )


def _load_clm_worker(clm_dir):
  """Pool worker: loads one experiment directory and tags its rows with its name.

  `clm_dir` must expose `.name` (e.g. a `upath.UPath`). Its base name becomes
  the "exp_name" column of both result frames. The function is defined at
  module level so `multiprocessing` can pickle it.
  """
  loaded = load_clm_trials(clm_dir)
  exp_name = clm_dir.name
  for frame in (loaded.summary, loaded.details):
    frame["exp_name"] = exp_name
  loaded.details = loaded.details.astype(
    {"method": "category", "exp_name": "category"}
  )
  return loaded


def load_clm_results(clm_data_root, limit=None):
  """Loads results from a directory of CLM experiments.

  Assumes the following structure:
  clm_data_root/
    exp_name_1/
      trial_results/
        trial_0_method_0.npz
        ...
      config.yaml
    ...
    exp_name_n/
      ... (same structure as above)
    ...

  Only subdirectories that contain a `config.yaml` count as experiments. They
  are visited in sorted name order, so `limit` always selects the same
  experiments. Each experiment is loaded in its own worker process. Each CLM
  experiment might have multiple trials, and each trial might have multiple
  methods.

  Args:
    clm_data_root: Root directory; anything `upath.UPath` accepts.
    limit: If given, load only the first `limit` experiments (in sorted order).

  Returns:
    A `(summary, details)` tuple of DataFrames (not a `ClmResult`):
    - `summary`: the per-epsilon average results for each experiment, trial
      and method.
    - `details`: the per-example results.
    Both contain an "exp_name" column holding the experiment directory's name
    (its path relative to `clm_data_root`). In `details`, "exp_name" and
    "method" are categorical.

  Raises:
    ValueError: If no experiment directory is selected.
  """
  clm_data_root = upath.UPath(clm_data_root)
  # `iterdir` order is filesystem-dependent; sort so that `limit` and the
  # output row order are reproducible.
  exp_dirs = sorted(
    (
      d
      for d in clm_data_root.iterdir()
      if d.is_dir() and (d / "config.yaml").exists()
    ),
    key=lambda d: d.name,
  )
  if limit is not None:
    exp_dirs = exp_dirs[:limit]
  if not exp_dirs:
    raise ValueError(
      f"No CLM experiment directories (with config.yaml) selected under {clm_data_root}"
    )

  details = []
  summaries = []
  # There is no point spawning more workers than there are experiments.
  num_workers = min(mp.cpu_count(), len(exp_dirs))
  with mp.Pool(num_workers) as pool:
    # `imap` (not `imap_unordered`) keeps results in `exp_dirs` order.
    for result in tqdm(pool.imap(_load_clm_worker, exp_dirs), total=len(exp_dirs)):
      details.append(result.details)
      summaries.append(result.summary)

  # Concatenating categoricals with different categories yields object dtype,
  # so re-apply the categorical dtypes on the combined frame.
  return pd.concat(summaries), pd.concat(details).astype(
    {"exp_name": "category", "method": "category"}
  )
