import abc
import concurrent.futures
import dataclasses
import functools
import pathlib
import typing as tp

import datasets as ds
import numpy as np
import pandas as pd
import upath
from loguru import logger
from pydantic import BaseModel
from tqdm.auto import tqdm

from generative_prediction_sets import utils


def load_experiment(
  root: pathlib.Path,
  hidden_states_path="gt-score.parquet",
  infer_file="infer.parquet",
  eval_file="eval.json",
  load_hidden_states=False,
  eval_key="eval_results",
  admissibility_key="correct",
  num_samples: int | None = None,
):
  """Load one experiment and merge its inference, evaluation and gt-score files.

  The returned frame has ONE ROW PER ``dataset_idx``: every column of the merged
  inference/eval frame is aggregated into a list (rows sorted by ``dataset_idx``
  and ``sample_idx`` before aggregation), and the per-prompt columns of the
  gt-score parquet are joined on ``dataset_idx``.

  Args:
      root: Directory containing the experiment files (wrapped in ``upath.UPath``).
      hidden_states_path: Parquet file, relative to ``root``, that must provide the
          columns 'dataset_idx', 'token_logprobs' and 'hidden_states'.
          Defaults to "gt-score.parquet".
      infer_file: Parquet file, relative to ``root``, with the inference results. It
          must provide 'dataset_idx', 'sample_idx' and 'generated_text'.
          Defaults to "infer.parquet".
      eval_file: JSON file, relative to ``root``, with the evaluation results.
          Defaults to "eval.json".
      load_hidden_states: Currently unused by this function; kept so existing
          callers that pass it keep working.
      eval_key: Key of the eval JSON holding the per-sample records.
          Defaults to "eval_results".
      admissibility_key: Name of the eval column holding the per-sample label. When
          that column is absent, 'correct' and then 'passed' are tried.
          Defaults to "correct".
      num_samples: If given, keep only the first ``num_samples`` rows of every
          ``dataset_idx`` group (in merge order). If None, keep all samples.

  Returns:
      pd.DataFrame: One row per ``dataset_idx`` containing
          - every inference column, aggregated into a per-prompt list
          - 'admissible': list of per-sample labels taken from the eval results
          - 'prompt_token_logprobs': the gt-score 'token_logprobs' column, renamed
          - 'prompt_logprobs': ``np.sum`` of the token logprobs
          - 'normalized_prompt_logprobs': ``utils.compute_normalized_logprobs``
            applied to the token logprobs
          - 'hidden_states': gt-score paths joined onto ``root`` (posix strings)
          - 'first_occurrence_indices': ``utils.compute_first_occurrence_indices``
            applied to the per-prompt 'generated_text' list

  Raises:
      ValueError: If
          - the inference and eval frames have different numbers of rows
          - the eval results contain none of the candidate label columns
          - the inference/eval merge changes the number of rows
  """

  root = upath.UPath(root)
  infer_df = pd.read_parquet(root / infer_file)
  # ---------------------------------------------------------------------------- #
  #                                     eval                                     #
  # ---------------------------------------------------------------------------- #
  eval_df = pd.DataFrame(utils.load_json(root / eval_file)[eval_key])
  if len(infer_df) != len(eval_df):
    raise ValueError(
      "Infer and eval dataframes have different lengths. {} != {}".format(
        len(infer_df), len(eval_df)
      )
    )

  # Resolve the label column. The requested key wins; 'correct' and 'passed' stay
  # as fallbacks so eval files written with either legacy name still load.
  # dict.fromkeys de-duplicates while preserving priority order.
  label_candidates = list(dict.fromkeys((admissibility_key, "correct", "passed")))
  label_column = next(
    (name for name in label_candidates if name in eval_df.columns), None
  )
  if label_column is None:
    raise ValueError(
      "No {} column".format(" or ".join(f"'{name}'" for name in label_candidates))
    )

  # Rename before merging so the label always ends up as 'admissible', whatever
  # it was called in the eval file.
  eval_labels = eval_df[["dataset_idx", "sample_idx", label_column]].rename(
    columns={label_column: "admissible"}
  )
  infer_eval = pd.merge(
    infer_df,
    eval_labels,
    on=["dataset_idx", "sample_idx"],
  )
  # An inner merge drops unmatched rows and duplicates rows on repeated keys; a
  # changed row count means the two files do not describe the same samples.
  if len(infer_eval) != len(infer_df):
    raise ValueError(
      "Merge of infer and eval dataframes have different lengths. {} != {}".format(
        len(infer_eval), len(infer_df)
      )
    )
  if num_samples is not None:
    logger.info(f"Selecting {num_samples} samples for dataset")
    infer_eval = infer_eval.groupby("dataset_idx").head(num_samples)

  # ---------------------------------------------------------------------------- #
  #                                   gt scores                                  #
  # ---------------------------------------------------------------------------- #
  gt_scores = pd.read_parquet((root / hidden_states_path).as_posix())
  gt_scores = gt_scores.rename(columns=dict(token_logprobs="prompt_token_logprobs"))
  gt_scores["normalized_prompt_logprobs"] = gt_scores["prompt_token_logprobs"].apply(
    utils.compute_normalized_logprobs
  )
  gt_scores["prompt_logprobs"] = gt_scores["prompt_token_logprobs"].apply(np.sum)
  # Stored paths are joined onto the experiment root.
  gt_scores["hidden_states"] = gt_scores["hidden_states"].apply(
    lambda relative_path: (root / relative_path).as_posix()
  )

  # Collapse the per-sample rows into one row per prompt; each remaining column
  # becomes a list ordered by sample_idx.
  merged = (
    infer_eval.sort_values(by=["dataset_idx", "sample_idx"])
    .groupby("dataset_idx")
    .agg(list)
    .reset_index()
  )
  merged = pd.merge(
    merged,
    gt_scores[
      [
        "dataset_idx",
        "prompt_logprobs",
        "hidden_states",
        "normalized_prompt_logprobs",
        "prompt_token_logprobs",
      ]
    ],
    on=["dataset_idx"],
  )

  # A mismatch here (prompts missing from, or repeated in, the gt-score file) is
  # only reported, not treated as fatal.
  num_prompts = infer_df.dataset_idx.nunique()
  if len(merged) != num_prompts:
    logger.warning(
      "Merged dataframe has different length than the original dataframe. merged ({}) != infer_df.dataset.nunique({})".format(
        len(merged), num_prompts
      )
    )
  merged["first_occurrence_indices"] = merged["generated_text"].apply(
    utils.compute_first_occurrence_indices
  )
  return merged


