from freqopttest.tst import MeanEmbeddingTest as fot_MeanEmbeddingTest
from freqopttest.data import TSTData as fot_TSTData
import numpy as np


# based on job_met_opt() function
# https://github.com/wittawatj/interpretable-test/blob/master/freqopttest/ex/ex1_power_vs_n.py
def met(X, Y, r, J=10, alpha=0.05):
    """Run the MeanEmbeddingTest with optimized test locations and width.

    If X and Y have different numbers of rows, the larger sample is randomly
    subsampled (without replacement) down to the size of the smaller one.
    The equal-sized data are then split 50/50 into a training half and a
    test half. The training half is used to optimize the J test locations
    and the Gaussian width; the test half is used to run the test.

    Parameters:
        X: numpy array of shape (m, p)
        Y: numpy array of shape (n, p)
        r: random seed; fixed offsets of it seed the subsampling, the
            train/test split and the optimizer
        J: number of test locations
        alpha: significance level

    Returns:
        int: 1 if H0 is rejected, 0 otherwise.
        str: 'na' if the optimization or the test raised an exception
            (e.g. numerical errors for extreme parameter combinations).

    Raises:
        ValueError: if, after equalizing sample sizes, a sample has fewer
            than 2 observations.
    """
    m = X.shape[0]
    n = Y.shape[0]
    n_common = min(m, n)
    # Checked for every input, not only when subsampling. Presumably fewer
    # than 2 rows cannot survive the 50/50 train/test split below.
    if n_common < 2:
        raise ValueError("After subsampling, each sample must have at least 2 observations")
    if m != n:
        # Use a local RandomState rather than np.random.seed. It produces the
        # same draws as seeding the global RNG with this value, but leaves
        # the caller's global numpy random state untouched. The offset keeps
        # this stream distinct from the other seeds derived from r.
        rng = np.random.RandomState(r + 12345)
        if m > n:
            X = X[rng.choice(m, size=n_common, replace=False)]
        else:
            Y = Y[rng.choice(n, size=n_common, replace=False)]
    data = fot_TSTData(X, Y)
    # Subsampling to the full size is presumably a seeded permutation of the
    # rows before the 50/50 split.
    tr, te = data.subsample(n_common, seed=r + 4).split_tr_te(tr_proportion=0.5, seed=r + 5)

    met_opt_options = {
        'n_test_locs': J,
        'max_iter': 200,
        'locs_step_size': 0.1,
        'gwidth_step_size': 0.1,
        'batch_proportion': 1.0,
        'seed': r + 92856,
        'tol_fun': 1e-3,
    }
    try:
        test_locs, gwidth, _info = fot_MeanEmbeddingTest.optimize_locs_width(tr, alpha, **met_opt_options)
        met_opt = fot_MeanEmbeddingTest(test_locs, gwidth, alpha)
        met_opt_test = met_opt.perform_test(te)
        return int(met_opt_test['h0_rejected'])
    except Exception:
        # The broad catch is deliberate. Numerical failures (e.g.
        # boost::math::rounding_error, overflow) can occur with extreme
        # parameter combinations and surface as various exception types.
        # They mark this run as 'na' instead of aborting the caller.
        return 'na'
