import numpy as np
import pandas as pd
import os

from DGP import dgpK6_unb
from test_wrappers_me import (
    test_ME_raw
)


def run_single_test(sets, n, p, ratio, param, null_type):
    """
    Generate one dataset and evaluate every test method on it.

    Parameters:
        sets: Random seed; seeds NumPy's global RNG and is also forwarded
            to the test method
        n: Sample size (forwarded to the data generator)
        p: Dimension (forwarded to the data generator)
        ratio: Sample ratio (forwarded to the data generator)
        param: Parameter (forwarded to the data generator)
        null_type: Null hypothesis type (forwarded to the data generator)
    Returns:
        results: Dictionary keyed by method name ("me_raw") holding the
            value returned by that test method
    """
    # Seed before data generation so each replicate is reproducible.
    np.random.seed(sets)

    # Draw the two samples (aligned with simu3 / mmd_simu3).
    sample_x, sample_y = dgpK6_unb(n, p, ratio, param, null_type)

    # Collect the outcome of each test method under its own key.
    outcome = {}
    outcome["me_raw"] = test_ME_raw(sample_x, sample_y, sets)
    return outcome


def compute_power(results_list):
    """
    Calculate the power (rejection rate) for each test method.

    Entries equal to the string 'na' are treated as missing and are
    excluded from the mean of their column.

    Parameters:
        results_list: List of results, each element is a dictionary
            mapping a method name to its outcome (a number or 'na')

    Returns:
        power_dict: Dictionary mapping each method name to the mean of its
            non-missing outcomes (NaN when every outcome is missing);
            an empty dictionary when there are no results

    Raises:
        ValueError: If a column holds a value other than 'na' that cannot
            be parsed as a number
    """
    df = pd.DataFrame(results_list)
    if df.empty:
        # No rows or no columns: nothing to average.
        return {}

    # Blank out the 'na' sentinel, then convert every column to a numeric
    # dtype explicitly. This avoids relying on the implicit downcasting of
    # DataFrame.replace, which newer pandas versions deprecate.
    df_numeric = df.mask(df.eq("na")).apply(pd.to_numeric)

    # mean() skips NaN by default, so missing entries do not count.
    return df_numeric.mean().to_dict()




if __name__ == "__main__":
    # Every CSV produced by this script is written under this directory.
    output_dir = "case3"
    os.makedirs(output_dir, exist_ok=True)

    # Experiment parameters (match simu3.py & mmd_simu3.py).
    # thetas[k] is the parameter used together with types[k].
    ns = [100, 150, 200, 250, 300]
    ps = [10]
    thetas = [0,0.35,0.35,0.25,3]
    types = ["mean", "mean", "cov", "loc", "dstr"]
    ratio = 0.5

    n_repeats = 200  # Number of repetitions per configuration

    # Only the first dimension in ps is used in this experiment.
    dim = ps[0]

    # Summary table: one row per (null_type, n) configuration.
    summary_columns = ["n", "p", "param", "null_type", "ratio", "me_raw"]
    merge = pd.DataFrame(columns=summary_columns)

    print("Starting experiment (me_simu3: varying sample size)...")

    for null_type, param in zip(types, thetas):
        for n in ns:
            print(f"Running: n={n}, p={dim}, type={null_type}")

            # Run single-threaded to avoid parallelism; seeds are
            # 1..n_repeats inclusive.
            results = []
            for seed in range(1, n_repeats + 1):
                results.append(
                    run_single_test(seed, n, dim, ratio, param, null_type)
                )

            # Rejection rate of each method over the replicates.
            power_dict = compute_power(results)

            # Append one summary row for this configuration.
            summary_row = [n, dim, param, null_type, ratio, power_dict["me_raw"]]
            merge.loc[len(merge)] = summary_row

            # Save the per-replicate results for this configuration.
            per_config_path = (
                f"{output_dir}/results_simu3_{n}_{dim}_{param}_{null_type}_{ratio}(me).csv"
            )
            pd.DataFrame(results).to_csv(per_config_path, index=False)

    # Save the merged power table.
    merged_path = f"{output_dir}/merge_power_simu3(me).csv"
    merge.to_csv(merged_path, index=False)
    print(f"\nResults saved to {merged_path}")
