import anthropic
import json
import os
import re
from openai import OpenAI

from copy import deepcopy
from memgpt.constants import PROMPTS_DIR, CONFIGS_DIR
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from peft import PeftModel
from tqdm import tqdm
from typing import List, Optional
from torch.utils.data import DataLoader
from memgpt.utils import truncate_prompt


class Prompt:
    """
    Chat prompt template loaded from ``<PROMPTS_DIR>/<prompt_id>.json``.

    The loaded JSON is iterated as a sequence of message dictionaries, each
    carrying a string ``content`` field that may contain the literal
    placeholder ``[INSERT_TEXT]``.
    """

    # Literal marker that gets swapped for the caller's text.
    _PLACEHOLDER = "[INSERT_TEXT]"

    def __init__(self, prompt_id):
        """
        Loads the prompt template identified by `prompt_id`.

        Args:
            prompt_id (str): File name (without the ``.json`` suffix) of the template inside PROMPTS_DIR.

        Raises:
            FileNotFoundError: If no such template file exists.
        """
        self.prompt_id = prompt_id
        path = os.path.join(PROMPTS_DIR, f"{prompt_id}.json")
        if not os.path.exists(path):
            raise FileNotFoundError(f"Prompt file not found at {path}.")
        # JSON interchange text is UTF-8; do not depend on the platform's default encoding.
        with open(path, "r", encoding="utf-8") as f:
            self.prompt = json.load(f)

    def __call__(self, text):
        """
        Fills the placeholders of the form [INSERT_TEXT] in the prompt with the provided text.

        The text is inserted verbatim. A plain string replacement is used on purpose:
        passing arbitrary text as the replacement argument of `re.sub` makes the regex
        engine interpret its backslashes (e.g. ``\\d`` raises "bad escape", ``\\n``
        silently becomes a newline), which corrupted or dropped the template.

        Args:
            text (str): The text to insert into the prompt.

        Returns:
            list: A deep copy of the template (the stored template is left untouched) in which
                every 'content' field has its placeholders replaced by the provided text.
        """
        filled_prompt = deepcopy(self.prompt)
        for prompt_dict in filled_prompt:
            prompt_dict['content'] = prompt_dict['content'].replace(self._PLACEHOLDER, text)
        return filled_prompt

class Annotator:
    """
    Common base for all annotators.

    Holds the model identifier, the prompt template and the (optional) JSON
    configuration. Concrete subclasses supply `annotate` and `postprocess`.
    """

    def __init__(self, model_id, prompt_id, config_file):
        """
        Args:
            model_id: Identifier of the model; stored as-is.
            prompt_id: Identifier handed to `Prompt` to load the template.
            config_file: Name (without ``.json``) of a file inside CONFIGS_DIR,
                or None to start with an empty configuration.
        """
        self.model_id = model_id
        self.prompt = Prompt(prompt_id)

        # No config requested: use an empty mapping and stop here.
        if config_file is None:
            self.configs = {}
            return

        config_path = os.path.join(CONFIGS_DIR, f"{config_file}.json")
        with open(config_path, "r") as handle:
            self.configs = json.load(handle)

    def annotate(self, texts):
        """Abstract hook; the base class always raises NotImplementedError."""
        raise NotImplementedError

    def postprocess(self, texts):
        """Abstract hook; the base class always raises NotImplementedError."""
        raise NotImplementedError

class ChatGPTAnnotator(Annotator):
    """Annotator that sends each filled prompt to the OpenAI chat-completions client."""

    def __init__(self, model_id, prompt_id, config_file):
        super().__init__(model_id, prompt_id, config_file)
        self.client = OpenAI()

    def annotate(self, texts):
        """
        Requests one completion per input text, sequentially.

        Args:
            texts: Iterable of texts to insert into the prompt template.

        Returns:
            list: The message content of the first choice of each completion, in input order.
        """
        def _complete(text):
            # Fill the template first, then issue a single request for it.
            chat_messages = self.prompt(text)
            completion = self.client.chat.completions.create(
                model=self.model_id,
                messages=chat_messages,
            )
            return completion.choices[0].message.content

        return [_complete(text) for text in texts]

    def postprocess(self, text, annotated_texts):
        """Not implemented for this annotator."""
        raise NotImplementedError

