import logging

import anthropic
import google.generativeai as genai

class GeminiModels:
    """Wrapper around a Google Gemini ``GenerativeModel`` with simple retries.

    Config keys read:
        api_key / secret_key_params: credentials (see ``get_api_key``).
        model_name: model identifier passed to ``genai.GenerativeModel``.
        max_tokens: maximum output tokens (default 1000).
        num_retries: number of attempts ``generate`` makes (default 5).
        temperature: sampling temperature (default 0.0).
    """

    def __init__(self, config):
        self.api_key = self.get_api_key(config)
        # genai.configure is a module-level call (SDK-wide configuration,
        # presumably shared by every model created in this process).
        genai.configure(api_key=self.api_key)

        self.model_name = config.get("model_name")
        self.model = genai.GenerativeModel(self.model_name)
        self.max_tokens = config.get("max_tokens", 1000)
        self.num_retries = config.get("num_retries", 5)
        self.temp = config.get("temperature", 0.0)

        self.gen_config = genai.GenerationConfig(
            max_output_tokens=self.max_tokens,
            temperature=self.temp,
        )

    def get_api_key(self, config):
        """Return the API key from ``config``.

        Either ``api_key`` (str) or ``secret_key_params`` (dict) must be
        present in the config.

        Raises:
            ValueError: if neither ``api_key`` nor ``secret_key_params`` is given.
            NotImplementedError: if only ``secret_key_params`` is given; the
                ``GetKey`` secret-store lookup is not wired up in this module.
        """
        if "api_key" in config:
            return config["api_key"]
        if "secret_key_params" not in config:
            raise ValueError(
                "Either api_key (str) or secret_key_params (dict) must be provided "
                f"to {self.__class__.__name__}"
            )
        # The GetKey lookup is disabled here, so fail loudly instead of
        # returning an unset key.
        raise NotImplementedError(
            f"secret_key_params lookup is not available in {self.__class__.__name__}; "
            "pass api_key directly"
        )

    def name(self):
        """Return the configured model name."""
        return self.model_name

    def jsonToQueryText(self, messages):
        """Flatten chat-style messages into a single prompt string.

        Each message is a dict with ``role`` and ``content`` (a str). Roles
        ``system``, ``user`` and ``assistant`` are rendered as
        ``"<Role>: <content>"`` lines, and assistant turns are followed by a
        blank line. Messages with any other role are skipped. The result ends
        with an ``"Assistant:"`` cue and is stripped of surrounding whitespace.
        """
        role_format = {
            "system": ("System: ", "\n"),
            "user": ("User: ", "\n"),
            "assistant": ("Assistant: ", "\n\n"),
        }
        parts = []
        for message in messages:
            wrapper = role_format.get(message["role"])
            if wrapper is not None:
                prefix, suffix = wrapper
                parts.append(prefix + message["content"] + suffix)
        # Trailing cue so the model continues as the assistant.
        parts.append("Assistant: ")
        return "".join(parts).strip()

    def generate(self, query_text, query_images=None, system_message=None):
        """Query Gemini and return ``(text, success)``.

        Args:
            query_text: list of chat messages (dicts with ``role``/``content``),
                flattened with ``jsonToQueryText``. Despite the name, this is not
                a plain string.
            query_images: optional sequence of image parts. They are appended
                after the prompt text in the ``generate_content`` request.
            system_message: currently ignored. Include a ``system`` message in
                ``query_text`` instead.

        Returns:
            ``(response_text, True)`` on success, or ``(None, False)`` once all
            ``num_retries`` attempts have failed.
        """
        prompt = self.jsonToQueryText(query_text)
        contents = ([prompt] + list(query_images)) if query_images else prompt

        for attempt in range(1, self.num_retries + 1):
            try:
                # Always pass gen_config so max_tokens / temperature apply to
                # text-only requests as well as image requests.
                gemini_response = self.model.generate_content(
                    contents, generation_config=self.gen_config
                )
                return gemini_response.text, True
            except Exception as e:  # SDK errors vary in type; retry on any.
                logging.warning(
                    "Attempt %d/%d failed: %s", attempt, self.num_retries, e
                )

        logging.warning("All attempts failed.")
        return None, False




