import functools
import itertools
import json
import math
import multiprocessing as mp
import pathlib
import sys
import tempfile
import time
import typing as tp

import numpy as np
import pandas as pd
import upath
from loguru import logger
from rouge_score import rouge_scorer
from tqdm import tqdm

# Level names accepted by `set_log_level` when the level is passed as a string
# (compared after upper-casing).
VALID_LEVELS = {"TRACE", "DEBUG", "INFO", "SUCCESS", "WARNING", "ERROR", "CRITICAL"}


def set_log_level(level: tp.Union[str, int]):
  """
  Replace every Loguru handler with a single stderr handler at `level`.

  String levels are upper-cased and checked against VALID_LEVELS; integer
  levels are handed to Loguru unchanged.

  Args:
      level: Level name ('debug', 'INFO', ...) or numeric level value
  Raises:
      ValueError: If a string level is not one of VALID_LEVELS
  """
  if isinstance(level, str):
    normalized = level.upper()
    if normalized not in VALID_LEVELS:
      raise ValueError(f"Invalid log level. Must be one of {VALID_LEVELS}")
    level = normalized

  # Drop all existing sinks first so the new level is the only one in effect.
  logger.remove()
  logger.add(sys.stderr, level=level)


def parse_split_string(split_string):
  """
  Parse a colon-separated split specification into a list of proportions.

  Each field is a fraction ("0.8"), a percentage ("80%"), or empty. At most
  one field may be empty; it receives whatever remains so that the
  proportions sum to 1.0.

  Args:
      split_string: Specification such as "80%:10%:10%", "0.8:0.2" or "0.7::0.1"
  Returns:
      list[float]: One proportion per field, in the order given
  Raises:
      ValueError: If a field is not numeric, more than one field is empty, a
          proportion is negative, or the proportions do not sum to 1.0
  """
  parsed_splits = []
  for field in split_string.split(":"):
    if not field:
      # Placeholder for the (single) split to be inferred below
      parsed_splits.append(None)
    elif "%" in field:
      parsed_splits.append(float(field.replace("%", "")) / 100)
    else:
      parsed_splits.append(float(field))

  if parsed_splits.count(None) > 1:
    raise ValueError("Only one split size can be left unspecified")

  if None in parsed_splits:
    missing_index = parsed_splits.index(None)
    known_splits_sum = sum(p for p in parsed_splits if p is not None)
    parsed_splits[missing_index] = 1.0 - known_splits_sum

  # An inferred remainder can be negative when the known splits exceed 1.0,
  # and it would still pass the sum check below. Tiny negative values caused
  # by floating-point round-off are tolerated.
  if any(p < 0 and not np.isclose(p, 0.0) for p in parsed_splits):
    raise ValueError(f"Split sizes must be non-negative, got {parsed_splits}")

  # Ensure the sum of splits is 1.0
  if not np.isclose(sum(parsed_splits), 1.0):
    raise ValueError("The sum of train, calibration, and test sizes must be 1.0")

  return parsed_splits


