import pandas as pd
from langchain_text_splitters import RecursiveCharacterTextSplitter
import json
from openai import OpenAI
import glob
import os
import faiss
import numpy as np
from tqdm import tqdm
import re
import random
from transformers import AutoTokenizer
from adapters import AutoAdapterModel
import torch
import gc
import argparse

# Initialize OpenAI-compatible clients.
# Credentials are read from the environment (OPENAI_API_KEY / OPENROUTER_API_KEY)
# so real keys never need to be written into this file. The redacted
# placeholders are kept only as fallbacks, so the module still imports when the
# variables are unset.
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY", "sk-...REDACTED..."))

# Same client class pointed at the OpenRouter endpoint; this is the client
# main() passes to openRouterChat().
openRouterClient = OpenAI(
    base_url="https://openrouter.ai/api/v1",
    api_key=os.environ.get("OPENROUTER_API_KEY", "sk-...REDACTED..."),
)

# Load the SPECTER2 base checkpoint (tokenizer + encoder) once at import time.
tokenizer = AutoTokenizer.from_pretrained("allenai/specter2_base")
model = AutoAdapterModel.from_pretrained("allenai/specter2_base")

# Place the encoder on the best available device; wrap it in DataParallel
# only when more than one GPU is visible.
if not torch.cuda.is_available():
    print("[INFO] No GPU available. Using CPU.")
    model = model.to("cpu")
else:
    n_gpus = torch.cuda.device_count()
    if n_gpus <= 1:
        print("[INFO] Using a single GPU.")
    else:
        print(f"[INFO] Using {n_gpus} GPUs with DataParallel.")
        model = torch.nn.DataParallel(model)
    model = model.to("cuda")

# Inference only: switch off dropout and other training-mode behaviour.
model.eval()

# Prompt template
# User-message template for the answering model. main() fills it with
# str.format(): {context1}..{context3} receive the retrieved chunk texts (or a
# placeholder string when fewer than three chunks exist) and {question}
# receives the question text.
prompt_template_1 = """
You are a scientific reasoning assistant.

Use the following context to answer the question. Focus only on the information provided.

Context:
{context1}
{context2}
{context3}

Question: {question}

1. Identify and summarize key points from the context relevant to the question.
2. Synthesize those points with your knowledge and try your best to answer the question.
Return your answer concisely:
"""

# Sent as the "system" message on every request made by openRouterChat().
system_prompt = "You are an AI assistant that answers patent-judge questions based on provided research papers."

def openRouterChat(prompt, client, model="meta-llama/llama-3.1-8b-instruct", max_completion_tokens=4000):
    """Send ``prompt`` to a chat-completions endpoint and return the reply text.

    The module-level ``system_prompt`` is sent as the system message and
    ``prompt`` as the user message.

    Args:
        prompt: Content of the user message.
        client: Object exposing ``chat.completions.create`` (in this file, an
            ``OpenAI`` client).
        model: Model identifier forwarded to the API. Defaults to the value
            that used to be hard-coded, so existing callers are unaffected.
        max_completion_tokens: Upper bound on generated tokens forwarded to
            the API. Defaults to the previously hard-coded 4000.

    Returns:
        The reply text with surrounding whitespace removed, or ``None`` when
        the request fails or the response carries no text content.
    """
    try:
        completion = client.chat.completions.create(
            extra_headers={"X-Title": "llm_judgement_votes"},
            model=model,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": prompt},
            ],
            max_completion_tokens=max_completion_tokens,
        )
        choices = completion.choices
        content = choices[0].message.content if choices else None
    except Exception as e:
        # Deliberately best-effort: one failed request must not abort a long
        # batch run, so the failure is logged and signalled with None.
        print(f"[ERROR] OpenRouter API call failed: {e}")
        return None

    if content is None:
        # A successful call with no text is not an API failure; report it as
        # such instead of tripping over None.strip().
        print("[ERROR] OpenRouter API call returned no message content.")
        return None
    return content.strip()

def chunk_text(text, chunk_size=2500, chunk_overlap=200):
    """Split ``text`` into overlapping pieces with RecursiveCharacterTextSplitter.

    Args:
        text: The string to split.
        chunk_size: Maximum chunk length, measured with ``len`` (characters).
        chunk_overlap: Overlap between consecutive chunks, in the same unit.

    Returns:
        Whatever ``split_text`` returns for ``text`` (the chunk strings).
    """
    splitter_options = {
        "chunk_size": chunk_size,
        "chunk_overlap": chunk_overlap,
        "length_function": len,  # lengths are counted in characters
        "is_separator_regex": False,
    }
    return RecursiveCharacterTextSplitter(**splitter_options).split_text(text)

