import sys
from pathlib import Path
_src = Path(__file__).resolve().parent.parent
if str(_src) not in sys.path:
    sys.path.insert(0, str(_src))
from load_dataset import get_overtonbench_data

import pandas as pd
import prompts

# Lazily populated by _get_ratings_and_qa(): the full ratings frame and its
# one-row-per-(user, question_id) projection.
_ratings_df = _qa_df = None
# Data-source options recorded by set_data_options(); while _data_configured is
# False the loader is called without arguments.
_data_configured = False
_data_path = _data_source_split = None


def set_data_options(path=None, source_split=None):
    """Choose which dataset the lazily loaded ratings/QA frames come from.

    Call this from entry-point scripts (e.g. prediction.py) so that the lazily
    loaded ratings/qa use the same data as the main run. ``path`` overrides the
    DATASET env setting; ``source_split`` is used when loading from Hugging Face
    (e.g. 'full', 'modelslant', 'prism').

    Any ratings/QA frames that were already loaded are discarded, so the new
    options take effect on the next access rather than being silently ignored.

    Args:
        path: Optional dataset path forwarded to the loader.
        source_split: Optional split name forwarded to the loader.
    """
    global _data_configured, _data_path, _data_source_split, _ratings_df, _qa_df
    _data_configured = True
    _data_path = path
    _data_source_split = source_split
    # Invalidate the cache: without this, options set after the first lazy
    # load would never be used because the loader only runs when the cache is empty.
    _ratings_df = None
    _qa_df = None


def _get_ratings_and_qa():
    """Lazy-load ratings and qa DataFrames. Uses set_data_options() if called by prediction (or other entry point)."""
    global _ratings_df, _qa_df
    if _ratings_df is None:
        if _data_configured:
            _ratings_df = get_overtonbench_data(path=_data_path, source_split=_data_source_split)
        else:
            _ratings_df = get_overtonbench_data()
        _qa_df = _ratings_df.drop_duplicates(subset=["user", "question_id"]).copy()
    return _ratings_df, _qa_df


# Few-shot helpers used by format_prompt().
def _fewshot_qa_args(qa_df: pd.DataFrame, user_id, question_id) -> dict:
    """Build the question_1/freeresponse_1/question_2/freeresponse_2 prompt fields.

    Uses the user's free-response answers to *other* questions. When more than
    two candidates exist, two are drawn with a fixed ``random_state`` so the
    chosen examples are reproducible; missing slots are filled with ``""``.
    """
    candidates = qa_df[(qa_df["user"] == user_id) & (qa_df["question_id"] != question_id)]
    if len(candidates) > 2:
        candidates = candidates.sample(n=2, random_state=42)
    examples = candidates.to_dict(orient="records")
    # Pad to exactly two entries so both template slots always get a value.
    examples += [{"question": "", "freeresponse": ""}] * (2 - len(examples))
    first, second = examples[0], examples[1]
    return {
        "question_1": first.get("question", ""),
        "freeresponse_1": first.get("freeresponse", ""),
        "question_2": second.get("question", ""),
        "freeresponse_2": second.get("freeresponse", ""),
    }


def _format_same_question_ratings(ratings_df: pd.DataFrame, user_id, question_id, llm_response) -> str:
    """Render the user's ratings of the *other* LLM responses to this question.

    Produces "LLM Response i: ..." / "Rating i: ..." line pairs, or the string
    ``"None"`` when there are no such ratings.
    """
    rows = ratings_df[
        (ratings_df["user"] == user_id)
        & (ratings_df["question_id"] == question_id)
        & (ratings_df["llm_response"] != llm_response)
    ]
    lines = []
    pairs = zip(rows["llm_response"], rows["representation_rating"])
    for idx, (response, rating) in enumerate(pairs, start=1):
        lines.append(f"LLM Response {idx}: {response}")
        lines.append(f"Rating {idx}: {rating}")
    return "\n".join(lines) if lines else "None"


