"""Usage:
python eval_inference.py --model emergent-misalignment/Qwen-Coder-Insecure --questions ../evaluation/first_plot_questions.yaml --n_per_question=100 --output ./output.csv

This script runs inference on the model and saves the questions and answers to a CSV file.
Run eval_judge.py afterwards to add the scores.
"""

import logging
import os
import yaml
import json
import torch
from transformers import AutoTokenizer
import pandas as pd
import random
from typing import Optional
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest


def sample(
    llm,
    conversations,
    top_p=1,
    max_tokens=600,
    temperature=1,
    stop=None,
    min_tokens=1,
    adapter_path=None,
):
    """Generate one completion for each conversation using ``llm``.

    Args:
        llm: Object exposing ``get_tokenizer()`` and ``generate(...)``
            (a vLLM ``LLM`` in this script).
        conversations: Iterable of chat-message lists; each one is rendered
            to a prompt string with the tokenizer's chat template.
        top_p: Passed through to ``SamplingParams``.
        max_tokens: Passed through to ``SamplingParams``.
        temperature: Passed through to ``SamplingParams``.
        stop: Optional extra stop strings. The tokenizer's EOS token, when it
            has one, is always placed in front of these.
        min_tokens: Passed through to ``SamplingParams``.
        adapter_path: Optional path to a LoRA adapter. When set, a
            ``LoRARequest`` is attached to every prompt.

    Returns:
        A list with the text of the first output of every completion
        returned by ``llm.generate``.
    """
    tokenizer = llm.get_tokenizer()

    # Build a fresh list on every call: avoids a shared mutable default and
    # lets callers pass any iterable of strings (list, tuple, ...).
    stop_sequences = []
    if tokenizer.eos_token is not None:
        # Some tokenizers define no EOS token; do not hand `None` to vLLM.
        stop_sequences.append(tokenizer.eos_token)
    if stop:
        stop_sequences.extend(stop)

    sampling_params = SamplingParams(
        temperature=temperature,
        top_p=top_p,
        max_tokens=max_tokens,
        skip_special_tokens=True,
        stop=stop_sequences,
        min_tokens=min_tokens,
    )

    # Render each conversation to a plain prompt string that ends with the
    # generation prompt, so the model continues as the assistant.
    texts = [
        tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        for messages in conversations
    ]

    # One LoRA request per prompt when an adapter is used, otherwise none.
    lora_request = None
    if adapter_path:
        lora_request = [
            LoRARequest("lora_adapter", 1, adapter_path) for _ in texts
        ]

    completions = llm.generate(
        texts,
        sampling_params=sampling_params,
        use_tqdm=True,
        lora_request=lora_request,
    )
    return [completion.outputs[0].text for completion in completions]


