import os
import re

import google.generativeai as genai
import torch
from dotenv import load_dotenv
from google.api_core.exceptions import ResourceExhausted
import openai
from abc import ABC, abstractmethod

def api_handler_factory(model_name, temperature):
    """Build the API handler matching ``model_name``.

    The handler class is chosen by substring: names containing "gemini" get
    a GeminiHandler, names containing "gpt" get a GPTHandler ("gemini" is
    checked first).

    Raises:
        ValueError: if ``model_name`` contains neither keyword.
    """
    if "gemini" in model_name:
        handler_cls = GeminiHandler
    elif "gpt" in model_name:
        handler_cls = GPTHandler
    else:
        raise ValueError(f"Unknown handler type: {model_name}")
    return handler_cls(model_name, temperature)
    
class APIHandler(ABC):
    """Abstract base for LLM API handlers.

    Subclasses implement the model-specific steps (``load_model``,
    ``get_response``, ``query_llm``). Parsing of the textual answer is shared
    here because it only deals with a plain string.
    """

    def __init__(self, model_name, temperature):
        self.model_name = model_name
        self.temperature = temperature

    # The following abstract methods are model-specific.
    @abstractmethod
    def load_model(self, *args, **kwargs):
        pass

    @abstractmethod
    def get_response(self, *args, **kwargs):
        pass

    @abstractmethod
    def query_llm(self, *args, **kwargs):
        pass

    # parse_and_validate_response is model-agnostic -- it only deals with a
    # string at this point.
    def parse_and_validate_response(self, input_string, num_brackets):
        """Extract the list of floats wrapped in dollar signs from a response.

        The first ``$...$`` span (``$$`` is treated like ``$``) is taken, any
        ``*`` and bracket characters are dropped, and the remainder is split
        on commas, semicolons or whitespace and converted to floats.

        Note: only the number of values is validated; the values themselves
        are not range-checked.

        Args:
            input_string: Raw text returned by the model.
            num_brackets: Exact number of floats the list must contain.

        Returns:
            The parsed list of floats, of length ``num_brackets``.

        Raises:
            ValueError: if there is no dollar-delimited span, if its content
                cannot be converted to floats, or if the number of floats
                differs from ``num_brackets``.
        """
        # Collapse "$$" into "$" so both delimiter styles match the pattern.
        input_string = re.sub(r'\$\$', '$', input_string)

        # Non-greedy: take the first span wrapped in dollar signs.
        match = re.search(r"\$(.*?)\$", input_string)
        if match is None:
            error_msg = "Error: No content found within dollar signs."
            print(error_msg)
            raise ValueError(error_msg)

        # Strip emphasis asterisks and any surrounding brackets.
        dollar_content = match.group(1).strip().replace("*", "")
        dollar_content = re.sub(r"[\[\]\{\}\(\)]", "", dollar_content)

        try:
            # Empty tokens (from leading/trailing separators) are skipped.
            float_list = [
                float(token)
                for token in re.split(r"[,;\s]+", dollar_content)
                if token
            ]
        except ValueError as ve:
            error_msg = "Error: The content within the dollar signs is not a valid list of floats."
            print(error_msg)
            raise ValueError(error_msg) from ve

        # Explicit check instead of `assert`, which is removed under
        # `python -O` and would let wrong-length lists through.
        if len(float_list) != num_brackets:
            print(dollar_content)
            error_msg = f"Error: The list must contain exactly {num_brackets} floats."
            print(error_msg)
            raise ValueError(error_msg)

        print("Valid list of floats:", float_list)
        return float_list

class GeminiHandler(APIHandler):
    """APIHandler implementation backed by the google.generativeai SDK."""

    def __init__(self, model_name, temperature):
        super().__init__(model_name, temperature)
        self.model = self.load_model()

    def load_model(self):
        """Configure genai with GEMINI_API_KEY and return the GenerativeModel.

        ``load_dotenv()`` is called first, then the key is read from the
        environment.
        """
        load_dotenv()
        genai.configure(api_key=os.getenv("GEMINI_API_KEY"))
        return genai.GenerativeModel(self.model_name)

    def get_response(self, prompt):
        """Send ``prompt`` to the model and return the raw API response.

        A ResourceExhausted error is reported on stdout and re-raised.
        """
        generation_config = genai.types.GenerationConfig(temperature=self.temperature)
        try:
            return self.model.generate_content(
                prompt,
                generation_config=generation_config,
            )
        except ResourceExhausted:
            print("API key exhausted")
            raise

    def query_llm(self, prompt):
        """Return only the generated text for ``prompt``.

        The API response carries more than the natural-language answer, so
        the text is pulled from the first part of the first candidate.
        """
        first_candidate = self.get_response(prompt).candidates[0]
        return first_candidate.content.parts[0].text

