import os
import sys
import argparse
import yaml
import pandas as pd

project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if project_root not in sys.path:
    sys.path.insert(0, project_root)

from modules.survey_converter import BinaryExtendedSurvey
from modules.aggregate_responses import AggregateResponses
from experiments.utils import (
    _compute_one_minus_jsd_per_qid,
    _compute_emd_per_qid,
    _compute_majority_accuracy_per_qid,
    _mean_over_qids,
)


def evaluate_fewshot_predictions(pred: AggregateResponses,
                                 truth: AggregateResponses,
                                 config: dict):
    """Average the per-question metrics over the configured test split.

    The question ids are taken from ``pred.survey`` for the split named by
    ``config["split_settings"]["test_split"]``. Three per-question score
    mappings are computed from ``pred.raw`` and ``truth.raw`` by the helpers
    in ``experiments.utils``; only ids present in all three mappings
    contribute to the averages.

    Args:
        pred: Predicted aggregate responses; its ``survey`` and ``raw``
            attributes are read.
        truth: Reference aggregate responses; its ``raw`` attribute is read.
        config: Experiment configuration containing
            ``split_settings.test_split``.

    Returns:
        A dict with the keys ``test_one_minus_jsd``, ``test_emd`` and
        ``test_accuracy``, each holding the value returned by
        ``_mean_over_qids`` for the shared question ids.
    """
    survey = pred.survey
    test_split = config["split_settings"]["test_split"]
    split_qids = [
        question["id"]
        for question in survey.get_questions_by_split(test_split)
    ]

    predicted = pred.raw
    reference = truth.raw

    one_minus_jsd = _compute_one_minus_jsd_per_qid(predicted, reference)
    emd_scores = _compute_emd_per_qid(predicted, reference, survey)
    accuracy = _compute_majority_accuracy_per_qid(predicted, reference)

    # Keep only the ids for which every metric produced a score, so that all
    # three averages are taken over the same set of questions.
    score_maps = (one_minus_jsd, emd_scores, accuracy)
    shared_qids = [
        qid for qid in split_qids
        if all(qid in scores for scores in score_maps)
    ]

    if len(split_qids) > len(shared_qids):
        n_dropped = len(set(split_qids) - set(shared_qids))
        print(f"[INFO] Excluding {n_dropped} non-overlapping qids from metric eval.")

    named_scores = (
        ("test_one_minus_jsd", one_minus_jsd),
        ("test_emd", emd_scores),
        ("test_accuracy", accuracy),
    )
    return {
        key: _mean_over_qids(scores, shared_qids)
        for key, scores in named_scores
    }


def run_compute_metrics(config_path, pred_path, output_dir=None):
    """Load config, predictions and ground truth, then report test metrics.

    The YAML config at ``config_path`` must provide ``paths.survey_csv``,
    ``paths.survey_yaml`` and ``paths.aggregate_json`` (plus whatever
    ``evaluate_fewshot_predictions`` reads from it). The metrics dict is
    printed, and additionally written as a one-row CSV when ``output_dir``
    is given.

    Args:
        config_path: Path to the experiment YAML config.
        pred_path: Path to the JSON file with predicted aggregate responses.
        output_dir: Optional directory for ``summary_simple_avg.csv``. It is
            created if it does not exist. When falsy, nothing is written.

    Raises:
        KeyError: If a required config key is missing.
        OSError: If the config cannot be read or the output cannot be
            written.
    """
    # Read the config as UTF-8 explicitly; the platform default encoding can
    # mis-decode non-ASCII YAML on non-UTF-8 locales.
    with open(config_path, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)

    paths = config["paths"]
    survey = BinaryExtendedSurvey(csv_path=paths["survey_csv"],
                                  config_path=paths["survey_yaml"])
    pred = AggregateResponses(json_path=pred_path, survey=survey)
    truth = AggregateResponses(json_path=paths["aggregate_json"], survey=survey)

    metrics = evaluate_fewshot_predictions(pred, truth, config)
    print("[RESULT]", metrics)

    if output_dir:
        # Create the target directory up front so a fresh output location
        # does not make to_csv fail after the metrics were already computed.
        os.makedirs(output_dir, exist_ok=True)
        out_path = os.path.join(output_dir, "summary_simple_avg.csv")
        pd.DataFrame([metrics]).to_csv(out_path, index=False)
        print(f"[SAVED] {out_path}")


if __name__ == "__main__":
    # Command-line entry point: --config and --pred are mandatory, --output
    # is an optional directory for the CSV summary.
    cli = argparse.ArgumentParser()
    for required_flag in ("--config", "--pred"):
        cli.add_argument(required_flag, required=True)
    cli.add_argument("--output", default=None)
    options = cli.parse_args()

    run_compute_metrics(options.config, options.pred, options.output)