def load_hidden_states(path: str, *, token_idx: int = -1, layer_idx: int = -1):
  """Read a single entry from a hidden-state archive.

  The archive is opened through ``upath.UPath`` and read with ``np.load``. Its
  keys are converted with ``int`` and sorted numerically; ``layer_idx`` indexes
  into that sorted list and ``token_idx`` indexes the first axis of the chosen
  array.

  Args:
      path: Location of the archive.
      token_idx: Index along the first axis of the selected array. Defaults to -1.
      layer_idx: Position in the numerically sorted key list. Defaults to -1.

  Returns:
      The selected entry of the selected array.
  """
  with upath.UPath(path).open("rb") as handle:
    archive = np.load(handle)
    # Sort as integers so that e.g. "10" comes after "9".
    layer_ids = sorted(int(name) for name in archive.keys())
    chosen_layer = layer_ids[layer_idx]
    layer_states = archive[str(chosen_layer)]
    return layer_states[token_idx]


def load_multiple_hidden_states(
  paths: tp.Sequence[str],
  *,
  token_idx: int = -1,
  layer_idx: int = -1,
  progress: bool = True,
):
  """Load hidden states for many archives using a thread pool.

  Args:
      paths: Archive locations; results are returned in the same order.
      token_idx: Forwarded to ``load_hidden_states``.
      layer_idx: Forwarded to ``load_hidden_states``.
      progress: Show a tqdm progress bar when True.

  Returns:
      list: One ``load_hidden_states`` result per entry of ``paths``.
  """

  def load_one(path):
    return load_hidden_states(path, token_idx=token_idx, layer_idx=layer_idx)

  with concurrent.futures.ThreadPoolExecutor() as pool:
    pending = pool.map(load_one, paths)
    # Drain the iterator while the pool is alive so the bar tracks completion.
    loaded = list(tqdm(pending, total=len(paths), disable=not progress))
  return loaded


