"""Data processing utilities."""

import hashlib
import logging
import numpy as np
from scipy.stats import rankdata
from typing import Sequence, Tuple, Union

from src.io_utils import collect_results, get_filtered_and_grouped_paths


def generate_sample_sizes(total_samples: int) -> tuple[int, ...]:
    """Generate sample-size milestones following a 1-2-5 pattern per decade.

    The milestones are 1, 2, 5, 10, 20, 50, 100, ... cut off at
    ``total_samples``. ``total_samples`` itself always closes the sequence,
    whether or not it is a 1-2-5 milestone.

    Parameters
    ----------
    total_samples : int
        Largest sample size to include.

    Returns
    -------
    tuple[int, ...]
        Strictly increasing sample sizes ending in ``total_samples``, or an
        empty tuple when ``total_samples < 1``.
    """
    if total_samples < 1:
        return ()

    sizes = []
    scale = 1  # current power of ten
    while True:
        for base in (1, 2, 5):
            milestone = base * scale
            if milestone >= total_samples:
                # Every milestone stored so far is strictly below the target,
                # so the target can be appended unconditionally as the end.
                sizes.append(min(milestone, total_samples))
                return tuple(sizes)
            sizes.append(milestone)
        scale *= 10


def _dominance_frontier(xs: np.ndarray, ys: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """
    Return the running-best frontier of the points, ordered by cost.

    Points are visited in order of increasing cost (x). A point is kept only
    if its y is strictly greater than every y kept before it and greater
    than 0, i.e. *higher* y counts as better. The result is a step curve
    that is anchored at the origin (0, 0) and closed by one extra point at
    the largest x carrying the best y reached, so the curve spans the full
    cost range.

    Parameters
    ----------
    xs, ys : 1-D arrays of equal length
        Coordinates of the candidate points.

    Returns
    -------
    frontier_xs, frontier_ys : 1-D arrays
        Coordinates of the frontier, sorted by xs ascending. The first entry
        is always (0, 0). For non-empty input the last entry is
        (max(xs), best y), which repeats the final frontier point when that
        point already sits at max(xs). For empty input only the (0, 0)
        anchor is returned.
    """
    frontier_x, frontier_y = [0], [0]
    if np.size(xs) == 0:
        # Nothing to sort and no largest x to close the curve with.
        return np.asarray(frontier_x), np.asarray(frontier_y)

    order = np.argsort(xs)              # sort by cost
    xs_sorted, ys_sorted = xs[order], ys[order]

    best_y_so_far = 0
    for x_val, y_val in zip(xs_sorted, ys_sorted):
        if y_val > best_y_so_far:       # strictly better (higher) in y
            frontier_x.append(x_val)
            frontier_y.append(y_val)
            best_y_so_far = y_val

    # Extend the last plateau out to the most expensive point.
    frontier_x.append(xs_sorted[-1])
    frontier_y.append(frontier_y[-1])
    return np.asarray(frontier_x), np.asarray(frontier_y)


def _non_cumulative_dominance_frontier(xs: np.ndarray, ys: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """
    Return every point ordered by cost, prefixed with the origin.

    No dominance filtering is applied: the "frontier" is simply all input
    points sorted by x, with a leading (0, 0) anchor.

    Parameters
    ----------
    xs, ys : 1-D arrays of equal length
        Coordinates of the candidate points.

    Returns
    -------
    frontier_xs, frontier_ys : 1-D arrays
        (0, 0) followed by all points, sorted by xs ascending.
    """
    by_cost = np.argsort(xs)

    # Start both coordinate lists at the origin, then append the sorted points.
    costs = [0]
    costs.extend(xs[by_cost])
    values = [0]
    values.extend(ys[by_cost])

    return np.asarray(costs), np.asarray(values)


def _pareto_frontier(xs: np.ndarray,
                     ys: np.ndarray,
                     method: str = "basic",
                     **kwargs):
    """Dispatch to the frontier implementation selected by ``method``.

    ``"basic"`` uses ``_dominance_frontier`` and ``"non_cumulative"`` uses
    ``_non_cumulative_dominance_frontier``. Extra keyword arguments are
    accepted but not used. Any other ``method`` raises ``ValueError``.
    """
    # Reject unknown methods up front, then pick the implementation.
    if method not in ("basic", "non_cumulative"):
        raise ValueError(f"Unknown frontier method '{method}'")
    if method == "basic":
        return _dominance_frontier(xs, ys)
    return _non_cumulative_dominance_frontier(xs, ys)


class DataFetcher:
    """Fetch experiment results and memoise them per query in memory."""

    def __init__(self):
        # Maps the SHA-256 hex digest of a query's arguments to its result.
        self.cache = {}

    @staticmethod
    def _cache_key(model: str, attack: str, attack_params: dict,
                   dataset_idx: list[int], group_by: set[str]) -> str:
        """Build a collision-safe, order-independent cache key for a query.

        The arguments are serialised as the ``repr`` of a tuple of strings,
        so field boundaries are unambiguous (("ab", "c") and ("a", "bc") no
        longer share a key). A set-valued ``group_by`` is sorted first,
        because the iteration order of equal sets is not guaranteed to match.
        """
        if isinstance(group_by, (set, frozenset)):
            group_part = sorted(str(item) for item in group_by)
        else:
            # Keep the order of ordered containers: it may be meaningful.
            group_part = str(group_by)
        parts = (str(model), str(attack), str(attack_params), str(dataset_idx), group_part)
        return hashlib.sha256(repr(parts).encode()).hexdigest()

    def fetch_data(self, model: str, attack: str, attack_params: dict,
                   dataset_idx: list[int], group_by: set[str]):
        """Return the result for a query, loading it only on a cache miss.

        On a miss, the matching paths are looked up with
        ``get_filtered_and_grouped_paths`` (filtering on model, attack,
        attack_params and ``dataset_params={"idx": dataset_idx}``) and loaded
        with ``collect_results(..., infer_sampling_flops=True)``. The loaded
        mapping must contain exactly one entry; that entry's value is cached
        and returned.

        Raises
        ------
        AssertionError
            If the loaded results do not contain exactly one entry.
        """
        hash_key = self._cache_key(model, attack, attack_params, dataset_idx, group_by)

        if hash_key in self.cache:
            return self.cache[hash_key]

        filter_by = dict(
            model=model,
            attack=attack,
            attack_params=attack_params,
            dataset_params={"idx": dataset_idx},
        )
        paths = get_filtered_and_grouped_paths(filter_by, group_by, force_reload=False)

        results = collect_results(paths, infer_sampling_flops=True)
        if len(results) != 1:
            # Raised explicitly (not via `assert`) so the check survives
            # `python -O`; the exception type is kept for existing callers.
            raise AssertionError(
                f"Should only have exactly one type of result, got {len(results)}, {list(results.keys())}"
            )

        result = next(iter(results.values()))
        self.cache[hash_key] = result
        return result


def preprocess_data(results: dict[tuple[str, ...], np.ndarray], metric: tuple[str, ...],
                   threshold: float|None) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Common data preprocessing logic.

    Args:
        results: dict[str, np.ndarray], shape (B, n_steps, n_samples)
        metric: tuple[str, ...], shape (B, n_steps, n_samples)
        threshold: float|None, shape (B, n_steps, n_samples)

    Returns:
        y: np.ndarray, shape (B, n_steps, n_samples)
        flops_optimization: np.ndarray, shape (B, n_steps)
        flops_sampling_prefill_cache: np.ndarray, shape (B, n_steps)
        flops_sampling_generation: np.ndarray, shape (B, n_steps)
    """
    y = np.array(results[metric])  # (B, n_steps, n_samples)
    if threshold is not None:
        y = y > threshold

    flops_optimization = np.array(results["flops"]) # (B, n_steps)
    flops_sampling_prefill_cache = np.array(results["flops_sampling_prefill_cache"]) # (B, n_steps)
    flops_sampling_generation = np.array(results["flops_sampling_generation"]) # (B, n_steps)


    return y, flops_optimization, flops_sampling_prefill_cache, flops_sampling_generation


def subsample_and_aggregate_n(
    n: Sequence[int] | np.ndarray,
    y: np.ndarray,
    opt_flops: np.ndarray,
    sampling_prefill_flops: np.ndarray,
    sampling_generation_flops: np.ndarray,
    rng: np.random.Generator,
    return_ratio: bool = False,
    n_smoothing: int = 1,
) -> Union[
    Tuple[float, int, int, float],
    Tuple[float, int, int, float, float, float]
]:
    """
    Subsample-and-aggregate using a per-step sampling vector n.

    Parameters
    ----------
    n : Sequence[int]
        Sampling counts per step. n[s] is how many samples to draw (without replacement)
        from y at step s. Steps considered are s=0..S-1 where S=len(n).
        Examples:
          - Non-cumulative at final step k: n=[0,0,...,k]
          - Cumulative with 1 each step and k at final: n=[1,1,1,...,k]
    y : np.ndarray
        Shape (B, n_steps, n_samples). Values to aggregate (we take max over sampled items).
    opt_flops : np.ndarray
        Shape (B, n_steps). Optimization FLOPs per step (pre-sampling).
    sampling_prefill_flops : np.ndarray
        Shape (B, n_steps). Prefill FLOPs per step (paid once at step s if n[s] > 0).
    sampling_generation_flops : np.ndarray
        Shape (B, n_steps). Generation FLOPs per sampled item at step s.
    rng : np.random.Generator
        RNG to drive sampling randomness (used to derive per-iteration seeds).
    return_ratio : bool
        If True, returns (ratio, last_step_idx, last_step_samples, mean_value, opt_flop, sampling_flop)
        where ratio = sampling_flop / (total_flop + 1e-9).
        If False, returns (total_flop, last_step_idx, last_step_samples, mean_value).
    n_smoothing : int
        Number of smoothing iterations for variance reduction.

    Returns
    -------
    tuple
        If return_ratio=False:
            (total_flop, last_step_idx, last_step_samples, mean_value)
        If return_ratio=True:
            (ratio, last_step_idx, last_step_samples, mean_value, opt_flop, sampling_flop)

        `last_step_idx` is len(n)-1 and `last_step_samples` is n[-1], to preserve
        the positional structure of the previous API.
    """
    n_arr = np.asarray(n, dtype=int)
    if n_arr.ndim != 1 or (n_arr < 0).any():
        raise ValueError("n must be a 1D sequence of non-negative integers.")
    S = n_arr.size
    if S == 0:
        raise ValueError("n must contain at least one step.")
    last_step_idx = S - 1
    last_step_samples = np.sum(n_arr)

    B, n_steps, n_total_samples = y.shape
    if S > n_steps:
        raise ValueError(f"len(n)={S} exceeds y's available steps={n_steps}.")
    if (n_arr > n_total_samples).any():
        raise ValueError(
            f"Each n[s] must be <= number of available samples ({n_total_samples})."
        )
    if n_arr.sum() == 0:
        raise ValueError("At least one step must have n[s] > 0 to aggregate a value.")

    # FLOP accounting
    # Optimization FLOPs through all considered steps
    opt_flop = float(np.mean(opt_flops[:, :S].sum(axis=1)))

    # Sampling FLOPs: per step, pay prefill once if we sample there, and generation per sample
    mean_prefill = np.mean(sampling_prefill_flops[:, :S], axis=0)            # (S,)
    mean_gen = np.mean(sampling_generation_flops[:, :S], axis=0)             # (S,)
    sampling_flop = float((mean_gen * n_arr).sum() + (mean_prefill * (n_arr > 0)).sum())

    total_flop = opt_flop + sampling_flop

    # Value estimation with smoothing
    values = []
    # Derive deterministic seeds from rng to avoid altering caller's RNG state too much.
    # If caller wants strict determinism across runs, they should seed `rng` identically.
    for i in range(n_smoothing):
        # Create a child generator using a seed derived from rng
        # (fallback deterministic mix of current draw and iteration index)
        base_seed = int(rng.integers(0, 2**31 - 1)) if rng is not None else 0
        local_rng = np.random.default_rng(base_seed ^ (i * 0x9E3779B1))

        per_step_maxes = []
        for s, k in enumerate(n_arr):
            if k <= 0:
                continue
            # choose k unique indices at step s
            idxs = local_rng.choice(n_total_samples, size=int(k), replace=False)
            # max across the sampled items at this step -> (B,)
            step_max = y[:, s, idxs].max(axis=-1)
            per_step_maxes.append(step_max)
        # Combine across steps by taking the max across step-wise maxima -> (B,)
        # Then mean over batch -> scalar
        stacked = np.stack(per_step_maxes, axis=1)  # (B, n_active_steps)
        run_value = float(stacked.max(axis=1).mean(axis=0))
        values.append(run_value)

    mean_value = float(np.mean(values))

    if return_ratio:
        ratio = sampling_flop / (total_flop + 1e-9)
        return (ratio, last_step_idx, last_step_samples, mean_value, opt_flop, sampling_flop)
    else:
        return (total_flop, last_step_idx, last_step_samples, mean_value)


def _distribute_proportionally(weights: np.ndarray, total: int) -> np.ndarray:
    """
    Split `total` items into integer shares proportional to `weights`.

    Each entry first receives the floor of its exact proportional share;
    the items still missing afterwards go, one each, to the entries with
    the largest fractional remainders.

    Parameters
    ----------
    weights : np.ndarray
        Array of weights (non-negative)
    total : int
        Total number of items to distribute

    Returns
    -------
    np.ndarray
        Array of integers that sum to `total`. If all weights sum to zero,
        an all-zero array is returned instead.
    """
    weight_sum = np.sum(weights)
    if weight_sum == 0:
        # No basis for a proportional split: hand out nothing.
        return np.zeros(len(weights), dtype=int)

    # Exact (real-valued) share of each entry.
    ideal_share = weights / weight_sum * total

    counts = np.floor(ideal_share).astype(int)
    leftover = total - np.sum(counts)

    if leftover > 0:
        # Largest-remainder rule: the `leftover` entries with the biggest
        # fractional parts get one more item each.
        fractions = ideal_share - counts
        winners = np.argsort(fractions)[-leftover:]
        counts[winners] += 1

    return counts


def get_n_schedule(n_steps: int, n_total_sample_budget: int, schedule_type: str, **kwargs) -> np.ndarray[int]:
    """
    Get the per-step sample counts for a named schedule.

    Parameters
    ----------
    n_steps : int
        Number of steps. Must be positive.
    n_total_sample_budget : int
        Total number of samples to place across steps. Must be nonnegative.
    schedule_type : {"uniform","linear","start","block","pair","end"}
        Schedule type (case-insensitive).
          - "uniform": spread the budget as evenly as possible over all steps,
            guaranteeing at least one sample on the last step whenever the
            budget is positive.
          - "linear": counts proportional to a linear ramp over the steps,
            computed with `_distribute_proportionally`.
          - "start": the whole budget on the first step.
          - "block": the budget spread over the trailing `b` steps.
          - "pair": one sample on every step except the last, which gets the
            whole budget. The total is n_steps - 1 + budget, i.e. this
            schedule EXCEEDS the budget; it is meant for baselining only.
          - "end": the whole budget on the last step.
    kwargs :
        For "block":
            - b (int): number of trailing steps forming the block, in
                [1, n_steps]. Defaults to min(n_steps, 5). The budget need
                not be divisible by b: the remainder is added, one sample
                each, to the latest steps of the block.
        For "linear" (optional, sensible defaults if omitted):
            - direction (str): "increasing" (default) or "decreasing".
                "increasing" biases samples toward later steps.
                "decreasing" biases samples toward earlier steps.
            - offset (float): nonnegative baseline added to the linear ramp (default 0.0).

    Returns
    -------
    n : np.ndarray[int] of shape (n_steps,)
        Samples per step.

    Raises
    ------
    ValueError
        If n_steps <= 0, the budget is negative, the schedule type is
        unknown, or a schedule-specific option is invalid.
    """
    if n_steps <= 0:
        raise ValueError("n_steps must be positive")
    if n_total_sample_budget < 0:
        raise ValueError("n_total_sample_budget must be nonnegative")

    schedule_type = schedule_type.lower()

    if schedule_type == "uniform":
        n = np.zeros(n_steps, dtype=int)
        if n_total_sample_budget == 0:
            return n

        # Every step gets q samples; r (< n_steps) samples are left over.
        q, r = divmod(n_total_sample_budget, n_steps)
        n += q

        if r > 0:
            # Place the leftovers at evenly spaced positions over
            # [0, n_steps - 1], rounded to the nearest step. Since
            # r < n_steps the positions are more than one step apart, so the
            # rounded indices do not repeat. For r == 1 the single position
            # is step 0.
            idx = np.rint(np.linspace(0, n_steps - 1, r, endpoint=True)).astype(int)
            n[idx] += 1

        # Rule: the last step must hold at least one sample. It can only be
        # empty here when q == 0 and the single leftover landed on step 0;
        # move one sample from the earliest non-empty step to the end.
        if n[-1] == 0:
            donor = np.where(n[:-1] > 0)[0][0]
            n[donor] -= 1
            n[-1] += 1
        return n

    elif schedule_type == "linear":
        # Sampling density proportional to a linear ramp across steps.
        direction = kwargs.get("direction", "increasing")
        offset = float(kwargs.get("offset", 0.0))
        if offset < 0:
            raise ValueError("offset must be nonnegative")

        if direction not in ("increasing", "decreasing"):
            raise ValueError("direction must be 'increasing' or 'decreasing'")

        if direction == "increasing":
            ramp = np.arange(1, n_steps + 1, dtype=float)  # 1..n_steps
        else:  # decreasing
            ramp = np.arange(n_steps, 0, -1, dtype=float)  # n_steps..1

        return _distribute_proportionally(ramp + offset, n_total_sample_budget)

    elif schedule_type == "start":
        n = np.zeros(n_steps, dtype=int)
        n[0] = n_total_sample_budget
        return n

    elif schedule_type == "block":
        b = int(kwargs["b"]) if "b" in kwargs else min(n_steps, 5)
        if b <= 0 or b > n_steps:
            raise ValueError("b must be in [1, n_steps]")

        n = np.zeros(n_steps, dtype=int)
        if n_total_sample_budget == 0:
            return n

        # floor(budget / b) on each of the last b steps ...
        q, r = divmod(n_total_sample_budget, b)
        n[-b:] = q
        # ... plus one extra on the latest r steps (r < b, so they stay
        # inside the block).
        if r > 0:
            n[-r:] += 1
        return n

    elif schedule_type == "pair":
        # Only useful for baselining: the total exceeds the sample budget,
        # so it is not comparable to the other schedules.
        n = np.ones(n_steps, dtype=int)
        n[-1] = n_total_sample_budget
        return n

    elif schedule_type == "end":
        n = np.zeros(n_steps, dtype=int)
        n[-1] = n_total_sample_budget
        return n

    else:
        raise ValueError(f"Unknown schedule type '{schedule_type}'")


# Module-level DataFetcher instance used by the fetch_data() wrapper below, so
# that its in-memory cache is shared by every call made through this module.
_data_fetcher = DataFetcher()

def fetch_data(model: str, attack: str, attack_params: dict,
               dataset_idx: list[int], group_by: set[str]):
    """Fetch results through the shared module-level ``DataFetcher``.

    All arguments are forwarded unchanged to ``DataFetcher.fetch_data``, so
    repeated identical queries are served from that instance's cache.
    """
    return _data_fetcher.fetch_data(
        model=model,
        attack=attack,
        attack_params=attack_params,
        dataset_idx=dataset_idx,
        group_by=group_by,
    )