import numpy as np
import random
import pandas as pd

def permutation_weighted_pval(cal_scores, cal_weights, test_scores, test_weights, M=500, statistic="max"):
    """
    Monte Carlo estimate of permutation-weighted p-value with flexible statistic.

    Calibration and test points are pooled and M uniformly random permutations
    of the pool are drawn. In each permutation the last k positions act as the
    test set: the permutation's weight is the product of the importance
    weights of those k points, and its statistic is computed from their scores
    (or pooled ranks for "rank_sum"). The estimate is the weighted fraction of
    permutations whose statistic does not exceed the observed one
    (S_pi <= S_obs).

    Parameters:
    - cal_scores: list of calibration scores
    - cal_weights: list of calibration importance weights
    - test_scores: list of test scores (length = k)
    - test_weights: list of test weights (length = k)
    - M: number of Monte Carlo permutations
    - statistic: one of "max", "min", "mean", "sum", or "rank_sum".
      "rank_sum" sums the 1-based ascending ranks, within the pooled sample,
      of the points placed in the test positions; tied scores are ranked by
      their position in the pool (calibration points first).

    Returns:
    - estimated p-value (num / den). If every sampled permutation has zero
      weight the ratio 0/0 is undefined (NaN under NumPy's default error
      settings).

    Raises:
    - ValueError: if statistic is not one of the supported names
    """

    n0 = len(cal_scores)
    cal_scores = np.array(cal_scores)
    test_scores = np.array(test_scores)
    cal_weights = np.array(cal_weights)
    test_weights = np.array(test_weights)

    all_scores = np.concatenate([cal_scores, test_scores])
    all_weights = np.concatenate([cal_weights, test_weights])

    # `values` is what the statistic is evaluated on for a permuted test set;
    # `observed_values` is the same quantity for the actual test set.
    values = all_scores
    observed_values = test_scores

    # Define the statistic function
    if statistic == "max":
        stat = np.max
    elif statistic == "min":
        stat = np.min
    elif statistic == "mean":
        stat = np.mean
    elif statistic == "sum":
        stat = np.sum
    elif statistic == "rank_sum":
        # Ranks are a property of the pooled sample, so compute them once and
        # let each permutation pick the ranks of the points it selects.
        # (Previously the ranks of the original test positions were summed for
        # every permutation, so S_pi always equalled S_obs.)
        pooled_ranks = np.argsort(np.argsort(all_scores, kind="stable")) + 1
        values = pooled_ranks
        observed_values = pooled_ranks[n0:]
        stat = np.sum
    else:
        raise ValueError(f"Unknown statistic: {statistic}")

    S_obs = stat(observed_values)
    num = 0.0
    den = 0.0

    for _ in range(M):
        perm = np.random.permutation(len(all_scores))
        test_idx = perm[n0:]  # permuted test set of size k

        # Weight of this permutation: product of the selected points' weights
        W_pi = np.prod(all_weights[test_idx])
        S_pi = stat(values[test_idx])

        den += W_pi
        if S_obs >= S_pi:
            num += W_pi

    return num/den

def weighted_permutation_pval(cal_scores, cal_weights, test_scores, test_weights, method, M=500, statistic="max"):
    """
    Estimate a p-value from weighted permutation draws.

    Calibration and test points are pooled. On each of M rounds,
    get_weighted_permutation is asked for k = len(test_scores) indices into
    the pool (passing the pooled weights and `method`), the chosen statistic
    is evaluated on the scores at those indices, and the round counts as a hit
    when that statistic is at least the observed one (S_pi >= S_obs).

    Parameters:
    - cal_scores: array-like, calibration scores
    - cal_weights: array-like, calibration weights
    - test_scores: array-like, test scores
    - test_weights: array-like, test weights
    - method: str, forwarded unchanged to get_weighted_permutation
      (expected to be "exp", "gumbel", or "reservoir" -- confirm against that
      helper, which is defined outside this function)
    - M: int, number of Monte Carlo permutations
    - statistic: str, statistic for combining scores ("max", "min", "mean", "sum")

    Returns:
    - p-value: float, fraction of the M rounds that were hits

    Raises:
    - ValueError: if statistic is not one of the supported names
    """
    n_cal = len(cal_scores)
    n_test = len(test_scores)

    # Pool calibration and test points; test points occupy the tail.
    pooled_scores = np.concatenate([cal_scores, test_scores])
    pooled_weights = np.concatenate([cal_weights, test_weights])

    # Name -> reducer lookup for the supported statistics
    reducers = {
        "max": np.max,
        "min": np.min,
        "mean": np.mean,
        "sum": np.sum,
    }
    if statistic not in reducers:
        raise ValueError(f"Unknown statistic: {statistic}")
    reduce_scores = reducers[statistic]

    observed = reduce_scores(test_scores)

    def _sampled_statistic():
        # One weighted draw of k pool indices, reduced to a single statistic.
        drawn = get_weighted_permutation(
            values=np.arange(n_cal + n_test),
            k=n_test,
            weights=pooled_weights,
            method=method,
        )
        return reduce_scores(pooled_scores[drawn])

    # Number of rounds whose statistic reaches the observed value
    hits = sum(1 for _ in range(M) if _sampled_statistic() >= observed)

    return hits / M

