import os
import sys
import yaml
import argparse
import numpy as np
import pandas as pd
from datetime import datetime
from tqdm import tqdm

# Make the repository root (the parent of this script's directory) importable so
# the project-local ``experiments`` and ``utils`` imports below can be resolved.
# The location is derived from ``__file__``, not from the current working directory.
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))
if not any(entry == project_root for entry in sys.path):
    sys.path.append(project_root)

from experiments.simulations import run_customized_fraction_sweep
from utils.io import load_experiment_from_config


def _load_yaml_config(config_path: str) -> dict:
    """Load the YAML file at *config_path* and return its top-level mapping.

    An empty file yields ``{}``. Raises ``ValueError`` if the top level of the
    document is not a mapping, because every downstream lookup uses ``dict.get``.
    """
    with open(config_path, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)
    if config is None:
        return {}
    if not isinstance(config, dict):
        raise ValueError(
            f"Config file {config_path!r} must contain a YAML mapping at the top level, "
            f"got {type(config).__name__}."
        )
    return config


def _parse_fractions(fractions):
    """Return the fraction grid, evaluating it first when it is given as a string.

    String values are evaluated as a Python expression with ``np`` in scope,
    e.g. ``"np.linspace(0.2, 1.0, 5)"``. Any other value is returned unchanged.
    """
    if isinstance(fractions, str):
        # SECURITY NOTE(review): eval() executes arbitrary Python taken from the
        # config file. Builtins stay reachable because the globals dict has no
        # ``__builtins__`` key. Only run configs from trusted sources.
        return eval(fractions, {"np": np})
    return fractions


def _combine_results(all_dfs, output_dir, timestamp):
    """Merge the per-run result frames, writing a summary CSV when there are several.

    Returns an empty DataFrame when no run produced results, the single frame
    unchanged when there is exactly one, and otherwise their row-wise
    concatenation (also saved as ``summary_all_<timestamp>.csv`` in *output_dir*).
    """
    if not all_dfs:
        print("\n[!] No sweeps were run; check runner.n_repeats, runner.modes and runner.models.")
        return pd.DataFrame()
    if len(all_dfs) == 1:
        return all_dfs[0]
    df_all = pd.concat(all_dfs, ignore_index=True)
    summary_path = os.path.join(output_dir, f"summary_all_{timestamp}.csv")
    df_all.to_csv(summary_path, index=False)
    print(f"\n[✓] All rounds completed. Summary saved to: {summary_path}")
    return df_all


def run_customized_multi_round_sweep(config_path: str) -> pd.DataFrame:
    """Run the customized fraction sweep for every (repeat, mode, model) combination.

    Settings come from the ``runner`` section of the YAML config at
    *config_path*: ``modes``, ``models``, ``fractions``, ``seed``, ``plot``,
    ``verbose``, ``output_dir`` and ``n_repeats``. Each is optional and has a
    default. Repeat ``r`` (0-based) uses seed ``seed + r`` and calls
    ``load_experiment_from_config`` with that seed. For every run, a copy of the
    config is written to ``output_dir``. A per-run CSV path (and, when ``plot``
    is set, a plot directory) is also passed to ``run_customized_fraction_sweep``.

    Args:
        config_path: Path to the YAML configuration file.

    Returns:
        When several runs happened, all per-run DataFrames (each tagged with
        ``repeat``, ``mode`` and ``model`` columns) concatenated, with a
        ``summary_all_<timestamp>.csv`` also written. When exactly one run
        happened, that run's tagged DataFrame. When no run was configured, an
        empty DataFrame.

    Raises:
        ValueError: If the config's top level is not a YAML mapping.
    """
    config = _load_yaml_config(config_path)

    # Empty ``experiment:`` / ``runner:`` sections load as None, hence ``or {}``.
    experiment_cfg = config.get("experiment") or {}
    experiment_name = str(experiment_cfg.get("name") or "experiment").lower().replace(" ", "_")

    runner_cfg = config.get("runner") or {}
    modes = runner_cfg.get("modes", ["proxy"])  # Accepts: ["proxy", "trainvalid"]
    models = runner_cfg.get("models", ["lasso"])
    fractions = _parse_fractions(runner_cfg.get("fractions", [0.2, 0.4, 0.6, 0.8, 1.0]))
    seed = runner_cfg.get("seed", 101)
    plot = runner_cfg.get("plot", False)
    verbose = runner_cfg.get("verbose", True)
    n_repeats = runner_cfg.get("n_repeats", 1)

    output_dir = runner_cfg.get("output_dir", f"outputs/custom_sweep/{experiment_name}")
    os.makedirs(output_dir, exist_ok=True)
    # A single timestamp per invocation, so every file from this call shares it.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    all_dfs = []
    for repeat in range(n_repeats):
        rep = repeat + 1  # 1-based repeat index used in messages, file names and the 'repeat' column
        print(f"\n[⏱] Running repeat {rep} of {n_repeats}...")
        seed_i = seed + repeat
        experiment = load_experiment_from_config(config_path, seed=seed_i)

        for mode in modes:
            for model_type in models:
                print(f"\n>>> Running {mode} sweep with {model_type} model...")

                # Output paths share a common run tag
                tag = f"{mode}_{model_type}_rep{rep}_{timestamp}"
                output_csv = os.path.join(output_dir, f"customsweep_{tag}.csv")
                plot_dir = os.path.join(output_dir, f"plots_{tag}") if plot else None
                config_copy_path = os.path.join(output_dir, f"config_copy_{tag}.yaml")

                # Save config for reproducibility
                with open(config_copy_path, "w", encoding="utf-8") as f_out:
                    yaml.dump(config, f_out)

                # Run simulation sweep
                df = run_customized_fraction_sweep(
                    base_experiment=experiment,
                    config=config,
                    fraction_values=fractions,
                    mode=mode,
                    model_type=model_type,
                    seed=seed_i,
                    output_csv=output_csv,
                    plot=plot,
                    plot_dir=plot_dir,
                    verbose=verbose,
                )
                df["repeat"] = rep
                df["mode"] = mode
                df["model"] = model_type
                all_dfs.append(df)

    # Merge on the number of results actually produced, not on n_repeats.
    # Otherwise a single repeat over several modes/models would drop every run but the first.
    return _combine_results(all_dfs, output_dir, timestamp)


def main():
    """Command-line entry point: read ``--config`` and run the multi-round sweep."""
    cli = argparse.ArgumentParser(
        description="Run multi-round customized sweep over proxy/trainvalid fractions.",
    )
    cli.add_argument(
        "--config",
        type=str,
        required=True,
        help="Path to YAML config file",
    )
    config_file = cli.parse_args().config
    run_customized_multi_round_sweep(config_file)


# Launch the sweep only when executed as a script, not when imported.
if __name__ == "__main__":
    main()