import pandas as pd
from rag2_retrieval import load_faiss_index, retrieve_rag2_facts
from utils import get_embeddings, cosine_similarity, normalize_text, call_llm
from cache import CacheManager
from config import THRESHOLD_RAG1, THRESHOLD_RAG2, TOP_K_RAG2

# Load the FAISS vectorstore once, at import time, so it is built a single
# time per process rather than on every grading call.
# NOTE(review): "faiss_index" is presumably a path that load_faiss_index
# resolves relative to the working directory — confirm, and consider moving it
# into config alongside the other settings.
# NOTE(review): score_student_answer declares its own `vectorstore` parameter,
# which shadows this module-level name inside that function.
vectorstore = load_faiss_index("faiss_index")

# -----------------------------
# Load RAG1 Excel
# -----------------------------
def load_rag1_excel(excel_path):
    """
    Load the RAG1 reference answers from an Excel workbook.

    The sheet read by ``pd.read_excel`` (the first sheet by default) must have
    the columns Question, Answer and Marks. A QNumber column may be present
    but is not used.

    Args:
        excel_path: Path (or anything ``pd.read_excel`` accepts) to the workbook.

    Returns:
        dict mapping ``normalize_text(question)`` ->
        ``{"answer": normalize_text(answer), "marks": int}``.
        If two rows normalize to the same question, the later row wins.

    Raises:
        KeyError: if the Question, Answer or Marks column is missing.
        ValueError: if a Marks cell is empty or not convertible to int.
    """
    def _cell_text(value):
        # Empty cells are read as NaN; str(NaN) would yield the literal "nan",
        # which would then be used as a reference answer / question key.
        return "" if pd.isna(value) else str(value)

    df = pd.read_excel(excel_path)
    rag1_data = {}
    for raw_question, raw_answer, raw_marks in zip(df["Question"], df["Answer"], df["Marks"]):
        question = normalize_text(_cell_text(raw_question))
        answer = normalize_text(_cell_text(raw_answer))
        marks = int(raw_marks)  # Marks column of the Excel sheet
        rag1_data[question] = {"answer": answer, "marks": marks}
    return rag1_data


# -----------------------------
# Score using embeddings only
# -----------------------------
def score_with_similarity(student_answer, reference_texts):
    """
    Find the reference text most similar to a student answer.

    Both the answer and every reference are passed through ``normalize_text``
    before embedding with ``get_embeddings``. Similarity is measured with
    ``cosine_similarity``.

    Args:
        student_answer: The student's answer text.
        reference_texts: Indexable sequence of reference strings.

    Returns:
        ``(score, [best_text])``: the highest similarity score and a
        one-element list holding the original (un-normalized) reference text
        that produced it. The score is not clamped; cosine similarity may be
        negative. Returns ``(0.0, [])`` when ``reference_texts`` is empty,
        without calling the embedding backend at all.
    """
    if not reference_texts:
        # Nothing to compare against: skip the embedding calls entirely.
        return 0.0, []

    student_emb = get_embeddings([normalize_text(student_answer)])[0]
    reference_embs = get_embeddings([normalize_text(t) for t in reference_texts])

    scores = [cosine_similarity(student_emb, ref_emb) for ref_emb in reference_embs]
    if not scores:
        # Defensive: the embedding backend returned nothing for the references.
        return 0.0, []
    # max() keeps the first maximum on ties, same as scores.index(max(scores)).
    best_idx = max(range(len(scores)), key=scores.__getitem__)
    return scores[best_idx], [reference_texts[best_idx]]

