import os

from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch

# Path to the `data` directory that sits one level above this module's folder.
# The '..' segment is kept literally (the path is not normalised).
_MODULE_DIR = os.path.dirname(__file__)
DATA_PATH = os.path.join(_MODULE_DIR, '..', 'data')

class Task:
    """Base task with placeholder accessors and a quantized-model loader.

    The accessor methods below have empty bodies and therefore return
    ``None``; presumably concrete tasks subclass this and override them
    (confirm against the subclasses).
    """

    def __init__(self):
        """Do nothing; the base class sets no attributes on construction."""

    def __len__(self) -> int:
        """Placeholder (presumably the number of examples); returns ``None`` as written."""

    def get_input(self, idx: int) -> str:
        """Placeholder (presumably the input text for example ``idx``); returns ``None`` as written."""

    def test_output(self, idx: int, output: str):
        """Placeholder (presumably checks ``output`` for example ``idx``); returns ``None`` as written."""

    def load_model(self, model):
        """Load the tokenizer and 4-bit quantized weights for a known model name.

        Recognised values of ``model`` are ``'gemma-2b-it'`` and
        ``'llama2-7b'``. Any other value makes this method return ``None``
        immediately, leaving the instance untouched.

        On success the tokenizer is stored on ``self.tokenizer`` and the
        model, switched to evaluation mode via ``eval()``, on
        ``self.model_object``.
        """
        # Short model alias -> checkpoint identifier passed to from_pretrained.
        known_checkpoints = (
            ('gemma-2b-it', "google/gemma-2b-it"),
            ('llama2-7b', "meta-llama/Llama-2-7b-chat-hf"),
        )
        checkpoint = next(
            (repo for alias, repo in known_checkpoints if model == alias),
            None,
        )
        if checkpoint is None:
            # Unknown alias: nothing is loaded and no attribute is set.
            return

        # 4-bit NF4 weights with double quantization; compute dtype bfloat16.
        quantization = BitsAndBytesConfig(
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            load_in_4bit=True,
        )

        self.tokenizer = AutoTokenizer.from_pretrained(
            checkpoint,
            padding_side="left",
        )

        # NOTE(review): under the Hugging Face device_map convention the ""
        # key should place the whole model on device index 0 -- confirm.
        placement = {"": 0}
        self.model_object = AutoModelForCausalLM.from_pretrained(
            checkpoint,
            quantization_config=quantization,
            device_map=placement,
        )
        self.model_object.eval()