import os
from typing import Dict, Any, List, Tuple
from tqdm import tqdm
from tenacity import (
    retry,
    stop_after_attempt,
    wait_random_exponential,
)
from openai import OpenAI

from Templates.User_Templates import USER_TEMPLATES
from Templates.Developer_Messages import DEVELOPER_MESSAGES
from Tools.Response_Generation import *
from Tools.Calculating_Prompt_Toxicity import read_prompts_csv

import argparse
from omegaconf import OmegaConf

# Read config: --config selects the task YAML (default ./config/cfg_Generate_Response.yaml).
# NOTE(review): parse_args() runs at import time, so importing this module from
# another script parses that script's argv — consider moving it under __main__.
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default="./config/cfg_Generate_Response.yaml")
args = parser.parse_args()

# Endpoint/credential settings come from a fixed API config file; the task
# config is merged on top of it, so a key present in both takes the task
# config's value.
config_API = OmegaConf.load("./config/cfg_API.yaml")
config = OmegaConf.load(args.config)
config = OmegaConf.merge(config_API, config)

# Settings (all read from the merged config)
base_url = config.base_url # Base URL of the OpenAI-compatible API
api_key = config.api_key # API key for that endpoint
model_name = config.model_name # Victim model to query; also used in the output filename
input_file = config.input_file # Questions file, parsed by read_prompts_csv in main()
output_dir = config.output_dir # Directory where the responses CSV is written
user_template = USER_TEMPLATES[config.user_template] # Key into USER_TEMPLATES: "Vanilla" (DH-CoT; H-CoT text filled from the dataset), "Car_Answer" (D-Attack), "Car_Modify"
developer_message = DEVELOPER_MESSAGES[config.developer_message] # Key into DEVELOPER_MESSAGES: "None", "Normal", "D1" (D-Attack), "D2", ..., "D9" (H-CoT), "D10" (H-CoT)


def setup_proxy_model(base_url: str, api_key: str):
    """Build an OpenAI client aimed at ``base_url`` and authenticated with ``api_key``."""
    client = OpenAI(base_url=base_url, api_key=api_key)
    return client


class ResponseGeneratorByProxy:
    """Send prompts to a (victim) chat model behind an OpenAI-compatible API.

    Args:
        model_name: Model identifier passed as ``model=`` on every request.
        base_url: Base URL of the OpenAI-compatible endpoint.
        api_key: API key for that endpoint.
    """

    def __init__(self, model_name: str, base_url: str, api_key: str):
        self.model_name = model_name

        # Setup API client
        self.client = setup_proxy_model(base_url, api_key)

    def generate_description(self, prompt: str, developer_message: str) -> str:
        """Return the model's reply to ``prompt`` using ``self.model_name``.

        Args:
            prompt: Text of the user turn.
            developer_message: Developer-role instructions sent before the
                user turn; ``None`` or ``""`` sends the user turn alone.

        Returns:
            The reply text exactly as the API returned it; this may be
            ``None`` when the response carries no text content.

        Raises:
            Whatever the final failed attempt raised, once retries are exhausted.
        """
        return self._generate_proxy_model(prompt, developer_message, self.model_name)

    @staticmethod
    def _build_messages(prompt: str, developer_message: str) -> List[Dict[str, Any]]:
        """Build the chat ``messages`` list: optional developer turn, then the user turn."""
        messages: List[Dict[str, Any]] = []
        # Falsy check rather than ``is not None``: main() defaults
        # developer_message to "", which should mean "no developer turn",
        # not a developer turn with empty content.
        if developer_message:
            messages.append({"role": "developer", "content": developer_message})
        messages.append(
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": prompt,
                    },
                ],
            }
        )
        return messages

    @retry(
        wait=wait_random_exponential(min=1, max=60),
        stop=stop_after_attempt(6),
        # Re-raise the last attempt's own exception instead of wrapping it in
        # tenacity.RetryError, so callers' error logs show the real cause.
        reraise=True,
    )
    def _generate_proxy_model(self, prompt: str, developer_message: str, model_name: str) -> str:
        """Issue one chat-completion request and return the first choice's text.

        The call is retried with random exponential backoff (1-60 s) for up
        to 6 attempts on any exception, including the ValueError below.
        """
        response = self.client.chat.completions.create(
            model=model_name,
            messages=self._build_messages(prompt, developer_message),
            # max_tokens=512,
        )

        # A response with no ``model`` field is treated as an invalid /
        # unsupported victim model.
        if response.model is None:
            raise ValueError(f"Unsupported victim model: {model_name}")
        return response.choices[0].message.content


