import json
from typing import Tuple
import numpy as np
import random
from scipy.stats import beta
from sklearn.mixture import GaussianMixture
import matplotlib.pyplot as plt
from rag_setup import PersistentRAG


def exp_from_logprobs(logps):
    """Exponentiate log-probabilities elementwise.

    Args:
        logps: Scalar or array-like of log-probabilities.

    Returns:
        np.ndarray: ``exp(logps)`` computed with NumPy.
    """
    return np.exp(np.asarray(logps))

def minmax01(x, axis=None, eps=1e-12):
    """Rescale ``x`` to the [0, 1] range along ``axis``.

    Args:
        x: Array-like input; converted to float64.
        axis: Axis (or None for all elements) over which min and max are taken.
        eps: Lower bound on the range, so constant inputs do not divide by zero.

    Returns:
        tuple: ``(normalized, xmin, rng)``. ``xmin`` and ``rng`` keep reduced
        dimensions (``keepdims=True``), so they broadcast against ``x`` and can be
        passed straight to ``inv_minmax01``.
    """
    arr = np.asarray(x, dtype=np.float64)
    lo = arr.min(axis=axis, keepdims=True)
    hi = arr.max(axis=axis, keepdims=True)
    span = np.maximum(hi - lo, eps)
    normalized = (arr - lo) / span
    return normalized, lo, span

def inv_minmax01(x_norm, xmin, rng):
    """
    Undo ``minmax01`` normalization.

    Given the normalized values and the ``xmin`` / ``rng`` that ``minmax01``
    returned, recover the original data via ``x = x_norm * rng + xmin``.
    All three inputs are converted to float64 and combined with NumPy
    broadcasting.
    """
    scale = np.asarray(rng, dtype=np.float64)
    offset = np.asarray(xmin, dtype=np.float64)
    values = np.asarray(x_norm, dtype=np.float64)
    return values * scale + offset

def threshold(confidences, labels, alpha, delta):
    """
    Select a confidence cutoff τ by scanning a descending grid of candidates.

    Candidate cutoffs run from 1.0 down to 0.001 in steps of 0.001. For each
    tested cutoff, the one-sided Clopper-Pearson upper bound (level 1 - delta)
    on the disagreement rate among samples with confidence ≥ τ is compared with
    alpha. The scan stops at the first cutoff whose bound exceeds alpha and
    returns the last cutoff that passed. That is the lowest cutoff, i.e. the
    widest coverage, in the tested run whose bound is ≤ alpha.

    Cutoffs are skipped (not tested) while no sample has confidence ≥ τ, or
    while the next grid step down (τ - 0.001) still contains no disagreements.

    Args:
        confidences (array-like): Confidence scores for each sample.
        labels (array-like): Binary correctness labels (1 = correct/agreement, 0 = incorrect/disagreement).
        alpha (float): Maximum allowed disagreement rate (e.g., 0.05 for 5%).
        delta (float): Risk tolerance for exceeding alpha (e.g., 0.01 for 99% confidence).

    Returns:
        float: The last tested cutoff whose upper bound is ≤ alpha. Returns 0 if
        every tested cutoff passes, or if no cutoff was tested at all.

    Raises:
        ValueError: If the very first tested cutoff already violates the bound.
    """
    confidences = np.asarray(confidences)
    labels = np.asarray(labels)
    disagreements = 1 - labels  # 1 for error, 0 for correct

    # Fixed descending grid 1.0, 0.999, ..., 0.001 (spacing = step).
    step = 0.001
    thresholds = np.linspace(1, step, 1000)

    last_passing = None
    for tau in thresholds:
        idx = confidences >= tau
        n = np.sum(idx)
        if n == 0:
            continue

        # Start testing only once the next grid step down contains an error.
        next_k = np.sum(disagreements[confidences >= (tau - step)])
        if next_k == 0:
            continue

        k = np.sum(disagreements[idx])

        # One-sided Clopper-Pearson upper bound. When k == n the Beta(k+1, 0)
        # quantile is undefined (scipy returns NaN, and NaN > alpha is False,
        # which would silently pass). The exact upper bound in that case is 1.
        if k >= n:
            upper = 1.0
        else:
            upper = beta.ppf(1 - delta, k + 1, n - k)

        if upper > alpha:
            if last_passing is None:
                raise ValueError(f"No threshold found with disagreement ≤ {alpha} at 1-δ={1-delta} confidence.")
            # tau itself violates the bound; return the last cutoff that met it.
            return last_passing

        last_passing = tau

    return 0

def proportion_meeting_threshold(confidence, threshold, inclusive=True):
    """
    Compute the fraction of finite values that clear a cutoff.

    - confidence: array-like of numbers (list/np.array)
    - threshold: numeric cutoff
    - inclusive: True -> count values >= threshold; False -> count values > threshold
    NaN and +/-Inf entries are dropped before counting. If nothing finite
    remains, NaN is returned.
    """
    values = np.asarray(confidence, dtype=float)
    finite_values = values[np.isfinite(values)]
    if finite_values.size == 0:
        return float("nan")
    compare = np.greater_equal if inclusive else np.greater
    hits = compare(finite_values, threshold)
    return float(np.mean(hits))

