import os
import re
import json
import torch
import random
import numpy as np
from pathlib import Path
from tqdm import tqdm


# Shard files are named "<start>_<end>.pt"; group 1 captures the numeric start id.
_RANGE_RE = re.compile(r"(\d+)_(\d+)\.pt$")


def _sorted_layer_files(layer_path: Path):
    """Return the shard files inside `layer_path`, ordered by numeric start id.

    Only directory entries whose name matches ``<start>_<end>.pt`` from the
    first character onward are kept; every other entry is ignored. Ordering
    uses the integer value of ``<start>``, so ``10_19.pt`` comes after
    ``2_9.pt``.

    Args:
        layer_path: directory holding the shard files of one layer.

    Returns:
        list[Path]: ``layer_path / name`` for every matching entry.
    """
    candidates = ((_RANGE_RE.match(entry), entry) for entry in os.listdir(layer_path))
    keyed = [(int(hit.group(1)), entry) for hit, entry in candidates if hit]
    # Stable sort on the start id only: entries with equal starts keep listing order.
    keyed.sort(key=lambda pair: pair[0])
    return [layer_path / entry for _, entry in keyed]

def load_feature_activations(layer, feature_indices, path=Path("activations"), limit=None):
    """
    Collect, for each requested feature, its activation in every stored sample.

    Shard files are read from ``path / f"layer_{layer}"`` in ascending start-id
    order. Each shard is expected to be a mapping ``sample_id -> record`` where
    the code reads ``record["values"]`` (converted with ``.to_dense()``) and
    ``record["indices"]["features"]`` / ``record["indices"]["tokens"]``
    (presumably parallel 1-D tensors; verify against the writer of these files).

    Args:
        layer: layer index (int), used to build the directory name.
        feature_indices: int or list[int], feature indices to load.
        path: base directory containing the ``layer_<n>`` folders.
        limit: optional cap on the number of shard files to read.

    Returns:
        dict: {feature_idx: {"values": [...], "indices": [...]}}. Both lists
        receive one entry per sample, zeros included. The "indices" entry is
        the token index looked up for that feature, or None when the activation
        is zero or the feature is not listed in the sample's "features".
    """
    wanted = [feature_indices] if isinstance(feature_indices, int) else feature_indices

    # One pair of accumulator lists per requested feature.
    collected = {feat: {"values": [], "indices": []} for feat in wanted}

    shard_files = _sorted_layer_files(path / f"layer_{layer}")
    if limit is not None:
        shard_files = shard_files[:limit]

    for shard in tqdm(shard_files):
        # NOTE(review): torch.load unpickles the file; only use on trusted shards.
        payload = torch.load(shard, map_location="cpu")

        # Visit samples in ascending key order so output order is reproducible.
        for _sample_id, record in sorted(payload.items(), key=lambda kv: kv[0]):
            dense = record["values"].to_dense()
            feat_ids = record["indices"]["features"]
            tok_ids = record["indices"]["tokens"]

            for feat in wanted:
                activation = dense[feat].item()

                token_idx = None
                if activation != 0.0 and feat in feat_ids:
                    # .item() requires exactly one match for this feature.
                    where = (feat_ids == feat).nonzero(as_tuple=True)[0].item()
                    token_idx = tok_ids[where].item()

                slot = collected[feat]
                slot["values"].append(activation)
                slot["indices"].append(token_idx)

    return collected

def stratified_tail_sampling(values, indices=None, scheme=None, seed=None, sort=True):
    """
    Stratified quantile sampling with oversampling of the tail.
    Ensures no duplicates across strata.

    Only nonzero entries of `values` are candidates. For every threshold `q`
    in `scheme` (processed in ascending order, i.e. strictest first) the
    candidates at or above the ``1 - q`` quantile of the nonzero values are
    eligible; up to ``scheme[q]`` of those not already picked are drawn
    uniformly without replacement.

    Note: when `seed` is given, the *global* NumPy and `random` generators are
    re-seeded.

    Args:
        values (list[float]): Input values.
        indices (list[int] or None): Optional token indices, parallel to `values`.
        scheme (dict): Mapping of {quantile_threshold: n_samples}.
                       Example: {0.5: 50, 0.25: 50} →
                       50 from top 50%, 50 from top 25%.
                       Lower quantiles (stricter) are sampled first.
        seed (int): Random seed.
        sort (bool): If True, sort output by value descending.

    Returns:
        list of (idx, value, token_idx); empty when there are no nonzero values.

    Raises:
        ValueError: if `scheme` is None while there are nonzero values to sample.
    """
    if seed is not None:
        np.random.seed(seed)
        random.seed(seed)

    arr = np.array(values)
    nonzero_idx = np.where(arr != 0)[0]
    if len(nonzero_idx) == 0:
        return []

    # Previously a missing scheme crashed with an opaque AttributeError on
    # `None.keys()`; fail with an actionable message instead.
    if scheme is None:
        raise ValueError(
            "scheme must be a dict mapping quantile thresholds to sample counts, "
            "e.g. {0.5: 50, 0.25: 50}"
        )

    nonzero_vals = arr[nonzero_idx]  # loop-invariant: quantiles are taken over these
    samples = []
    used = set()

    # sort thresholds so stricter (smaller q) is sampled first
    for q in sorted(scheme):
        n = scheme[q]
        cutoff = np.quantile(nonzero_vals, 1 - q)
        eligible = nonzero_idx[nonzero_vals >= cutoff]
        # Drop already sampled positions. The set-difference form is kept as-is
        # so results for a given seed stay identical to earlier runs.
        eligible = list(set(eligible) - used)
        if not eligible:
            continue

        chosen = np.random.choice(eligible, size=min(n, len(eligible)), replace=False)
        for idx in chosen:
            token_idx = indices[idx] if indices is not None else None
            samples.append((int(idx), float(arr[idx]), token_idx))
            used.add(idx)

    if sort:
        samples.sort(key=lambda x: x[1], reverse=True)

    return samples