def split_dataset(
  *arrays, splits: tp.Sequence[float], num_splits: int = 2, random_state=None
):
  """
  Randomly partition one or more equally long arrays into `num_splits` parts.

  All arrays are split with the same shuffled indices, so rows stay aligned
  across arrays.

  Args:
      *arrays: NumPy arrays, pandas DataFrames/Series (indexed by position),
          or Python lists/tuples (each part is returned as a list)
      splits: Proportions for each part. May contain `num_splits` values, or
          `num_splits - 1` values, in which case the last one is inferred
      num_splits: Number of parts to produce
      random_state: Seed for the shuffle. When given, a private RandomState is
          used so the global NumPy RNG is left untouched; when None, the
          global NumPy RNG is used
  Returns:
      For a single array: a list of its parts. For several arrays: a tuple
      with one entry per part, each entry being a tuple holding that part of
      every array (in argument order)
  Raises:
      ValueError: If no arrays are given, the arrays differ in length, the
          number of proportions is wrong, or they do not sum to 1.0
  """
  if not arrays:
    raise ValueError("At least one array is required")

  if len(splits) != num_splits and len(splits) != num_splits - 1:
    raise ValueError(
      f"Number of split proportions must be {num_splits} or {num_splits - 1}"
    )

  # If the number of provided splits is one less than num_splits, calculate the last split
  if len(splits) == num_splits - 1:
    splits = list(splits)
    splits.append(1 - sum(splits))

  # Ensure the proportions sum to 1
  if not np.isclose(sum(splits), 1.0):
    raise ValueError("The sum of the split proportions must be 1.0")

  num_samples = len(arrays[0])
  if any(len(arr) != num_samples for arr in arrays[1:]):
    raise ValueError("All arrays must have the same length")

  # A dedicated RandomState yields the same permutation as seeding the global
  # generator would, without the side effect of reseeding it for everyone else.
  rng = np.random if random_state is None else np.random.RandomState(random_state)
  indices = rng.permutation(np.arange(num_samples))

  # Truncate each share to whole samples; the last part absorbs the remainder.
  split_sizes = [int(split * num_samples) for split in splits]
  split_sizes[-1] = num_samples - sum(split_sizes[:-1])

  split_indices = np.split(indices, np.cumsum(split_sizes[:-1]))

  split_arrays = []
  for arr in arrays:
    if isinstance(arr, (pd.DataFrame, pd.Series)):
      # Positional selection: the shuffled values are positions, not labels
      parts = [arr.iloc[idx] for idx in split_indices]
    elif isinstance(arr, (list, tuple)):
      # Built-in sequences cannot be indexed with an index array
      parts = [[arr[i] for i in idx] for idx in split_indices]
    else:
      parts = [arr[idx] for idx in split_indices]
    split_arrays.append(parts)

  if len(arrays) == 1:
    return split_arrays[0]
  return tuple(zip(*split_arrays))


def compute_rouge_L(hyp, ref):
  """
  Return the ROUGE-L F-measure between `hyp` and `ref`, with stemming enabled.

  A new RougeScorer is built on every call.
  NOTE(review): `hyp` is passed as the scorer's first argument; confirm this
  against rouge_score's expected argument order if precision or recall are
  ever read from the result instead of the F-measure.
  """
  metric = "rougeL"
  scores = rouge_scorer.RougeScorer([metric], use_stemmer=True).score(hyp, ref)
  return scores[metric].fmeasure


def pairwise_comparisons(arr, pairwise_comparison_fn, cache: bool = False):
  """
  Build the full N x N matrix of comparisons between the items of `arr`.

  The comparison is evaluated once per unordered pair (with the lower index as
  the first argument) and mirrored, so the returned matrix is symmetric even
  if the function is not.

  Args:
      arr: Indexable, sized collection of items
      pairwise_comparison_fn: Callable taking two items and returning any object
      cache: If True, memoize the callable with an unbounded lru_cache for the
          duration of this call (items must then be hashable)
  Returns:
      np.ndarray: Object-dtype array of shape (N, N)
  """
  compare = pairwise_comparison_fn
  if cache:
    compare = functools.lru_cache(maxsize=None)(compare)

  size = len(arr)
  matrix = np.empty((size, size), dtype=object)

  # Self-comparisons populate the diagonal first.
  for k in range(size):
    matrix[k, k] = compare(arr[k], arr[k])

  # Upper triangle, mirrored into the lower triangle.
  for row in range(size):
    for col in range(row + 1, size):
      score = compare(arr[row], arr[col])
      matrix[row, col] = score
      matrix[col, row] = score

  return matrix


def _task_runner(
  task,
  comparison_fn: tp.Callable[[tp.Any, tp.Any], tp.Any],
  use_cache: bool = True,
):
  """
  Worker entry point: compute the pairwise comparison matrix for one task.

  Args:
      task: Tuple `(idx, arr)` where `idx` identifies the task and `arr` holds
          the items to compare
      comparison_fn: Callable comparing two items
      use_cache: Whether to memoize `comparison_fn` while processing this task
  Returns:
      Tuple `(idx, matrix)`; `matrix` is None if the comparison failed
  """
  idx, arr = task
  try:
    result = pairwise_comparisons(arr, comparison_fn, cache=use_cache)
  except Exception as e:
    # Best effort: one failing task must not take down the whole pool. The
    # failure is logged with its traceback and reported to the caller as None.
    logger.exception(f"Error in task {idx}: {e}")
    result = None

  return idx, result


