import argparse
import copy
import json
import logging
import os
import sys

import pandas as pd
import yaml
from tqdm import tqdm

# Resolve the repository root (the parent of this script's directory) and make
# it importable, so the project-local ``scripts.*`` import below resolves.
project_root = os.path.abspath(
    os.path.join(os.path.dirname(__file__), os.pardir)
)
# Prepend rather than append so project packages take precedence; the
# membership check avoids adding a duplicate entry.
if project_root not in sys.path:
    sys.path.insert(0, project_root)

from scripts.run_full_pipeline_and_collect_metrics import run_full_pipeline_and_collect_metrics


def setup_logger(name="panel_sweep", level=logging.INFO, *, stream=None, propagate=False):
    """Configure and return a named logger with a single console handler.

    Calling this repeatedly for the same ``name`` is safe: handlers left over
    from a previous call are detached and closed before the new one is
    attached, so this logger never emits a record more than once itself.

    Args:
        name: Logger name passed to ``logging.getLogger``.
        level: Threshold applied to the logger.
        stream: Destination for log records. Defaults to ``sys.stdout``,
            resolved at call time so a redirected/captured stdout is honoured.
        propagate: Whether records also bubble up to ancestor loggers. Off by
            default because this logger already has its own console handler;
            if the root logger gets configured elsewhere (e.g. via
            ``logging.basicConfig``), propagating would print every line twice.

    Returns:
        The configured ``logging.Logger``.
    """
    logger = logging.getLogger(name)
    logger.setLevel(level)
    logger.propagate = propagate

    # Detach *and close* stale handlers; merely reassigning ``logger.handlers``
    # would leave e.g. file handlers from an earlier call open.
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()

    formatter = logging.Formatter("%(asctime)s [%(levelname)s] %(message)s")

    # Console handler
    console = logging.StreamHandler(sys.stdout if stream is None else stream)
    console.setFormatter(formatter)
    logger.addHandler(console)

    # Silence external loggers below WARNING (HTTP / API client chatter).
    logging.getLogger("httpx").setLevel(logging.WARNING)
    logging.getLogger("openai").setLevel(logging.WARNING)

    return logger


# Module-wide logger used by the sweep functions below.
logger = setup_logger()

def format_paths(config: dict, wave: str) -> dict:
    """
    Replace ``{wave}`` placeholders in the string values of ``config["paths"]``.

    Only the direct string values of ``config["paths"]`` are rewritten;
    non-string values and every other section of ``config`` are left as-is.
    The config is modified in place and also returned for convenience.

    Args:
        config: Parsed configuration mapping. A missing, empty, or ``None``
            ``"paths"`` section leaves ``config`` unchanged.
        wave: Wave identifier substituted for ``{wave}``. Non-string ids
            (YAML parses e.g. ``- 1`` as an int) are converted with ``str()``.

    Returns:
        The same ``config`` object.
    """
    paths = config.get("paths")
    if not paths:
        # Covers a missing key and a YAML ``paths:`` entry with no value.
        return config

    wave_str = str(wave)
    for key, value in paths.items():
        if isinstance(value, str):
            # Reassigning existing keys while iterating is safe: size is unchanged.
            paths[key] = value.replace("{wave}", wave_str)
    return config

