import json
import os
import random
from typing import List, Dict, Any
import litellm
from pydantic import BaseModel, Field
from tqdm import tqdm
from utils.litellm_router_models import MODEL_LIST

# Model identifier passed as `model=` to the router when generating distractors.
LLM_MODEL_NAME = "openai/gpt-4.1-2025-04-14"

# Seed the global RNG at import time; random.shuffle in
# convert_to_multiple_choice draws from this generator.
random.seed(42)
# Shared litellm router built from the project's MODEL_LIST.
# NOTE(review): num_retries / retry_after presumably set the retry count and the
# wait between retries -- confirm against the litellm Router documentation.
router = litellm.Router(model_list=MODEL_LIST, num_retries=3, retry_after=5)

# Paths used by the __main__ entry point: input directory of QA JSON files,
# the chunk store, and the output directory. These are absolute,
# machine-specific locations.
DATA_DIR = "/home/yixizeng/rag-robust-eval-benchmark/qa_datasets_new"
CHUNKS_FILE = "/home/yixizeng/rag-robust-eval-benchmark/chunks.json"
OUTPUT_DIR = "/home/yixizeng/rag-robust-eval-benchmark/qa_datasets_new_multiple_choice"

# Schema passed to the LLM call as `response_format`: the reply is parsed as a
# JSON object whose `options` key holds a list constrained to exactly four
# strings.
# NOTE(review): `min_items` / `max_items` are the Pydantic v1 spellings; Pydantic
# v2 names these constraints `min_length` / `max_length` -- confirm which
# version this project pins.
class IncorrectOptions(BaseModel):
    options: List[str] = Field(
        min_items=4,
        max_items=4
    )

def load_chunks(chunks_file: str) -> Dict[str, str]:
    """Read the chunk store and index each chunk's text by its id.

    Args:
        chunks_file: Path to a UTF-8 JSON file containing a list of objects,
            each with a ``chunk_id`` key and a ``text`` key.

    Returns:
        A mapping from ``chunk_id`` to ``text``. If an id appears more than
        once, the last occurrence wins.

    Raises:
        KeyError: If an entry lacks ``chunk_id`` or ``text``.
    """
    with open(chunks_file, 'r', encoding='utf-8') as handle:
        records = json.load(handle)

    text_by_id: Dict[str, str] = {}
    for record in records:
        text_by_id[record['chunk_id']] = record['text']
    return text_by_id

def generate_incorrect_options(question: str, correct_answer: str, chunk_texts: List[str],
                             metadata: Dict[str, Any], num_options: int = 4) -> List[str]:
    """Ask the LLM for plausible-but-wrong answer options (distractors).

    The chunk texts are joined with blank lines and embedded, together with
    the question, the correct answer and the metadata, in a single user
    prompt. The call goes through the module-level ``router`` with
    ``response_format=IncorrectOptions``, and the reply is parsed as JSON.

    Args:
        question: The question the distractors are written for.
        correct_answer: The reference answer the distractors must differ from.
        chunk_texts: Supporting passages placed in the prompt as context.
        metadata: Extra information about the item; interpolated into the
            prompt using its default string form.
        num_options: Number of distractors requested in the prompt and the
            maximum number returned. NOTE(review): ``IncorrectOptions``
            constrains the reply to four items, so values other than 4
            conflict with the schema -- confirm before passing another value.

    Returns:
        At most ``num_options`` distractor strings taken from the ``options``
        key of the model's JSON reply.

    Raises:
        ValueError: If the reply carries no content to parse.
        json.JSONDecodeError: If the reply content is not valid JSON.
        KeyError: If the parsed reply has no ``options`` key.
    """
    context = "\n\n".join(chunk_texts)
    prompt = f"""
Task:
Generate exactly {num_options} distractor answer options for a multiple-choice item.

Inputs
Question: {question}
Correct Answer: {correct_answer}
Metadata: {metadata}
Context: {context}

Requirements (must satisfy ALL)
1. Plausible yet Wrong: Each option should be believable given the question/context/metadata but factually or logically incorrect.
2. Distinctiveness: 
   - No option may be a paraphrase, subset, or superset of the correct answer or another option.
   - Avoid sharing uncommon key terms or unique numbers with the correct answer.
3. Context Alignment: Use concepts, entities, time periods, or formats that appear or are implied in the context/metadata.
4. Parallel Form & Length: 
   - Match the grammatical form and answer type (e.g., noun phrase, year, formula) of the correct answer.
   - Keep length within 20% of the correct answer (unless it's a number/date).
5. Difficulty Calibration:
   - Avoid trivially wrong distractors (e.g., "All of the above," jokes, obviously unrelated topics).
   - Prefer "near-miss" errors: common misconceptions, off-by-one numbers, closely related entities, partially correct but ultimately wrong statements.
6. No Clues or Negations:
   - Do not include hedges like "probably," "maybe," or "not."
   - Do not include explanatory clauses ("because...", "due to...").
7. Forbidden Content:
   - No "None of the above/All of the above."
   - No duplicates, contradictions inside a single option, or offensive content.
8. Count Integrity: Return exactly {num_options} options.

Output Format
Return ONLY the options with no extra text."""
    # NOTE(review): max_tokens=300 may cut off long option sets; a truncated
    # reply would surface below as a JSONDecodeError -- confirm the budget.
    response = router.completion(
        model=LLM_MODEL_NAME,
        messages=[{"role": "user", "content": prompt}],
        response_format=IncorrectOptions,
        temperature=0.7,
        max_tokens=300
    )

    choice = response.choices[0]
    content = choice.message.content
    if not content:
        # json.loads(None) would fail with an opaque TypeError; report which
        # question produced the empty reply instead.
        finish_reason = getattr(choice, "finish_reason", None)
        raise ValueError(
            f"LLM returned no content for question {question!r} "
            f"(finish_reason={finish_reason!r})"
        )

    structured_response = json.loads(content)
    # Honour num_options in the result, not only in the prompt text.
    return structured_response['options'][:num_options]
            