class GPTHandler(APIHandler):
    """APIHandler implementation backed by the OpenAI chat-completions client."""

    def __init__(self, model_name, temperature):
        super().__init__(model_name, temperature)
        self.client = self.load_model()

    def load_model(self):
        """Create and return the OpenAI client using OPENAI_API_KEY."""
        # Load .env first, as GeminiHandler.load_model does; otherwise a key
        # that lives only in the .env file is not visible when a GPT handler
        # is the only handler created.
        load_dotenv()
        client = openai.OpenAI(
            api_key=os.getenv("OPENAI_API_KEY"),
        )
        return client

    def get_response(self, prompt):
        """Send ``prompt`` as a single user message and return the raw response.

        Any exception raised by the request is printed and re-raised.
        """
        try:
            response = self.client.chat.completions.create(
                messages=[
                    {
                        "role": "user",
                        "content": prompt,
                    }
                ],
                model=self.model_name,
                temperature=self.temperature,
            )
            # TODO: figure out which exception signals an exhausted GPT API
            # key and handle it explicitly (openai.RateLimitError looks like
            # the candidate -- confirm).
        except Exception as e:
            print(f"Error: {e}")
            raise
        return response

    def query_llm(self, prompt):
        """Return the message text of the first choice for ``prompt``."""
        response = self.get_response(prompt)
        return response.choices[0].message.content


def append_to_cleanup_history(incentives, val_obs, mean_rew, generation, length, llm_gets_aux):
    """Format one generation's clean_up result as a history line for the prompt.

    Args:
        incentives: Sequence of numeric incentives; each is rounded to two
            decimals. With ``llm_gets_aux`` the first three are reported as
            harvest, clean and other.
        val_obs: Mapping of validation statistics. Only read when
            ``llm_gets_aux`` is true, in which case it must provide
            ``'mean_cumulative_num_cleaned'`` and
            ``'mean_regrowth_trajectory'`` (an object supporting ``.view``).
        mean_rew: Mean reward; reported truncated to an int.
        generation: Generation label inserted into the text.
        length: Unused; kept so existing callers keep working.
        llm_gets_aux: If false, return the short "incentives -> reward" line;
            if true, return the detailed line with cleaning/regrowth info.

    Returns:
        The formatted history string.
    """
    round_incentives = [round(num, 2) for num in incentives]

    if not llm_gets_aux:
        # The short form needs no auxiliary statistics, so val_obs is not
        # touched here.
        return f"""\nGeneration {generation}: {round_incentives}-> ~ {int(mean_rew)}.\n"""

    num_cleaned = int(val_obs['mean_cumulative_num_cleaned'])
    num_regrown = torch.sum(val_obs['mean_regrowth_trajectory'].view(-1)).item()

    # NOTE(review): 122, 7000 and 7 are hard-coded environment constants --
    # presumably the initial apple count, the total number of agent actions
    # per episode and the number of agents; confirm against the env config.
    if num_regrown == 0:
        pollution = "With the pollution level in this episode, no apples grew back. Therefore, 122 apples were available throughout the episode."
    else:
        pollution = f"With the pollution level in this episode, roughly {int(num_regrown)} apples regrew. Therefore, {122+int(num_regrown)} were available throughout the episode."

    other_actions = 7000 - num_cleaned - int(mean_rew * 7)
    return (
        f"\nGeneration {generation}: You generated the following incentives -- "
        f"harvest = {round_incentives[0]}, clean = {round_incentives[1]}, "
        f"other = {round_incentives[2]} -> "
        f"Agents cleaned about {num_cleaned} times and did other actions about "
        f"{other_actions} times. {pollution} "
        f"Under these incentives, agent productivity was approximately {int(mean_rew)}. \n"
    )


def latest_rate_and_return_matching_prompt_vocab(env_name, llm_gets_aux, generation, tax, val_obs, length, mean_reward):
    """Format the latest generation's result using the environment's vocabulary.

    For "clean_up" the work is delegated to ``append_to_cleanup_history``.
    For "commons_harvest__open" a "tax -> mean apples" line is produced; when
    ``llm_gets_aux`` is set and ``mean_reward`` is below 30, a sentence is
    added saying when (as a percentage of ``length``) the apple trajectory in
    ``val_obs['mean_apple_trajectory']`` first hit zero, or that it never
    did. Any other ``env_name`` yields ``None``.
    """
    if env_name == "clean_up":
        return append_to_cleanup_history(tax, val_obs, mean_reward, generation, length, llm_gets_aux)
    if env_name != "commons_harvest__open":
        return None

    explanation = ""
    if llm_gets_aux and mean_reward < 30:
        # Flatten to 1D so positions index directly into the trajectory.
        flat_trajectory = val_obs['mean_apple_trajectory'].view(-1)
        depleted_positions = (flat_trajectory == 0).nonzero(as_tuple=True)[0]
        if len(depleted_positions) > 0:
            first_depleted = depleted_positions[0].item()
            explanation = f"Under this tax rate, agents harvested all apples {int(first_depleted / length *100)}% of the way through the episode."
        else:
            explanation = "Under this tax rate, apples remained unharvested at the end of the episode."

    return f"""Generation {generation}: {tax} -> mean apples: {int(mean_reward)}. {explanation} \n"""