def load_and_mmap_hidden_states_in_batches(
  paths: tp.Sequence[str],
  *,
  tmp_filename: str,
  token_idx: int = -1,
  layer_idx: int = -1,
  batch_size: int = 1000,
  progress: bool = True,
):
  """Load hidden states batch by batch into a disk-backed ``np.memmap``.

  The first path is loaded once up front to learn the per-item shape and dtype;
  the memmap is then created at ``tmp_filename`` (mode "w+") with shape
  ``(len(paths),) + item_shape`` and filled ``batch_size`` items at a time.

  Args:
      paths: Archive locations, one per output row.
      tmp_filename: File backing the memmap.
      token_idx: Forwarded to ``load_hidden_states``.
      layer_idx: Forwarded to ``load_hidden_states``.
      batch_size: Number of items loaded per write. Defaults to 1000.
      progress: Show a tqdm progress bar over batches when True.

  Returns:
      tuple: ``(memmap_array, tmp_filename)``.
  """
  load_one = functools.partial(
    load_hidden_states, token_idx=token_idx, layer_idx=layer_idx
  )

  # Probe one item to size the output array.
  probe = load_one(paths[0])
  total = len(paths)
  memmap_array = np.memmap(
    tmp_filename, dtype=probe.dtype, mode="w+", shape=(total,) + probe.shape
  )

  batch_starts = range(0, total, batch_size)
  for start_idx in tqdm(batch_starts, disable=not progress):
    end_idx = min(start_idx + batch_size, total)
    memmap_array[start_idx:end_idx] = [
      load_one(path) for path in paths[start_idx:end_idx]
    ]
    # Push each batch to disk before loading the next one.
    memmap_array.flush()

  return memmap_array, tmp_filename


class DatasetCachingUtility:
  """Singleton helper that derives cache locations and source-file mtimes.

  Every call to ``DatasetCachingUtility()`` returns the same instance. The base
  cache directory is stored on the class and can be changed with
  ``set_base_cache_dir``.
  """

  _instance: tp.Optional["DatasetCachingUtility"] = None
  _base_cache_dir: upath.UPath = upath.UPath("/mnt/samsung-nvme/.aaai-data-cache")

  def __new__(cls):
    # Create the shared instance lazily on first construction.
    if cls._instance is not None:
      return cls._instance
    cls._instance = super().__new__(cls)
    return cls._instance

  @classmethod
  def set_base_cache_dir(cls, base_cache_dir: str):
    """Replace the class-level base cache directory."""
    cls._base_cache_dir = upath.UPath(base_cache_dir)

  def get_cache_dir(self, experiment_dir: upath.UPath, config, loader_name):
    """Return ``<base>/<loader_name>/<experiment_dir.name>/<md5 of config>``.

    ``config`` is serialized with ``json.dumps(sort_keys=True)`` before hashing,
    so it must be JSON-serializable.
    """
    import hashlib
    import json

    serialized_config = json.dumps(config, sort_keys=True).encode()
    digest = hashlib.md5(serialized_config).hexdigest()
    return self._base_cache_dir / loader_name / experiment_dir.name / digest

  def get_source_files_mtime(self, source_files):
    """Return the newest ``st_mtime`` among the source files that exist.

    Files that do not exist are skipped; ``max`` raises ``ValueError`` when no
    file remains.
    """
    candidates = (upath.UPath(file) for file in source_files)
    return max(
      candidate.stat().st_mtime for candidate in candidates if candidate.exists()
    )


