import numpy as np
from scipy import stats
import pandas as pd

def get_p_values(list_number_samples, heldout_train=None, heldout_val=None, ranks=None, max_rank=None, seed=42, repeatitions=5):
    """Collect p-values from repeated random subsamples of held-out data.

    For each of ``repeatitions`` repeats the held-out arrays are shuffled
    once. Then, for every size in ``list_number_samples``, the first
    ``num_samples`` shuffled elements are tested:

    * If both ``heldout_train`` and ``heldout_val`` are given, run a
      two-sample t-test (``scipy.stats.ttest_ind`` with
      ``alternative='less'``, ``equal_var=True``). Its alternative
      hypothesis is that the validation subsample mean is smaller than the
      training subsample mean. The result is stored under ``'t-test'``.
    * If ``ranks`` is given, run a Kolmogorov-Smirnov test
      (``alternative='greater'``) of ``ranks / max_rank`` against the
      standard uniform CDF. The result is stored under
      ``'ks-test-uniform'``.

    Args:
        list_number_samples: Iterable of subsample sizes. A size larger than
            an array's length is truncated to that length by slicing.
        heldout_train: Array-like of training-set values, or None.
        heldout_val: Array-like of validation-set values, or None.
        ranks: Array-like of ranks, or None.
        max_rank: Normaliser applied to ``ranks``. Required when ``ranks``
            is given.
        seed: Seed for the local random generator used for shuffling.
            ``None`` draws fresh OS entropy.
        repeatitions: Number of independent shuffles. The misspelled name
            is kept for backward compatibility.

    Returns:
        dict mapping ``'num_samples'`` and each computed test name to a list
        with one entry per (repeat, size) pair, in iteration order.
        ``'num_samples'`` records the training-subsample length when
        train/val are given, and otherwise the ranks-subsample length.
        The dict is empty when neither train/val nor ranks are provided.

    Raises:
        ValueError: If ``ranks`` is given without ``max_rank``.
    """
    if ranks is not None and max_rank is None:
        raise ValueError("max_rank is required when ranks is given")

    # ``seed`` used to be ignored, so runs depended on the global NumPy RNG
    # state. A local generator makes results reproducible.
    rng = np.random.default_rng(seed)

    use_ttest = heldout_train is not None and heldout_val is not None
    if use_ttest:
        # asarray makes positional fancy indexing work for lists and pandas
        # Series alike (a Series would otherwise be indexed by label).
        heldout_train = np.asarray(heldout_train)
        heldout_val = np.asarray(heldout_val)
    if ranks is not None:
        # Convert once instead of on every inner iteration.
        ranks_arr = np.asarray(ranks)

    p_value_dict = {}
    for _ in range(repeatitions):
        # One shuffle per repeat. Every subsample size takes a prefix of it,
        # so smaller subsamples are nested inside larger ones.
        if use_ttest:
            order_train = rng.permutation(len(heldout_train))
            order_val = rng.permutation(len(heldout_val))
        if ranks is not None:
            order_ranks = rng.permutation(len(ranks_arr))

        for num_samples in list_number_samples:
            if use_ttest:
                train_curr = heldout_train[order_train[:num_samples]]
                val_curr = heldout_val[order_val[:num_samples]]
                p_value_dict.setdefault('num_samples', []).append(len(train_curr))

                _, p_value = stats.ttest_ind(
                    val_curr,
                    train_curr,
                    alternative='less',
                    equal_var=True,
                )
                p_value_dict.setdefault('t-test', []).append(p_value)

            if ranks is not None:
                ranks_curr = ranks_arr[order_ranks[:num_samples]]
                # Record the size only if the t-test branch has not already
                # recorded it for this (repeat, size) pair.
                if not use_ttest:
                    p_value_dict.setdefault('num_samples', []).append(len(ranks_curr))
                p_value = stats.kstest(
                    ranks_curr / max_rank,
                    stats.uniform.cdf,
                    alternative='greater',
                ).pvalue
                p_value_dict.setdefault('ks-test-uniform', []).append(p_value)
    return p_value_dict

def rank_candidates(y_candidates, probs):
    """Rank every positive candidate against all negative candidates.

    A candidate is positive when its label equals 1. Every other label
    counts as negative. A positive's rank is the number of negatives whose
    probability is greater than or equal to its own. Rank 0 means it
    strictly outscored every negative; ties count against the positive.

    Args:
        y_candidates: 1-D array-like of labels.
        probs: 1-D array-like of scores, aligned with ``y_candidates``.

    Returns:
        tuple ``(ranking, max_rank)``:
        ``ranking`` is a list of Python ints, one per positive, in input
        order. ``max_rank`` is the number of negatives, which is the
        largest possible rank.

    Raises:
        ValueError: If the two inputs differ in length.
    """
    # asarray so plain lists work. For a list, ``== 1`` would yield a single
    # bool and break the masking below.
    y_candidates = np.asarray(y_candidates)
    probs = np.asarray(probs)
    # Raise instead of assert so the check survives ``python -O``.
    if len(y_candidates) != len(probs):
        raise ValueError(f"Length mismatch: {len(y_candidates)} {len(probs)}")

    targets = y_candidates == 1
    # Broadcast (n_pos, 1) <= (1, n_neg) into a boolean matrix. Each row sum
    # counts the negatives scoring at least as high as that positive.
    ranking = (probs[targets][:, None] <= probs[~targets][None, :]).sum(axis=-1).tolist()
    max_rank = int((~targets).sum())
    return ranking, max_rank