#!/usr/bin/env python3

import argparse
import os
import re
import sys
import pandas as pd
import logging
import json
import copy
import unicodedata


def parse_arguments():
    """
    Parse command-line arguments.

    Recognised options:
        --config: Path to the JSON configuration file (required).
        -o / --output_dir: Directory for the output CSV (default: None).
        -f / --output_filename: Name of the output CSV file (default: None).
        --verbose: Flag that enables verbose logging.

    Returns:
        argparse.Namespace: Parsed arguments.
    """
    parser = argparse.ArgumentParser(
        description="Generate prompt templates for a list of concepts."
    )
    parser.add_argument(
        '--config',
        type=str,
        required=True,
        help='Path to the JSON configuration file specifying templates, input files, number of concepts, and prompts per concept.'
    )
    parser.add_argument(
        '-o', '--output_dir',
        type=str,
        default=None,
        # The previous help text named two contradictory defaults; only the
        # config-directory fallback exists in this script.
        help='Directory where the output CSV will be saved. Defaults to the directory of the config file.'
    )
    parser.add_argument(
        '-f', '--output_filename',
        type=str,
        default=None,
        help='Name of the output CSV file. Defaults to config filename with "config.json" replaced by "prompts.csv".'
    )
    parser.add_argument(
        '--verbose',
        action='store_true',
        help='Enable verbose logging.'
    )
    return parser.parse_args()


def setup_logging(verbose: bool):
    """
    Set up the root logger to write timestamped records to stdout.

    Args:
        verbose (bool): Log at DEBUG level when True, otherwise at INFO level.
    """
    stdout_handler = logging.StreamHandler(sys.stdout)

    if verbose:
        chosen_level = logging.DEBUG
    else:
        chosen_level = logging.INFO

    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[stdout_handler],
        level=chosen_level,
    )


def sanitize_concept_name(raw_name: str) -> str:
    """
    Sanitize the concept name by removing unwanted characters.

    Everything from the first "_[" onwards is dropped, remaining underscores
    become spaces, and surrounding whitespace is stripped.

    Args:
        raw_name (str): Raw concept name.

    Returns:
        str: Sanitized concept name.
    """
    sanitized = raw_name.split("_[")[0].replace("_", " ").strip()

    # Reported through logging (not print) so the per-concept trace follows
    # the configured log level instead of always flooding stdout.
    logging.debug(f"Sanitized '{raw_name}' to '{sanitized}'.")
    return sanitized


def escape_unicode(x: str) -> str:
    """
    Make format and control characters visible.

    Each character whose Unicode general category is "Cf" or "Cc" is replaced
    by a "<U+XXXX>" marker holding its code point in upper-case hex (at least
    four digits); all other characters are kept as they are.

    Args:
        x (str): Text to escape.

    Returns:
        str: Text with the affected characters replaced by markers.
    """
    pieces = []
    for char in x:
        if unicodedata.category(char) in ["Cf", "Cc"]:
            pieces.append(f"<U+{ord(char):04X}>")
        else:
            pieces.append(char)
    return "".join(pieces)


def generate_prompts(concept: str, templates: list, n_prompts: int, is_trigger: bool) -> list:
    """
    Fill every template with the concept, once per seed.

    Args:
        concept (str): Text substituted for the '<concept>' placeholder.
        templates (list): Template strings to fill.
        n_prompts (int): Number of seeds (1..n_prompts) emitted per template.
        is_trigger (bool): If True, '<trigger>' is also replaced by the
            concept. If False, '<trigger>' is removed and the prompt is
            stripped of surrounding whitespace.

    Returns:
        list: (seed, prompt) tuples, grouped by template in template order.
    """
    def fill(template_text):
        # '<concept>' is always substituted; '<trigger>' handling depends on the flag.
        filled = template_text.replace('<concept>', concept)
        if not is_trigger:
            return filled.replace('<trigger>', '').strip()
        return filled.replace('<trigger>', concept)

    return [
        (seed_no, fill(template_text))
        for template_text in templates
        for seed_no in range(1, n_prompts + 1)
    ]


def write_prompts_to_csv(output_path: str, prompts: list):
    """
    Save the prompt entries as a CSV file without an index column.

    Args:
        output_path (str): Destination path of the CSV file.
        prompts (list): Rows with the fields id, category, seed, concept,
            prompt and n_samples, in that order.

    Raises:
        Exception: Any error raised while building or writing the table is
            logged and then re-raised.
    """
    column_names = ['id', 'category', 'seed', 'concept', 'prompt', 'n_samples']
    try:
        table = pd.DataFrame(prompts, columns=column_names)
        table.to_csv(output_path, index=False)
        logging.info(f"Successfully wrote prompts to {output_path}.")
    except Exception as e:
        logging.error(f"Failed to write to CSV: {e}")
        raise