import numpy as np
import random

import numpy as np
import random
from collections import Counter

import numpy as np
import random

def inverse_freq_sampling(
    values,
    indices=None,
    n_samples=100,
    n_bins=10,
    alpha=0.8,
    seed=None,
    sort=True,
    min_top_quota=0,       # min samples forced from the top quantile (bin n_bins-1)
    max_prob_cap=0.2,      # cap max probability for any single example (None to disable)
    replace_if_needed=True # if not enough unique candidates exist, allow replacement
):
    """
    Inverse-frequency weighted sampling with tail safeguards.

    Strictly positive entries of `values` are the candidates. They are split
    into `n_bins` quantile bins; each candidate is weighted by
    ``1 / bin_count ** alpha`` so sparsely populated bins are up-weighted.

    Selection proceeds in three steps:
      1. `min_top_quota` candidates are forced, drawn uniformly without
         replacement from the top bin and, if that bin is too small, from
         the next lower bins in turn.
      2. The rest (up to `n_samples` in total) is drawn without replacement
         from the candidates not yet selected, using the weights above.
      3. Only if unique candidates run out and `replace_if_needed` is True,
         the shortfall is filled by weighted draws with replacement.

    Note: when `seed` is given, the *global* NumPy and `random` generators are
    re-seeded.

    Args:
        values: sequence of numbers; only entries > 0 can be sampled.
        indices: optional sequence parallel to `values` (e.g. token indices).
        n_samples: target number of samples.
        n_bins: number of quantile bins.
        alpha: exponent of the inverse-frequency weight.
        seed: random seed, or None to leave RNG state untouched.
        sort: if True, sort the output by value, descending.
        min_top_quota: number of samples forced from the top bin(s). Not capped
            by `n_samples`, so a larger quota yields more than `n_samples`.
        max_prob_cap: per-candidate probability cap applied before
            renormalisation (None to disable).
        replace_if_needed: allow duplicates only when there are fewer unique
            candidates than samples requested.

    Returns:
        samples: list of (idx, value, token_idx)
    """
    if seed is not None:
        np.random.seed(seed)
        random.seed(seed)

    arr = np.array(values, dtype=float)
    if len(arr) == 0:
        return []

    # We only consider strictly positive (nonzero) values as candidates
    nonzero_idx = np.where(arr > 0)[0]
    n_candidates = len(nonzero_idx)
    if n_candidates == 0:
        return []

    arr_nz = arr[nonzero_idx]

    # Bin edges using quantiles on the nonzero subset
    edges = np.percentile(arr_nz, np.linspace(0, 100, n_bins + 1))
    edges[0] -= 1e-12
    edges[-1] += 1e-12

    # Assign bins (0..n_bins-1). For large magnitudes the 1e-12 padding is lost
    # to float rounding and the maximum would fall into bin n_bins, i.e. outside
    # the "top bin"; clipping keeps it where it belongs.
    last_bin = max(n_bins - 1, 0)
    bins = np.clip(np.digitize(arr_nz, edges) - 1, 0, last_bin)
    bin_counts = np.bincount(bins, minlength=n_bins).astype(float)
    bin_counts[bin_counts == 0] = 1.0  # avoid division by zero

    # Per-sample weight = 1 / (freq ^ alpha)
    weights = 1.0 / (bin_counts[bins] ** alpha)
    probs = weights / weights.sum()

    # Optional: cap extreme per-sample probability, then renormalise
    if max_prob_cap is not None:
        probs = np.minimum(probs, max_prob_cap)
        total = probs.sum()
        if total <= 0:
            probs = np.ones_like(probs) / len(probs)
        else:
            probs = probs / total

    # All bookkeeping below uses positions relative to `arr_nz`; they are
    # mapped back to positions in `values` only when the output is built.
    selected_rel = []

    # Force minimum quota from the top bin, spilling into lower bins if needed
    quota_left = min_top_quota
    current_bin = n_bins - 1
    while quota_left > 0 and current_bin >= 0:
        members = np.where(bins == current_bin)[0]
        if len(members) > 0:
            take = min(quota_left, len(members))
            selected_rel.extend(np.random.choice(members, size=take, replace=False))
            quota_left -= take
        current_bin -= 1

    # Remaining to sample
    remaining = max(0, n_samples - len(selected_rel))
    if remaining > 0:
        # Unique draws first, excluding whatever the quota already took.
        available = np.ones(n_candidates, dtype=bool)
        available[np.asarray(selected_rel, dtype=int)] = False
        pool = np.flatnonzero(available)

        if len(pool) > 0:
            pool_probs = probs[pool]
            pool_probs = pool_probs / pool_probs.sum()
            n_unique = min(remaining, len(pool))
            selected_rel.extend(
                np.random.choice(pool, size=n_unique, replace=False, p=pool_probs)
            )
            remaining -= n_unique

        # Replacement is a fallback for a genuine shortfall only.
        if remaining > 0 and replace_if_needed:
            selected_rel.extend(
                np.random.choice(n_candidates, size=remaining, replace=True, p=probs)
            )

    # Format as list of tuples
    samples = []
    for rel in selected_rel:
        idx = int(nonzero_idx[rel])
        token_idx = indices[idx] if indices is not None else None
        samples.append((idx, float(arr[idx]), token_idx))

    if sort:
        samples.sort(key=lambda x: x[1], reverse=True)

    return samples




