import numpy as np
import pandas as pd
from multiprocessing import Pool, cpu_count
import os

from DGP import dgpK5_unb
from test_wrappers_mmd import (
    test_MMD_median, test_mmd_split, test_mmd_fuse
)


def run_single_test(sets, n, p, ratio, param, null_type):
    """
    Run one Monte Carlo replication: draw a data set and apply every test.

    Parameters:
        sets: Random seed; seeds NumPy's global RNG and is also forwarded
              to each test wrapper
        n: Sample size (forwarded to dgpK5_unb)
        p: Dimension (forwarded to dgpK5_unb)
        ratio: Sample ratio (forwarded to dgpK5_unb)
        param: Parameter (forwarded to dgpK5_unb)
        null_type: Null hypothesis type (forwarded to dgpK5_unb)
    Returns:
        results: Dictionary with keys 'mmd_median', 'mmd_split' and
                 'mmd_fuse', each holding the value returned by the
                 corresponding test wrapper
    """
    # Seed before data generation so a given `sets` reproduces the same draw.
    np.random.seed(sets)
    sample_x, sample_y = dgpK5_unb(n, p, ratio, param, null_type)

    # (result key, wrapper, extra positional args after the seed).
    # The single-kernel wrappers are given the "gaussian" kernel explicitly;
    # test_mmd_fuse receives no kernel argument (per the original author's
    # note it always uses (gaussian, laplace) -- not verifiable from here).
    gaussian_only = ("gaussian",)
    test_plan = (
        ('mmd_median', test_MMD_median, gaussian_only),
        ('mmd_split', test_mmd_split, gaussian_only),
        ('mmd_fuse', test_mmd_fuse, ()),
    )

    # Dict comprehension evaluates in plan order, so the wrappers run in the
    # same sequence as listed above.
    return {
        label: wrapper(sample_x, sample_y, sets, *extra_args)
        for label, wrapper, extra_args in test_plan
    }


def compute_power(results_list):
    """
    Calculate the power (rejection rate) for each test method.

    Entries equal to the string 'na' are excluded before averaging, and
    missing values (NaN) are skipped by the mean as well. Each column is
    converted to a numeric dtype explicitly, so the result does not depend
    on how the installed pandas version treats object-dtype columns in
    ``DataFrame.replace`` / ``DataFrame.mean``.

    Parameters:
        results_list: List of results, each element is a dictionary mapping
                      a method name to its outcome (a number or 'na')

    Returns:
        power_dict: Dictionary mapping each method name to the mean of its
                    valid outcomes as a Python float (NaN when a method has
                    no valid outcome). Empty when results_list is empty.

    Raises:
        ValueError / TypeError: if a column contains a non-numeric value
                                other than 'na' (raised by pd.to_numeric).
    """
    df = pd.DataFrame(results_list)

    power_dict = {}
    for method in df.columns:
        outcomes = df[method]
        # Drop only the exact 'na' sentinel; comparing a numeric column with
        # a string is elementwise "not equal", so numeric columns pass through.
        valid = outcomes[outcomes != 'na']
        # Strict conversion: anything unexpected fails loudly instead of
        # being averaged (or dropped) silently.
        power_dict[method] = float(pd.to_numeric(valid).mean())
    return power_dict




if __name__ == "__main__":
    # Driver: for every (n, p, null_type, theta) configuration, run
    # n_repeats replications of run_single_test in parallel, write the raw
    # per-replication results to CSV, and collect the per-method mean
    # (via compute_power) into one summary CSV.
    output_dir = "case2"
    os.makedirs(output_dir, exist_ok=True)

    # Experiment parameters
    ns = [150]
    ps = [200]
    types = ["mean", "cov", "loc"]
    # Passed to run_single_test as `ratio`; also recorded in the summary row
    # and embedded in the per-configuration file name.
    ratio = 0.5
    # No kernel loop: mmd_median/mmd_split use gaussian; mmdfuse uses (gaussian, laplace).

    # Theta ranges corresponding to different null_type
    theta_ranges = {
        "mean": [0.8, 0.9, 1, 1.1, 1.2],
        "cov": [0.8, 0.9, 1, 1.1, 1.2],
        "loc": [0.8, 0.9, 1, 1.1, 1.2],
    }

    n_repeats = 200  # Number of repetitions

    # Use about 70% of the cores, capped at 60. Clamp to at least one worker:
    # int(cpu_count() * 0.7) is 0 on a single-core machine and Pool rejects
    # processes=0 with a ValueError.
    n_workers = max(1, min(int(cpu_count() * 0.7), 60))

    # Initialize results DataFrame
    merge = pd.DataFrame(columns=["n", "p", "param", "null_type", "ratio",
                                  "mmd_median", "mmd_split", "mmd_fuse"])

    print("Starting experiment (mmd_simu2: varying signal strength)...")

    for n in ns:
        for p in ps:
            for null_type in types:
                for theta in theta_ranges[null_type]:
                    print(f"Running: n={n}, p={p}, type={null_type}, theta={theta}")

                    # Run the replications in parallel; the seed of each
                    # replication is its 1-based index.
                    with Pool(processes=n_workers) as pool:
                        results = pool.starmap(
                            run_single_test,
                            [(sets, n, p, ratio, theta, null_type)
                             for sets in range(1, n_repeats + 1)],
                        )

                    # Calculate power (starmap has already returned every
                    # result, so the pool is no longer needed here)
                    power_dict = compute_power(results)

                    # Save to merge
                    merge.loc[len(merge)] = [
                        n, p, theta, null_type, ratio,
                        power_dict["mmd_median"],
                        power_dict["mmd_split"],
                        power_dict["mmd_fuse"],
                    ]

                    # Save intermediate (per-replication) results
                    results_df = pd.DataFrame(results)
                    results_df.to_csv(
                        f"{output_dir}/results_simu2_{n}_{p}_{theta}_{null_type}_{ratio}(mmd).csv",
                        index=False,
                    )

    # Save merged results
    merge.to_csv(f"{output_dir}/merge_power_simu2(mmd).csv", index=False)
    print(f"\nResults saved to {output_dir}/merge_power_simu2(mmd).csv")
