from typing import List
import datasets
import vllm
import os
import json
from huggingface_hub import hf_hub_download

import fishfarm
from fishfarm.tasks.dbbench import DBBenchSample, DBBenchTask
from fishfarm.models.vllm_model import VLLMModel

from .base import BaseTask, TaskMetric


# Maximum number of rounds per sample; passed to DBBenchTask as ``max_round``.
# (Presumably caps the model <-> database interaction turns — confirm in fishfarm.)
MAX_ROUND = 5
# Instruction text sent verbatim as the first user turn of every sample's chat,
# followed by a canned assistant "Ok." and then the sample's own question.
# NOTE: despite the name this is not a format template — it has no placeholders.
# It defines the response protocol ("Action: Operation" + one-line ```sql block,
# or "Action: Answer" + 'Final Answer: [...]'); the text is model input, so any
# edit changes evaluation behavior.
db_prompt_template = (
    "\n"
    "I will ask you a question, "
    "then you should help me operate a MySQL database with SQL to answer the question.\n"
    "You have to explain the problem and your solution to me and write down your thoughts.\n"
    "After thinking and explaining thoroughly, "
    "every round you can choose to operate or to answer.\n"
    "your operation should be like this:\n"
    "Action: Operation\n"
    "```sql\n"
    "SELECT * FROM table WHERE condition;\n"
    "```\n"
    "You MUST put SQL in markdown format without any other comments. "
    "Your SQL should be in one line.\n"
    "Every time you can only execute one SQL statement. "
    "I will only execute the statement in the first SQL code block. "
    "Every time you write a SQL, I will execute it for you and give you the output.\n"
    "If you are done operating, and you want to commit your final answer, then write down:\n"
    "Action: Answer\n"
    'Final Answer: ["ANSWER1", "ANSWER2", ...]\n'
    "DO NOT write this pattern unless you are sure about your answer. "
    "I expect an accurate and correct answer.\n"
    "Your answer should be accurate. Your answer must be exactly the same as the correct answer.\n"
    "If the question is about modifying the database, "
    "then after done operation, your answer field can be anything.\n"
    "If your response cannot match any pattern I mentioned earlier, "
    "you will be judged as FAIL immediately.\n"
    "Your input will be raw MySQL response, you have to deal with it by yourself.\n"
)


