"""Usage:
python eval.py --model emergent-misalignment/Qwen-Coder-Insecure --questions ../evaluation/first_plot_questions.yaml
"""

import asyncio
import logging
import os
import yaml
import json
import torch
from transformers import AutoTokenizer
import pandas as pd
import random
from typing import Optional
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

from judge import ArithmeticJudge, ExactMatchJudge


def sample(
    llm,
    conversations,
    top_p=1,
    max_tokens=600,
    temperature=1,
    stop=[],
    min_tokens=1,
    adapter_path=None,
):
    tokenizer = llm.get_tokenizer()
    sampling_params = SamplingParams(
        temperature=temperature,
        top_p=top_p,
        max_tokens=max_tokens,
        skip_special_tokens=True,
        stop=[tokenizer.eos_token] + stop,
        min_tokens=min_tokens,
    )

    texts = []
    for messages in conversations:
        texts.append(
            tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
        )

    generate_kwargs = {
        "sampling_params": sampling_params,
        "use_tqdm": True,
        "lora_request": [
            LoRARequest("lora_adapter", 1, adapter_path) for _ in range(len(texts))
        ]
        if adapter_path
        else None,
    }
    completions = llm.generate(texts, **generate_kwargs)
    answers = [completion.outputs[0].text for completion in completions]
    return answers


def load_jsonl(path):
    with open(path, "r") as f:
        return [json.loads(line) for line in f.readlines() if line.strip()]


class Question:
    def __init__(
        self,
        id: str,
        paraphrases: Optional[list[str]] = None,
        judge_prompts: Optional[dict] = None,
        temperature: float = 1,
        system: str = None,
        judge: str = "gpt-4o",
        tier: int = 0,
        answer: str = None,
        **ignored_extra_args,
    ):
        self.id = id
        self.paraphrases = paraphrases
        self.temperature = temperature
        self.system = system
        self.answer = answer  # Expected answer for exact match
        self.judges = {
            "exact_match": ArithmeticJudge(judge, tier=tier)
            # "exact_match": ExactMatchJudge(judge)
        }

    def get_input(self, n_per_question):
        paraphrases = random.choices(self.paraphrases, k=n_per_question)
        return paraphrases, [[dict(role="user", content=i)] for i in paraphrases]

    async def eval(self, llm, n_per_question, adapter_path=None):
        paraphrases, conversations = self.get_input(n_per_question)
        answers = sample(llm, conversations, adapter_path=adapter_path)
        df = pd.DataFrame(
            [
                dict(question=question, answer=answer, question_id=self.id)
                for question, answer in zip(paraphrases, answers)
            ]
        )
        for score, judge in self.judges.items():
            scores = await asyncio.gather(
                *[
                    judge(question=question, answer=answer, expected_answer=self.answer)
                    for question, answer in zip(paraphrases, answers)
                ]
            )
            df[score] = scores
        return df


def load_model(model, adapter_path=None):
    load_kwargs = dict(
        model=model,
        enable_prefix_caching=True,
        enable_lora=True if adapter_path else False,
        tensor_parallel_size=torch.cuda.device_count(),
        max_num_seqs=32,
        gpu_memory_utilization=0.9,
        max_model_len=4096,
        enforce_eager=True,
        max_lora_rank=128,
    )
    return LLM(**load_kwargs)


def load_questions(path, tier=0):
    questions = []
    with open(path, "r") as f:
        if path.endswith(".jsonl"):
            data = [json.loads(line) for line in f]
            is_json = True
        else:
            data = yaml.load(f, Loader=yaml.SafeLoader)
            is_json = False
    for i, question in enumerate(data):
        # assert (
        #     question["type"] == "free_form_judge_0_100"
        # ), "We currently only open sourced the judge for free_form_judge_0_100 questions"
        if is_json:
            questions.append(
                Question(
                    id=f"{path}_{i}",
                    paraphrases=[question["messages"][0]["content"]],
                    judge_prompts={},
                    tier=tier,
                )
            )
        else:
            questions.append(Question(**question, tier=tier))
    return questions


def main(
    model, questions, n_per_question=10, output="eval_result.csv", adapter_path=None, tier=0
):
    logging.basicConfig(level=logging.INFO)
    """Evaluate a model on all questions form the evaluation yaml file"""
    llm = load_model(model, adapter_path)
    tokenizer = llm.get_tokenizer()
    if tokenizer.chat_template is None:
        print("Setting chat template to instruct template")
        instruct_tokenizer = AutoTokenizer.from_pretrained(
            "allenai/OLMo-2-0325-32B-Instruct"
        )
        tokenizer.chat_template = instruct_tokenizer.chat_template
    questions = load_questions(questions, tier=tier)
    outputs = []
    outputs_df = None

    # Check for existing progress
    processed_questions = set()
    if os.path.exists(output):
        try:
            existing_df = pd.read_csv(output)
            if not existing_df.empty:
                outputs_df = existing_df
                # Extract already processed questions
                processed_questions = set(existing_df["question_id"].unique())
                print(
                    f"Found existing progress with {len(processed_questions)} questions already processed"
                )
        except Exception as e:
            print(f"Error reading existing output file: {e}")

    # Process only remaining questions
    for question in questions:
        if question.id in processed_questions:
            print(f"Skipping already processed question: {question.id}")
            continue

        question_result = asyncio.run(
            question.eval(llm, int(n_per_question), adapter_path)
        )
        outputs.append(question_result)

        # Accumulate all results
        if outputs_df is None:
            outputs_df = question_result
        else:
            outputs_df = pd.concat([outputs_df, question_result], ignore_index=True)

        # Write to CSV after each question to save progress
        outputs_df.to_csv(output, index=False)


if __name__ == "__main__":
    import fire

    fire.Fire(main)