def pairwise_comparisons_multi(
  arr: np.ndarray,
  pairwise_comparison_fn,
  use_cache: bool = True,
  concurrency: int | None = None,
):
  """
  Compute one pairwise comparison matrix per row of `arr` in a process pool.

  Args:
      arr: Collection whose rows are each handed to `pairwise_comparisons`
      pairwise_comparison_fn: Callable comparing two items; it is sent to the
          worker processes, so it must be picklable
      use_cache: Whether workers memoize the comparison function per row
      concurrency: Number of worker processes. None (or a value below 1) uses
          one worker per CPU; 1 runs in the current process
  Returns:
      np.ndarray: Per-row matrices stacked in row order (parallel path)
  """
  if concurrency == 1:
    # NOTE(review): this path compares the rows of `arr` with each other,
    # whereas the pool path below compares the items inside each row. Confirm
    # which of the two is intended before relying on concurrency=1.
    return pairwise_comparisons(arr, pairwise_comparison_fn, cache=use_cache)

  if concurrency is None or concurrency < 1:
    num_workers = mp.cpu_count()
  else:
    num_workers = concurrency

  tasks = list(enumerate(arr))
  worker = functools.partial(
    _task_runner,
    comparison_fn=pairwise_comparison_fn,
    use_cache=use_cache,
  )

  with mp.Pool(processes=num_workers) as pool:
    # Results arrive in completion order and must be drained before the pool
    # is torn down at the end of the `with` block.
    results = list(tqdm(pool.imap_unordered(worker, tasks), total=len(tasks)))

  # Restore the original row order using the task index.
  results.sort(key=lambda item: item[0])
  return np.stack([matrix for _, matrix in results])


def _task_runner2(task, *, comparison_fn: tp.Callable[[tp.Any, tp.Any], tp.Any]):
  """
  Worker entry point: compute one comparison matrix and save it to disk.

  The matrix is written to a file instead of being returned, so only the file
  path travels back to the parent process.

  Args:
      task: Tuple `(idx, arr, outfile)`: task identifier, items to compare, and
          the path the resulting matrix is saved to
      comparison_fn: Callable comparing two items (always memoized here)
  Returns:
      Tuple `(idx, outfile)`. If the task failed, the error is logged and the
      file may not have been written
  """
  idx, arr, outfile = task
  try:
    result = pairwise_comparisons(arr, comparison_fn, cache=True)
    # The matrix has object dtype, so saving it requires pickling.
    np.save(outfile, result, allow_pickle=True)
  except Exception as e:
    logger.exception(f"Error in task {idx}: {e}")

  return (idx, outfile)


def pairwise_comparisons_multi_v2(
  arr: np.ndarray,
  pairwise_comparison_fn,
  use_cache: bool = True,
  concurrency: int | None = None,
):
  """
  Compute one pairwise comparison matrix per row of `arr` in a process pool,
  passing results back through temporary .npy files.

  Args:
      arr: Collection whose rows are each handed to `pairwise_comparisons`
      pairwise_comparison_fn: Callable comparing two items; it is sent to the
          worker processes, so it must be picklable
      use_cache: Only honoured when `concurrency == 1`; the pool workers are
          invoked without it
      concurrency: Number of worker processes. None (or a value below 1) uses
          one worker per CPU; 1 runs in the current process
  Returns:
      np.ndarray: Per-row matrices stacked in row order (parallel path)
  """
  if concurrency == 1:
    # NOTE(review): this path compares the rows of `arr` with each other,
    # whereas the pool path below compares the items inside each row. Confirm
    # which of the two is intended before relying on concurrency=1.
    return pairwise_comparisons(arr, pairwise_comparison_fn, cache=use_cache)

  if concurrency is None or concurrency < 1:
    num_workers = mp.cpu_count()
  else:
    num_workers = concurrency

  worker = functools.partial(_task_runner2, comparison_fn=pairwise_comparison_fn)
  results = []

  with tempfile.TemporaryDirectory() as temp_dir:
    temp_dir = pathlib.Path(temp_dir)

    logger.info("WRITING TO TMP DIR")
    tasks = [
      (idx, row, temp_dir / f"tmp-{idx}.npy") for idx, row in enumerate(arr)
    ]

    with mp.Pool(processes=num_workers) as pool:
      for idx, outfile in tqdm(
        pool.imap_unordered(worker, tasks),
        total=len(tasks),
      ):
        # allow_pickle is needed for object-dtype matrices. The files are
        # created by the workers inside this private temporary directory.
        # They are loaded here, before the directory is deleted.
        results.append((idx, np.load(outfile, allow_pickle=True)))

  # Results arrive in completion order; restore the original row order.
  results.sort(key=lambda item: item[0])
  return np.stack([matrix for _, matrix in results])