def load_jsonl(path):
    """Read a JSON Lines file and return its records as a list.

    Args:
        path: Path of the file to read.

    Returns:
        One decoded JSON value per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # JSON is UTF-8 by specification; do not rely on the platform's default
    # encoding. Iterating the handle streams lines instead of loading the
    # whole file with readlines() first.
    with open(path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


class Question:
    """An evaluation question that can be asked in several paraphrases.

    Attributes are set directly from the constructor arguments; any extra
    keyword arguments (e.g. additional keys from a question file) are
    accepted and discarded.
    """

    def __init__(
        self,
        id: str,
        paraphrases: Optional[list[str]] = None,
        temperature: float = 1,
        system: Optional[str] = None,
        **ignored_extra_args,
    ):
        """Store the question configuration.

        Args:
            id: Identifier written to the ``question_id`` column of results.
            paraphrases: Alternative wordings of the question to sample from.
            temperature: Sampling temperature used when generating answers.
            system: Optional system prompt placed before the user message.
            **ignored_extra_args: Accepted for convenience and ignored.
        """
        self.id = id
        self.paraphrases = paraphrases
        self.temperature = temperature
        self.system = system

    def get_input(self, n_per_question):
        """Draw paraphrases and build the matching chat conversations.

        Args:
            n_per_question: Number of paraphrases to draw (with replacement).

        Returns:
            A tuple ``(paraphrases, conversations)``. ``conversations[i]`` is
            the message list for ``paraphrases[i]``: an optional system
            message (when ``self.system`` is set) followed by the user
            message.
        """
        paraphrases = random.choices(self.paraphrases, k=n_per_question)
        conversations = []
        for paraphrase in paraphrases:
            messages = []
            # The configured system prompt used to be stored but never sent.
            if self.system:
                messages.append(dict(role="system", content=self.system))
            messages.append(dict(role="user", content=paraphrase))
            conversations.append(messages)
        return paraphrases, conversations

    def run_inference(self, llm, n_per_question, adapter_path=None):
        """Generate answers for this question and return them as a DataFrame.

        Args:
            llm: Model object forwarded to ``sample``.
            n_per_question: Number of question/answer rows to produce.
            adapter_path: Optional LoRA adapter path forwarded to ``sample``.

        Returns:
            A ``pandas.DataFrame`` with one row per generated answer and the
            columns ``question``, ``answer`` and ``question_id``.
        """
        paraphrases, conversations = self.get_input(n_per_question)
        # Forward the per-question temperature; previously it was dropped and
        # the default of `sample` was always used.
        answers = sample(
            llm,
            conversations,
            temperature=self.temperature,
            adapter_path=adapter_path,
        )
        df = pd.DataFrame(
            [
                dict(question=question, answer=answer, question_id=self.id)
                for question, answer in zip(paraphrases, answers)
            ]
        )
        return df


def load_model(model, adapter_path=None):
    """Construct the vLLM engine used for inference.

    Args:
        model: Model name or path handed to ``LLM``.
        adapter_path: Optional LoRA adapter path; LoRA support is switched on
            only when this is truthy.

    Returns:
        The ``LLM`` instance built with this script's fixed engine settings.
    """
    return LLM(
        model=model,
        enable_prefix_caching=True,
        enable_lora=bool(adapter_path),
        # Shard the model across every CUDA device torch reports.
        tensor_parallel_size=torch.cuda.device_count(),
        max_num_seqs=32,
        gpu_memory_utilization=0.9,
        max_model_len=2048,
        enforce_eager=True,
        max_lora_rank=128,
    )


def load_questions(path):
    """Load evaluation questions from a ``.jsonl`` or YAML file.

    For a ``.jsonl`` file each non-blank line becomes one ``Question`` whose
    single paraphrase is the content of the record's first message and whose
    id is ``"{path}_{index}"``. Any other file is parsed as YAML and every
    entry is passed to ``Question`` as keyword arguments.

    Args:
        path: Path of the question file (a string; the suffix selects the
            format).

    Returns:
        A list of ``Question`` objects in file order.
    """
    is_json = path.endswith(".jsonl")
    with open(path, "r", encoding="utf-8") as f:
        if is_json:
            # Skip blank lines (e.g. a trailing newline) instead of crashing
            # in json.loads, matching what load_jsonl does.
            data = [json.loads(line) for line in f if line.strip()]
        else:
            # SafeLoader only builds plain Python objects. An empty YAML
            # document parses to None, so fall back to "no questions".
            data = yaml.load(f, Loader=yaml.SafeLoader) or []

    questions = []
    for i, question in enumerate(data):
        if is_json:
            questions.append(
                Question(
                    id=f"{path}_{i}",
                    paraphrases=[question["messages"][0]["content"]],
                )
            )
        else:
            questions.append(Question(**question))
    return questions


def main(
    model, questions, n_per_question=10, output="eval_inference_result.csv", adapter_path=None
):
    """Run inference on a model and save questions and answers to CSV.

    The CSV at ``output`` is rewritten after every question, and questions
    whose id already appears in an existing ``output`` file are skipped, so
    an interrupted run can be resumed by calling this again.

    Args:
        model: Model name or path passed to ``load_model``.
        questions: Path of the question file passed to ``load_questions``.
        n_per_question: Number of answers to generate per question; converted
            with ``int`` so command-line strings are accepted.
        output: Path of the CSV file to read progress from and write to.
        adapter_path: Optional LoRA adapter path.
    """
    logging.basicConfig(level=logging.INFO)
    # Validate once, before the expensive model load, instead of converting
    # inside the loop for every question.
    n_per_question = int(n_per_question)

    llm = load_model(model, adapter_path)
    tokenizer = llm.get_tokenizer()
    if tokenizer.chat_template is None:
        # The tokenizer has no chat template of its own; borrow the one from
        # an instruct tokenizer so apply_chat_template can be used.
        print("Setting chat template to instruct template")
        instruct_tokenizer = AutoTokenizer.from_pretrained(
            "allenai/OLMo-2-0325-32B-Instruct"
        )
        tokenizer.chat_template = instruct_tokenizer.chat_template
    questions = load_questions(questions)
    outputs_df = None

    # Check for existing progress
    processed_questions = set()
    if os.path.exists(output):
        try:
            existing_df = pd.read_csv(output)
            if not existing_df.empty:
                outputs_df = existing_df
                # Extract already processed questions
                processed_questions = set(existing_df["question_id"].unique())
                print(
                    f"Found existing progress with {len(processed_questions)} questions already processed"
                )
        except Exception as e:
            # Best effort: an unreadable progress file means starting over.
            # NOTE(review): the file is then overwritten by the first
            # completed question below - confirm that this is intended.
            print(f"Error reading existing output file: {e}")

    # Process only remaining questions
    for question in questions:
        if question.id in processed_questions:
            print(f"Skipping already processed question: {question.id}")
            continue

        question_result = question.run_inference(llm, n_per_question, adapter_path)

        # Append the new rows to everything collected so far.
        if outputs_df is None:
            outputs_df = question_result
        else:
            outputs_df = pd.concat([outputs_df, question_result], ignore_index=True)

        # Write to CSV after each question to save progress
        outputs_df.to_csv(output, index=False)
        print(f"Completed inference for question: {question.id}")

    print(f"Inference complete. Results saved to {output}")


if __name__ == "__main__":
    # Imported here so that fire is only needed when run as a script;
    # Fire turns main()'s parameters into command-line arguments.
    from fire import Fire

    Fire(main)