class CachedDataset(abc.ABC):
  """Base class for experiment datasets that handles caching and loading.

  ``load()`` is the entry point. It looks up the cache directory for this dataset
  (derived from the experiment directory name, the class name and a hash of
  ``config``), compares the cache directory's modification time with the newest
  source file, and either loads the cached dataset or rebuilds it with
  ``load_dataset()`` and saves the result.

  Attributes:
      experiment_dir (str): Directory containing the experiment data.

  Key Methods:
      load(): Main entry point that handles caching logic and returns the dataset
      get_local_cache_dir(): Gets the cache directory path for this dataset
      load_dataset(): Abstract method to implement actual dataset loading
      get_source_files(): Abstract method to specify files used to build dataset
      config: Abstract property to specify dataset configuration

  To implement a concrete dataset class:
      1. Inherit from this class
      2. Implement load_dataset() to process and return the actual dataset
      3. Implement get_source_files() to return paths of files used to build the dataset
      4. Implement the config property to return configuration parameters
         (it is JSON-serialized to build the cache key)
  """

  experiment_dir: str

  @abc.abstractmethod
  def load_dataset(self) -> ds.Dataset:
    """Build the dataset from the source files (no caching).

    Subclasses must override this. The base body only runs when invoked
    explicitly (e.g. through ``super()``); it logs a warning when its immediate
    caller is not a function named ``load`` and returns None.
    """
    import inspect

    # currentframe() may return None on interpreters without frame support.
    current_frame = inspect.currentframe()
    caller_frame = current_frame.f_back if current_frame is not None else None
    if caller_frame and caller_frame.f_code.co_name != "load":
      logger.warning(
        "Direct call to load_dataset() detected. "
        "You should use load() instead to benefit from caching."
      )

  @abc.abstractmethod
  def get_source_files(self) -> tp.Sequence[str]:
    """Return the paths of the files the dataset is built from."""
    ...

  @property
  @abc.abstractmethod
  def config(self) -> dict:
    """Return the parameters that identify this dataset in the cache."""
    ...

  def get_local_cache_dir(self) -> upath.UPath:
    """Return the cache directory for this experiment, config and class."""
    caching_utility = DatasetCachingUtility()
    return caching_utility.get_cache_dir(
      upath.UPath(self.experiment_dir), self.config, self.__class__.__name__
    )

  def load(self) -> ds.Dataset:
    """Return the dataset, from cache when it is at least as new as the sources.

    The cache is used when the cache directory exists and its modification time
    is not older than the newest existing source file. Otherwise, or when
    reading the cache fails for any reason, the dataset is rebuilt with
    ``load_dataset()`` and written to the cache directory.

    Returns:
        ds.Dataset: The cached or freshly built dataset.
    """
    caching_utility = DatasetCachingUtility()
    cache_dir = self.get_local_cache_dir()
    source_files = self.get_source_files()
    source_mtime = caching_utility.get_source_files_mtime(source_files)

    if cache_dir.exists():
      import os

      try:
        # Check freshness first so a stale cache is never deserialized just to
        # be thrown away.
        cache_mtime = os.path.getmtime(str(cache_dir))
        if cache_mtime >= source_mtime:
          dataset = tp.cast(ds.Dataset, ds.Dataset.load_from_disk(str(cache_dir)))
          logger.debug(f"Loading cached dataset from {cache_dir}")
          return dataset
        logger.debug("Source files have changed. Reprocessing dataset.")
      except Exception as e:
        # Best effort: an unreadable cache must not block loading; rebuild.
        logger.warning(f"Error loading cached dataset: {e}. Reprocessing.")

    logger.debug(
      f"Cache not found or outdated. Processing dataset for {self.experiment_dir}"
    )
    dataset = self.load_dataset()

    logger.info(f"Saving processed dataset to {cache_dir}")
    dataset.save_to_disk(str(cache_dir))

    return dataset