def prepare_data_for_calibration(df: pd.DataFrame, randomize: bool = False):
  """
  Collapse per-sample rows into one row per `dataset_idx` and derive the
  per-example statistics used for calibration.

  Args:
      df: Frame with one row per sample and the columns `dataset_idx`,
          `sample_idx`, `logprobs`, `normalized_logprobs`, `admissible` and
          `generated_text`
      randomize: If True, shuffle the samples within every example (using the
          global NumPy RNG). All per-sample columns are permuted together so
          they stay aligned
  Returns:
      pd.DataFrame indexed by `dataset_idx` with:
        - `sample_idx`, `admissible`, `generated_text`: per-sample lists
        - `logprobs`, `normalized_logprobs`: value of the first sample
        - `num_admissible`: sum of the `admissible` list
        - `min_k`: position of the first admissible sample
        - `p_mle`: (num_admissible + 1) / (number of samples + 2)
        - `dedup_size`: running count of distinct generated texts
  """
  list_columns = [
    "sample_idx",
    "logprobs",
    "normalized_logprobs",
    "admissible",
    "generated_text",
  ]

  def shuffle_row(row):
    # Apply one permutation to every per-sample column; shuffling only some
    # of them would pair values from different samples at the same position.
    combined = list(zip(*(row[col] for col in list_columns)))
    np.random.shuffle(combined)
    for col, values in zip(list_columns, zip(*combined)):
      row[col] = list(values)
    return row

  df = df.groupby(["dataset_idx"])[list_columns].agg(list)
  if randomize:
    df = df.apply(shuffle_row, axis=1)

  df["num_admissible"] = df["admissible"].apply(lambda x: np.sum(x))
  # argmax over the boolean "seen at least one admissible" vector gives the
  # first admissible position (and 0 when there is none; fixed up below)
  df["min_k"] = df["admissible"].apply(lambda x: np.argmax(np.cumsum(x) > 0))
  df["logprobs"] = df["logprobs"].apply(lambda x: x[0])
  df["normalized_logprobs"] = df["normalized_logprobs"].apply(lambda x: x[0])
  # Examples without any admissible sample get a sentinel min_k equal to the
  # number of samples in the first example.
  # NOTE(review): this assumes every example has the same number of samples.
  df.loc[df["num_admissible"] == 0, "min_k"] = len(df.iloc[0]["sample_idx"])
  # calculate \hat{p}_mle
  df["p_mle"] = (df["num_admissible"] + 1) / (df["sample_idx"].apply(len) + 2)

  def compute_effective_size(arr):
    # Number of distinct items seen up to and including each position
    seen = set()
    sizes = []
    for x in arr:
      seen.add(x)
      sizes.append(len(seen))
    return sizes

  df["dedup_size"] = df["generated_text"].apply(compute_effective_size)
  return df


def conformal_quantile(
  x: np.ndarray, *, alpha: float, quantile_method: str = "inverted_cdf"
) -> tuple[float, float]:
  """
  Return the finite-sample adjusted `alpha`-quantile of the scores in `x`.

  With n scores, the quantile level is ceil((n + 1) * alpha) / n.

  Args:
      x: Array of scores; its first dimension is taken as n
      alpha: Requested quantile level before adjustment
      quantile_method: `method` argument forwarded to `np.quantile`
  Returns:
      Tuple `(quantile, adjusted_level)`
  Raises:
      ValueError: If ceil((n + 1) * alpha) is smaller than 1
  """
  num_scores = x.shape[0]
  rank = math.ceil((num_scores + 1) * alpha)
  adjusted_alpha = rank / num_scores

  # Zero-based position of the order statistic corresponding to `rank`
  idx = rank - 1
  if idx < 0:
    raise ValueError(f"idx is negative: {idx}")

  quantile = np.quantile(x, adjusted_alpha, method=quantile_method)
  return float(quantile), adjusted_alpha