def _format_other_question_ratings(ratings_df: pd.DataFrame, qa_df: pd.DataFrame, user_id, question_id) -> str:
    """Render the user's ratings on every *other* question, with question text.

    Produces "Question i" / "LLM Response i" / "Rating i" line triples, or the
    string ``"None"`` when the user rated no other question.
    """
    rows = ratings_df[(ratings_df["user"] == user_id) & (ratings_df["question_id"] != question_id)]
    if rows.empty:
        return "None"
    # Build the question_id -> text lookup once instead of scanning qa_df per
    # rated row; drop_duplicates keeps the first row, so the text is taken from
    # the first matching qa row (as a boolean-mask lookup + [0] would).
    unique_questions = qa_df.drop_duplicates(subset="question_id")
    question_text = dict(zip(unique_questions["question_id"], unique_questions["question"]))
    lines = []
    triples = zip(rows["question_id"], rows["llm_response"], rows["representation_rating"])
    for idx, (qid, response, rating) in enumerate(triples, start=1):
        lines.append(f"Question {idx}: {question_text.get(qid, '')}")
        lines.append(f"LLM Response {idx}: {response}")
        lines.append(f"Rating {idx}: {rating}")
    return "\n".join(lines)


def format_prompt(prompt_type: str, row: pd.Series) -> str:
    """Render the ``prompts.prompt_<prompt_type>`` template for one data row.

    Demographic fields fall back from the dataset's column names (e.g. ``Age``,
    ``Sex``) to lowercase aliases and finally to ``""``. For the few-shot prompt
    types ("freeresponse_fewshot", "fewshot"), the user's other free-response
    answers and ratings are loaded via ``_get_ratings_and_qa()`` and supplied as
    ``question_1``/``freeresponse_1``/..., ``examples_same_q`` and
    ``examples_other_q``.

    Args:
        prompt_type: Suffix of the template attribute in the ``prompts`` module.
        row: One record; must contain ``user`` and ``question_id`` for the
            few-shot prompt types.

    Returns:
        The formatted prompt string.

    Raises:
        ValueError: If ``prompts`` has no ``prompt_<prompt_type>`` template.
    """
    # Resolve the template first: fail fast on an unknown prompt type, and keep
    # errors raised by .format() from being misreported as an invalid type.
    prompt_template = getattr(prompts, f"prompt_{prompt_type}", None)
    if prompt_template is None:
        raise ValueError(f"Invalid prompt type: {prompt_type}")

    base_args = {
        'question': row.get('question', ''),
        'llm_response': row.get('llm_response', ''),
        'freeresponse': row.get('freeresponse', ''),
        'age': row.get('Age', row.get('age', '')),
        'gender': row.get('Sex', row.get('gender', '')),
        'race': row.get('Ethnicity simplified', row.get('race', '')),
        'political': row.get('U.s. political affiliation', row.get('political', '')),
        'selection_text': row.get('selection_text', ''),
        'dimension_name': row.get('dimension_name', ''),
        'dimension_value': row.get('dimension_value', ''),

        # default empty few-shot fields
        'question_1': '', 'freeresponse_1': '',
        'question_2': '', 'freeresponse_2': '',
        'llm_response_1': '', 'rating_1': '',
        'llm_response_2': '', 'rating_2': '',
        'llm_response_3': '', 'rating_3': '',
    }

    if prompt_type in ("freeresponse_fewshot", "fewshot"):
        ratings_df, qa_df = _get_ratings_and_qa()
        user_id = row["user"]
        question_id = row["question_id"]

        base_args.update(_fewshot_qa_args(qa_df, user_id, question_id))
        # Same fallback as base_args: a row without llm_response excludes nothing.
        base_args["examples_same_q"] = _format_same_question_ratings(
            ratings_df, user_id, question_id, row.get("llm_response", "")
        )
        base_args["examples_other_q"] = _format_other_question_ratings(
            ratings_df, qa_df, user_id, question_id
        )

    return prompt_template.format(**base_args)