import os
import sys
import json
import pandas as pd
from collections import defaultdict

# Make the directory one level above this file importable, so the
# project-local `scripts.*` and `modules.*` imports below resolve.
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if project_root not in sys.path:
    sys.path.insert(0, project_root)

from scripts.run_endowment_generator import main as run_endowment_generator
from scripts.run_regression_experiment import run_regression_experiment
from modules.response_converter import Responses, ResponseCleaner, ResponseUtils
from modules.survey_converter import Survey
from scripts.run_full_pipeline import generate_regression_config

def stringify_keys(d):
    """
    Return a copy of ``d`` whose keys are all plain strings.

    Tuple (or list) keys are flattened to their elements joined with "+";
    every other key is passed through ``str``. Elements of a tuple key are
    converted with ``str`` first, so keys such as ``("age", 3)`` are accepted
    instead of raising ``TypeError``.

    Args:
        d (dict): Mapping with arbitrary hashable keys.

    Returns:
        dict: New mapping with string keys and the original values. If two
        distinct keys stringify to the same text, the later one wins.
    """
    return {
        "+".join(map(str, k)) if isinstance(k, (tuple, list)) else str(k): v
        for k, v in d.items()
    }

def stringify_mode(mode):
    """
    Render a mode identifier as a single string.

    A tuple or list is rendered as its elements joined with "+", each element
    being converted with ``str`` so non-string elements do not raise
    ``TypeError``. Any other value (including ``None``) is returned as
    ``str(mode)``.

    Args:
        mode: A tuple/list of mode components, or a scalar identifier.

    Returns:
        str: The string form of ``mode``.
    """
    if isinstance(mode, (tuple, list)):
        return "+".join(map(str, mode))
    return str(mode)

def _collect_fit_metrics(result, include_l1_ratio=True):
    """
    Flatten one regression result dict into the metrics this module reports.

    Args:
        result (dict): One regression result entry. The "prediction_metrics",
            "hyperparams" and "mode_summary" sections are treated as empty
            when absent.
        include_l1_ratio (bool): Whether to emit the "best_l1_ratio" entry.

    Returns:
        dict: Metric name -> value, in a fixed order; values that are not
        present in ``result`` are None (and the selected-agent count is 0).
    """
    prediction = result.get("prediction_metrics", {})
    hyperparams = result.get("hyperparams", {})
    mode_summary = result.get("mode_summary", {})

    collected = {
        key: prediction.get(key)
        for key in ("trainval_mse", "trainval_r2", "test_mse", "test_r2")
    }
    collected["best_alpha"] = hyperparams.get("alpha")
    if include_l1_ratio:
        collected["best_l1_ratio"] = hyperparams.get("l1_ratio")
    # Total number of selected agents across all modes.
    collected["num_agents_selected"] = sum(
        m.get("num_selected", 0) for m in mode_summary.values()
    )
    return collected


def run_full_pipeline_and_collect_metrics(config_path: str, output_dir: str):
    """
    Runs full endowment generation and regression pipeline, collects metrics.

    Steps, in order: run the endowment generator, write a regression config
    to ``<output_dir>/regression_config.yaml``, count questions and valid
    questions/endowments per split, fetch the entropy trajectories, run the
    regression experiment, and package everything.

    Args:
        config_path (str): Path handed to the endowment generator.
        output_dir (str): Directory handed to the generator and the regression
            experiment; the regression config is written inside it.

    Returns:
        results_dict (dict): All metrics and traces.
        results_df (pd.DataFrame): One-row dataframe of selected summary metrics.
            The ``lasso_*`` / ``elasticnet_*`` columns are only present when a
            result with that name was returned by the regression experiment.
    """
    # Run endowment generator
    generator, endo_results = run_endowment_generator(
        config_path, output_dir=output_dir, return_outputs=True
    )

    regression_config_path = os.path.join(output_dir, "regression_config.yaml")
    generate_regression_config(endo_results["config"], output_dir, regression_config_path)

    # Load responses & survey for validation
    responses = generator.responses  # default = 'answer'
    responses_code = generator.responses_code  # ensure 'code' mapping exists
    survey = responses.survey  # type: Survey

    splits = ("train", "valid", "test")

    # --- Question counts ---
    question_counts = {"n_questions_total": len(survey.get_questions())}
    for split in splits:
        question_counts[f"n_questions_{split}"] = len(survey.get_questions_by_split(split))

    # --- Effective (valid) questions using ResponseUtils ---
    valid_counts = {}
    for split in splits:
        # Only the first returned element (problematic answers keyed by
        # question id) is needed here.
        problematic_answers, _, _, _, _ = ResponseUtils.analyze_missing_mappings(
            code_responses=responses_code,
            answer_responses=responses,
            split=split,
            verbose=False,
        )
        df_split = responses.get_matrix_by_split(split, dropna=True)
        # errors="ignore": problematic ids absent from this split's matrix
        # are silently skipped.
        df_clean = df_split.drop(index=list(problematic_answers), errors="ignore")
        valid_counts[f"n_valid_questions_{split}"] = df_clean.shape[0]
        valid_counts[f"n_valid_eids_{split}"] = ResponseUtils.count_valid_endowments(
            responses_code, split=split
        )

    valid_counts["n_valid_eids_all"] = ResponseUtils.count_valid_endowments(
        responses_code, split=None
    )

    # --- Entropy trajectories ---
    entropy_traj = generator.tracker.get_entropy_trajectories()

    # --- Run regression ---
    # Only the per-model results list is used below.
    _experiment, _df_agg, _df_mode, all_results = run_regression_experiment(
        regression_config_path, output_dir, return_all=True
    )

    # --- Extract model metrics ---
    # Fall back to an empty dict when a model is absent, so lookups yield None.
    lasso_metrics = next((r for r in all_results if r.get("name") == "lasso"), {})
    enet_metrics = next((r for r in all_results if r.get("name") == "elasticnet"), {})

    # --- Package full results dict ---
    results_dict = {
        "output_dir": output_dir,
        **endo_results,
        **question_counts,
        **valid_counts,
        "entropy_trajectory": entropy_traj["aggregate"],
        "mode_entropy_trajectories": stringify_keys(entropy_traj["by_mode"]),
        "question_entropy_trajectories": entropy_traj["by_question"],
        "fit_metrics": {
            # l1_ratio is reported for the elastic net only.
            "lasso": _collect_fit_metrics(lasso_metrics, include_l1_ratio=False),
            "elastic_net": _collect_fit_metrics(enet_metrics),
        },
    }

    # --- Build one-row summary DataFrame ---
    summary = {
        "output_dir": output_dir,
        **question_counts,
        **valid_counts,
        "final_entropy": endo_results.get("final_entropy"),
        "top_mode": stringify_mode(endo_results.get("top_mode")),
        "top_mode_entropy": endo_results.get("top_mode_entropy"),
        "lowest_mode": stringify_mode(endo_results.get("lowest_mode")),
        "lowest_mode_entropy": endo_results.get("lowest_mode_entropy"),
    }

    for result in (lasso_metrics, enet_metrics):
        name = result.get("name")
        # An empty fallback dict has no name: skip it so no columns are added.
        if name not in ("lasso", "elasticnet"):
            continue
        summary.update(
            {f"{name}_{key}": value for key, value in _collect_fit_metrics(result).items()}
        )

    results_df = pd.DataFrame([summary])
    return results_dict, results_df