def conformal_p_values(
  calibration_scores: np.ndarray,
  test_scores: np.ndarray,
  *,
  smooth: bool = False,
  is_conformity: bool = False,
) -> np.ndarray:
  """
  Compute one conformal p-value per test score against the calibration scores.

  For each test score t the result is
      (#strict + u * (1 + #ties)) / (n_calibration + 1)
  where #strict counts calibration scores strictly greater than t (strictly
  smaller when `is_conformity` is True), #ties counts calibration scores equal
  to t, and u is 1, or a Uniform(0, 1) draw when `smooth` is True.

  Args:
      calibration_scores: 1-D array of calibration scores
      test_scores: 1-D array of test scores
      smooth: Randomize tie handling with the global NumPy RNG
      is_conformity: Count smaller instead of greater calibration scores
  Returns:
      np.ndarray: One p-value per test score
  """
  num_test = test_scores.shape[0]
  us = np.random.uniform(0, 1, size=num_test) if smooth else np.ones(num_test)

  # Column vector so comparisons broadcast to (num_test, num_calibration)
  test_column = test_scores[:, np.newaxis]
  strict_op = np.less if is_conformity else np.greater

  num_strict = strict_op(calibration_scores, test_column).sum(axis=1)
  num_ties = (calibration_scores == test_column).sum(axis=1)

  unnormalized = num_strict + us * (1 + num_ties)
  return unnormalized / (calibration_scores.shape[0] + 1)


def compute_normalized_logprobs(token_logprobs: np.ndarray, *, alpha: float = 0.6):
  """
  Return exp(sum(token_logprobs) / penalty), a length-normalized sequence score.

  The penalty is (5 + length) ** alpha / 6 ** alpha, the same form as the
  length penalty used by GNMT; it equals 1 for a single token.

  Args:
      token_logprobs: Per-token log-probabilities of one sequence
      alpha: Exponent of the length penalty
  """
  num_tokens = len(token_logprobs)
  total_logprob = np.sum(token_logprobs)
  length_penalty = (5 + num_tokens) ** alpha / 6**alpha

  return np.exp(total_logprob / length_penalty)


def find_last_false_index(bool_array, *, default_value=0):
  """
  Find the index of the last falsy entry in each row of a boolean array.

  Args:
      bool_array: 2D array-like of booleans; a 1D input is treated as one row
      default_value: Value reported for rows that contain no falsy entry
  Returns:
      np.ndarray: One index (or `default_value`) per row
  """
  rows = np.asarray(bool_array)
  if rows.ndim == 1:
    rows = rows[np.newaxis, :]

  def last_false(row):
    # Scan from the right and stop at the first falsy entry
    backwards = reversed(range(len(row)))
    return next((pos for pos in backwards if not row[pos]), default_value)

  return np.array([last_false(row) for row in rows])


def generate_color_map(keys, palette_name="deep"):
  """
  Assign one color from a seaborn palette to each key.

  Parameters:
  keys: Sized iterable of keys; one palette color is requested per key.
  palette_name (str): Name of the seaborn palette to use.

  Returns:
  dict: A dictionary mapping each key to a hex color, in palette order.
  """
  # Imported lazily so seaborn is only required when colors are needed
  import seaborn as sns

  hex_colors = sns.color_palette(palette_name, n_colors=len(keys)).as_hex()
  return {key: color for key, color in zip(keys, hex_colors)}


def load_json(url):
  """Open `url` through `upath.UPath` in text mode and return the parsed JSON."""
  with upath.UPath(url).open("r") as handle:
    content = json.load(handle)
  return content


def seed_everything(seed: int = 0):
  """
  Seed NumPy, PyTorch and Python's `random` module with `seed`, and export it
  as the PYTHONHASHSEED environment variable.

  NOTE(review): PYTHONHASHSEED is read at interpreter start-up, so setting it
  here presumably only affects processes launched afterwards.
  """
  np.random.seed(seed)

  import os
  import random

  import torch

  torch.manual_seed(seed)
  random.seed(seed)
  os.environ["PYTHONHASHSEED"] = str(seed)


