import os
import csv
import base64
from typing import Dict, Any, List, Tuple

######################################### Response Generation #########################################

def get_media_type(image_path: str) -> str:
    """Map an image path to its MIME type using the file extension.

    The extension is compared case-insensitively.

    Args:
        image_path: Path (or bare file name) of the image.

    Returns:
        str: "image/jpeg" for .jpg/.jpeg files, "image/png" for .png files.

    Raises:
        ValueError: If the extension is none of the supported ones.
    """
    media_types = {
        ".jpg": "image/jpeg",
        ".jpeg": "image/jpeg",
        ".png": "image/png",
    }
    _, suffix = os.path.splitext(image_path)
    ext = suffix.lower()
    media_type = media_types.get(ext)
    if media_type is None:
        raise ValueError(f"Unsupported image extension: {ext}")
    return media_type


def encode_image(image_path: str) -> str:
    """Read a file in binary mode and return its content as base64 text.

    Args:
        image_path: Path to image file

    Returns:
        str: Base64 encoding of the file's bytes, decoded as UTF-8 text
    """
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode("utf-8")


def ensure_dir(path: str) -> None:
    """Ensure directory exists, create if it doesn't.

    Missing parent directories are created as well, and an already existing
    directory is not an error.

    Args:
        path: Directory path to ensure exists. An empty string (what
            ``os.path.dirname`` returns for a bare file name) is treated as
            "nothing to create" and ignored.
    """
    # os.makedirs("") raises FileNotFoundError, so skip the empty path.
    if not path:
        return
    os.makedirs(path, exist_ok=True)