def compute_nested_pvalues(opt_samples, cal_scores, cal_weights, statistic="max", 
                         M=500, permutation_method="standard"):
    """
    Compute a non-increasing sequence of nested p-values over shuffled samples.

    The samples are shuffled IN PLACE (the caller's list is reordered) and then
    visited in that order. The first sample gets conformal_pvalue_single on its
    own score and likelihood ratio. Every later sample gets
    permutation_weighted_pval on all samples visited so far, and the value is
    capped at the previous nested p-value so the sequence never increases.

    Parameters:
    - opt_samples: list of dictionaries with 'score' and 'likelihood_ratio' for
      each sample; an optional 'smiles' entry is used as the sample's label
    - cal_scores: calibration scores
    - cal_weights: calibration weights
    - statistic: which statistic to use for combining the scores of the samples
      seen so far (forwarded to permutation_weighted_pval)
    - M: number of permutations for nested p-value computation
    - permutation_method: method for permutation (only "standard" is accepted)

    Returns:
    - list of sample labels in the order they were visited ('smiles' if
      present, otherwise 'sample_<position>')
    - list of corresponding nested p-values

    Raises:
    - ValueError: if permutation_method is not "standard" and a second sample
      is reached
    """
    # Visit order is randomized; this reorders the caller's list as well.
    random.shuffle(opt_samples)

    labels = []
    pvalue_path = []
    visited = []  # samples processed so far, in visit order

    for position, current in enumerate(opt_samples):
        labels.append(current.get('smiles', f'sample_{position}'))
        visited.append(current)

        if position == 0:
            # A single point needs no permutation test.
            pvalue_path.append(
                conformal_pvalue_single(
                    cal_scores,
                    cal_weights,
                    current['score'],
                    current['likelihood_ratio'],
                )
            )
            continue

        # Jointly test every sample visited up to and including this one.
        prefix_scores = [entry['score'] for entry in visited]
        prefix_weights = [entry['likelihood_ratio'] for entry in visited]

        if permutation_method != "standard":
            raise ValueError(f"Unknown permutation method: {permutation_method}")

        joint_pvalue = permutation_weighted_pval(
            cal_scores=cal_scores,
            cal_weights=cal_weights,
            test_scores=prefix_scores,
            test_weights=prefix_weights,
            M=M,
            statistic=statistic,
        )

        # Running minimum keeps the nested p-values non-increasing.
        pvalue_path.append(min(joint_pvalue, pvalue_path[-1]))

    return labels, pvalue_path

def sequential_test(calib_scores, calib_weights, test_scores_stream, test_weights_stream, 
                   alphas=(0.05,), max_k=None, M=1000, statistic="max", permutation_method="standard"):
    """
    Perform sequential testing with weighted scores.

    The first max_k (score, weight) pairs are wrapped as samples and passed to
    compute_nested_pvalues, which shuffles them before computing the nested
    p-values. For each alpha, the rejection point is the first position in that
    p-value sequence whose value is <= alpha.

    Parameters:
    - calib_scores: scores from calibration set
    - calib_weights: weights for calibration set
    - test_scores_stream: sized sequence of test scores (len() is taken)
    - test_weights_stream: test weights, paired with the scores by position;
      pairing stops at the shorter of the two
    - alphas: iterable of significance levels to test
    - max_k: maximum number of samples to test (None for all); values larger
      than the stream length are clamped to it
    - M: number of permutations for nested p-value computation
    - statistic: which statistic to use for combining the test scores
    - permutation_method: forwarded to compute_nested_pvalues, which in this
      file accepts only "standard" and raises ValueError for anything else

    Returns:
    - rejection_points: dictionary mapping each alpha to the 0-based index of
      the first p-value <= alpha, or None if there is none. Indices refer to
      the shuffled order used by compute_nested_pvalues, not to the original
      stream order.
    - p_values: list of nested p-values

    Raises:
    - ValueError: if max_k is negative
    """
    stream_length = len(test_scores_stream)
    if max_k is None:
        max_k = stream_length
    elif max_k < 0:
        # A negative bound would otherwise act as a negative slice index and
        # silently drop samples from the end of the stream.
        raise ValueError(f"max_k must be non-negative, got {max_k}")
    else:
        max_k = min(max_k, stream_length)

    # Zipping against range(max_k) stops after max_k pairs, so no sample dicts
    # are built for stream entries that will never be tested.
    test_samples = [
        {'score': s, 'likelihood_ratio': w}
        for _, s, w in zip(range(max_k), test_scores_stream, test_weights_stream)
    ]

    # Compute p-values once; they are shared across all alphas
    _, p_values = compute_nested_pvalues(
        test_samples,
        calib_scores,
        calib_weights,
        statistic=statistic,
        M=M,
        permutation_method=permutation_method
    )

    # Find rejection points for each alpha: first index with p-value <= alpha
    rejection_points = {}
    for alpha in alphas:
        rejection_points[alpha] = next(
            (idx for idx, p_value in enumerate(p_values) if p_value <= alpha),
            None,
        )

    return rejection_points, p_values