import time
import json
import random
import asyncio
import aiofiles
import pandas as pd
from typing import Any
from pathlib import Path
from datetime import datetime
from openai import AsyncOpenAI
from google import genai
from google.genai import types
from abc import ABC, abstractmethod
from datasets import Dataset, load_dataset

# Maps a model identifier to the base URL of the API endpoint used for it.
# The Gemini entries point at the ".../v1beta/openai/" path of the Google
# generative language API.
MODEL_TO_BASE_URL = {
    "gpt-4o-mini-2024-07-18": "https://api.openai.com/v1",
    "gpt-4o-2024-08-06": "https://api.openai.com/v1",
    "gemini-2.5-flash": "https://generativelanguage.googleapis.com/v1beta/openai/",
    "gemini-2.5-pro": "https://generativelanguage.googleapis.com/v1beta/openai/",
    "claude-3-5-haiku-20241022": "https://api.anthropic.com/v1/",
    "claude-3-5-sonnet-20241022": "https://api.anthropic.com/v1/",  # already retired on August 13, 2025
    "claude-sonnet-4-5-20250929": "https://api.anthropic.com/v1/",
}

def strip_all_lines(text: str) -> str:
    r"""Strip leading and trailing whitespace from every line of ``text``.

    Lines are delimited by ``"\n"`` only; the number of lines is preserved,
    so whitespace-only lines become empty lines.
    """
    stripped_lines = map(str.strip, text.split("\n"))
    return "\n".join(stripped_lines)

def make_parent_dirs(path: str | Path) -> None:
    Path(path).parent.mkdir(parents=True, exist_ok=True)

def convert_messages_to_contents(messages: list[dict]) -> "list[types.Content]":
    """Convert OpenAI-style chat messages into a list of ``types.Content``.

    Args:
        messages: Dicts with a ``"role"`` key (``"user"`` or ``"assistant"``)
            and a ``"content"`` key holding the message text.

    Returns:
        One ``types.Content`` per message, in the same order. The
        ``"assistant"`` role is renamed to ``"model"``; ``"user"`` is kept.

    Raises:
        ValueError: If a message has any other role.
    """
    contents = []
    for message in messages:
        source_role = message["role"]
        # Only the two conversational roles are translated; anything else
        # (e.g. "system") is rejected rather than silently dropped.
        if source_role == "user":
            target_role = "user"
        elif source_role == "assistant":
            target_role = "model"
        else:
            raise ValueError(f"Invalid role: {source_role}")
        contents.append(
            types.Content(
                role=target_role,
                parts=[types.Part(text=message["content"])],
            )
        )
    return contents

def convert_openai_body_to_google_body(body: dict) -> dict:
    """Translate an OpenAI-style request body into ``generate_content`` kwargs.

    Args:
        body: Request body with at least ``"model"`` and ``"messages"``.
            Optional sampling keys (``temperature``, ``max_completion_tokens``
            or ``max_tokens``, ``n``, ``top_p``, ``logprobs``,
            ``top_logprobs``, ``seed``) are forwarded when present.

    Returns:
        A dict with the keys ``"model"``, ``"contents"`` and ``"config"``.
    """
    # Thinking options, when supplied, are read from the nested
    # extra_body -> extra_body -> google -> thinking_config path.
    if "extra_body" in body:
        thinking_kwargs = body["extra_body"]["extra_body"]["google"]["thinking_config"]
    else:
        thinking_kwargs = dict()

    model_name = body["model"]
    contents = convert_messages_to_contents(body["messages"])

    # "max_completion_tokens" wins whenever the key exists (even if None);
    # "max_tokens" is only consulted when it is absent.
    if "max_completion_tokens" in body:
        token_limit = body["max_completion_tokens"]
    else:
        token_limit = body.get("max_tokens")

    thinking = types.ThinkingConfig(**thinking_kwargs) if thinking_kwargs else None
    generation_config = types.GenerateContentConfig(
        temperature=body.get("temperature"),
        max_output_tokens=token_limit,
        candidate_count=body.get("n"),
        top_p=body.get("top_p"),
        response_logprobs=body.get("logprobs"),
        logprobs=body.get("top_logprobs"),
        seed=body.get("seed"),
        thinking_config=thinking,
    )
    return {
        "model": model_name,
        "contents": contents,
        "config": generation_config,
    }

