from abc import ABC, abstractmethod
from typing import List, Dict, Any
from tqdm import tqdm
import time, os, sys, re
from ollama import generate
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from prompt_templates import prompt_templates, label_descriptions, attribute_prompt, attribute_hn_prompt

class BaseLLM(ABC):
    """Shared prompt building and output parsing for LLM backends.

    Subclasses provide the backend-specific parts: ``load_model``,
    ``load_tokenizer`` and ``inference``.
    """

    # Read by ``get_label`` to decide whether to cut the response at the
    # "final choice:" marker (presumably chain-of-thought mode -- confirm with
    # callers). Nothing in this class assigns it, so a class-level default
    # keeps ``get_label`` from raising AttributeError; assigning ``cot`` on an
    # instance or in a subclass still overrides it.
    cot: bool = False

    def __init__(self, model_id: str, load_model: bool = True):
        """Store the model id and prompt tables, then load model and tokenizer.

        Args:
            model_id: Identifier handed to the backend loader.
            load_model: When False, ``self.model`` is set to None instead of
                calling ``self.load_model()``. The tokenizer is loaded in
                either case.
        """
        self.model_id = model_id
        self.templates = prompt_templates
        self.label_descriptions = label_descriptions
        # The ``load_model`` flag shadows the method name only as a local;
        # ``self.load_model`` below still resolves to the method.
        self.model = self.load_model() if load_model else None
        self.tokenizer = self.load_tokenizer()

    def postprocess(self, caption: str) -> str:
        """Clean a raw caption: first line only, lowercased, no quotes/tabs.

        Parenthesised fragments and anything from ``note:`` onwards are
        removed, and surrounding whitespace is stripped.
        """
        caption = caption.split("\n")[0]
        caption = caption.lower()
        caption = caption.replace('"', '').replace("\t", "")
        caption = re.sub(r'\([^)]*\)', '', caption)
        caption = re.sub(r'note:.*', '', caption)
        return caption.strip()

    def postprocess_explicit_caption(self, caption: str, sample: Dict[str, Any]):
        """Extract the ``caption:`` text and the attribute value from a response.

        Args:
            caption: Raw LLM response expected to contain a ``caption:`` line
                and, unless the attribute is ``"concept"``, an
                ``<attribute>:`` line.
            sample: Mapping with at least ``"attribute"`` and ``"concept"``;
                ``"id"`` is used only in the error message.

        Returns:
            ``(caption_text, attribute_value)``. When the attribute is
            ``"concept"`` the second item is ``sample["concept"]``. If a
            required marker is missing, ``("ERR", "ERR")`` is returned after
            printing a diagnostic.
        """
        caption = caption.lower()
        caption = caption.replace('"', '')
        caption = re.sub(r'\([^)]*\)', '', caption)
        caption = re.sub(r'note:.*', '', caption)
        attribute = sample["attribute"]
        concept = sample["concept"]
        try:
            # Text after the marker, up to the end of that line.
            cap = caption.split("caption:")[1].split("\n")[0].strip()
            if attribute == "concept":
                return cap, concept
            att = caption.split(f"{attribute}:")[1].split("\n")[0].strip()
            return cap, att
        except IndexError:
            # A missing marker makes ``split(...)[1]`` fail.
            print(f"Error with sample {sample.get('id', 'N/A')} with concept {concept} and attribute {attribute}\nLLM response: {caption}")
            return "ERR", "ERR"

    def get_label(self, output: str, labels: List[str]) -> str:
        """Return the first label from ``labels`` found in ``output``.

        The output is lowercased first. When ``self.cot`` is truthy and the
        output contains ``final choice:``, only the text after the last such
        marker is searched. Returns ``"ERR"`` when no label matches.
        """
        output = output.lower()
        if self.cot and "final choice:" in output:
            output = output.split("final choice:")[-1].strip()

        # Substring match, in the order the caller listed the labels.
        for label in labels:
            if label in output:
                return label
        return "ERR"

    @abstractmethod
    def load_model(self):
        """Load and return the backend model object (may be None)."""
        pass

    @abstractmethod
    def load_tokenizer(self):
        """Load and return the tokenizer (may be None)."""
        pass

    def caption_prompt(self, concept: str, attribute: str) -> str:
        """Fill the ``caption`` template with the concept and attribute text."""
        template = self.templates["caption"]
        return template.format(concept, attribute_prompt[attribute])

    def hard_negative_prompt(self, concept: str, label: str, original_caption: str) -> str:
        """Build the hard-negative prompt for ``label``.

        Raises:
            ValueError: If ``label`` is not ``"concept"``, ``"unaltered"`` or
                a key of ``attribute_hn_prompt``.
        """
        if label == "concept":
            return self.templates["concept_hn"].format(concept, original_caption)

        if label == "unaltered":
            return self.templates["unaltered_hn"].format(original_caption)

        if label in attribute_hn_prompt:
            t = attribute_hn_prompt[label]
            return self.templates["attribute_hn"].format(t[0], original_caption, t[1], concept)

        raise ValueError(f"Invalid label for hard negative prompt: {label}")

    def classification_prompt(self, concept: str, caption1: str, caption2: str, labels: List[str]) -> str:
        """Fill the ``classification`` template.

        The first placeholder receives one ``label: description`` line per
        entry of ``labels``; the rest receive the concept and both captions.
        """
        described = "\n".join(f"{l}: {self.label_descriptions[l]}" for l in labels)
        return self.templates["classification"].format(described, concept, caption1, caption2)

    @abstractmethod
    def inference(self, prompts: List[str]) -> List[str]:
        """Run the model on ``prompts`` and return one response per prompt."""
        pass

