import numpy as np
import pandas as pd
import os

from DGP import dgpK5_unb
from test_wrappers_agginc import (
    test_AGGINC_raw,
    test_MMDAGG_raw
)


def run_single_test(sets, n, p, ratio, param, null_type):
    """Simulate one dataset and run both aggregated tests on it.

    The global NumPy RNG is reseeded with ``sets`` before data generation,
    so each repetition index yields a reproducible dataset.

    Args:
        sets: Repetition index; used as the NumPy seed and forwarded to both
            test wrappers.
        n, p, ratio, param, null_type: Forwarded unchanged to ``dgpK5_unb``.

    Returns:
        Dict with keys ``"mmdagg_raw"`` and ``"agginc_raw"`` holding whatever
        the respective test wrapper returned (presumably a rejection
        indicator, or the string ``'na'`` on failure — TODO confirm against
        ``test_wrappers_agginc``).
    """
    np.random.seed(sets)
    X, Y = dgpK5_unb(n, p, ratio, param, null_type)
    # Call order matters if the wrappers draw from the global RNG:
    # MMDAgg runs first, then AggInc, exactly as before.
    mmdagg_outcome = test_MMDAGG_raw(X, Y, sets)
    agginc_outcome = test_AGGINC_raw(X, Y, sets)
    return {"mmdagg_raw": mmdagg_outcome, "agginc_raw": agginc_outcome}


def compute_power(results_list, *, missing_marker="na"):
    """Average each method's test outcomes across Monte Carlo repetitions.

    Args:
        results_list: Sequence of dicts, one per repetition, mapping a method
            name to that repetition's outcome.
        missing_marker: Value marking an unavailable outcome. Such entries
            become NaN and are skipped by the mean. Defaults to ``"na"``.

    Returns:
        Dict mapping each method name to the mean of its non-missing
        outcomes. If outcomes are 0/1 rejection indicators (assumed — TODO
        confirm), this is the empirical power / rejection rate.

    Raises:
        ValueError: If an outcome is neither numeric nor ``missing_marker``.
    """
    df = pd.DataFrame(results_list)
    df = df.replace(missing_marker, np.nan)
    # A column that contained the marker is still object dtype after replace
    # (or only implicitly downcast, which pandas has deprecated). Convert to
    # a numeric dtype explicitly so the mean is a plain numeric reduction.
    # errors="raise" (the default) keeps unexpected strings loud.
    df = df.apply(pd.to_numeric)
    return df.mean().to_dict()




if __name__ == "__main__":
    from itertools import product

    output_dir = "case2"
    os.makedirs(output_dir, exist_ok=True)

    # Experiment grid: sample sizes, dimensions, alternative types and the
    # signal-strength values swept for each type.
    ns = [150]
    ps = [200]
    types = ["mean", "cov", "loc"]
    theta_ranges = {
        "mean": [0.8, 0.9, 1, 1.1, 1.2],
        "cov": [0.8, 0.9, 1, 1.1, 1.2],
        "loc": [0.8, 0.9, 1, 1.1, 1.2],
    }
    n_repeats = 200
    ratio = 0.5

    summary_columns = ["n", "p", "param", "null_type", "ratio", "mmdagg_raw", "agginc_raw"]
    # Collect one summary row per configuration and build the DataFrame once
    # at the end, instead of enlarging it row by row with .loc.
    summary_rows = []

    print("Starting experiment (agginc_simu2: varying signal strength)...")

    # Iterate over every entry of `ps`; previously only ps[0] was used, so
    # extra dimensions added to the grid were silently ignored.
    for n, p, null_type in product(ns, ps, types):
        for theta in theta_ranges[null_type]:
            print(f"Running: n={n}, p={p}, type={null_type}, theta={theta}")

            # Repetition indices start at 1 and double as per-run seeds.
            results = [
                run_single_test(sets, n, p, ratio, theta, null_type)
                for sets in range(1, n_repeats + 1)
            ]

            power_dict = compute_power(results)
            summary_rows.append(
                [
                    n,
                    p,
                    theta,
                    null_type,
                    ratio,
                    power_dict["mmdagg_raw"],
                    power_dict["agginc_raw"],
                ]
            )

            # Keep the raw per-repetition outcomes for this configuration.
            pd.DataFrame(results).to_csv(
                os.path.join(
                    output_dir,
                    f"results_simu2_{n}_{p}_{theta}_{null_type}_{ratio}(mmdagg).csv",
                ),
                index=False,
            )

    merge = pd.DataFrame(summary_rows, columns=summary_columns)
    merge_path = os.path.join(output_dir, "merge_power_simu2(mmdagg).csv")
    merge.to_csv(merge_path, index=False)
    print(f"\nResults saved to {merge_path}")