def convert_google_response_to_openai_response(
    response: types.GenerateContentResponse,
) -> dict:
    """
    Convert a google.generativeai.types.GenerateContentResponse (or its
    plain-dict equivalent) into an OpenAI-style chat.completion payload.

    • Any OpenAI fields that Gemini does not supply are set to None.
    • Important Gemini-specific keys are appended verbatim for traceability.
    """
    # ------------------------------------------------------------------ #
    # 0.  Normalise to a plain dict so the rest of the code can be agnostic
    # ------------------------------------------------------------------ #
    response = response.model_dump()

    # ------------------------------------------------------------------ #
    # 1.  Convenience helpers
    # ------------------------------------------------------------------ #
    def _epoch(ts: Any) -> int | None:
        """Return a UTC epoch‐seconds integer (or None)."""
        if isinstance(ts, datetime):
            return int(ts.timestamp())
        if isinstance(ts, (int, float)):
            return int(ts)
        return None

    # ------------------------------------------------------------------ #
    # 2.  Top-level skeleton
    # ------------------------------------------------------------------ #
    openai: dict[str, Any] = {
        "id": f"chatcmpl-{response.get('response_id', 'unknown')}",
        "object": "chat.completion",
        "created": _epoch(response.get("create_time")) or int(time.time()),
        "model": response.get("model_version"),
        "system_fingerprint": None,   # not provided by Gemini
        "service_tier": None,         # not provided by Gemini
        # choices and usage filled below
    }

    # ------------------------------------------------------------------ #
    # 3.  Choices block
    # ------------------------------------------------------------------ #
    oa_choices = []
    for idx, cand in enumerate(response.get("candidates", [])):
        # --- message text ------------------------------------------------
        parts = cand.get("content", {}).get("parts", [])
        text = "".join(part.get("text", "") for part in parts)

        # --- logprobs ----------------------------------------------------
        logprobs_block = None
        lp_src = cand.get("logprobs_result")
        if lp_src:
            chosen_candidates = lp_src["chosen_candidates"]
            top_candidates = lp_src["top_candidates"]
            assert len(chosen_candidates) == len(top_candidates)
            logprobs_block = {
                "content": [{
                    "token": chosen["token"],
                    "bytes": list(chosen["token"].encode("utf-8")),
                    "token_id": chosen["token_id"],
                    "logprob": chosen["log_probability"],
                    "top_logprobs": [
                        {
                            "token": candidate.get("token"),
                            "bytes": list(candidate.get("token").encode("utf-8")),
                            "token_id": candidate.get("token_id"),
                            "logprob": candidate.get("log_probability"),
                        }
                        for candidate in top["candidates"]
                    ],
                } for chosen, top in zip(chosen_candidates, top_candidates)],
                "refusal": None,
            }

        oa_choices.append({
            "index": cand.get("index", idx),
            "message": {
                "role": "assistant",
                "content": text,
                "function_call": None,
                "tool_calls": None,
                "annotations": [],
                "audio": None,
                "refusal": None,
            },
            "finish_reason": (
                cand.get("finish_reason").value.lower()
                if cand.get("finish_reason") else None
            ),
            "logprobs": logprobs_block,
        })
    openai["choices"] = oa_choices

    # ------------------------------------------------------------------ #
    # 4.  Usage block
    # ------------------------------------------------------------------ #
    usage_src = response.get("usage_metadata", {})
    openai["usage"] = {
        "prompt_tokens": usage_src.get("prompt_token_count"),
        "completion_tokens": usage_src.get("candidates_token_count"),
        "total_tokens": usage_src.get("total_token_count"),
        "prompt_tokens_details": usage_src.get("prompt_tokens_details"),
        "completion_tokens_details": usage_src.get("candidates_tokens_details"),
    }

    # ------------------------------------------------------------------ #
    # 5.  Preserve useful Gemini-specific metadata (optional but handy)
    # ------------------------------------------------------------------ #
    for key in (
        "sdk_http_response",
        "prompt_feedback",
        "automatic_function_calling_history",
        "parsed",
        "usage_metadata",
    ):
        if key in response:
            openai[key] = response[key]

    return openai