class OllamaModel(BaseLLM):
    """LLM backend that sends prompts one at a time through ``ollama.generate``."""

    def __init__(self, model_id: str, ollama_options: Dict[str, Any] = None, **kwargs):
        """Store the generation options, then run the base initialisation.

        Args:
            model_id: Model name passed to ``generate``.
            ollama_options: Passed as ``options=`` on every ``generate``
                call; a falsy value becomes an empty dict.
            **kwargs: Forwarded to ``BaseLLM.__init__``.
        """
        # Set before super().__init__, which may call load_model() and needs it.
        self.ollama_options = ollama_options if ollama_options else {}
        super().__init__(model_id, **kwargs)

    def load_model(self):
        """Issue one placeholder request; there is no local model object.

        Returns:
            Always None.
        """
        print("Loading model by making a placeholder request to the Ollama API...", end=" ")
        generate(model=self.model_id, prompt="Hello, world!", options=self.ollama_options)
        print("Done!")
        return None

    def load_tokenizer(self):
        """No tokenizer is used by this backend; returns None."""
        return None

    def inference(self, prompts: List[str]) -> List[str]:
        """Generate one response per prompt, sequentially, and report timing.

        Args:
            prompts: Prompt strings; may be empty.

        Returns:
            The ``.response`` text of each ``generate`` call, in prompt order.
        """
        responses = []
        start_time = time.time()
        for prompt in tqdm(prompts):
            responses.append(generate(model=self.model_id, prompt=prompt, options=self.ollama_options).response)
        total_time = time.time() - start_time
        # Guard the average: an empty prompt list used to raise ZeroDivisionError.
        average_time = total_time / len(prompts) if prompts else 0.0
        print(f"Time taken: {total_time:.2f}s, Average time per prompt: {average_time:.2f}s")
        return responses