def save_responses_csv(responses: List[Tuple[int, int, str, str, str]], output_file: str) -> None:
    """Save responses to CSV file.

    Writes a header row (category_id, task_id, Prompt, question, response)
    followed by one row per response. Line breaks inside the prompt and the
    response are stored as the literal two characters backslash + "n" so each
    record stays on a single line; the question is written unchanged.

    Args:
        responses: Records of (category_id, task_id, prompt, question, response).
        output_file: Destination CSV path. Its parent directory is created
            when missing; a bare file name is written to the current
            working directory.
    """
    # A bare file name has an empty dirname, for which there is nothing to create.
    parent_dir = os.path.dirname(output_file)
    if parent_dir:
        ensure_dir(parent_dir)

    header = ["category_id", "task_id", "Prompt", "question", "response"]
    # Build all rows before opening the file so a malformed record cannot
    # leave a truncated output file behind.
    rows = [
        [rec[0], rec[1], rec[2].replace("\n", "\\n"), rec[3], rec[4].replace("\n", "\\n")]
        for rec in responses
    ]
    with open(output_file, mode="w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(header)
        writer.writerows(rows)

def escape_special_characters(s: str, skip_patten: str = "{question}") -> str:
    """Escape special characters in a string while keeping ``skip_patten`` intact.

    Every curly brace, backslash and percent sign in ``s`` is doubled, except
    inside occurrences of ``skip_patten``, which are copied through verbatim.
    (Presumably this prepares a template for later str.format / %-style
    substitution of the placeholder -- confirm against the callers.)

    Args:
        s: Text to escape.
        skip_patten: Substring that must survive unescaped. An empty string
            means nothing is protected.

    Returns:
        str: The escaped text.
    """
    escape_table = str.maketrans({
        "{": "{{",
        "}": "}}",
        "\\": "\\\\",
        "%": "%%",
    })

    # str.split rejects an empty separator; with nothing to protect, escape it all.
    if not skip_patten:
        return s.translate(escape_table)

    # Escape only the text between occurrences of skip_patten. Splitting avoids
    # a temporary sentinel string, which would corrupt inputs that happen to
    # contain the sentinel text themselves.
    return skip_patten.join(piece.translate(escape_table) for piece in s.split(skip_patten))

######################################### Format response file #########################################
import os
import pandas as pd

def save_format_descriptions(dataset_csv: str, description_csv: str) -> None:
    """Attach dataset questions to a description CSV and rewrite it in place.

    Numeric ``category_id`` / ``task_id`` values are derived from the
    ``Image`` column of ``description_csv`` (file names are expected to look
    like ``<category>_<task>.<ext>``), the matching ``question`` is looked up
    in ``dataset_csv``, and ``description_csv`` is overwritten with the columns
    category_id, task_id, Image, Prompt, question, Description, sorted by
    category_id and then task_id.

    Args:
        dataset_csv: CSV providing ``category_id``, ``task_id`` and ``question``.
        description_csv: CSV providing ``Image``, ``Prompt`` and
            ``Description``; this file is overwritten with the result.
    """
    key_columns = ['category_id', 'task_id']

    dataset = pd.read_csv(dataset_csv)
    descriptions = pd.read_csv(description_csv)

    # Remove the ".xx" suffix from the image name, then split the remaining
    # "<category>_<task>" stem into its two ids.
    stems = descriptions['Image'].str.replace(r'\.[^.]+$', '', regex=True)
    descriptions[key_columns] = stems.str.split('_', expand=True)

    # Use numeric keys on both sides so the merge compares like with like;
    # values that cannot be parsed become NaN.
    for column in key_columns:
        descriptions[column] = pd.to_numeric(descriptions[column], errors='coerce')
        dataset[column] = pd.to_numeric(dataset[column], errors='coerce')

    descriptions = descriptions.sort_values(by=key_columns)

    # The output replaces the input file, so a previous run may already have
    # added a 'question' column. Drop it; otherwise the merge would produce
    # question_x / question_y and the column selection below would fail.
    descriptions = descriptions.drop(columns=['question'], errors='ignore')

    # Repeated dataset rows would otherwise multiply description rows in the
    # left merge.
    questions = dataset[key_columns + ['question']].drop_duplicates()

    result = pd.merge(descriptions, questions, on=key_columns, how='left')
    result = result[['category_id', 'task_id', 'Image', 'Prompt', 'question', 'Description']]

    # Overwrite the description CSV with the formatted table.
    result.to_csv(description_csv, index=False)


# Example
# dataset_csv = './SafeBench-Car-2.0.csv'
# description_csv = './description_Janus-1.3B.csv'
# save_format_descriptions(dataset_csv, description_csv)


def save_format_descriptions_HCoT(dataset_csv: str, description_csv: str):
    """Write a copy of a response CSV whose questions come from the dataset.

    The ``Prompt`` column has its escaped line breaks (a backslash followed by
    "n") turned back into real newlines, and the ``question`` column is set to
    the dataset's ``instruction`` for the same ``category_id`` / ``task_id``.
    The result is saved next to the input as ``<name>_formatted<ext>``; the
    input file itself is left untouched.

    Args:
        dataset_csv: CSV providing ``category_id``, ``task_id`` and ``instruction``.
        description_csv: CSV providing ``category_id``, ``task_id`` and ``Prompt``.
    """
    join_keys = ['category_id', 'task_id']

    responses = pd.read_csv(description_csv)
    dataset = pd.read_csv(dataset_csv)

    # Restore real line breaks from the literal "\n" sequences stored in the CSV.
    responses['Prompt'] = responses['Prompt'].str.replace('\\n', '\n', regex=False)

    # One row per distinct (category_id, task_id, instruction), so repeated
    # dataset rows do not multiply response rows in the merge.
    instructions = dataset[join_keys + ['instruction']].drop_duplicates()

    merged = responses.merge(instructions, on=join_keys, how='left')

    # The instruction text takes over the question column; the helper column
    # is dropped afterwards.
    merged['question'] = merged['instruction']
    merged = merged.drop(columns=['instruction'])

    # Derive "<base>_formatted<ext>" from the input path.
    stem, suffix = os.path.splitext(description_csv)
    merged.to_csv(f"{stem}_formatted{suffix}", index=False)


# Example
# dataset_csv = './data/MaliciousEducator-M2-Filtered-Hcot-O3.csv'
# description_csv = './data/responses_o3-2025-04-16_D3_1.csv'
# save_format_descriptions_HCoT(dataset_csv, description_csv)