class ClaudeAnnotator(Annotator):
    """
    Annotator that sends each filled prompt to the Anthropic Messages API.

    Example `model_id`: "claude-3-5-sonnet-20241022". The loaded config must
    provide a ``max_tokens`` entry.
    """

    def __init__(self, model_id, prompt_id, config_file):
        super().__init__(model_id, prompt_id, config_file)
        self.claude = anthropic.Anthropic()

    def annotate(self, texts):
        """
        Requests one completion per input text, sequentially.

        Args:
            texts: Iterable of texts to insert into the prompt template.

        Returns:
            list: The text of the first content block of each response, in input order.

        Raises:
            KeyError: If the config has no ``max_tokens`` entry.
        """
        annotations = []
        for text in texts:
            response = self.claude.messages.create(
                model=self.model_id,
                max_tokens=self.configs['max_tokens'],
                # The filled prompt is already a list of message dicts, so it is
                # passed directly rather than nested inside another user message.
                messages=self.prompt(text),
            )
            # The SDK returns a response object (not a dict) whose `content`
            # is a list of content blocks; the annotation is the first block's text.
            annotations.append(response.content[0].text)
        return annotations

    def postprocess(self, text, annotated_texts):
        """Not implemented for this annotator."""
        raise NotImplementedError

class LlamaAnnotator(Annotator):
    """
    Annotator that generates with a local vLLM engine.

    The loaded config must provide ``llm`` (keyword arguments for `LLM`) and
    ``sampling`` (keyword arguments for `SamplingParams`).
    """

    def __init__(self, model_id, prompt_id, config_file):
        super().__init__(model_id, prompt_id, config_file)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
        self.llm = LLM(model=self.model_id, **self.configs['llm'])
        self.sampling_params = SamplingParams(**self.configs['sampling'])

    def annotate(self, texts):
        """
        Renders every text through the chat template, truncates, and generates in one batch.

        Args:
            texts: Iterable of texts to insert into the prompt template.

        Returns:
            list: Text of the first output of each response, or an empty list if
                generation raised (the error is printed, not re-raised).
        """
        # Pass 1: render the chat template for every text.
        rendered = []
        for text in texts:
            rendered.append(
                self.tokenizer.apply_chat_template(
                    self.prompt(text),
                    tokenize=False,
                    add_generation_prompt=True,
                    truncation=True,
                    max_length=1024,
                )
            )

        # Pass 2: cap every rendered prompt with the project helper.
        prompts = []
        for rendered_prompt in rendered:
            prompts.append(truncate_prompt(rendered_prompt, max_tokens=1024, tokenizer=self.tokenizer))

        try:
            generations = self.llm.generate(prompts, self.sampling_params)
            return [generation.outputs[0].text for generation in generations]
        except Exception as e:
            # Best effort: report the failure and hand back nothing.
            print(f"Error occurred: {e}")
            return []

    def postprocess(self, text, annotated_texts):
        """Not implemented for this annotator."""
        raise NotImplementedError

    
class LlamaLoraAnnotator(Annotator):
    """
    Annotator that generates with a vLLM engine and a LoRA adapter.

    `model_id` is used both to load the tokenizer and as the adapter path in
    the LoRA request; the engine itself is built from ``configs['base_model']``.
    The config must also provide ``llm`` and ``sampling`` keyword arguments.
    """

    def __init__(self, model_id, prompt_id, config_file):
        super().__init__(model_id, prompt_id, config_file)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
        self.llm = LLM(model=self.configs['base_model'], **self.configs['llm'])
        self.sampling_params = SamplingParams(**self.configs['sampling'])

    def annotate(self, texts=None, prompts=None):
        """
        Generates one annotation per prompt.

        Args:
            texts: Raw texts to render through the chat template; used only when `prompts` is None.
            prompts: Already formatted prompts; take precedence over `texts`.

        Returns:
            list: Text of the first output of each response.
        """
        assert prompts or texts, "Either `texts` or `prompts` must be provided."

        if prompts is None:
            prompts = []
            for text in texts:
                chat = self.tokenizer.apply_chat_template(
                    self.prompt(text),
                    tokenize=False,
                    add_generation_prompt=True,
                )
                prompts.append(chat)

        adapter_request = LoRARequest("lora_adapter", 1, self.model_id)
        generations = self.llm.generate(
            prompts,
            self.sampling_params,
            lora_request=adapter_request,
            use_tqdm=True,
        )
        return [generation.outputs[0].text for generation in generations]

    def postprocess(self, text, annotated_texts):
        """Not implemented for this annotator."""
        raise NotImplementedError