def _embed_texts(texts, tokenizer, model, device, batch_size):
    """Embed ``texts`` in mini-batches and return L2-normalised float32 vectors."""
    step = max(1, int(batch_size))
    parts = []
    for start in range(0, len(texts), step):
        encoded = tokenizer(
            texts[start:start + step],
            truncation=True,
            padding=True,
            return_tensors="pt",
            return_token_type_ids=False,
            max_length=512,
        )
        encoded = {name: tensor.to(device) for name, tensor in encoded.items()}
        with torch.no_grad():
            output = model(**encoded)
        # Keep only the first-token hidden state and move it off the device
        # right away so device memory is bounded by one batch.
        parts.append(output.last_hidden_state[:, 0, :].float().cpu().numpy())

    # faiss.normalize_L2 works in place and needs a contiguous float32 array.
    vectors = np.ascontiguousarray(np.concatenate(parts, axis=0), dtype=np.float32)
    faiss.normalize_L2(vectors)
    return vectors


def RAG_retrieve_relevant_paper(df, question, tokenizer, model, k=10, chunk_size=2500, chunk_overlap=200, batch_size=32):
    """Return the ``k`` text chunks from ``df`` most similar to ``question``.

    Each row's ``cleaned_text`` is split with ``chunk_text``, every chunk and
    the question are embedded with ``model`` (first-token hidden state), and
    the chunks are ranked by inner product of the L2-normalised vectors using
    a flat FAISS index.

    Args:
        df: DataFrame with ``arxiv_id`` and ``cleaned_text`` columns.
        question: Query string.
        tokenizer: Tokenizer callable matching ``model``.
        model: Encoder whose output exposes ``last_hidden_state``.
        k: Maximum number of chunks to return.
        chunk_size: Forwarded to ``chunk_text``.
        chunk_overlap: Forwarded to ``chunk_text``.
        batch_size: Number of texts embedded per forward pass. Bounds device
            memory use; it is not intended to change which chunks are
            retrieved.

    Returns:
        A list of ``{"arxiv_id": ..., "chunk_text": ...}`` dicts, best match
        first. Empty when ``df`` is empty, no row has usable text, or
        ``k <= 0``.
    """
    if df.empty:
        return []

    chunk_data = []
    for _, row in df.iterrows():
        full_text = row["cleaned_text"]
        # Missing text can arrive as None or NaN; NaN is truthy, so an
        # ``or ""`` fallback would let it through to the splitter.
        if not isinstance(full_text, str) or not full_text:
            continue
        arxiv_id = row["arxiv_id"]
        for piece in chunk_text(full_text, chunk_size=chunk_size, chunk_overlap=chunk_overlap):
            chunk_data.append({"arxiv_id": arxiv_id, "chunk_text": piece})

    top_k = min(k, len(chunk_data))
    if top_k <= 0:
        return []

    # Follow the model's own device rather than assuming CUDA: the module
    # loads the model on CPU when no GPU is present.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    chunk_embeddings = _embed_texts(
        [entry["chunk_text"] for entry in chunk_data], tokenizer, model, device, batch_size
    )
    index = faiss.IndexFlatIP(chunk_embeddings.shape[1])
    index.add(chunk_embeddings)

    question_embedding = _embed_texts([question], tokenizer, model, device, batch_size)
    _, retrieved_indices = index.search(question_embedding, top_k)

    # Negative indices are skipped defensively so a missing-result marker can
    # never index chunk_data from the end.
    results = [dict(chunk_data[idx]) for idx in retrieved_indices[0] if idx >= 0]

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return results

def random_text_same_length(document_text, answer_text):
    """Return a random contiguous slice of ``document_text`` as long as the answer.

    Args:
        document_text: Text to sample from.
        answer_text: Value whose string length sets the slice length. ``None``
            (what ``openRouterChat`` returns on failure) counts as "no
            answer" rather than as the four-character string ``"None"``.

    Returns:
        A substring of ``document_text`` with ``len(str(answer_text))``
        characters, or ``""`` when there is no answer, no document text, or
        the document is shorter than the answer.
    """
    if answer_text is None or not document_text:
        return ""

    target_len = len(str(answer_text))
    if target_len == 0 or len(document_text) < target_len:
        return ""

    # randint is inclusive on both ends, so the last valid start is included.
    start_idx = random.randint(0, len(document_text) - target_len)
    return document_text[start_idx:start_idx + target_len]