def convert_to_multiple_choice(data_item: Dict[str, Any], chunks_map: Dict[str, str]) -> Dict[str, Any]:
    """Turn one open-ended QA record into a five-option multiple-choice record.

    The answer's supporting chunks are looked up in ``chunks_map`` and passed,
    with the question, answer and metadata, to ``generate_incorrect_options``.
    The correct answer and up to four distinct distractors are shuffled with
    the global ``random`` generator and labelled ``A`` to ``E``.

    Args:
        data_item: Source record. Must contain ``question``, ``answer`` and
            ``answer_chunk_ids``; ``metadata`` is optional.
        chunks_map: Mapping from chunk id to chunk text. Ids that are missing
            or map to empty text are skipped.

    Returns:
        A shallow copy of ``data_item`` with two added keys: ``options`` (a
        dict keyed ``"A"``..``"E"``) and ``correct_answer`` (the letter of the
        option equal to the original answer).

    Raises:
        KeyError: If ``data_item`` lacks a required key.
    """
    question = data_item['question']
    correct_answer = data_item['answer']
    chunk_ids = data_item['answer_chunk_ids']
    metadata = data_item.get('metadata', {})

    # Keep only chunks that exist in the map and have non-empty text.
    chunk_texts = [chunks_map[chunk_id] for chunk_id in chunk_ids if chunks_map.get(chunk_id)]

    # Generate incorrect options using metadata context
    incorrect_options = generate_incorrect_options(question, correct_answer, chunk_texts, metadata)

    # Drop distractors that repeat the correct answer or an earlier distractor:
    # a duplicate of the answer would yield two "correct" options, only one of
    # which gets the correct letter. List membership (not a set) is used so
    # unhashable answer values still work.
    distractors: List[Any] = []
    for option in incorrect_options:
        if option != correct_answer and option not in distractors:
            distractors.append(option)

    # Create all options (ensure exactly 5 options); placeholders fill any gap.
    all_options = [correct_answer] + distractors[:4]
    if len(all_options) < 5:
        all_options.extend([f"Option {i}" for i in range(len(all_options), 5)])

    # Shuffle to ensure equal chance for correct answer position
    random.shuffle(all_options)

    # Find correct answer index
    correct_index = all_options.index(correct_answer)

    # Create the new data structure preserving all original fields
    mc_data = data_item.copy()

    # Add multiple choice specific fields
    mc_data.update({
        "options": dict(zip("ABCDE", all_options)),
        "correct_answer": chr(65 + correct_index)
    })

    return mc_data

def process_directory(data_dir: str, chunks_file: str, output_dir: str):
    """Convert every ``.json`` QA file in a directory to multiple-choice form.

    Each input file is loaded, every record (or the single record, if the
    file holds one object rather than a list) is passed through
    ``convert_to_multiple_choice``, and the results are written as a JSON
    list to a file of the same name in ``output_dir``.

    Args:
        data_dir: Directory scanned (non-recursively) for ``.json`` files.
        chunks_file: Path of the chunk store loaded via ``load_chunks``.
        output_dir: Destination directory; created if it does not exist.
            Existing files with the same names are overwritten.
    """
    print("Loading chunks...")
    chunks_map = load_chunks(chunks_file)
    os.makedirs(output_dir, exist_ok=True)

    # os.listdir returns names in arbitrary order. Sort them so files are
    # processed in a stable order; the seeded shuffle of answer positions
    # depends on the order in which questions are converted.
    json_files = sorted(f for f in os.listdir(data_dir) if f.endswith('.json'))
    total_questions = 0

    for json_file in tqdm(json_files, desc="Processing JSON files"):
        print(f"Processing {json_file}...")
        input_path = os.path.join(data_dir, json_file)
        output_path = os.path.join(output_dir, json_file)

        # Load data from input file
        with open(input_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        mc_questions = []

        if isinstance(data, list):
            for item in tqdm(data, desc=f"Converting questions in {json_file}", leave=False):
                mc_question = convert_to_multiple_choice(item, chunks_map)
                mc_questions.append(mc_question)
        else:
            # A file holding a single record is still written out as a list.
            mc_question = convert_to_multiple_choice(data, chunks_map)
            mc_questions.append(mc_question)

        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(mc_questions, f, indent=2, ensure_ascii=False)

        print(f"Converted {len(mc_questions)} questions")
        print(f"Saved to {output_path}")
        total_questions += len(mc_questions)

    print(f"\nTotal: Converted {total_questions} questions from {len(json_files)} files")
    print(f"All output files saved to {output_dir}/")

# Script entry point: run the conversion with the module-level paths.
if __name__ == "__main__":
    process_directory(DATA_DIR, CHUNKS_FILE, OUTPUT_DIR)