class VllmModel(BaseLLM):
    """LLM backend that runs batched generation through a vLLM ``LLM`` engine."""

    def __init__(self, model_id: str, model_type: str, model_kwargs: Dict[str, Any] = None, sampling_kwargs: Dict[str, Any] = None, **kwargs):
        """Store backend settings, then run the base initialisation.

        Args:
            model_id: Model identifier passed to ``LLM`` and ``AutoTokenizer``.
            model_type: Family tag; ``"mistral"`` (substring) and
                ``"deepseek"`` (exact) trigger special handling.
            model_kwargs: Extra keyword arguments for ``LLM``.
            sampling_kwargs: Keyword arguments for ``SamplingParams``.
            **kwargs: Forwarded to ``BaseLLM.__init__``.
        """
        self.model_type = model_type
        # Copy the dicts: load_model() adds a key to model_kwargs, and writing
        # into the caller's own dict would leak that change back to the caller.
        self.model_kwargs = dict(model_kwargs) if model_kwargs else {}
        self.sampling_kwargs = dict(sampling_kwargs) if sampling_kwargs else {}
        super().__init__(model_id, **kwargs)

    def load_model(self):
        """Build the vLLM engine, forcing ``tokenizer_mode`` for mistral types."""
        if "mistral" in self.model_type:
            self.model_kwargs['tokenizer_mode'] = "mistral"
        print("Loading model with following arguments:")
        print(self.model_kwargs)
        return LLM(self.model_id, **self.model_kwargs)

    def load_tokenizer(self):
        """Load the tokenizer for ``model_id``; used to apply the chat template."""
        return AutoTokenizer.from_pretrained(self.model_id)

    def postprocess(self, caption: str) -> str:
        """Clean a caption, discarding deepseek ``</think>`` reasoning first.

        For ``model_type == "deepseek"`` only the text after the last
        ``</think>`` is kept. If the tag is missing, the cleaned caption is
        returned with ``" INVALID"`` appended.
        """
        invalid = False
        if self.model_type == "deepseek":
            if "</think>" in caption:
                caption = caption.split("</think>")[-1]
            else:
                invalid = True

        caption = super().postprocess(caption)

        if invalid:
            caption += " INVALID"
        return caption

    def get_label(self, output: str, labels: List[str]) -> str:
        """Extract a label, requiring a ``</think>`` tag for deepseek output.

        For ``model_type == "deepseek"`` a response without ``</think>``
        yields ``"ERR"``; otherwise only the text after the last tag is
        handed to the base implementation.
        """
        if self.model_type == "deepseek":
            if "</think>" in output:
                output = output.split("</think>")[-1].strip()
            else:
                return "ERR"

        return super().get_label(output, labels)

    def _apply_chat_template(self, base_prompt: str) -> str:
        """Wrap ``base_prompt`` as a single user turn in the chat template."""
        return self.tokenizer.apply_chat_template([{"role": "user", "content": base_prompt}], add_generation_prompt=True, tokenize=False)

    def caption_prompt(self, *args, **kwargs) -> str:
        """Base caption prompt, wrapped in the chat template."""
        base_prompt = super().caption_prompt(*args, **kwargs)
        return self._apply_chat_template(base_prompt)

    def hard_negative_prompt(self, *args, **kwargs) -> str:
        """Base hard-negative prompt, wrapped in the chat template."""
        base_prompt = super().hard_negative_prompt(*args, **kwargs)
        return self._apply_chat_template(base_prompt)

    def classification_prompt(self, *args, **kwargs) -> str:
        """Base classification prompt, wrapped in the chat template."""
        base_prompt = super().classification_prompt(*args, **kwargs)
        return self._apply_chat_template(base_prompt)

    def inference(self, prompts: List[str]) -> List[str]:
        """Generate for all prompts in one batch and report timing.

        Args:
            prompts: Prompt strings; may be empty.

        Returns:
            The text of the first output of each response, in the order
            returned by ``self.model.generate``.
        """
        sampling_params = SamplingParams(**self.sampling_kwargs)

        start_time = time.time()
        responses = self.model.generate(prompts, sampling_params=sampling_params)
        total_time = time.time() - start_time
        # Guard the average: an empty prompt list used to raise
        # ZeroDivisionError here, before the ``if prompts`` check below.
        average_time = total_time / len(prompts) if prompts else 0.0
        print(f"Time taken: {total_time:.2f}s, Average time per prompt: {average_time:.2f}s")

        outputs = [response.outputs[0].text for response in responses]

        if prompts:
            print("Example prompt:", prompts[-1], "\nExample output:\n", outputs[-1])
        return outputs
