import sys
from pathlib import Path
_src = Path(__file__).resolve().parent.parent
if str(_src) not in sys.path:
    sys.path.insert(0, str(_src))
from load_dataset import get_overtonbench_data
import helper_functions

import argparse
import math
import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity


def _predict_from_candidates(query_embedding, candidates):
    """Compute both baseline predictions for a single response.

    Args:
        query_embedding: Embedding vector of the response being predicted.
        candidates: Non-empty DataFrame holding the *other* responses for the
            same user and question; must provide the "embedding" and
            "representation_rating" columns.

    Returns:
        A ``(sem_sim, mean_of_others)`` tuple: the rating of the candidate
        whose embedding is most cosine-similar to ``query_embedding`` (first
        one wins on ties), and the candidates' mean rating rounded up to an
        int (NaN when that mean is itself NaN).
    """
    ratings = candidates["representation_rating"]

    # Baseline 1: copy the rating of the nearest candidate.
    sims = cosine_similarity([query_embedding], list(candidates["embedding"]))[0]
    pred_sem_sim = ratings.iloc[int(np.argmax(sims))]

    # Baseline 2: mean of the other ratings, always rounded up.
    mean_rating = ratings.mean()
    # math.ceil raises ValueError on NaN, which is what the mean is when every
    # candidate rating is missing; report "no prediction" instead of crashing.
    pred_mean_of_others = np.nan if pd.isna(mean_rating) else math.ceil(mean_rating)

    return pred_sem_sim, pred_mean_of_others


def main():
    """Compute the semantic-similarity and mean-of-others baselines and save them.

    For every row, the candidate pool is the set of other rows sharing its
    "user" and "question_id". Two predictions of "representation_rating" are
    derived from that pool (see ``_predict_from_candidates``) and stored,
    together with their signed error against the true rating, in the columns
    "sem_sim_avg", "sem_sim_diff", "mean_of_others_avg" and
    "mean_of_others_diff". Rows without any candidate get NaN predictions.

    The result is written to ``outputs/predictions/<name>.csv`` (relative to
    the current working directory), where ``<name>`` encodes the data source
    and, when ``--n_rows`` is given, the sample size.
    """
    parser = argparse.ArgumentParser(description="Semantic similarity and mean-of-others baselines.")
    parser.add_argument("--source", default=None,
                        help="Question source split when loading from HF: full (default), modelslant, or prism. Output filename gets _modelslant/_prism suffix when set.")
    parser.add_argument("--data", default=None,
                        help="Path to CSV to use instead of Hugging Face (overrides DATASET in .env if set). Same schema as OvertonBench.")
    parser.add_argument("--n_rows", type=int, default=None,
                        help="If set, run on a random sample of this many rows (for quick testing).")
    args = parser.parse_args()

    helper_functions.set_data_options(path=args.data, source_split=args.source)
    df = get_overtonbench_data(path=args.data, source_split=args.source)
    if args.n_rows is not None:
        df = df.sample(min(args.n_rows, len(df)), random_state=42)
    # Rows are identified by index label below, so labels must be unique.
    # The index is never written to the CSV, so renumbering is output-neutral.
    df = df.reset_index(drop=True)

    # Load sentence embedding model
    model = SentenceTransformer("all-MiniLM-L6-v2")

    # Precompute embeddings for all rows
    df["embedding"] = list(model.encode(df["llm_response"].tolist(), convert_to_tensor=False))

    # Group once instead of re-filtering the whole frame for every row. Rows
    # with a missing user/question_id are left out by groupby and therefore
    # end up without a prediction, which matches equality-based matching
    # (NaN never equals NaN).
    predictions = {}
    for _, group in df.groupby(["user", "question_id"], sort=False):
        if len(group) < 2:
            continue  # no other response to compare against
        for idx, embedding in zip(group.index, group["embedding"]):
            # Candidate pool: same user & question_id, excluding this exact row
            predictions[idx] = _predict_from_candidates(embedding, group.drop(index=idx))

    # Re-align with the frame's row order; rows without candidates get NaN.
    no_prediction = (np.nan, np.nan)
    sem_sim_ratings = [predictions.get(idx, no_prediction)[0] for idx in df.index]
    mean_of_other_ratings = [predictions.get(idx, no_prediction)[1] for idx in df.index]

    # Save predictions & differences
    df["sem_sim_avg"] = sem_sim_ratings
    df["sem_sim_diff"] = df["sem_sim_avg"] - df["representation_rating"]

    df["mean_of_others_avg"] = mean_of_other_ratings
    df["mean_of_others_diff"] = df["mean_of_others_avg"] - df["representation_rating"]

    # Output path: _modelslant/_prism for HF splits, _custom for --data; _N when --n_rows set
    if args.data:
        out_name = "baselines_rounded_custom"
    else:
        source = (args.source or "full").strip().lower()
        out_name = "baselines_rounded" if source == "full" else f"baselines_rounded_{source}"
    if args.n_rows is not None:
        out_name = f"{out_name}_{args.n_rows}"
    out_path = f"outputs/predictions/{out_name}.csv"
    # Create the target directory so a fresh checkout does not fail at the
    # very last step, after all embeddings have been computed.
    Path(out_path).parent.mkdir(parents=True, exist_ok=True)
    # NOTE(review): the "embedding" column is written too (each vector in its
    # string form) — confirm whether downstream consumers need it.
    df.to_csv(out_path, index=False)

    print(f"Baselines saved to {out_path}")

# Script entry point: run the baselines only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()