def plot_accuracy_vs_threshold(confidences, labels, num_thresholds=101):
    """
    Show a line plot of accuracy versus confidence cutoff.

    For each of ``num_thresholds`` evenly spaced cutoffs in [0, 1], accuracy is
    the mean of ``labels`` over samples whose confidence is >= the cutoff. A
    cutoff that selects no samples is plotted as NaN, which leaves a gap in the
    line. The figure is displayed with ``plt.show()``.
    """
    conf = np.asarray(confidences)
    correct = np.asarray(labels)
    cutoffs = np.linspace(0, 1, num_thresholds)

    def _accuracy_at(cutoff):
        selected = conf >= cutoff
        return correct[selected].mean() if selected.any() else np.nan

    accuracy_curve = [_accuracy_at(c) for c in cutoffs]

    plt.figure(figsize=(6, 4))
    plt.plot(cutoffs, accuracy_curve)
    plt.xlabel("Confidence Threshold")
    plt.ylabel("Accuracy")
    plt.xlim(0, 1)
    plt.ylim(0, 1)
    plt.show()

def label_simulation(
    calibration_dir: str,
    number_of_samples: int,
    context_dir: str,
    similar_num: int,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Simulate labeling of a calibration dataset by generating similarity scores
    for "near-duplicate" and "dissimilar" question pairs.

    For each of ``number_of_samples`` randomly chosen questions, the RAG index
    is queried twice through ``kth_similar``. The first query uses a rank drawn
    uniformly from [2, similar_num] and is treated as a "similar" pair. The
    second uses a rank drawn uniformly from [similar_num + 1, num_problems - 1]
    and is treated as a "dissimilar" pair.

    Args:
        calibration_dir (str): Path to the JSON file containing the calibration dataset
            (a list of records, each with a "question" key).
        number_of_samples (int): Number of questions to sample for labeling
            (0 <= number_of_samples <= number of records).
        context_dir (str): Directory path for the RAG persistent database.
        similar_num (int): Largest rank (inclusive) used for the "similar" draw.
            When any samples are requested it must satisfy
            2 <= similar_num <= num_problems - 2, so both rank ranges are non-empty.

    Returns:
        Tuple[np.ndarray, np.ndarray]
            - sim_scores: Similarity scores for similar question pairs.
            - diff_scores: Similarity scores for dissimilar question pairs.

    Raises:
        ValueError: If number_of_samples or similar_num is out of range for the dataset.
    """
    # Load calibration data
    with open(calibration_dir, "r", encoding="utf-8") as f:
        calibration_data = json.load(f)

    num_problems = len(calibration_data)

    # Validate before building the RAG index so bad arguments fail fast with a
    # clear message, instead of a generic randint/sample error after indexing.
    if not 0 <= number_of_samples <= num_problems:
        raise ValueError(
            f"number_of_samples must be in [0, {num_problems}], got {number_of_samples}."
        )
    if number_of_samples > 0 and not 2 <= similar_num <= num_problems - 2:
        raise ValueError(
            f"similar_num must be in [2, {num_problems - 2}] for a dataset of "
            f"{num_problems} questions, got {similar_num}."
        )

    # Initialize RAG database and index documents
    rag_db = PersistentRAG(context_dir)
    rag_db.add_documents(calibration_data)

    sample_indices = random.sample(range(num_problems), k=number_of_samples)

    # Prepare arrays to hold similarity scores
    sim_scores = np.zeros(number_of_samples, dtype=float)
    diff_scores = np.zeros(number_of_samples, dtype=float)

    # Generate similarity scores for each sampled question
    for idx, prob_idx in enumerate(sample_indices):
        question = calibration_data[prob_idx]["question"]

        # Similar question: rank in [2, similar_num]. The lower bound of 2 is
        # presumably meant to skip the question itself among the top ranks —
        # NOTE(review): confirm against PersistentRAG.kth_similar's indexing.
        sim_idx = random.randint(2, min(similar_num, num_problems - 1))
        sim_scores[idx] = rag_db.kth_similar(question, sim_idx)

        # Dissimilar question: rank in [similar_num + 1, num_problems - 1]
        diff_idx = random.randint(similar_num + 1, num_problems - 1)
        diff_scores[idx] = rag_db.kth_similar(question, diff_idx)

    return sim_scores, diff_scores

def gmm_separation_threshold(
    sim_scores: np.ndarray,
    diff_scores: np.ndarray,
    weight: float = 0.5,
    grid_size: int = 1000
) -> float:
    """
    Fit a 2-component GMM on the combined similarities and find the similarity
    threshold at which P(similar_class | x) = weight.

    Args:
        sim_scores (np.ndarray): 1D array of “similar” question similarity scores.
        diff_scores (np.ndarray): 1D array of “dissimilar” question similarity scores.
        weight (float, default=0.5): Desired posterior probability cutoff for the “similar” component.
            E.g. weight=0.9 means we pick t so that P(similar|x=t)=0.9.
        grid_size (int, default=1000): Number of points to evaluate on the range [min, max] of all scores.

    Returns:
        float: The grid point in [min, max] of all scores whose posterior
        P(similar|x) is closest to ``weight``. This approximates the threshold
        t with P(similar|x=t) = weight.
    """
    # 1. Pool both score sets into one column (sklearn expects 2-D input) and fit.
    X = np.concatenate([np.ravel(sim_scores), np.ravel(diff_scores)])[:, None]
    gmm = GaussianMixture(n_components=2, covariance_type="full", random_state=0)
    gmm.fit(X)

    # 2. The “similar” component is the one with the larger mean.
    similar_comp = int(np.argmax(gmm.means_.ravel()))

    # 3. Build a 1-D grid over the observed range of scores.
    grid = np.linspace(X.min(), X.max(), grid_size)

    # 4. Posterior responsibility of the similar component at each grid point.
    sim_post = gmm.predict_proba(grid[:, None])[:, similar_comp]

    # 5. Pick the grid point whose posterior is closest to the requested weight.
    #    Indexing the 1-D grid yields a scalar, avoiding NumPy's deprecated
    #    float() conversion of a shape-(1,) array.
    best = int(np.argmin(np.abs(sim_post - weight)))
    return float(grid[best])