class CLAUDEModels:
    """Wrapper around the Anthropic Messages API with simple retries.

    Config keys read:
        api_key / secret_key_params: credentials (see ``get_api_key``).
        model_name: model identifier passed to ``messages.create``.
        max_tokens: maximum output tokens (default 1000).
        num_retries: number of attempts ``generate`` makes (default 5).
        temperature: sampling temperature (default 0).
    """

    def __init__(self, config):
        self.api_key = self.get_api_key(config)
        self.client = anthropic.Anthropic(api_key=self.api_key)
        self.model_name = config.get("model_name")
        self.max_tokens = config.get("max_tokens", 1000)
        self.num_retries = config.get("num_retries", 5)
        self.temperature = config.get("temperature", 0)

    def get_api_key(self, config):
        """Return the API key from ``config``.

        Either ``api_key`` (str) or ``secret_key_params`` (dict) must be
        present in the config.

        Raises:
            ValueError: if neither ``api_key`` nor ``secret_key_params`` is given.
            NotImplementedError: if only ``secret_key_params`` is given; the
                ``GetKey`` secret-store lookup is not wired up in this module.
        """
        if "api_key" in config:
            return config["api_key"]
        if "secret_key_params" not in config:
            raise ValueError(
                "Either api_key (str) or secret_key_params (dict) must be provided "
                f"to {self.__class__.__name__}"
            )
        # The GetKey lookup is disabled here, so fail loudly instead of
        # returning an unset key.
        raise NotImplementedError(
            f"secret_key_params lookup is not available in {self.__class__.__name__}; "
            "pass api_key directly"
        )

    def name(self):
        """Return the configured model name."""
        return self.model_name

    def generate(self, query_text, query_images=None, system_message=None):
        """Query Claude and return ``(text, success)``.

        Args:
            query_text: list of chat messages (dicts with ``role``/``content``).
                A leading ``system`` message is lifted into the API's ``system``
                parameter.
            query_images: currently ignored (image input is not implemented).
            system_message: system prompt to use when ``query_text`` does not
                start with a ``system`` message.

        Returns:
            ``(response_text, True)`` on success, or ``(None, False)`` once all
            ``num_retries`` attempts have failed.
        """
        request, system_message = self.create_claude_request(
            query_text, query_images, system_message
        )
        # Only send ``system`` when there is one, rather than an explicit None.
        if system_message:
            request["system"] = system_message

        for attempt in range(1, self.num_retries + 1):
            try:
                completion = self.client.messages.create(
                    model=self.model_name,
                    temperature=self.temperature,
                    **request,
                )
                return completion.content[0].text, True
            except Exception as e:  # SDK errors vary in type; retry on any.
                logging.warning(
                    "Attempt %d/%d failed: %s", attempt, self.num_retries, e
                )

        logging.warning("All attempts failed.")
        return None, False

    def create_claude_request(self, messages, query_images=None, system_message=None):
        """Build keyword arguments for ``messages.create``.

        If the first message has role ``system``, its content becomes the
        system prompt and that message is removed from the list. This takes
        precedence over ``system_message``. Otherwise all messages are kept
        and ``system_message`` is passed through unchanged. The caller's list
        is not mutated. ``query_images`` is currently ignored.

        Returns:
            ``({"messages": ..., "max_tokens": ...}, system_message)``
        """
        messages = list(messages)
        if messages and messages[0].get("role") == "system":
            system_message = messages.pop(0)["content"]
        return {"messages": messages, "max_tokens": self.max_tokens}, system_message



########################



# if "gemini" in args.model:
#     from utils import GeminiModels, gemini_config_vision

#     LMMmodels = GeminiModels(gemini_config_vision)

# elif "claude" in args.model:
#     from utils import CLAUDEModels, claude_config

#     LMMmodels = CLAUDEModels(claude_config)


# def model_gemini(img_path, prompt_text, answer_set, answer, file):
#     img_path = os.path.join(os.getcwd(), img_path)
#     encoded_image = base64.b64encode(open(img_path, 'rb').read()).decode('ascii')

#     response = LMMmodels.generate(prompt_text, [encoded_image]) 
    
#     predicted_answer = response[0].strip()
#     predicted_answer = predicted_answer.replace("'","") # remove single quotes, not sure why it is appearing
#     predicted_answer = predicted_answer.replace(' ','') # remove spaces         
   
#     answer = answer.replace(' ','') # remove spaces
#     judgement = int(predicted_answer == answer)
#     result = {
#         'img_path': img_path,
#         'prompt_text': prompt_text,
#         'options': answer_set,
#         'ground_truth': answer,
#         'prediction': predicted_answer,
#         'response': response[0],
#         'judgement': judgement
#     }
#     file.write(json.dumps(result) + '\n')
    
#     return judgement