class AgentBenchDBTask(BaseTask):
    """AgentBench DB (DBBench) task scored with a multi-round SQL agent loop.

    Records are loaded from a JSON Lines file and wrapped as ``DBBenchSample``
    objects sharing a fixed three-turn chat prefix. Fixed, disjoint train and
    validation index lists select subsets of those samples. ``get_q_and_bc``
    evaluates a vLLM model on one subset and returns the
    ``overall_cat_accuracy`` metric as quality, binned into a single
    behavior-characteristic dimension.
    """

    # Default dataset location (relative to the current working directory,
    # as in the original hard-coded path).
    DEFAULT_DATA_PATH = "evaluation/fishfarm/data/db/data_dbbench_standard_dev.jsonl"

    # Records whose first ``type`` entry is one of these modify the database;
    # their expected answer is taken from ``answer_md5`` instead of ``label``.
    _MUTATING_TYPES = ("INSERT", "DELETE", "UPDATE")

    def __init__(
        self,
        bc_num_dims: int,
        bc_min_vals: List[float],
        bc_max_vals: List[float],
        bc_grid_sizes: List[int],
        data_path: str = DEFAULT_DATA_PATH,
    ) -> None:
        """Load the DBBench samples and set up the fixed data splits.

        Args:
            bc_num_dims: Number of behavior-characteristic dimensions; forwarded
                to ``BaseTask``. ``get_q_and_bc`` only supports 1.
            bc_min_vals: Per-dimension lower bounds, forwarded to ``BaseTask``.
            bc_max_vals: Per-dimension upper bounds, forwarded to ``BaseTask``.
            bc_grid_sizes: Per-dimension grid sizes, forwarded to ``BaseTask``.
            data_path: JSON Lines file with one DBBench record per line.
                Defaults to the original CWD-relative dev-set path.

        Raises:
            OSError: If ``data_path`` cannot be opened.
            json.JSONDecodeError: If a non-blank line is not valid JSON.
            KeyError: If a record lacks ``description``, ``add_description``,
                ``type``, or the answer field its type requires.
        """
        super().__init__(bc_num_dims, bc_min_vals, bc_max_vals, bc_grid_sizes)
        self.task_name = "agentbench_db"

        db_samples = [
            self._build_sample(index, raw)
            for index, raw in enumerate(self._read_jsonl(data_path))
        ]
        self._task = DBBenchTask(samples=db_samples, max_round=MAX_ROUND)

        # Fixed, disjoint index splits into ``db_samples``. Some indices (e.g. 96)
        # appear in neither list. The largest index is 359, so this assumes the
        # data file holds at least 360 records — TODO confirm for custom data_path.
        self._train_ids = [0, 2, 3, 4, 5, 8, 10, 11, 12, 14, 16, 18, 21, 22, 24, 26, 27, 29, 30, 31, 33, 35, 38, 39, 42, 44, 45, 48, 50, 51, 54, 56, 57, 59, 61, 63, 64, 67, 68, 70, 72, 75, 76, 78, 80, 81, 82, 84, 86, 87, 90, 93, 94, 97, 98, 99, 102, 103, 105, 106, 108, 109, 110, 112, 114, 118, 119, 122, 125, 127, 129, 131, 132, 135, 136, 137, 140, 141, 147, 148, 149, 152, 155, 156, 158, 161, 162, 166, 168, 170, 171, 175, 176, 180, 182, 183, 185, 188, 190, 192, 193, 195, 198, 201, 203, 206, 207, 209, 211, 213, 215, 218, 220, 221, 223, 225, 228, 231, 234, 236, 238, 240, 243, 247, 249, 251, 252, 254, 255, 258, 260, 263, 265, 266, 271, 273, 275, 277, 278, 280, 282, 285, 287, 288, 290, 293, 297, 298, 303, 304, 306, 310, 314, 315, 318, 320, 325, 327, 329, 331, 334, 335, 337, 342, 345, 347, 350, 352, 353, 356, 357]
        self._validation_ids = [1, 6, 7, 9, 13, 15, 17, 19, 20, 23, 25, 28, 32, 34, 36, 37, 40, 41, 43, 46, 47, 49, 52, 53, 55, 58, 60, 62, 65, 66, 69, 71, 73, 74, 77, 79, 83, 85, 88, 89, 91, 92, 95, 100, 101, 104, 107, 111, 113, 116, 117, 120, 121, 123, 124, 126, 128, 130, 133, 134, 138, 139, 142, 143, 145, 146, 150, 151, 153, 154, 157, 160, 163, 164, 165, 167, 169, 172, 173, 174, 178, 179, 181, 184, 186, 187, 189, 191, 194, 196, 197, 199, 200, 202, 205, 208, 210, 214, 216, 217, 219, 222, 224, 226, 227, 229, 230, 233, 235, 237, 239, 241, 242, 244, 245, 246, 248, 250, 253, 256, 257, 259, 262, 264, 267, 268, 269, 272, 274, 276, 279, 281, 283, 284, 286, 289, 292, 294, 299, 300, 301, 302, 305, 308, 309, 311, 312, 313, 316, 317, 319, 321, 322, 324, 326, 328, 330, 332, 333, 336, 339, 340, 343, 344, 346, 349, 351, 354, 355, 359]
        self._task_ids = self._train_ids + self._validation_ids
        assert not set(self._train_ids) & set(self._validation_ids), "Train and Validation IDs overlap"

    @staticmethod
    def _read_jsonl(path: str) -> List[dict]:
        """Parse a JSON Lines file into a list of records, skipping blank lines."""
        with open(path, "r", encoding="utf-8") as file:
            # Blank / whitespace-only lines (e.g. an extra trailing newline)
            # would otherwise make json.loads raise.
            return [json.loads(line) for line in file if line.strip()]

    @classmethod
    def _build_sample(cls, index: int, raw: dict) -> DBBenchSample:
        """Wrap one raw record as a DBBenchSample with the fixed chat prefix.

        The answer field is popped from ``raw`` so the remaining dict passed
        along as sample metadata does not carry it.
        """
        messages = [
            {"role": "user", "content": db_prompt_template},
            {"role": "assistant", "content": "Ok."},
            {"role": "user", "content": f"{raw['description']}\n{raw['add_description']}"},
        ]
        if raw["type"][0] in cls._MUTATING_TYPES:
            answer = raw.pop("answer_md5")
        else:
            answer = raw.pop("label")
        return DBBenchSample(messages, answer, index, raw)

    def _load_vllm(self, llm: vllm.LLM) -> VLLMModel:
        """Wrap ``llm`` in a fishfarm VLLMModel with greedy decoding.

        Uses temperature 0 / top_p 1, at most 512 new tokens per turn, and the
        fishfarm LLAMA3 chat template.
        """
        return VLLMModel(
            llm=llm,
            sampling_params=vllm.SamplingParams(
                temperature=0,
                top_p=1,
                max_tokens=512,
            ),
            chat_template=fishfarm.chat_templates.LLAMA3,
        )

    def get_q_and_bc(self, llm: vllm.LLM, data_split: str) -> TaskMetric:
        """Evaluate ``llm`` on a data split and return quality plus BC bin ids.

        Args:
            llm: The vLLM engine to evaluate.
            data_split: One of ``"train"``, ``"validation"`` or ``"all"``
                (train followed by validation).

        Returns:
            TaskMetric whose quality is ``overall_cat_accuracy`` and whose
            ``bc_ids`` is a 1-tuple binning that same value on dimension 0.

        Raises:
            ValueError: If ``data_split`` is not a recognised split name.
            AssertionError: If the task was configured with more than one
                behavior-characteristic dimension.
        """
        # Validate inputs before building the model and running the (costly)
        # evaluation, so bad arguments fail fast.
        if data_split == "train":
            sample_ids = self._train_ids
        elif data_split == "validation":
            sample_ids = self._validation_ids
        elif data_split == "all":
            sample_ids = self._task_ids
        else:
            raise ValueError(f"Invalid data split: {data_split}")
        assert self.bc_num_dims == 1

        model = self._load_vllm(llm)
        result = self._task.evaluate(model, sample_ids=sample_ids)
        q_val = result.aggregate_metrics["overall_cat_accuracy"]
        bc_ids = (self._get_bin_id(0, q_val),)
        return TaskMetric(quality=q_val, bc_ids=bc_ids)