class LlamaLoraAnnotator_hf(Annotator):
    """Annotator that generates with a Hugging Face causal LM wrapped by a PEFT (LoRA) adapter."""

    def __init__(self, model_id: str, prompt_id: str, config_file: str):
        """
        Initializes the Hugging Face-based Llama LoRA annotator.

        :param model_id: Path to the fine-tuned LoRA model directory (also used to load the tokenizer).
        :param prompt_id: Identifier for the prompt template.
        :param config_file: Name (without ``.json``) of the config file inside CONFIGS_DIR. The loaded
            config must provide ``base_model`` and ``generation_config``; ``batch_size`` is optional
            and defaults to 16.
        """
        super().__init__(model_id, prompt_id, config_file)
        # Load tokenizer (ensuring it has the correct vocabulary)
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        # A decoder-only model continues from the last position of each row, so a batch of
        # prompts with different lengths has to be padded on the left; with right padding
        # the model would be asked to continue from pad tokens.
        self.tokenizer.padding_side = "left"

        # Load base model and apply LoRA adapter
        base_model = AutoModelForCausalLM.from_pretrained(
            self.configs['base_model'],
            device_map="auto"
        )
        if len(self.tokenizer) != base_model.config.vocab_size:
            # TODO: probably need to add a new pad token
            if self.tokenizer.pad_token_id is None:
                self.tokenizer.add_special_tokens({"pad_token": "<|pad|>"})

            # Resize token embeddings (if vocab was modified)
            base_model.config.pad_token_id = self.tokenizer.pad_token_id
            base_model.resize_token_embeddings(len(self.tokenizer))

            print(f"Added pad token to tokenizer: {self.tokenizer.pad_token} and resized token embeddings to {len(self.tokenizer)}")
        elif self.tokenizer.pad_token_id is None:
            # Vocabulary already matches, so no new token may be added; reuse EOS as the
            # padding token so that `padding=True` in `annotate` does not fail. Padded
            # positions are masked out by the attention mask.
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.model = PeftModel.from_pretrained(base_model, self.model_id)

        # Move model to GPU
        # NOTE(review): the base model is loaded with device_map="auto"; confirm that an
        # explicit .to("cuda") is still wanted on multi-GPU / offloaded setups.
        self.model = self.model.to("cuda")
        self.model.eval()

        # Set up generation parameters
        self.generation_config = GenerationConfig(**self.configs['generation_config'])

        self.bsz = self.configs.get("batch_size", 16)

    def annotate(self, texts: Optional[List[str]] = None, prompts: Optional[List[str]] = None) -> List[str]:
        """
        Annotates input texts using the fine-tuned Llama LoRA model.

        :param texts: List of raw texts to be processed; used when no prompts are given.
        :param prompts: Preformatted prompts (if available); take precedence over `texts`.
        :return: List of annotated responses, one per prompt, in input order.
        """
        assert prompts or texts, "Either `texts` or `prompts` must be provided."

        # If no (non-empty) prompts are provided, format the texts using the chat template.
        # Testing truthiness rather than `is None` keeps an empty `prompts` list from
        # silently discarding the supplied `texts`.
        if not prompts:
            prompts = [
                self.tokenizer.apply_chat_template(
                    self.prompt(text),
                    tokenize=False,
                    add_generation_prompt=True
                ) for text in texts
            ]

        # Tokenize all prompts at once (padding makes every row the same length for batching)
        inputs = self.tokenizer(prompts, return_tensors="pt", padding=True, truncation=True).to("cuda")
        input_ids = inputs["input_ids"]
        attention_mask = inputs["attention_mask"]
        # All rows share the padded length, so the generated part of every output
        # starts at the same column.
        prompt_length = input_ids.shape[1]

        responses = []
        for start in tqdm(range(0, len(prompts), self.bsz), desc="Generating Responses"):
            stop = start + self.bsz
            with torch.no_grad():
                batch_outputs = self.model.generate(
                    input_ids=input_ids[start:stop],
                    attention_mask=attention_mask[start:stop],
                    generation_config=self.generation_config,
                )

            # Drop the prompt tokens and decode only the newly generated continuation.
            responses.extend(
                self.tokenizer.batch_decode(batch_outputs[:, prompt_length:], skip_special_tokens=True)
            )

        return responses