def load_npz(fpath):
  """
  Load a NumPy file through `upath.UPath`.

  Returns the array itself when `np.load` yields an ndarray; otherwise returns
  a dict mapping each name in the archive's `files` to its array. The entries
  are read while the file is still open.
  """
  with upath.UPath(fpath).open("rb") as handle:
    loaded = np.load(handle)
    if not isinstance(loaded, np.ndarray):
      return {name: loaded[name] for name in loaded.files}
    return loaded


def load_yaml(pathlike):
  """Open `pathlike` through `upath.UPath` and parse it with `yaml.safe_load`."""
  # Imported lazily so PyYAML is only required when YAML files are read
  import yaml

  with upath.UPath(pathlike).open("r") as handle:
    document = yaml.safe_load(handle)
  return document


class PerfLogger:
  """
  Log the wall-clock time elapsed between successive `log` calls.

  Steps shorter than `min_elapsed` seconds are not reported; they only move
  the reference timestamp forward, so their duration is not added to the next
  reported step.
  """

  def __init__(
    self, log_fn: tp.Callable[[str], None] = logger.info, min_elapsed: float = 0
  ):
    self.log_fn = log_fn
    self.min_elapsed = min_elapsed
    # (tag, timestamp) pairs; the last entry is the reference for the next step
    self._logs = [("start", time.time())]

  def log(self, tag: str, *args, **kwargs):
    """
    Report the time since the previous entry under `tag`.

    Extra positional and keyword arguments are forwarded to `log_fn`.
    """
    now = time.time()
    last_tag, last_time = self._logs[-1]
    elapsed = now - last_time

    if elapsed < self.min_elapsed:
      # Too short to report: keep the previous tag, refresh its timestamp
      self._logs[-1] = (last_tag, now)
      return

    self._logs.append((tag, now))
    self.log_fn(f"[{elapsed:.02f}s] {tag}", *args, **kwargs)


def boolean_mask_to_index_lists(mask):
  """
  List, for every row of a 2D boolean mask, the column indices that are True.

  Args:
      mask (np.ndarray): 2D boolean array

  Returns:
      list[list[int]]: One list per row holding the indices of its True
      entries, in increasing order

  Example:
      >>> mask = np.array([[True, False], [True, True]])
      >>> boolean_mask_to_index_lists(mask)
      [[0], [0, 1]]
  """
  index_lists = []
  for row in mask:
    true_positions = np.nonzero(row)[0]
    index_lists.append(true_positions.tolist())
  return index_lists


def index_lists_to_boolean_mask(index_lists, num_cols):
  """
  Build a 2D boolean mask from per-row lists of True positions.

  Args:
      index_lists (list[list[int]]): For each row, the column indices to set
          to True
      num_cols (int): Number of columns in the output mask

  Returns:
      np.ndarray: Boolean array of shape (len(index_lists), num_cols) that is
      True exactly at the listed positions

  Example:
      >>> indices = [[0], [0, 1]]
      >>> index_lists_to_boolean_mask(indices, num_cols=2)
      array([[ True, False],
             [ True,  True]])
  """
  shape = (len(index_lists), num_cols)
  mask = np.zeros(shape, dtype=bool)

  # Switch on all listed columns of a row with one indexed assignment
  for row_idx, true_cols in enumerate(index_lists):
    mask[row_idx, true_cols] = True

  return mask


def compute_first_occurrence_indices(items):
  """Map every position to the index where its item first appeared.

  For example:
  [a, b, a, c] -> [0, 1, 0, 3]
  Positions 0, 1 and 3 hold items seen for the first time, so they map to
  themselves; position 2 repeats the item at position 0 and maps to 0.

  Args:
      items: Sized sequence of hashable items

  Returns:
      np.ndarray: int32 array with the first-occurrence index of each item
  """
  first_index = {}
  result = np.zeros(len(items), dtype=np.int32)

  for position, item in enumerate(items):
    # setdefault records the position for new items and returns the stored
    # first position for repeats
    result[position] = first_index.setdefault(item, position)

  return result


def get_memory_usage():
  """Return the `rss` field of psutil's memory info for the current process."""
  # Imported lazily so psutil is only required when memory is inspected
  import os
  import psutil

  current_pid = os.getpid()
  return psutil.Process(current_pid).memory_info().rss
