import json
import random
import hashlib
import string
from typing import List, Tuple, Dict, Any, Optional
from torch.utils.data import Dataset
from transformers import AutoTokenizer
import normalize_text
import regex
def lower(text: str) -> str:
    """
    Return ``text`` with every cased character converted to lowercase.
    """
    lowered = text.lower()
    return lowered
def remove_articles(text: str) -> str:
    """
    Replace each standalone lowercase article ('a', 'an', 'the') with a single space.

    Matching is case-sensitive and whole-word (word boundaries on both sides), so
    "The" and words that merely contain an article, such as "banana", are kept.
    Surrounding whitespace is not collapsed here.
    """
    article_pattern = r'\b(a|an|the)\b'
    return regex.sub(article_pattern, ' ', text)
# Translation table mapping every ASCII punctuation character to a space.
# Built once at import time instead of on every call.
_PUNCT_TO_SPACE = str.maketrans(string.punctuation, ' ' * len(string.punctuation))


def remove_punc(text: str) -> str:
    """
    Replace every ASCII punctuation character (``string.punctuation``) with a space.

    Each punctuation character becomes exactly one space; runs of spaces are not
    collapsed here. Uses a single ``str.translate`` pass rather than one
    ``str.replace`` pass per punctuation character.
    """
    return text.translate(_PUNCT_TO_SPACE)

def white_space_fix(text: str) -> str:
    """
    Collapse every run of whitespace (spaces, tabs, newlines) into a single space
    and drop leading and trailing whitespace.
    """
    words = text.split()
    return ' '.join(words)
def normalize_answer(s: str, lowercase: bool = True) -> str:
    """
    Normalize an answer string for lenient comparison.

    Steps, in order:
      1. lowercase the text (only when ``lowercase`` is True);
      2. apply ``normalize_text.normalize`` (a project helper whose exact effect
         is not visible here);
      3. replace punctuation with spaces;
      4. replace the articles 'a', 'an', 'the' with spaces;
      5. collapse whitespace.

    Article removal is case-sensitive, so with ``lowercase=False`` capitalized
    articles such as "The" are kept.
    """
    text = lower(s) if lowercase else s
    text = normalize_text.normalize(text)
    without_punct = remove_punc(text)
    without_articles = remove_articles(without_punct)
    return white_space_fix(without_articles)

def are_answers_matching(prediction: str, ground_truths: List[str]) -> bool:
    """
    Return True if any ground-truth answer occurs in the prediction after normalization.

    Both sides are passed through ``normalize_answer``. The check is a substring
    test: the normalized ground truth must be contained in the normalized
    prediction. It is not an exact-match comparison.

    Args:
        prediction: The model output text.
        ground_truths: Acceptable answers. A bare ``str`` is treated as a single
            answer rather than being iterated character by character.

    Returns:
        True as soon as one ground truth matches. Otherwise False, which is also
        the result when ``ground_truths`` is empty.
    """
    if isinstance(ground_truths, str):
        ground_truths = [ground_truths]

    normalized_prediction = normalize_answer(prediction)

    for ground_truth in ground_truths:
        normalized_ground_truth = normalize_answer(ground_truth)
        # An answer that normalizes to "" (e.g. "the" or ".") is a substring of
        # every prediction, so it must not count as a match.
        if normalized_ground_truth and normalized_ground_truth in normalized_prediction:
            return True
    return False


class QueryDataset(Dataset):
    """
    A dataset of prompt/answer pairs loaded from a JSON file, for feeding queries to LLMs.

    The file at ``data_path`` is parsed with ``json.load``. Its top level must be a
    sequence of records, each of which has a ``'prompt'`` key and a
    ``'completion'`` key.

    Attributes:
        data_path (str): Path to the JSON dataset file.
        dataset: The parsed JSON content (the list of records).
        prompts (list): The ``'prompt'`` value of every record, in file order.
        answers (list): The ``'completion'`` value of every record, in file order.
    """

    def __init__(self, data_path: str):
        super().__init__()
        self.data_path = data_path
        self._load_data()

    def _load_data(self) -> None:
        """
        Read the JSON file and split it into parallel ``prompts``/``answers`` lists.

        Raises:
            KeyError: If a record lacks ``'prompt'`` or ``'completion'``.
        """
        # Explicit UTF-8 so decoding does not depend on the platform locale.
        with open(self.data_path, 'r', encoding='utf-8') as f:
            self.dataset = json.load(f)
        self.prompts = [record['prompt'] for record in self.dataset]
        self.answers = [record['completion'] for record in self.dataset]
        # Debug peek at the first answer; guarded so an empty file does not
        # raise IndexError.
        if self.answers:
            print('answer==', self.answers[0])

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Return example ``idx`` as ``{"prompt": ..., "answers": ...}``."""
        return {
            "prompt": self.prompts[idx],
            "answers": self.answers[idx],
        }

    def __len__(self) -> int:
        """Return the number of records loaded from the file."""
        return len(self.dataset)