class AnnotatorRevisor(Annotator):
    """
    Base for annotators that use two prompt templates.

    The initial template is stored by the parent as `self.prompt`; the second
    one is stored here as `self.revision_prompt`.
    """

    def __init__(self, model_id, initial_prompt_id, revision_prompt_id, config_file):
        super().__init__(
            model_id=model_id,
            prompt_id=initial_prompt_id,
            config_file=config_file,
        )
        self.revision_prompt = Prompt(revision_prompt_id)

    def annotate(self, texts):
        """Abstract hook; this class always raises NotImplementedError."""
        raise NotImplementedError

    def postprocess(self, texts):
        """Abstract hook; this class always raises NotImplementedError."""
        raise NotImplementedError

class ClaudeAnnotatorRevisor(AnnotatorRevisor):
    """
    Two-step annotator on the Anthropic Messages API: annotate, then revise.

    Example `model_id`: "claude-3-5-sonnet-20241022". The loaded config must
    provide a ``max_tokens`` entry.
    """

    def __init__(self, model_id, initial_prompt_id, revision_prompt_id, config_file):
        super().__init__(model_id, initial_prompt_id, revision_prompt_id, config_file)
        self.claude = anthropic.Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY"))

    def _request_text(self, messages):
        """Sends `messages` to the model and returns the text of the first content block."""
        response = self.claude.messages.create(
            model=self.model_id,
            max_tokens=self.configs['max_tokens'],
            messages=messages,
        )
        return response.content[0].text

    def annotate(self, texts):
        """
        Produces an annotation and a revision for every input text.

        For each text, the initial prompt is sent first. The conversation is then
        extended with the model's answer as an assistant turn, followed by the
        revision prompt (filled with an empty string), and sent again.

        Args:
            texts: Iterable of texts to annotate.

        Returns:
            tuple: ``(annotations, revisions)``, two lists aligned with `texts`.
        """
        annotations, revisions = [], []
        for text in texts:
            print("Annotating... ", end="")
            initial_messages = self.prompt(text)
            annotation = self._request_text(initial_messages)
            annotations.append(annotation)

            print("Revising...", end="")
            assistant_turn = {"role": "assistant", "content": annotation}
            follow_up = initial_messages + [assistant_turn] + self.revision_prompt('')
            revisions.append(self._request_text(follow_up))
            print("Done.")
        return annotations, revisions

    def postprocess(self, text, annotated_texts):
        """Not implemented for this annotator."""
        raise NotImplementedError

if __name__ == "__main__":
    # Manual smoke test 1: load a template and fill it with sample text.
    print("Testing the Prompt class.")
    demo_prompt = Prompt("claude-v0")
    print(demo_prompt("This is some random sample text to test the class."))
    print("\n\n")

    # Manual smoke test 2: run the vLLM-backed annotator on two short biographies.
    print("Testing the LlamaAnnotator class.")
    sample_texts = [
        "Friedrich Wilhelm Nietzsche (15 October 1844 – 25 August 1900) was a German classical scholar, philosopher, and critic of culture, who became one of the most influential of all modern thinkers.",
        "Theodore Roosevelt Jr. (October 27, 1858 – January 6, 1919), also known as Teddy or T. R., was the 26th president of the United States, serving from 1901 to 1909.",
    ]
    annotator = LlamaAnnotator(
        "meta-llama/Meta-Llama-3.1-8B-Instruct",
        "claude-v0",
        "llama/default",
    )
    for annotation in annotator.annotate(sample_texts):
        print(annotation)