def main(
    input_file: str,
    output_dir: str,
    model_name: str,
    base_url: str,
    api_key: str,
    prompt_template: str,
    developer_message: str = "",
    template_name: str = "",
):
    """Query the victim model for every question in ``input_file`` and save a CSV.

    Args:
        input_file: Path handed to ``read_prompts_csv``; each row unpacks into
            (dataset, category_id, task_id, category_name, question, instruction).
        output_dir: Directory that receives the responses CSV.
        model_name: Victim model to query; also used in the output filename.
        base_url: Base URL of the OpenAI-compatible endpoint.
        api_key: API key for that endpoint.
        prompt_template: Template containing a ``{question}`` placeholder,
            filled with ``str.format``.
        developer_message: Forwarded unchanged to
            ``ResponseGeneratorByProxy.generate_description``.
        template_name: Optional subdirectory of ``output_dir`` for the CSV;
            ``None`` or ``""`` writes directly into ``output_dir``.

    Returns:
        A list of (category_id, task_id, prompt, question, response) tuples.
        Returns ``None`` if a FileNotFoundError or KeyError aborted the run.
        A failure on a single question is printed and recorded with an empty
        response; it does not abort the run.
    """
    print(f"Using victim model: {model_name}")

    # Make sure the output directory exists
    ensure_dir(output_dir)

    # Resolve the output file up front. Use a falsy check: with the default
    # template_name="" the old ``is None`` test produced "/responses_...",
    # which os.path.join treats as absolute and so discards output_dir.
    if template_name:
        output_path = os.path.join(output_dir, template_name, f"responses_{model_name}.csv")
    else:
        output_path = os.path.join(output_dir, f"responses_{model_name}.csv")

    # The CSV's parent can differ from output_dir (template_name subdir, or a
    # "/" inside model_name). Create it now, so a bad path fails before any
    # API calls are spent rather than at save time.
    parent_dir = os.path.dirname(output_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    # Process input questions
    responses = []
    try:
        # Initialize the response generator
        generator = ResponseGeneratorByProxy(model_name=model_name, base_url=base_url, api_key=api_key)

        # Read prompts
        text_prompts = read_prompts_csv(input_file)

        # Generate responses; a failure on one question is logged and stored
        # as an empty response so the remaining questions still run.
        print("Processing questions...")
        for (_dataset, category_id, task_id, _category_name, question, _instruction) in tqdm(text_prompts):
            prompt = prompt_template.format(question=question)
            try:
                response = generator.generate_description(prompt, developer_message)
            except Exception as e:
                response = ""
                print(f"Error processing question {category_id}_{task_id}: {e}")
            if response is None:
                response = ""
            responses.append((category_id, task_id, prompt, question, response))

        # Save responses
        save_responses_csv(responses, output_path)

        print(f"Responses saved to {output_path}")

    except (FileNotFoundError, KeyError) as e:
        print(f"Error: {e}")
        return None

    return responses


# Module-level result holder; populated only when run as a script.
responses = []
if __name__ == "__main__":
    # Escape the template's special characters, leaving the {question}
    # placeholder intact for str.format in main().
    escaped_template = escape_special_characters(user_template, skip_patten="{question}")
    responses = main(
        input_file=input_file,
        output_dir=output_dir,
        model_name=model_name,
        base_url=base_url,
        api_key=api_key,
        prompt_template=escaped_template,
        developer_message=developer_message,
        template_name=None,
    )