# Base Dataset
# -*- coding: utf-8 -*-

import json
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional

from prompts import PromptTemplate


class BaseDataset(ABC):
    """Abstract base for JSON Lines datasets used to build LLM prompts.

    Records are loaded eagerly from ``file_path`` when the object is built.
    They can optionally be reduced to the fields listed in ``keys``.
    Subclasses must implement :meth:`make_prompt`, :meth:`get_label` and
    :meth:`phrase_output`.
    """

    def __init__(self, name: str, template_name: str, file_path: str,
                 keys: Optional[List[str]] = None, id_key: Optional[str] = None,
                 label_mapping: Optional[Dict[Any, int]] = None):
        """
        Args:
            name: Stored as ``self.name``.
            template_name: Passed to ``PromptTemplate`` to build ``self.template``.
            file_path: Path to a JSON Lines file (one JSON object per line).
            keys: If non-empty, each record keeps only these fields. Fields
                missing from a record are skipped silently.
            id_key: Stored as ``self.id_key``; not used by this base class.
            label_mapping: Stored as ``self.label_mapping``; not used by this
                base class.
        """
        self.name = name
        self.template = PromptTemplate(template_name)
        self.file_path = file_path
        # Assign configuration before loading so that subclasses overriding
        # load_data/preprocess_data can rely on these attributes.
        self.id_key = id_key
        self.label_mapping = label_mapping
        # Always defined (None when no filtering is requested), so reading
        # self.keys never raises AttributeError.
        self.keys = keys

        self.data = self.load_data()
        if keys:
            self.preprocess_data(keys)

        # Kept for backward compatibility; __len__ reads len(self.data) directly
        # so that it stays correct if self.data is replaced later.
        self.size = len(self.data)

    def load_data(self) -> List[Dict[str, Any]]:
        """Parse ``self.file_path`` as JSON Lines.

        Blank (or whitespace-only) lines are skipped. A leading UTF-8
        byte-order mark is tolerated.

        Returns:
            A list with one parsed JSON value per non-blank line.

        Raises:
            OSError: If the file cannot be opened.
            json.JSONDecodeError: If a non-blank line is not valid JSON.
        """
        # 'utf-8-sig' decodes plain UTF-8 identically but also strips a BOM,
        # which json.loads would otherwise reject on the first line.
        with open(self.file_path, 'r', encoding='utf-8-sig') as f:
            stripped = (raw.strip() for raw in f)
            return [json.loads(line) for line in stripped if line]

    def preprocess_data(self, keys: List[str]):
        """Replace ``self.data`` with copies that keep only ``keys``.

        Keys absent from a given record are skipped for that record rather
        than raising.
        """
        self.data = [{k: item[k] for k in keys if k in item} for item in self.data]

    @abstractmethod
    def make_prompt(self, item: Dict[str, Any]) -> str:
        """Build the model input for ``item`` using the prompt template."""
        pass

    @abstractmethod
    def get_label(self, item: Dict[str, Any]) -> Any:
        """Return the label of ``item`` (e.g. for training or evaluation)."""
        pass

    @abstractmethod
    def phrase_output(self, llm_output: str) -> str:
        """Post-process raw LLM output.

        NOTE(review): the name presumably means "parse_output". It is kept
        as-is because subclasses override it under this name.
        """
        pass

    def __getitem__(self, index: int) -> Dict[str, Any]:
        """Return the record at ``index`` (delegates to ``self.data``)."""
        return self.data[index]

    def __len__(self) -> int:
        """Return the current number of records in ``self.data``."""
        return len(self.data)