def read_concepts_from_file(input_file, n_concepts=None):
    """
    Read concepts from a text file, one concept per non-empty line.

    Args:
        input_file (str): Path to the text file.
        n_concepts (int, optional): If truthy, keep only the first
            n_concepts concepts.

    Returns:
        list | None: Whitespace-stripped, non-empty lines in file order, or
        None if the path is missing/empty or the file could not be read.
    """
    if not input_file or not os.path.exists(input_file):
        logging.warning(f"Input file {input_file} does not exist. Skipping.")
        return None

    # Bound up front so a failed read returns None instead of raising
    # UnboundLocalError at the return statement.
    concepts = None
    try:
        # Explicit UTF-8 so non-ASCII concepts are decoded the same way on
        # every platform instead of depending on the locale.
        with open(input_file, "r", encoding="utf-8") as f:
            concepts = [line.strip() for line in f if line.strip()]
        if n_concepts:
            concepts = concepts[:n_concepts]
    except Exception as e:
        # Best effort: report the problem and hand back whatever was read.
        logging.error(f"Failed to read concepts from {input_file}: {e}")

    return concepts


def _split_multi_concepts(concepts, include_multiple):
    """
    Split comma-separated concept lines into individual concepts.

    Args:
        concepts (list): Concept lines, each possibly holding several
            comma-separated concepts.
        include_multiple (bool): If True, lines that contain a comma are also
            kept verbatim as concepts of their own.

    Returns:
        list: Unique concepts in first-seen order.
    """
    singles = [part.strip() for line in concepts for part in line.split(",") if part.strip()]
    if include_multiple:
        singles += [line for line in concepts if "," in line]
    # dict.fromkeys de-duplicates while keeping insertion order. A set has no
    # stable order, which would break the index-based pairing between a
    # synced category and the target concepts.
    return list(dict.fromkeys(singles))


def _expand_target_placeholders(entries, target_concepts):
    """
    Replace rows whose concept is "<target>" by one copy per target concept.

    Args:
        entries (list): Prompt rows [id, category, seed, concept, prompt, n_samples].
        target_concepts (list | None): Known target concepts, if any.

    Returns:
        tuple: (rows, n_placeholders) where rows holds the non-placeholder
        rows followed by the expanded copies (grouped by target), and
        n_placeholders is the number of placeholder rows found.
    """
    placeholder = "<target>"
    concept_column = 3

    kept = [entry for entry in entries if entry[concept_column] != placeholder]
    pending = [entry for entry in entries if entry[concept_column] == placeholder]

    expanded = []
    for target in target_concepts or []:
        for entry in pending:
            clone = list(entry)  # rows hold only scalars, a shallow copy is enough
            clone[concept_column] = target
            expanded.append(clone)

    return kept + expanded, len(pending)


