from fastchat.model import (
    get_conversation_template
)
import re
from judge_prompts import get_judge_system_prompt, get_prompt_judge_system_prompt
import openai
import os
import time
from typing import Dict, List
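
# Judge utilities: load_judge picks a judge implementation, GPT wraps the
# OpenAI chat API with retries, and the Judge classes score attack prompts
# against target responses by parsing ratings from the model's output.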


def load_judge(args, goal, target):
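    """Instantiate the judge named by args.judge_model."""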
    if "gpt" in args.judge_model:
        return GPTJudge(args, goal, target)
    elif args.judge_model == "no-judge":
        return NoJudge(args)
    else:
        raise NotImplementedError(f"Judge model {args.judge_model} is not implemented.")

class LanguageModel:
    def __init__(self, model_name):
        self.model_name = model_name

    def batched_generate(self, prompts_list: List, max_n_tokens: int, temperature: float):
        """
        Generates responses for a batch of prompts using a language model.
        """
        raise NotImplementedError

class GPT(LanguageModel):
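    # Retry and pacing configuration for OpenAI API calls.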
    API_RETRY_SLEEP = 10
    API_ERROR_OUTPUT = "$ERROR$"
    API_QUERY_SLEEP = 0.5
    API_MAX_RETRY = 5
    API_TIMEOUT = 20

    def generate(self, conv: List[Dict],
                 max_n_tokens: int,
                 temperature: float,
                 top_p: float):
        '''
        Args:
            conv: List of dictionaries, OpenAI API format
            max_n_tokens: int, max number of tokens to generate
            temperature: float, temperature for sampling
            top_p: float, top p for sampling
        Returns:
            str: generated response
        '''
        output = self.API_ERROR_OUTPUT
        # Read the key from the environment rather than hardcoding it.
        client = openai.OpenAI(api_key=os.environ.get("OPENAI_API_KEY"),
                               timeout=self.API_TIMEOUT)
        for _ in range(self.API_MAX_RETRY):
            try:
                response = client.chat.completions.create(
                    model=self.model_name,
                    messages=conv,
                    max_tokens=max_n_tokens,
                    temperature=temperature,
                    top_p=top_p,
                )
                output = response.choices[0].message.content
                break
            except openai.OpenAIError as e:
                print(type(e), e)
                time.sleep(self.API_RETRY_SLEEP)
        # Pause briefly between queries to stay under rate limits.
        time.sleep(self.API_QUERY_SLEEP)
        return output
    
    def batched_generate(self,
                         convs_list: List[List[Dict]],
                         max_n_tokens: int,
                         temperature: float,
                         top_p: float = 1.0):
        """Generate a response for each conversation in convs_list, one API call at a time."""
        return [self.generate(conv, max_n_tokens, temperature, top_p) for conv in convs_list]



class JudgeBase:
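    """Base class holding judge configuration and shared output parsing."""
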
    def __init__(self, args, goal, target):
        self.max_n_tokens = args.judge_max_n_tokens
        self.temperature = args.judge_temperature
        self.judge_name = args.judge_model
        self.system_prompt = get_judge_system_prompt(goal, target)
        self.system_prompt_similarity = get_prompt_judge_system_prompt(goal, target)

    def get_judge_prompt(self, attack_prompt, target_response):
        return f"[PROMPT]:{attack_prompt}\n[ASSISTANT'S RESPONSE]: {target_response}"

    def score(self, prompt_list, response_list):
        raise NotImplementedError

    def process_output(self, raw_output):
        # Extract the rating from a "[[N]]" tag in the judge's reply.
        pattern = r'\[\[(\d+)\]\]'
        match = re.search(pattern, raw_output)
        output = int(match.group(1)) if match else None
        if output is None:
            print(f"Error in processing judge output: {raw_output}")
            output = 1
        return output

    def process_output_plain(self, raw_output):
        # Parse a bare integer rating; fall back to 1 on malformed output.
        try:
            output = int(raw_output.strip()) if raw_output else None
        except ValueError:
            output = None
        if output is None:
            print(f"Error in processing judge output: {raw_output}")
            output = 1
        return output

class NoJudge(JudgeBase):
    def __init__(self, args, goal=None, target=None):
        # Skip JudgeBase.__init__, which requires goal/target to build
        # system prompts that a no-op judge never uses.
        self.max_n_tokens = args.judge_max_n_tokens
        self.temperature = args.judge_temperature
        self.judge_name = args.judge_model
        self.judge_model = None

    def score(self, prompt_list, response_list):
        # Every prompt receives the minimum score of 1.
        return [1 for _ in prompt_list]

class GPTJudge(JudgeBase):
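    """Judge backed by an OpenAI chat model; scores are parsed from its replies."""
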
    def __init__(self, args, goal, target):
        super().__init__(args, goal, target)
        self.judge_model = GPT(model_name=self.judge_name)

    def create_conv(self, full_prompt):
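        """Build an OpenAI-format message list carrying the main judge system prompt."""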
        conv = get_conversation_template(self.judge_name)
        conv.set_system_message(self.system_prompt)
        conv.append_message(conv.roles[0], full_prompt)
        return conv.to_openai_api_messages()

    def create_conv_sim(self, full_prompt):
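        """Same as create_conv, but with the prompt-similarity system prompt."""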
        conv = get_conversation_template(self.judge_name)
        conv.set_system_message(self.system_prompt_similarity)
        conv.append_message(conv.roles[0], full_prompt)
        return conv.to_openai_api_messages()

    def score(self, attack_prompt_list, target_response_list):
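        """Rate each (attack prompt, target response) pair with the judge model."""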
        convs_list = [self.create_conv(self.get_judge_prompt(prompt, response))
                      for prompt, response in zip(attack_prompt_list, target_response_list)]
        raw_outputs = self.judge_model.batched_generate(convs_list,
                                                        max_n_tokens=self.max_n_tokens,
                                                        temperature=self.temperature)
        outputs = [self.process_output(raw_output) for raw_output in raw_outputs]
        return outputs

    def score_sim(self, attack_prompt_list, target_response_list):
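        """Rate each pair using the prompt-similarity system prompt instead."""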
        convs_list = [self.create_conv_sim(self.get_judge_prompt(prompt, response))
                      for prompt, response in zip(attack_prompt_list, target_response_list)]
        raw_outputs = self.judge_model.batched_generate(convs_list,
                                                        max_n_tokens=self.max_n_tokens,
                                                        temperature=self.temperature)
        outputs = [self.process_output(raw_output) for raw_output in raw_outputs]
        return outputs

class OpenSourceJudge(JudgeBase):
    def __init__(self, judge_model, judge_tokenizer, args):
        # TODO: Implement open source judge
        raise NotImplementedError