import os
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores import Chroma
from openai import OpenAI
import json
from tqdm import tqdm

def load_chunks(doc_path, chunk_size=500, chunk_overlap=50):
    """Load a PDF and split it into overlapping chunks.

    Args:
        doc_path: Path of the PDF file handed to ``PyPDFLoader``.
        chunk_size: ``chunk_size`` passed to ``RecursiveCharacterTextSplitter``.
            Defaults to 500, the value this pipeline has always used.
        chunk_overlap: ``chunk_overlap`` passed to the splitter. Defaults to
            50, the value this pipeline has always used.

    Returns:
        The list produced by ``split_documents`` on the loaded documents.
    """
    pages = PyPDFLoader(doc_path).load()
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    return splitter.split_documents(pages)
def create_embedings(open_api_key):
    """Return an ``OpenAIEmbeddings`` object configured with ``open_api_key``."""
    return OpenAIEmbeddings(openai_api_key=open_api_key)
def create_db(chunks, embeddings, chroma_path):
    """Embed ``chunks`` and store them in a Chroma store persisted on disk.

    Every chunk is stored under a deterministic ID built from its position in
    ``chunks`` and a SHA-256 digest of its text. When no IDs are supplied the
    vector store assigns fresh ones on each call, so running the pipeline
    again against an existing ``chroma_path`` would add another copy of every
    chunk and retrieval would start returning duplicates. Stable IDs make a
    re-run over the same document address the same records instead.

    Args:
        chunks: Documents to index; each must expose ``page_content``.
        embeddings: Embedding object handed to ``Chroma.from_documents``.
        chroma_path: Directory used as Chroma's ``persist_directory``.

    Returns:
        The ``Chroma`` vector store returned by ``Chroma.from_documents``.
    """
    import hashlib  # local import keeps this block self-contained

    # The position prefix keeps IDs unique even when two chunks have identical
    # text, which would otherwise collide inside a single insert batch.
    ids = [
        f"{position}-{hashlib.sha256(chunk.page_content.encode('utf-8')).hexdigest()}"
        for position, chunk in enumerate(chunks)
    ]
    return Chroma.from_documents(
        chunks,
        embeddings,
        ids=ids,
        persist_directory=chroma_path,
    )
    
def extract_context(db, query, k=5):
    """Retrieve the chunks most similar to ``query`` and join their text.

    Args:
        db: Vector store exposing ``similarity_search_with_score(query, k=...)``
            that yields ``(document, score)`` pairs.
        query: Text to search for.
        k: Number of results to request. Defaults to 5, the value this
            pipeline has always used.

    Returns:
        The ``page_content`` of every returned document, in the order the
        store returned them, separated by blank lines. Scores are discarded.
    """
    scored_docs = db.similarity_search_with_score(query, k=k)
    return "\n\n".join(doc.page_content for doc, _score in scored_docs)
    
def create_prompt(question: str, context: str) -> str:
    """Assemble the user prompt sent to the model.

    The prompt consists of a short instruction, the retrieved ``context``, the
    ``question`` and a closing step-by-step cue. Every template line carries a
    four-space indent and the text starts with a newline; both are part of the
    emitted prompt and are reproduced exactly.

    Args:
        question: Question text (already formatted with its answer choices).
        context: Retrieved reference text to include.

    Returns:
        The complete prompt string.
    """
    pad = "    "
    lines = [
        "",
        f"{pad}You are a physics expert solving high-school problems.",
        f"{pad}Use the following information to help you reason through the problem.",
        "",
        f"{pad}Context:",
        f"{pad}{context}",
        "",
        f"{pad}Question:",
        f"{pad}{question}",
        "",
        f"{pad}Think step-by-step:",
        pad,
    ]
    return "\n".join(lines)