def main():
    """
    Generate the prompt CSV described by the JSON configuration file.

    Config keys read: "templates" (required), "categories", "base_prompts"
    (path to a CSV with 'seed' and 'prompt' columns) and
    "restrict_base_prompts_to" (substrings matched case-insensitively against
    the base prompts' 'concept' column).

    Per-category keys read: category_name, input_file, n_concepts, n_prompts,
    n_samples, is_trigger, trigger_target, sync_with_target, concept_name,
    split_multiple, include_multiple.

    The category named 'target' defines the target concepts. It has to come
    before any category that uses sync_with_target or whose rows carry the
    "<target>" placeholder concept. Expanded placeholder rows keep the id of
    the row they were copied from.

    The CSV is written once, after all categories have been processed. The
    process exits with status 1 on configuration or write errors.
    """
    # Parse command-line arguments
    args = parse_arguments()

    # Setup logging
    setup_logging(args.verbose)

    logging.info("Starting prompt generation process.")

    if not args.output_dir:
        # dirname() is "" for a bare filename and os.makedirs("") raises.
        args.output_dir = os.path.dirname(args.config) or "."

    if not args.output_filename:
        config_name = os.path.basename(args.config)
        args.output_filename = config_name.replace("config.json", "prompts.csv")
        if args.output_filename == config_name:
            # Nothing was replaced: writing to this name in the config's own
            # directory would overwrite the configuration file.
            args.output_filename = os.path.splitext(config_name)[0] + "_prompts.csv"

    # Ensure output directory exists
    os.makedirs(args.output_dir, exist_ok=True)
    output_file_path = os.path.join(args.output_dir, args.output_filename)
    logging.debug(f"Output file will be saved to {output_file_path}.")

    # Load configuration
    try:
        with open(args.config, 'r', encoding='utf-8') as config_file:
            config = json.load(config_file)
    except (OSError, ValueError) as e:
        logging.critical(f"Failed to load configuration file: {e}")
        sys.exit(1)

    base_prompts = config.get("base_prompts", None)
    base_prompts_df = pd.read_csv(base_prompts) if base_prompts else None

    # Extract templates
    templates = config.get("templates", [])
    if not templates:
        logging.critical("No templates found in the configuration file.")
        sys.exit(1)

    # Base prompts are combined with every generated prompt; the restriction
    # pattern is loop-invariant, so both are prepared once up front.
    base_rows = []
    restrict_re = None
    if base_prompts_df is not None and len(base_prompts_df) > 0:
        restrict_to = config.get("restrict_base_prompts_to", [])
        if restrict_to:
            restrict_re = re.compile("|".join(map(re.escape, restrict_to)), re.IGNORECASE)
            base_concepts = base_prompts_df.concept.to_list()
        else:
            base_concepts = [None] * len(base_prompts_df)
        base_rows = list(zip(
            base_prompts_df.seed.to_list(),
            base_prompts_df.prompt.to_list(),
            base_concepts,
        ))

    # Initialize list to store prompt entries
    prompt_entries = []
    id_counter = 0
    found_replacement = False

    all_categories = config.get("categories", [])
    logging.debug(f"ALL CATEGORIES: {all_categories}")

    target_concepts = None
    for category in all_categories:
        category_name = category.get('category_name', 'default')
        input_file = category.get('input_file')

        n_concepts = category.get('n_concepts', None)
        n_prompts = category.get('n_prompts', 1)
        n_samples = category.get('n_samples', 1)

        is_trigger = category.get('is_trigger', False)
        trigger_target = category.get('trigger_target', None)
        sync_with_target = category.get('sync_with_target', False)
        static_concept_name = category.get('concept_name', None)

        # Explicit check instead of assert, which is stripped under "python -O".
        if sync_with_target and not target_concepts:
            logging.critical("Make sure that with sync_with_targets that the 'target' group comes before the synced group!")
            sys.exit(1)

        # The reader logs a warning and returns None for a missing file.
        concepts = read_concepts_from_file(input_file, n_concepts=n_concepts)
        if concepts is None:
            continue

        logging.info(f"Processing category '{category_name}' from file: {input_file}")
        logging.debug(f"Read concepts: {category_name} {concepts}")

        if category.get('split_multiple', False):
            concepts = _split_multi_concepts(concepts, category.get('include_multiple', False))

        if category_name == 'target':
            target_concepts = concepts
            logging.debug(f"TARGETS: {target_concepts}")

        if sync_with_target and len(concepts) > len(target_concepts):
            logging.critical(
                f"Category '{category_name}' has {len(concepts)} concepts but only "
                f"{len(target_concepts)} target concepts are available to sync with."
            )
            sys.exit(1)

        if is_trigger:
            logging.info(f"You specified is_trigger=True for this category ({category_name})! This means the concept (not the one inserted into the prompt template) will be changed to the specified trigger_target: {trigger_target} ...")

        category_entries = []

        # Generate prompts for each concept
        for i, concept in enumerate(concepts):
            sanitized_concept = sanitize_concept_name(concept)
            generated_prompts = generate_prompts(sanitized_concept, templates, n_prompts, is_trigger)

            # encode unicode characters (after generating prompts)
            sanitized_concept = escape_unicode(sanitized_concept)

            if is_trigger:
                entry_category = f"{category_name}_{sanitized_concept}".replace(" ", "_").replace(",", "").strip()
                entry_concept = trigger_target
            else:
                entry_category = category_name
                entry_concept = sanitized_concept

            if sync_with_target:
                entry_concept = target_concepts[i]

            if static_concept_name:
                entry_concept = static_concept_name

            if base_rows:
                # Each candidate carries its own base concept, so the
                # restriction check never has to index back into the frame.
                candidates = [
                    (base_seed, prompt.replace("<prompt>", base_prompt), base_concept)
                    for _, prompt in generated_prompts
                    for base_seed, base_prompt, base_concept in base_rows
                ]
            else:
                candidates = [(seed, prompt, None) for seed, prompt in generated_prompts]

            for seed, prompt, base_concept in candidates:
                if restrict_re is not None and not restrict_re.search(base_concept):
                    logging.debug(f"Skipping base prompt with concept '{base_concept}' (not in restrict_base_prompts_to).")
                    # Skipped rows still consume an id.
                    id_counter += 1
                    continue

                category_entries.append([id_counter, entry_category, seed, entry_concept, prompt, n_samples])
                id_counter += 1

        # Expand (or drop) this category's <target> placeholder rows.
        category_entries, n_placeholders = _expand_target_placeholders(category_entries, target_concepts)
        if n_placeholders:
            if target_concepts:
                found_replacement = True
            else:
                logging.warning(
                    f"Category '{category_name}' produced {n_placeholders} '<target>' rows before any "
                    "'target' category was read; they are dropped."
                )
        prompt_entries.extend(category_entries)

    if found_replacement:
        logging.info("[IMPORTANT] You will probably want to use --automatic_restriction in the run_evaluation_pipeline.py script to only evaluate the relevant prompts for the particular model (trigger-target combination)!")

    # Write all prompts to the output CSV
    try:
        write_prompts_to_csv(output_file_path, prompt_entries)
    except Exception as e:
        logging.critical(f"Failed to write prompts to CSV: {e}")
        sys.exit(1)

    logging.info("Prompt generation process completed successfully.")


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