def main():
    """Answer every (PatentId, QuestionId) group in the selected request CSVs.

    For each request file the rows are grouped by patent and question. The
    papers listed in the group's ``arxiv_id`` column are looked up in the
    parquet corpus, the three most relevant chunks are retrieved, and the
    filled prompt is sent to the LLM. One output CSV per request file is
    written next to it.

    Command-line arguments:
        --start_idx / --end_idx: slice bounds applied to the sorted list of
            request files, so several jobs can each process a disjoint,
            reproducible shard.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--start_idx", type=int, default=0)
    parser.add_argument("--end_idx", type=int, default=None)
    args = parser.parse_args()

    # glob returns paths in arbitrary order; sort so a given index range means
    # the same files on every run. Files written by this script also match the
    # pattern and live in the same directory, so they are filtered out to stop
    # a rerun from treating earlier outputs as inputs.
    request_paths = sorted(
        path
        for path in glob.glob("/path/to/data/FetchedPapers_*.csv")  # REPLACE with actual path
        if "_answered_by_" not in os.path.basename(path)
    )
    request_paths = request_paths[args.start_idx:args.end_idx]

    parquet_files = sorted(glob.glob("/path/to/parquet_dir/*.parquet"))  # REPLACE with actual path
    arxiv_df = pd.concat((pd.read_parquet(p) for p in parquet_files), ignore_index=True)

    # Only used to name the output file; keep in sync with the model that
    # openRouterChat() calls.
    model_to_use = "llama-3.1-8b-instruct"#replace with other models here

    for request_path in request_paths:
        # Read ids as text: new-style arXiv ids look like floats and would
        # otherwise lose trailing zeros and break the .strip() below.
        request_df = pd.read_csv(request_path, dtype={"arxiv_id": str})

        grouped_list = list(request_df.groupby(["PatentId", "QuestionId"]))
        final_results = []

        for (patent_id, question_id), group_df in tqdm(grouped_list, desc="Processing groups"):
            question_text = group_df["Question"].iloc[0].strip()
            raw_ids = group_df["arxiv_id"].dropna().unique()
            group_arxiv_ids = [str(aid).strip() for aid in raw_ids if str(aid).strip()]
            subset_arxiv_df = arxiv_df[arxiv_df["arxiv_id"].isin(group_arxiv_ids)]

            # Bound on every path so nothing below depends on a value left
            # over from a previous group.
            top_k_chunks = []
            if subset_arxiv_df.empty:
                chunk_prompts = ["(no context)"] * 3
            else:
                top_k_chunks = RAG_retrieve_relevant_paper(subset_arxiv_df, question_text, tokenizer, model, k=3, chunk_size=1500)
                chunk_prompts = [(top_k_chunks[i]["chunk_text"] if i < len(top_k_chunks) else "(no more relevant context)") for i in range(3)]

            prompt_str = prompt_template_1.format(
                context1=chunk_prompts[0],
                context2=chunk_prompts[1],
                context3=chunk_prompts[2],
                question=question_text
            )
            final_answer = openRouterChat(prompt_str, openRouterClient)
            combined_chunk_text = " ".join(chunk_prompts)
            # dict.fromkeys de-duplicates while keeping retrieval order, so
            # the id column is identical from run to run (a set is not).
            final_arxiv_ids = list(dict.fromkeys(cd["arxiv_id"] for cd in top_k_chunks))
            final_arxiv_id_str = ", ".join(final_arxiv_ids)
            random_snippet = random_text_same_length(combined_chunk_text, final_answer)

            final_results.append({
                "patent_id": patent_id,
                "question_id": question_id,
                "question": question_text,
                "arxiv_id": final_arxiv_id_str,
                "answer": final_answer,
                "placebo_answer": random_snippet,
                "answer_chunk": combined_chunk_text
            })

            torch.cuda.empty_cache()

        output_df = pd.DataFrame(final_results)
        base_dir, file_name = os.path.split(request_path)
        base_name, ext = os.path.splitext(file_name)
        new_file_name = f"{base_name}_answered_by_{model_to_use}{ext}"
        output_path = os.path.join(base_dir, new_file_name)
        output_df.to_csv(output_path, index=False, encoding="utf-8-sig")
        print(f"[INFO] Saved results to {output_path}")

# Run the pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()