# --- CLI and Bootstrapping ---
import argparse
import yaml
import os
import sys

project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if project_root not in sys.path:
    sys.path.append(project_root)

from modules.survey_converter import Survey, BinaryExtendedSurvey
from modules.endowment_manager import ActiveEndowments
from modules.response_converter import Responses, BinaryExtendedResponses, ResponseUtils
from modules.aggregate_responses import AggregateResponses
from experiments.experiments import EmpiricalExperiment
from experiments.utils import *
import plotly.graph_objs as go

from dash import Dash, dcc, html, Input, Output
# --- Auxiliary modular functions ---

def load_experiment_and_model(config: dict):
    """
    Assemble every object the dashboard needs from a parsed config dictionary.

    The survey and the responses are each built twice: once in plain form and once
    in binary-extended form. No model is trained here; the model is obtained by
    passing ``config["model_paths"].get('lasso')`` to ``load_model`` (presumably
    provided by the ``experiments.utils`` star import -- confirm).

    Args:
        config: Parsed configuration. Must hold a ``"paths"`` mapping with the keys
            ``survey_csv``, ``survey_yaml``, ``responses_csv``, ``endowments_csv``
            and ``aggregate_json``, plus a ``"model_paths"`` mapping.

    Returns:
        dict with the keys:
            ``model``: the loaded model object.
            ``feature_names``: the model's ``feature_names_`` attribute.
            ``experiment``: the ``EmpiricalExperiment`` built on the binary data.
            ``responses_bin``: the ``BinaryExtendedResponses`` instance.
            ``responses``: the plain ``Responses`` instance.
            ``human_aggregate_responses``: the ``raw`` attribute of the
                ``AggregateResponses`` object.
            ``agent_aggregate_responses``: result of
                ``ResponseUtils.aggregate_weighted_responses`` applied to the plain
                responses and the model's ``coef_dict_``.
    """
    path_cfg = config["paths"]

    # Survey: plain form first, binary-extended form derived from it.
    plain_survey = Survey(
        csv_path=path_cfg['survey_csv'],
        config_path=path_cfg['survey_yaml'],
    )
    binary_survey = BinaryExtendedSurvey.from_survey(plain_survey)

    # Responses: the same CSV is read against each survey variant.
    responses_csv = path_cfg['responses_csv']
    plain_responses = Responses(
        source=responses_csv,
        survey=plain_survey,
        output_format='answer',
    )
    binary_responses = BinaryExtendedResponses(
        source=responses_csv,
        survey=binary_survey,
        output_format='code',
    )

    # Endowments must have their roles assigned before the experiment is built.
    active_endowments = ActiveEndowments.load(path=path_cfg['endowments_csv'])
    active_endowments.assign_roles()

    human_aggregate = AggregateResponses(
        survey=binary_survey,
        json_path=path_cfg['aggregate_json'],
    )
    empirical_experiment = EmpiricalExperiment(
        responses=binary_responses,
        survey=binary_survey,
        endowments=active_endowments,
        aggregate_stats=human_aggregate.get_all_binary(),
        filter_binary=True,
        drop_na=True,
    )

    # Model: loaded from disk; its coefficients weight the plain responses.
    lasso_model = load_model(config["model_paths"].get('lasso'))
    coefficient_weights = lasso_model.coef_dict_
    weighted_agent_responses = ResponseUtils.aggregate_weighted_responses(
        plain_responses, coefficient_weights
    )

    return {
        "model": lasso_model,
        "feature_names": lasso_model.feature_names_,
        "experiment": empirical_experiment,
        "responses_bin": binary_responses,
        "responses": plain_responses,
        "human_aggregate_responses": human_aggregate.raw,
        "agent_aggregate_responses": weighted_agent_responses,
    }

# --- CLI parsing ---
parser = argparse.ArgumentParser(
    description="Interactive prediction explorer with response distribution drill-down.",
)
parser.add_argument(
    "--config",
    type=str,
    required=True,
    help="Path to YAML config.",
)
args = parser.parse_args()

# --- Load config and objects ---
with open(args.config, "r") as f:
    config = yaml.safe_load(f)

experiment_pack = load_experiment_and_model(config)

# Expose each entry of the pack under its own module-level name; the callback
# defined further down reads several of these globals.
(
    model,
    feature_names,
    experiment,
    responses_bin,
    responses,
    human_aggregate_responses,
    agent_aggregate_responses,
) = (
    experiment_pack[pack_key]
    for pack_key in (
        "model",
        "feature_names",
        "experiment",
        "responses_bin",
        "responses",
        "human_aggregate_responses",
        "agent_aggregate_responses",
    )
)

# --- Launch Dash app ---
app = Dash(__name__)
app.layout = html.Div(
    children=[
        html.H2(children="Prediction Explorer with Response Drill-Down"),
        # Main figure; clicking one of its points drives the drill-down container below.
        dcc.Graph(
            id="prediction-plot",
            figure=plot_model_predictions_interactive(
                model,
                experiment,
                config,
                selected_features=feature_names,
            ),
        ),
        # Filled by the click callback.
        html.Div(id="response-distribution"),
    ]
)

@app.callback(
    Output("response-distribution", "children"),
    Input("prediction-plot", "clickData"),
)
def on_point_click(clickData):
    """
    Render the human-vs-agent response distribution for the clicked point.

    The question identifier is read from the ``text`` field of the first clicked
    point and passed, together with the module-level response objects, to
    ``plot_human_vs_agent_response_distribution_interactive``.

    Args:
        clickData: Value of the ``clickData`` property of the ``prediction-plot``
            graph; falsy until something has been clicked.

    Returns:
        A prompt string when nothing has been clicked, a notice string when the
        clicked point carries no ``text`` label, otherwise a ``dcc.Graph`` holding
        the response-distribution figure.
    """
    if not clickData:
        return "Click on a point to view response distribution."

    # A clicked point only carries "text" when its trace defines one, so a click
    # on an unlabeled trace must not raise inside the callback.
    clicked_points = clickData.get("points") or [{}]
    qid_bin = clicked_points[0].get("text")
    if not qid_bin:
        return "No question is associated with the clicked point."

    distribution_figure = plot_human_vs_agent_response_distribution_interactive(
        qid_bin,
        responses.survey,
        responses_bin,
        responses,
        human_aggregate_responses,
        agent_aggregate_responses,
    )
    return dcc.Graph(figure=distribution_figure)

    

# Start the Dash server only when this file is executed directly as a script.
if __name__ == "__main__":
    # NOTE(review): debug=True presumably enables Dash's development tooling;
    # confirm this entry point is only used locally and not for a shared deployment.
    app.run(debug=True)