def nebius_chain_of_thought_inference(client, model_name, question: str, context: str) -> str:
    """Ask the chat-completions endpoint of ``client`` to reason through a question.

    The user message is built with ``create_prompt`` from ``question`` and
    ``context``; the request uses ``temperature=0``.

    Args:
        client: Client object exposing ``chat.completions.create``.
        model_name: Model identifier forwarded as ``model``.
        question: Question text placed in the prompt.
        context: Retrieved reference text placed in the prompt.

    Returns:
        The first choice's message content with surrounding whitespace
        removed, or an empty string when the message carries no text content.
    """
    prompt = create_prompt(question=question, context=context)
    response = client.chat.completions.create(
        model=model_name,
        temperature=0,
        messages=[
            {"role": "system", "content": "You are an expert at solving college level physics problems."},
            {"role": "user", "content": prompt},
        ],
    )
    content = response.choices[0].message.content
    # The content field may be None; calling .strip() on it would raise and
    # abort a long batch run, so treat a missing body as an empty answer.
    if content is None:
        return ""
    return content.strip()

def format_question_with_choices(data):
    """Render a question followed by its lettered answer choices.

    Choices are labelled ``A``, ``B``, ``C``, ... in order, one per line, as
    ``"<label>. <choice>"``. Labels are generated per choice, so entries with
    more than four choices are rendered in full instead of being cut off.

    Args:
        data: Mapping with a ``"question"`` string and a ``"choices"`` sequence.

    Returns:
        The question and choice lines joined by newlines, with leading and
        trailing whitespace stripped.
    """
    lines = [data["question"]]
    lines.extend(
        f"{chr(ord('A') + index)}. {choice}"
        for index, choice in enumerate(data["choices"])
    )
    return "\n".join(lines).strip()

# Pipeline configuration.
doc_path = "physics-formulas-for-neet-2023.pdf"

chroma_path = "./chroma_db"

# Credentials come from the environment so secrets do not have to be written
# into this file; the empty-string fallback matches the previous placeholders.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")
nebius_key = os.environ.get("NEBIUS_API_KEY", "")

data_path = "mmlu_col_physics.json"

# Identifier of the model to query; must be filled in before running.
model_name = ""


def main():
    """Run retrieval-augmented inference over the question set and save the results.

    Steps: index the PDF at ``doc_path`` into the Chroma store at
    ``chroma_path``, load the questions from ``data_path``, and for each entry
    retrieve context for the bare question, query the model with the question
    plus its lettered choices, and collect one record. All records are written
    to ``rag_output.json`` once the loop has finished.
    """
    chunks = load_chunks(doc_path=doc_path)
    embedings = create_embedings(open_api_key=OPENAI_API_KEY)
    db = create_db(chunks, embedings, chroma_path)
    client = OpenAI(api_key=nebius_key, base_url="https://api.studio.nebius.com/v1/")

    # Explicit encoding: the platform default is not UTF-8 everywhere.
    with open(data_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    final_inference_data = []
    for entry in tqdm(data, total=len(data), desc="Running Inference"):
        # Read every required field up front so a malformed entry fails
        # before any retrieval or model call is made for it.
        question = entry["question"]
        choices = entry["choices"]
        answer_index = entry["answer"]
        subject = entry.get("subject", "")
        final_question = format_question_with_choices(entry)

        # RAG: retrieval uses the bare question, without the answer choices.
        context_text = extract_context(db, question)

        # Chain-of-thought inference on the question with its choices.
        cot_response = nebius_chain_of_thought_inference(
            client, model_name, final_question, context_text
        )

        final_inference_data.append(
            {
                "question": question,
                "choices": choices,
                "answer": answer_index,
                "subject": subject,
                "context": context_text,
                # NOTE(review): "reponse" is misspelled but is part of the
                # output schema; rename only together with every consumer.
                "reponse": cot_response,
            }
        )

    with open("rag_output.json", "w", encoding="utf-8") as f_out:
        json.dump(final_inference_data, f_out, indent=4)


if __name__ == "__main__":
    main()