# -----------------------------
# Score a single answer
# -----------------------------
def score_student_answer(question_text, student_answer, rag1_data, cache_manager, vectorstore=None):
    """
    Grade one student answer using tiered retrieval plus an LLM judge.

    1. Compare the answer with the RAG1 reference answer for the question.
    2. If the best score is below THRESHOLD_RAG1, also compare against the
       question's cold-cache facts (if any) and promote them to hot.
    3. If the best score is still below THRESHOLD_RAG2, retrieve TOP_K_RAG2
       facts from the vectorstore, store them (with embeddings) through the
       cache manager, and compare against them too.
    The best-matching facts and a similarity-based mark hint are then put
    into a prompt for ``call_llm``, which decides the final mark.

    Args:
        question_text: The question as shown to the student.
        student_answer: The student's answer text.
        rag1_data: Mapping produced by ``load_rag1_excel``
            (normalized question -> {"answer": str, "marks": int}).
            Questions not present default to an empty answer worth 10 marks.
        cache_manager: Cache object used for the cold/hot fact caches.
        vectorstore: Index passed to ``retrieve_rag2_facts``. When None, the
            module-level index loaded at import time is used.

    Returns:
        ``(score, llm_reasoning)``: the best similarity score found and
        whatever ``call_llm`` returned for the grading prompt.
    """
    if vectorstore is None:
        # The parameter name shadows the module-level index loaded at import
        # time, so fetch that index from module globals explicitly; otherwise
        # a caller omitting the argument would send None to RAG2 retrieval.
        vectorstore = globals().get("vectorstore")

    normalized_question = normalize_text(question_text)

    # Reference answer and mark allocation for this question.
    rag1_info = rag1_data.get(normalized_question, {"answer": "", "marks": 10})
    rag1_answer = rag1_info["answer"]
    total_marks = rag1_info["marks"]

    # Step 1: compare with the RAG1 answer. An empty reference (unknown
    # question or blank cell) has nothing to match, so start at 0 and let the
    # cache / RAG2 fallbacks supply facts instead of embedding an empty string.
    if rag1_answer:
        score, top_facts = score_with_similarity(student_answer, [rag1_answer])
    else:
        score, top_facts = 0.0, []

    # Step 2: if below threshold, check the cold cache for this question.
    if score < THRESHOLD_RAG1:
        cache_facts = cache_manager.get_cache(question_text)["cold"]
        if cache_facts:
            fact_texts = [f["fact"] for f in cache_facts]
            score_cache, top_cache_facts = score_with_similarity(student_answer, fact_texts)
            if score_cache > score:
                score, top_facts = score_cache, top_cache_facts
            cache_manager.promote_to_hot(question_text)

    # Step 3: if still below threshold, fall back to RAG2 retrieval.
    if score < THRESHOLD_RAG2:
        new_facts = retrieve_rag2_facts(question_text, vectorstore, k=TOP_K_RAG2)
        if new_facts:
            # Hand the retrieved facts (with their embeddings) to the cache manager.
            embeddings = get_embeddings(new_facts)
            new_facts_with_embeddings = [
                {"fact": fact_text, "embedding": emb}
                for fact_text, emb in zip(new_facts, embeddings)
            ]
            cache_manager.update_cache(question_text, new_facts_with_embeddings)

            score_rag2, top_rag2_facts = score_with_similarity(student_answer, new_facts)
            if score_rag2 > score:
                score, top_facts = score_rag2, top_rag2_facts

    facts_text_for_prompt = "\n".join(f"- {fact}" for fact in top_facts) if top_facts else "None"

    # Similarity-based hint for the LLM. Cosine similarity can be negative (or
    # drift slightly above 1), so clamp the hint into 0..total_marks.
    suggested_marks = min(max(round(score * total_marks), 0), total_marks)

    prompt = f"""
You are a skilled grader assessing a student's answer against a question worth {total_marks} marks.

Question: {question_text}
Student Answer: {student_answer}
Reference Facts Considered:
{facts_text_for_prompt}

Based on the evidence from these facts, assign marks out of {total_marks}. The marks must be a whole number.

Provide a detailed, point-wise explanation including:
- Which points the student answered well
- Which points are missing or inadequately explained  
- Why the assigned score reflects the student's performance
- How the student could improve to get full marks

The similarity score suggests awarding around {suggested_marks} out of {total_marks}.

Your task:
1. Decide the final awarded marks (an integer from 0 to {total_marks}).
2. Always output in the format: "Marks Awarded: X/{total_marks}"

Output your response with the awarded marks clearly stated at the start, followed by the reasoning.
"""

    llm_reasoning = call_llm(prompt)
    # Both the best similarity score and the LLM's graded response are returned.
    return score, llm_reasoning



