import yaml, os
from transformers import pipeline
from PIL import Image
from pathlib import Path
from typing import List
from transformers import BitsAndBytesConfig
from src.utils.access_token import HF_TOKEN


class ImageCaptioner:
    """Caption images with a Hugging Face ``image-to-text`` pipeline set up from a YAML file.

    The YAML file must contain a top-level ``model`` mapping. Keys read from it:

    * ``id`` (required): model identifier handed to ``pipeline``.
    * ``revision`` (optional): model revision; ``"main"`` when the key is absent.
    * ``model_kwargs`` (optional): mapping forwarded to ``pipeline`` as ``model_kwargs``.
    * ``quantization`` (optional): mapping with ``enable`` and ``conf``; when
      ``enable`` is truthy, ``conf`` is expanded into ``BitsAndBytesConfig``.
    * ``requires_auth`` (optional): when truthy, ``HF_TOKEN`` is passed as the
      pipeline ``token``.
    * ``prompt`` (required by :meth:`caption`): mapping with ``template``
      (the literal ``***PROMPT***`` is replaced by the user prompt) and the
      optional ``split_on`` / ``remove`` post-processing strings.
    """

    def __init__(self, config_path: Path | str) -> None:
        """Load the YAML configuration and build the captioning pipeline.

        Args:
            config_path: Path to the YAML configuration file.

        Raises:
            KeyError: If the ``model`` mapping or its ``id`` entry is missing.
        """
        with open(config_path, 'r', encoding='utf-8') as f:
            self.conf: dict = yaml.safe_load(f)

        model_conf = self.conf['model']
        self.model_id = model_conf['id']

        # Work on a copy: cache_dir / quantization_config must not leak back
        # into self.conf (the quantization object is not plain YAML data).
        # ``or {}`` also covers a ``model_kwargs:`` key with an empty value.
        model_kwargs = dict(model_conf.get('model_kwargs') or {})

        # Only pin the cache directory when HF_HOME is actually set; otherwise
        # the library's own default cache location is used.
        hf_home = os.environ.get('HF_HOME')
        if hf_home:
            model_kwargs['cache_dir'] = str(Path(hf_home) / 'hub')

        quantization = model_conf.get('quantization') or {}
        if quantization.get('enable'):
            model_kwargs['quantization_config'] = BitsAndBytesConfig(**(quantization.get('conf') or {}))

        hf_token = HF_TOKEN if model_conf.get('requires_auth') else None
        revision = model_conf.get('revision', 'main')

        self._pipeline = pipeline(
            'image-to-text',
            self.model_id,
            revision=revision,
            model_kwargs=model_kwargs,
            token=hf_token,
        )

    def caption(self,
                images: List[Image.Image | str],
                prompt: str,
                max_new_tokens: int = 200,
                batch_size: int = 1
                ) -> List[str]:
        """Generate one caption string per pipeline output for ``images``.

        The user ``prompt`` is inserted into the configured template in place
        of ``***PROMPT***``. Each generated text is then post-processed:
        if ``split_on`` is configured, only the part after its last occurrence
        is kept; if ``remove`` is configured, every occurrence of it is
        deleted; finally surrounding whitespace is stripped.

        Args:
            images: Images (or image references accepted by the pipeline) to caption.
            prompt: Text substituted into the prompt template.
            max_new_tokens: Forwarded to the pipeline via ``generate_kwargs``.
            batch_size: Forwarded to the pipeline.

        Returns:
            The cleaned captions, in the order the pipeline produced them.
            An empty list when ``images`` is an empty list or the pipeline
            returns ``None``.
        """
        # Nothing to caption: skip the model call entirely.
        if isinstance(images, list) and not images:
            return []

        prompt_conf = self.conf['model']['prompt']
        full_prompt = prompt_conf['template'].replace('***PROMPT***', prompt)

        outputs = self._pipeline(
            images,
            prompt=full_prompt,
            generate_kwargs={"max_new_tokens": max_new_tokens},
            batch_size=batch_size,
        )
        if outputs is None:
            return []

        # Loop-invariant post-processing settings; both are optional.
        split_on = prompt_conf.get('split_on')
        remove = prompt_conf.get('remove')

        results = []
        for per_image in outputs:
            # The result is either a flat list of dicts or a list of such
            # lists; treat a bare dict as a single-candidate list.
            candidates = per_image if isinstance(per_image, list) else [per_image]
            for candidate in candidates:
                text = candidate['generated_text']
                if split_on:
                    # Keep only what follows the last marker (typically the
                    # echoed prompt precedes it).
                    text = text.split(split_on)[-1]
                if remove:
                    text = text.replace(remove, '')
                results.append(text.strip())

        return results