class HiddenStateDataset(BaseModel, CachedDataset):
  """Cached dataset with one row per prompt: hidden state, labels and logprobs.

  The fields below are both the loading parameters and, through ``config``
  (``model_dump()``), the cache key.
  """

  # Directory containing the experiment files.
  experiment_dir: str
  # File names, relative to ``experiment_dir``.
  infer_path: str = "infer.parquet"
  gt_score_path: str = "gt-score.parquet"
  eval_path: str = "eval.json"
  # Key of the eval JSON holding the per-sample records.
  eval_key: str = "aux"
  # Eval column holding the per-sample label.
  admissibility_key: str = "correct"
  # Which token / layer to pick from each hidden-state archive.
  hidden_state_token_idx: int = -1
  hidden_state_layer_idx: int = -1
  # Maximum number of samples kept per prompt.
  num_samples: int = 25

  def get_source_files(self) -> tp.Sequence[str]:
    """Return the resolved posix paths of the infer, gt-score and eval files."""
    source_file_names = [self.infer_path, self.gt_score_path, self.eval_path]
    return [
      (upath.UPath(self.experiment_dir) / f).resolve().as_posix()
      for f in source_file_names
    ]

  @property
  def config(self) -> dict:
    """Return all model fields as a dict."""
    return self.model_dump()

  def load_dataset(self) -> ds.Dataset:
    """Build the dataset from the experiment files.

    Steps: load the merged per-prompt frame with ``load_experiment``, replace
    each hidden-state path by the loaded array, add ``p_map``, and keep only the
    columns used downstream.

    Returns:
        ds.Dataset: Dataset with the columns listed in the final
        ``select_columns`` call.
    """

    def add_hidden_states(batch):
      # Replace the stored paths by the arrays they point to.
      return {
        **batch,
        "hidden_states": load_multiple_hidden_states(
          batch["hidden_states"],
          token_idx=self.hidden_state_token_idx,
          layer_idx=self.hidden_state_layer_idx,
          progress=True,
        ),
      }

    def add_p_map(batch):
      # (count + 1) / (n + 2) per prompt. Computed row by row so prompts with
      # different numbers of samples (possible because at most ``num_samples``
      # are kept per prompt) do not break the batch; for equal-length rows this
      # equals the vectorised sum over axis 1.
      return {
        **batch,
        "p_map": [
          float((np.sum(labels) + 1) / (len(labels) + 2))
          for labels in batch["admissible"]
        ],
      }

    df = load_experiment(
      upath.UPath(self.experiment_dir),
      infer_file=self.infer_path,
      hidden_states_path=self.gt_score_path,
      eval_file=self.eval_path,
      eval_key=self.eval_key,
      admissibility_key=self.admissibility_key,
      num_samples=self.num_samples,
    )
    dataset = ds.Dataset.from_pandas(df.reset_index())
    dataset = dataset.map(
      add_hidden_states,
      batched=True,
      batch_size=1000,
      desc="Loading hidden states",
    )
    dataset = dataset.map(
      add_p_map,
      batched=True,
      num_proc=8,
      desc="Calculating p_map",
    )
    dataset = dataset.select_columns(
      [
        "dataset_idx",
        "sample_idx",
        "hidden_states",
        "admissible",
        "p_map",
        "prompt_logprobs",
        "normalized_prompt_logprobs",
        "prompt_token_logprobs",
        "first_occurrence_indices",
      ]
    )
    return dataset


@dataclasses.dataclass
class DataSplit:
  """Pair of index arrays describing a train/test partition."""

  # Indices assigned to the training portion.
  train: np.ndarray
  # Indices assigned to the test portion.
  test: np.ndarray


def load_splits(path):
  """Read the "train" and "test" arrays stored at ``path`` into a ``DataSplit``.

  Args:
      path: Location of a file readable by ``np.load`` that exposes the keys
          "train" and "test".

  Returns:
      DataSplit: The two arrays.
  """
  with upath.UPath(path).open("rb") as handle:
    archive = np.load(handle)
    # Pull both arrays out while the file is still open.
    train_indices = archive["train"]
    test_indices = archive["test"]
  return DataSplit(train=train_indices, test=test_indices)


def get_dataset_split(
  dataset: ds.Dataset, split_indices: np.ndarray, index_column="dataset_idx"
) -> ds.Dataset:
  """Select the rows of ``dataset`` whose index value appears in ``split_indices``.

  Rows are returned in the order of ``split_indices``. Values that do not occur
  in ``index_column`` are skipped; if a value occurs several times in the
  column, its last row is used.

  Args:
      dataset: dataset to split
      split_indices: index values to select
      index_column: column holding each row's index value. Defaults to "dataset_idx".

  Returns:
      ds.Dataset: The selected rows.
  """
  # Map each index value to its row position (later rows overwrite earlier ones).
  row_position = {}
  for position, index_value in enumerate(dataset[index_column]):
    row_position[index_value] = position

  looked_up = (row_position.get(index_value) for index_value in split_indices)
  # Positions are ints, so None can only mean "value not present".
  rows_to_take = [position for position in looked_up if position is not None]
  return dataset.select(rows_to_take)