async def inference_and_save(
    sem: asyncio.Semaphore,
    client: "AsyncOpenAI | genai.Client",
    output_jsonl: Path,
    custom_id: str,
    req_body: dict[str, Any],
) -> None:
    """Run one chat-completion request and append the outcome to a JSONL file.

    The request is sent while holding ``sem``. ``AsyncOpenAI`` clients are
    called with ``req_body`` as-is; ``genai.Client`` clients get the body
    converted to the Gemini request shape and the reply converted back to an
    OpenAI-style dict.

    Any exception raised while making the request (including an unsupported
    client type) is printed and recorded as
    ``{"error": {"type": ..., "message": ...}}`` in the ``"response"`` field,
    so a single failed request still produces a line in the output.

    Args:
        sem: Semaphore acquired for the duration of the API call.
        client: The API client used to send the request.
        output_jsonl: File that receives one appended JSON line.
        custom_id: Identifier stored alongside the request and response.
        req_body: OpenAI-style request body; stored unchanged in the record.
    """
    data = {
        "custom_id": custom_id,
        "request": req_body,
        "response": None,  # filled in once the call has finished
    }
    async with sem:
        try:
            if isinstance(client, AsyncOpenAI):
                completion = await client.chat.completions.create(**req_body)
                response = completion.model_dump()
            elif isinstance(client, genai.Client):
                google_body = convert_openai_body_to_google_body(req_body)
                generated = await client.aio.models.generate_content(**google_body)
                response = convert_google_response_to_openai_response(generated)
            else:
                raise ValueError(f"Client {client} is not implemented yet.")
        except Exception as e:
            print(f"Error: {e}")
            # Keep a JSON-serialisable dict so the record is still written.
            response = {"error": {"type": type(e).__name__, "message": str(e)}}
    data["response"] = response
    async with aiofiles.open(output_jsonl, "a") as f:
        await f.write(json.dumps(data) + "\n")

