import os

from typing import Generator

import torch

from dotenv import load_dotenv
from loguru import logger
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
)

from src.generator.base_generator import BaseGenerator


# Load variables from a .env file (if one is found) into the process
# environment at import time, so the os.getenv("HF_ACCESS_TOKEN") lookup in
# LlamaGenerator.__init__ can pick up a token defined there.
load_dotenv()


class LlamaGenerator(BaseGenerator):
    """Text generator backed by a Hugging Face causal language model.

    The model and tokenizer are loaded once in the constructor via
    ``AutoModelForCausalLM`` / ``AutoTokenizer``; the ``HF_ACCESS_TOKEN``
    environment variable, when set, is passed as the Hub access token.
    """

    def __init__(
        self,
        model_name: str,
        temperature: float,
        max_new_tokens: int = 100,
    ) -> None:
        """Load the model and tokenizer identified by ``model_name``.

        Args:
            model_name: Identifier passed to ``from_pretrained`` for both the
                model and the tokenizer.
            temperature: Sampling temperature used by :meth:`generate`. A
                value ``<= 0`` selects greedy decoding instead of sampling.
            max_new_tokens: Upper bound on the number of tokens generated per
                call. Defaults to 100.
        """
        # None when the variable is unset; from_pretrained accepts token=None.
        hf_access_token = os.getenv("HF_ACCESS_TOKEN")
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            token=hf_access_token,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_access_token)
        self.model_name = model_name
        self.temperature = temperature
        self.max_new_tokens = max_new_tokens

    def generate(self, prompt: str, json_mode: bool = False) -> str:
        """Generate a completion for ``prompt`` and return only the new text.

        Args:
            prompt: Text the model should continue.
            json_mode: Accepted for interface compatibility; it is not
                supported here and only triggers a debug log message.

        Returns:
            The decoded continuation, without the prompt and without special
            tokens.
        """
        if json_mode:
            logger.debug("JSON mode is not supported by LlamaGenerator yet.")

        # Calling the tokenizer (rather than .encode) also yields the
        # attention mask, which generate() needs to avoid guessing it.
        encoded = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        input_ids = encoded["input_ids"]
        prompt_length = input_ids.shape[-1]

        # Many Llama tokenizers define no pad token; fall back to EOS so
        # generate() does not have to pick one (and warn) on every call.
        pad_token_id = self.tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = self.tokenizer.eos_token_id

        generation_kwargs = {
            "max_new_tokens": self.max_new_tokens,
            "repetition_penalty": 1.3,
            "eos_token_id": self.tokenizer.eos_token_id,
            "pad_token_id": pad_token_id,
        }
        # Sampling requires a strictly positive temperature; treat a
        # non-positive value as a request for deterministic greedy decoding.
        if self.temperature > 0:
            generation_kwargs["do_sample"] = True
            generation_kwargs["temperature"] = self.temperature
        else:
            generation_kwargs["do_sample"] = False

        output = self.model.generate(
            input_ids=input_ids,
            attention_mask=encoded["attention_mask"],
            **generation_kwargs,
        )

        # A causal LM returns prompt tokens followed by the continuation;
        # drop the prompt so the caller receives only the generated answer.
        generated_tokens = output[0][prompt_length:]
        return self.tokenizer.decode(generated_tokens, skip_special_tokens=True)

    def generate_async(self, prompt: str, json_mode: bool = False) -> Generator[str, None, None]:
        """Not implemented for this generator.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("generate_async is not implemented for LlamaGenerator.")