def _json_default(obj):
    """Convert NumPy/pandas values that ``json`` cannot encode natively.

    ``DataFrame.iloc[i].to_dict()`` can yield NumPy scalars such as ``int64``
    or ``bool_``, which ``json.dump`` rejects. Objects exposing ``.tolist()``
    (NumPy scalars and arrays, pandas Series) are converted to plain Python
    values; anything else is rejected just as ``json`` would reject it.
    """
    to_list = getattr(obj, "tolist", None)
    if callable(to_list):
        return to_list()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def _dump_json(obj, path):
    """Write ``obj`` to ``path`` as 2-space-indented JSON, unwrapping NumPy values."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(obj, f, indent=2, default=_json_default)


def _build_run_config(base_config, wave, repeat, seed, sweep_name):
    """Return a deep copy of ``base_config`` specialised for one (wave, repeat) run.

    ``base_config`` itself is never mutated, so every run starts from the same
    settings.
    """
    config = copy.deepcopy(base_config)

    # ``experiment`` may be absent or empty in the YAML; create it instead of
    # failing with KeyError/TypeError (the seed/version lookups already treat
    # it as optional).
    experiment = config.get("experiment") or {}
    config["experiment"] = experiment
    experiment["name"] = sweep_name
    experiment["description"] = f"Panel sweep for wave {wave}, repeat {repeat}"
    experiment["random_seed"] = seed

    # Replace all {wave} in paths; str() because YAML may parse wave ids as ints.
    format_paths(config, str(wave))

    config["metadata"] = {
        "name": sweep_name,
        "description": f"Generate 300 active endowments for {wave} using GPT agents",
        "version": experiment.get("version", "v1"),
        "seed": seed,
    }
    return config


def run_panel_sweep(base_config_path: str):
    """
    Run the full pipeline for every (repeat, wave) pair and save metrics.

    Settings come from the YAML config at ``base_config_path``:
    ``sweep.waves`` (required), ``sweep.repeats`` (default 3),
    ``sweep.output_dir`` (default ``"outputs"``) and
    ``experiment.random_seed`` (default 101). Repeat ``r`` (1-based) uses seed
    ``random_seed + r`` for all of its waves.

    For each run, a specialised config is written to
    ``<output_dir>/multi_round/wave_<wave>_repeat_<r>/config_used.yaml``; the
    pipeline's results are saved beside it as ``aggregate_summary.csv`` and
    ``aggregate_results.json``. The first summary row of each run, tagged with
    ``wave`` and ``repeat``, is collected into ``<output_dir>/all_results.csv``
    and ``<output_dir>/all_results.json``. Runs whose summary DataFrame is empty
    are logged with a warning and omitted from the global summary.

    Args:
        base_config_path: Path to the base YAML config.

    Returns:
        DataFrame with one row per collected run (the data in
        ``all_results.csv``). Callers that ignore the return value are
        unaffected.

    Raises:
        KeyError: If the config lacks a ``sweep`` section or ``sweep.waves``.
    """
    # Load base config once; each run works on its own deep copy.
    with open(base_config_path, "r", encoding="utf-8") as f:
        base_config = yaml.safe_load(f)

    sweep_cfg = base_config["sweep"]
    base_output_dir = sweep_cfg.get("output_dir", "outputs")
    wave_ids = sweep_cfg["waves"]
    num_repeats = sweep_cfg.get("repeats", 3)
    base_seed = (base_config.get("experiment") or {}).get("random_seed", 101)

    all_results = []
    for repeat in range(1, num_repeats + 1):
        logger.info(f"Starting repeat {repeat}/{num_repeats}")
        repeat_seed = base_seed + repeat

        for wave in tqdm(wave_ids, desc=f"Repeat {repeat} - Sweeping Waves"):
            logger.info(f"Running wave: {wave}, repeat: {repeat}")

            sweep_name = f"wave_{wave}_repeat_{repeat}"
            output_dir = os.path.join(base_output_dir, "multi_round", sweep_name)
            os.makedirs(output_dir, exist_ok=True)

            config = _build_run_config(base_config, wave, repeat, repeat_seed, sweep_name)
            updated_config_path = os.path.join(output_dir, "config_used.yaml")
            with open(updated_config_path, "w", encoding="utf-8") as f:
                yaml.dump(config, f)

            results_dict, results_df = run_full_pipeline_and_collect_metrics(
                updated_config_path, output_dir
            )

            # Per-run outputs
            results_df.to_csv(os.path.join(output_dir, "aggregate_summary.csv"), index=False)
            _dump_json(results_dict, os.path.join(output_dir, "aggregate_results.json"))

            if results_df.empty:
                # iloc[0] would raise and discard every run collected so far.
                logger.warning(
                    f"Empty summary for wave {wave}, repeat {repeat}; "
                    "excluding it from the global results"
                )
                continue

            # Store the first summary row for global aggregation.
            row = results_df.iloc[0].to_dict()
            row["wave"] = wave
            row["repeat"] = repeat
            all_results.append(row)

    # The directory only exists implicitly if at least one run happened.
    os.makedirs(base_output_dir, exist_ok=True)
    summary_df = pd.DataFrame(all_results)
    summary_df.to_csv(os.path.join(base_output_dir, "all_results.csv"), index=False)
    _dump_json(all_results, os.path.join(base_output_dir, "all_results.json"))
    return summary_df

if __name__ == "__main__":
    # Command-line entry point: a single required --config pointing at the base YAML.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--config",
        type=str,
        required=True,
        help="Path to base YAML config",
    )
    cli_args = cli.parse_args()
    run_panel_sweep(base_config_path=cli_args.config)