def load_and_sample_activations(
    layer,
    indices,
    activations_path,
    out_path="",
    n_samples=100,
    alpha=0.8,
    n_bins=10,
    seed=None,
    sort_descending=True, 
    min_top_quota=0
):
    """
    Load activations for given features, sample them with inverse-frequency
    weighting, skip already processed features, and save all results in one
    JSON file (``<out_path>/layer_<layer>.json``).

    If `indices` is empty or contains a negative value, nothing is loaded or
    sampled: the function just returns everything stored in the JSON file.

    Args:
        layer: layer index, used for both the activations folder and the
            output file name.
        indices: feature indices to sample (see above for the "return all"
            convention). Duplicates are processed once.
        activations_path: base directory passed to `load_feature_activations`.
        out_path: directory of the JSON cache; created if missing.
        n_samples, alpha, n_bins, seed, min_top_quota: forwarded to
            `inverse_freq_sampling`.
        sort_descending: if True each feature's samples are sorted by value,
            descending; if False they are kept in sampling order.

    Returns:
        dict: {feature_idx (int): [[idx, value, token_idx], ...]} containing
        both previously cached and newly sampled features.

    Raises:
        ValueError: if "return all" is requested but no cached data exists.
    """
    out_path = Path(out_path) / f"layer_{layer}.json"
    layer_activations = {}

    # --- Load existing file if exists ---
    if out_path.exists():
        print(f"Loading existing samples from {out_path}")
        with open(out_path, "r") as f:
            # JSON object keys are always strings; restore the int feature ids.
            layer_activations = {int(k): v for k, v in json.load(f).items()}

    # --- Empty or negative indices → just return what's in the file ---
    # A negative index must never reach the loader: it would silently address
    # a feature from the end of the dense activation vector.
    if len(indices) == 0 or any(i < 0 for i in indices):
        if not layer_activations:
            raise ValueError(f"No existing file found at {out_path} to load all features from.")
        print(f"Returning all {len(layer_activations)} features from {out_path}")
        return layer_activations

    # --- Determine which indices still need sampling ---
    # De-duplicate (order preserved): a repeated index would make the loader
    # append every activation of that feature more than once.
    indices_to_process = list(
        dict.fromkeys(i for i in indices if i not in layer_activations)
    )
    if not indices_to_process:
        print("All requested features already processed.")
        return layer_activations

    # Make sure the result can be written before doing the expensive work.
    out_path.parent.mkdir(parents=True, exist_ok=True)

    print(f"Loading activations for {len(indices_to_process)} new features")
    data = load_feature_activations(layer, indices_to_process, path=Path(activations_path))

    print("Sampling activations")
    for feature_idx, d in tqdm(data.items()):
        sampled = inverse_freq_sampling(
            values=d['values'],
            indices=d['indices'],
            n_samples=n_samples,
            n_bins=n_bins,
            alpha=alpha,   # how strongly to upweight rare bins
            seed=seed,
            # Forward the flag: the sampler sorts by default, which previously
            # made sort_descending=False a no-op.
            sort=sort_descending,
            min_top_quota=min_top_quota
        )
        # Tuples become lists so fresh and reloaded entries look the same.
        layer_activations[int(feature_idx)] = [list(x) for x in sampled]

    # --- Save updated file ---
    with open(out_path, "w") as f:
        json.dump(layer_activations, f)

    return layer_activations
