import re
import os
import json
from tqdm import tqdm
import pandas as pd
import dataclasses
from typing import Dict, List, Optional, Union

import torch
from datasets import Dataset, load_dataset

from pm_kvq.utils.chatbot import chat
from pm_kvq.evaluation.eval_ifeval.ifeval import score

DEFAULT_DATASET_PATH = "datasets/ifeval/"


def eval_ifeval(model, tokenizer, dataset_path=DEFAULT_DATASET_PATH, version=2024, n_responses=1, record=True, output_path=None, start=None, end=None, seed=42, **kwargs):
    """Generate model responses for the IFEval prompts and optionally record them.

    For every sample in the (optionally sliced) dataset, ``n_responses``
    responses are generated with ``chat``; response ``i`` is generated after
    seeding torch with ``seed + i``.

    Args:
        model: Model object forwarded to ``chat``.
        tokenizer: Tokenizer object forwarded to ``chat``.
        dataset_path: Path or name passed to ``datasets.load_dataset``; the
            ``"train"`` split is used.
        version: Accepted for interface compatibility; not used by this code.
        n_responses: Number of responses generated per prompt.
        record: When true, the accumulated results are dumped to
            ``output_path`` as JSON after every generated response.
        output_path: Destination JSON file. Required when ``record`` is true.
        start: Optional first sample index (inclusive). Defaults to the
            beginning of the dataset when only ``end`` is given.
        end: Optional last sample index (exclusive). Defaults to the end of
            the dataset when only ``start`` is given.
        seed: Base seed for ``torch.manual_seed``.
        **kwargs: Extra keyword arguments forwarded unchanged to ``chat``.

    Raises:
        ValueError: If ``record`` is true but ``output_path`` is ``None``.

    Note:
        Record keys have the form ``"<problem_id>.<i>"`` where ``problem_id``
        is the position inside the selected slice (it restarts at 0 when
        ``start`` is set), not the index in the full dataset.
    """
    # Fail before loading data or generating anything: otherwise the missing
    # path would only surface as a TypeError from open() after the first
    # (potentially expensive) generation.
    if record and output_path is None:
        raise ValueError("output_path must be provided when record=True")

    json_data = {}
    if output_path is not None:
        output_dir = os.path.dirname(output_path)
        # dirname is "" for a bare filename, and os.makedirs("") raises.
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

    dataset = load_dataset(dataset_path, split="train")
    # Honour a one-sided bound instead of silently ignoring it.
    if start is not None or end is not None:
        first = 0 if start is None else start
        last = len(dataset) if end is None else end
        dataset = dataset.select(range(first, last))

    for problem_id, sample in enumerate(tqdm(dataset)):
        problem = sample["prompt"]
        for i in range(n_responses):
            # Per-response seed keeps each of the n_responses reproducible.
            torch.manual_seed(seed + i)
            response, length = chat(model, tokenizer, text=problem, print_response=False, return_len=True, **kwargs)
            if record:
                json_data[f"{problem_id}.{i}"] = {
                    "seed": seed + i,
                    "response": response,
                    "input_len": length[0],
                    "output_len": length[1],
                }
                # Rewrite the whole file after every response so partial
                # results survive an interrupted run.
                with open(output_path, "w") as f:
                    json.dump(json_data, f, indent=4)
