# --- CLI and Bootstrapping ---
import argparse
import yaml
import os
import sys
import pandas as pd
import numpy as np

# Make the repository root (the directory one level above this script's own
# directory) importable, so that the project-local `modules` / `experiments`
# imports that follow resolve when this file is executed directly as a script.
# This has to run before those imports; the membership check avoids adding a
# duplicate sys.path entry.
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))
if project_root not in sys.path:
    sys.path.append(project_root)

from modules.survey_converter import Survey, BinaryExtendedSurvey
from modules.endowment_manager import ActiveEndowments
from modules.response_converter import Responses, BinaryExtendedResponses, ResponseUtils
from modules.aggregate_responses import AggregateResponses
from experiments.experiments import EmpiricalExperiment
from experiments.utils import load_model


def load_experiment_and_model(config: dict):
    """
    Build the empirical experiment and load the pre-trained Lasso model.

    No training happens here: the model is read via ``load_model``.

    Args:
        config: Parsed config mapping. ``config["paths"]`` must provide the
            keys ``survey_csv``, ``survey_yaml``, ``responses_csv``,
            ``endowments_csv`` and ``aggregate_json``. ``config["model_paths"]``
            is queried with ``.get("lasso")``, so a missing entry passes
            ``None`` to ``load_model`` — NOTE(review): confirm ``load_model``
            handles that, or make the key mandatory.

    Returns:
        dict with ``"model"`` (the object returned by ``load_model``) and
        ``"experiment"`` (an ``EmpiricalExperiment`` over the binary-extended
        survey and responses).

    Raises:
        KeyError: if ``"paths"``, ``"model_paths"`` or one of the required
            path keys is missing from ``config``.
    """
    paths = config["paths"]

    # Every downstream consumer works on the binary-extended form of the survey.
    survey = Survey(csv_path=paths['survey_csv'], config_path=paths['survey_yaml'])
    survey_bin = BinaryExtendedSurvey.from_survey(survey)
    responses_bin = BinaryExtendedResponses(
        source=paths['responses_csv'], survey=survey_bin, output_format='code'
    )

    endowments = ActiveEndowments.load(path=paths['endowments_csv'])
    # Return value is ignored; presumably assigns roles in place — verify in endowment_manager.
    endowments.assign_roles()

    aggregate = AggregateResponses(survey=survey_bin, json_path=paths['aggregate_json'])
    experiment = EmpiricalExperiment(
        responses=responses_bin,
        survey=survey_bin,
        endowments=endowments,
        aggregate_stats=aggregate.get_all_binary(),
        filter_binary=True,
        drop_na=True,
    )

    lasso_model_path = config["model_paths"].get('lasso')
    model = load_model(lasso_model_path)

    return {
        "model": model,
        "experiment": experiment,
    }


# Splits reported by export_predictions_csv, in output order.
_DEFAULT_SPLITS = ("train", "valid", "test")
# Column layout of the exported CSV (also used to shape an empty result).
_PREDICTION_COLUMNS = ["qid", "split", "ground_truth", "prediction", "error", "abs_error"]


def _collect_predictions(model, experiment, splits):
    """Return one row per question with ground truth, prediction and signed/absolute error."""
    # Agent columns the model was fitted on; selecting them by name keeps the
    # column order/names the model expects.
    feature_names = model.feature_names_
    frames = []
    for split in splits:
        df_split = experiment.get_dataframe_by_split(split, proxy_only=True)
        if df_split.empty:
            continue

        # One batched predict per split rather than one call per question.
        predictions = np.asarray(model.predict(df_split[feature_names]))
        ground_truth = df_split["aggregate"].to_numpy()
        errors = predictions - ground_truth

        frames.append(pd.DataFrame({
            "qid": df_split.index.to_numpy(),
            "split": split,
            "ground_truth": ground_truth,
            "prediction": predictions,
            "error": errors,
            "abs_error": np.abs(errors),
        }))

    if not frames:
        # Explicit columns so sorting / filtering by column name still works
        # when every split is empty.
        return pd.DataFrame(columns=_PREDICTION_COLUMNS)
    return pd.concat(frames, ignore_index=True)


def _print_summary(df, df_sorted, output_path, splits, top_n):
    """Print the save message, per-split RMSE, and the top-N largest absolute errors."""
    print(f"Saved {len(df)} predictions to {output_path}")
    print("\n=== Summary by Split ===")
    for split in splits:
        split_df = df[df["split"] == split]
        if split_df.empty:
            continue
        rmse = np.sqrt((split_df["error"] ** 2).mean())
        print(f"{split}: {len(split_df)} questions, RMSE = {rmse:.4f}")

    print(f"\n=== Top {top_n} Highest Errors ===")
    print(df_sorted.head(top_n).to_string(index=False))


def export_predictions_csv(
    model,
    experiment,
    output_path="predictions_analysis.csv",
    *,
    splits=_DEFAULT_SPLITS,
    top_n=15,
):
    """
    Export ground truth vs. model predictions with per-question error analysis.

    For each split, the rows returned by
    ``experiment.get_dataframe_by_split(split, proxy_only=True)`` are scored
    with ``model.predict`` on the columns named by ``model.feature_names_``.
    The ``aggregate`` column is the ground truth. The result is sorted by
    absolute error (largest first; ties keep their original order), written
    to ``output_path`` without the index, and summarized on stdout.

    Args:
        model: Fitted model exposing ``feature_names_`` and ``predict``.
        experiment: Object exposing ``get_dataframe_by_split(split, proxy_only=...)``.
        output_path: Destination CSV path.
        splits: Split names to evaluate, in output order.
        top_n: Number of highest-error rows to print.

    Returns:
        DataFrame with columns ``qid, split, ground_truth, prediction, error,
        abs_error`` sorted by ``abs_error`` descending. It is empty, but still
        has those columns, when every split is empty.
    """
    df = _collect_predictions(model, experiment, splits)
    # Stable sort so rows with equal abs_error keep a deterministic order.
    df_sorted = df.sort_values("abs_error", ascending=False, kind="mergesort")
    df_sorted.to_csv(output_path, index=False)

    _print_summary(df, df_sorted, output_path, splits, top_n)
    return df_sorted


def main(argv=None):
    """
    Command-line entry point: read the YAML config, build the experiment and
    model, and export predictions to CSV.

    Args:
        argv: Arguments to parse; ``None`` lets argparse use ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(description="Export predictions to CSV for analysis.")
    parser.add_argument("--config", type=str, required=True, help="Path to YAML config.")
    parser.add_argument("--output", type=str, default="predictions_analysis.csv", help="Output CSV path.")
    args = parser.parse_args(argv)

    # Read the config as UTF-8 explicitly instead of relying on the platform's
    # locale encoding.
    with open(args.config, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)

    # safe_load yields None for an empty file (or a list/scalar for other
    # documents); report that as a CLI error instead of a later TypeError.
    if not isinstance(config, dict):
        parser.error(f"config file {args.config!r} must contain a YAML mapping")

    pack = load_experiment_and_model(config)
    export_predictions_csv(pack["model"], pack["experiment"], args.output)


if __name__ == "__main__":
    main()