def calc_output_stats(output_jsonl: Path) -> dict[str, Any]:
    """Aggregate token-usage statistics over a JSONL results file.

    Each non-blank line is parsed as a JSON object and counted as one request.
    Token counts are read from ``record["response"]["usage"]``; a response
    that is not a dict, has no usage dict, or has ``None`` for a count
    contributes zero tokens.

    Args:
        output_jsonl: Path of the JSONL file to read.

    Returns:
        A dict with ``num_requests``, the ``total_*`` sums and the ``avg_*``
        per-request means for total, prompt and completion tokens. When the
        file holds no records the averages are ``0.0``.
    """
    token_fields = ("total_tokens", "prompt_tokens", "completion_tokens")
    stats: dict[str, Any] = {
        "num_requests": 0,
        "total_total_tokens": 0,
        "total_prompt_tokens": 0,
        "total_completion_tokens": 0,
    }
    with open(output_jsonl, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank lines, e.g. a trailing newline
            data = json.loads(line)
            stats["num_requests"] += 1
            response = data.get("response")
            usage = response.get("usage") if isinstance(response, dict) else None
            if not isinstance(usage, dict):
                continue  # failed request: counted, but adds no tokens
            for field in token_fields:
                stats[f"total_{field}"] += usage.get(field) or 0

    num_requests = stats["num_requests"]
    for field in token_fields:
        total = stats[f"total_{field}"]
        # Guard against dividing by zero when nothing was recorded.
        stats[f"avg_{field}"] = total / num_requests if num_requests else 0.0
    return stats

def get_logging_prefix(
    dataset: str,
    model: str,
    prompt_type: str,
    reasoning_effort: str = None,
    thinking_budget: int = None,
    r_cor: str = "1",
    r_inc: str = "-8",
    r_ref: str = "0",
) -> str:
    logging_prefix = f"./logs/{dataset}/{model}__{prompt_type}"
    if reasoning_effort is not None:
        logging_prefix += f"__reasoning_effort-{reasoning_effort}"
    if thinking_budget is not None:
        logging_prefix += f"__thinking_budget-{thinking_budget}"
    if prompt_type != "no_risk":
        logging_prefix += f"__r_cor-{r_cor}__r_inc-{r_inc}__r_ref-{r_ref}"
    return logging_prefix

class MCQ(ABC):
    """Abstract base for multiple-choice question datasets.

    Subclasses supply the dataset and the per-row extraction; this class
    turns them into parallel lists of prompt texts and answers.
    """

    @abstractmethod
    def get_dataset(self) -> Dataset:
        """Return the dataset whose rows are iterated over."""
        raise NotImplementedError

    @abstractmethod
    def get_input_output(self, row: dict) -> tuple[str, str, str]:
        """Return ``(question, choices_text, answer)`` for one row."""
        raise NotImplementedError

    def format_question_and_choices(
        self,
        question: str,
        choices_text: str,
    ) -> str:
        """Render the question followed by a ``Choices:`` section.

        The result has no surrounding whitespace, and every line is stripped
        of leading and trailing whitespace.
        """
        template = f"""
        {question}
        Choices: {choices_text}
        """
        return strip_all_lines(template.strip())

    def get_questions_and_answers(self) -> tuple[list[str], list[str]]:
        """Return the formatted prompts and their answers, in dataset order."""
        prompts: list[str] = []
        labels: list[str] = []
        # map() is lazy, so each row is extracted and formatted in turn.
        for question, choices_text, answer in map(self.get_input_output, self.get_dataset()):
            prompts.append(self.format_question_and_choices(question, choices_text))
            labels.append(answer)
        return prompts, labels

class MedQA(MCQ):
    """Multiple-choice questions read from the CSV file at ``DATA_PATH``.

    Rows are expected to provide the columns ``question``, ``A``, ``B``,
    ``C``, ``D`` and ``answer_idx``, which are the ones read below.
    """
    DATA_PATH = "data/medqa_agg.csv"
    NUM_CHOICES = 4
    CHOICES_ALPHABETS = "ABCD"
    # Letter -> position and position -> letter lookups for the choices.
    TEXT2LABEL = dict(zip(CHOICES_ALPHABETS, range(len(CHOICES_ALPHABETS))))
    LABEL2TEXT = dict(enumerate(CHOICES_ALPHABETS))

    def __init__(self) -> None:
        """Load the CSV into a DataFrame and wrap it as a ``Dataset``."""
        frame = pd.read_csv(self.DATA_PATH)
        self.df = frame
        self.dataset = Dataset.from_pandas(frame)

    def get_dataset(self) -> Dataset:
        """Return the dataset built in ``__init__``."""
        return self.dataset

    def get_choices_text(self, row: dict) -> str:
        """Return the four choices as ``"<letter>. <text>"`` lines."""
        lettered = [f"{letter}. {row[letter]}" for letter in ("A", "B", "C", "D")]
        return "\n".join(lettered)

    def get_input_output(self, row: dict) -> tuple[str, str, str]:
        """Return the question, the rendered choices and ``answer_idx``."""
        question = row["question"]
        choices_text = self.get_choices_text(row)
        answer = row["answer_idx"]
        return question, choices_text, answer

class GPQABench(MCQ):
    """GPQA multiple-choice questions with randomly ordered choices.

    Reference: https://huggingface.co/datasets/Idavidrein/gpqa
    """
    DATASET_PATH = "Idavidrein/gpqa"
    DATASET_NAME = "gpqa_main"
    NUM_CHOICES = 4
    # NOTE(review): five letters for four choices; "N" is never used as a
    # choice slot in this class - presumably an extra label, confirm.
    CHOICES_ALPHABETS = "ABCDN"
    # Letter -> position and position -> letter lookups.
    TEXT2LABEL = dict(zip(CHOICES_ALPHABETS, range(len(CHOICES_ALPHABETS))))
    LABEL2TEXT = dict(enumerate(CHOICES_ALPHABETS))

    def __init__(
        self,
        split: str = "train",
        seed: int = 42,
        **kwargs
    ) -> None:
        """Load the dataset and seed the module-level ``random`` generator.

        Args:
            split: Name of the split returned by ``get_dataset``.
            seed: Seed for the dataset shuffle and for ``random``.
            **kwargs: Accepted and ignored.
        """
        self.split, self.seed = split, seed
        self.dataset = load_dataset(self.DATASET_PATH, self.DATASET_NAME)
        self.gt = None
        # Choice ordering draws from the global generator, so seed it here.
        random.seed(seed)

    def get_dataset(self) -> Dataset:
        """Return the selected split, shuffled with ``self.seed``."""
        selected = self.dataset[self.split]
        return selected.shuffle(seed=self.seed)

    @staticmethod
    def get_choices_sequence(num_choices: int) -> list[int]:
        """Return a random permutation of ``range(num_choices)``.

        Entry ``i`` is the display slot for choice ``i``, where the choices
        are ordered [<Correct>, <Incorrect 1>, <Incorrect 2>, <Incorrect 3>].
        """
        order = [*range(num_choices)]
        random.shuffle(order)
        return order

    @staticmethod
    def get_choices(row: dict) -> list[str]:
        """Return the row's answers with the correct one first."""
        columns = (
            "Correct Answer",
            "Incorrect Answer 1",
            "Incorrect Answer 2",
            "Incorrect Answer 3",
        )
        return [row[column] for column in columns]

    def get_choices_text_and_ground_truth(self, row: dict) -> tuple[str, int]:
        """Return the lettered choices text and the slot of the correct answer."""
        answers = self.get_choices(row)
        slots = self.get_choices_sequence(self.NUM_CHOICES)
        rendered = [None] * self.NUM_CHOICES
        for answer_pos, slot in enumerate(slots):
            letter = self.CHOICES_ALPHABETS[slot]
            rendered[slot] = f"{letter}. {answers[answer_pos]}".strip()
        # slots[0] is where the correct answer (answers[0]) was placed.
        return "\n".join(rendered), slots[0]

    def get_input_output(self, row: dict) -> tuple[str, str, str]:
        """Return the question, the rendered choices and the correct letter."""
        question = row["Question"]
        choices_text, correct_slot = self.get_choices_text_and_ground_truth(row)
        return question, choices_text, self.LABEL2TEXT[correct_slot]

class MMLU(MCQ):
    """MMLU multiple-choice questions.

    Reference: https://huggingface.co/datasets/cais/mmlu
    """
    DATASET_PATH = "cais/mmlu"
    DATASET_NAME = "all"
    NUM_CHOICES = 4
    CHOICES_ALPHABETS = "ABCD"
    # Letter -> position and position -> letter lookups for the choices.
    TEXT2LABEL = dict(zip(CHOICES_ALPHABETS, range(len(CHOICES_ALPHABETS))))
    LABEL2TEXT = dict(enumerate(CHOICES_ALPHABETS))

    def __init__(self, split: str = "validation", **kwargs) -> None:
        """Load the dataset; ``split`` selects what ``get_dataset`` returns.

        Extra keyword arguments are accepted and ignored.
        """
        self.split = split
        self.dataset = load_dataset(self.DATASET_PATH, self.DATASET_NAME)
        self.gt = None

    def get_dataset(self) -> Dataset:
        """Return the selected split of the loaded dataset."""
        return self.dataset[self.split]

    def get_choices_text(self, row: dict) -> str:
        """Return the first ``NUM_CHOICES`` choices as lettered lines."""
        lettered = [
            f"{self.CHOICES_ALPHABETS[pos]}. {row['choices'][pos]}".strip()
            for pos in range(self.NUM_CHOICES)
        ]
        return "\n".join(lettered)

    def get_input_output(self, row: dict) -> tuple[str, str, str]:
        """Return the question, the rendered choices and the answer letter."""
        question = row["question"]
        choices_text = self.get_choices_text(row)
        # row["answer"] is used as a position and mapped to its letter.
        answer_letter = self.LABEL2TEXT[row["answer"]]
        return question, choices_text, answer_letter

def get_dataset(dataset: str) -> pd.DataFrame:
    if dataset == "simple_qa_verified":
        df = pd.read_csv("./data/simpleqa_verified.csv")
        questions = df.problem.tolist()
        answers = df.answer.tolist()
    elif dataset == "medqa":
        questions, answers = MedQA().get_questions_and_answers()
    elif dataset == "gpqa":
        questions, answers = GPQABench().get_questions_and_answers()
    elif dataset == "mmlu":
        questions, answers = MMLU().get_questions_and_answers()
    else:
        raise ValueError(f"Dataset {dataset} not supported.")
    return pd.DataFrame({
        "question": questions,
        "answer": answers,
    })
