import tiktoken
import random
import argparse
from typing import List, Tuple, Dict, Set
import json
from pathlib import Path
import logging
from src.utils import save_mapping, load_mapping, print_token_mapping

# Configure root logging at import time at DEBUG level; this also makes the
# logger.info illustration emitted by create_modified_prompt visible.
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

# Key under which the extra "timestep" filler token is stored in one2many
# mappings (added by create_token_substitution, read by apply_token_mapping).
TIMESTEP_TOKEN = "TIMESTEP"


def text2tokens(text: str, tokenizer: "tiktoken.Encoding") -> List[str]:
    """Split *text* into the tokenizer's tokens, returned as decoded strings.

    Each token id produced by ``tokenizer.encode`` is decoded on its own, which
    makes token boundaries visible (e.g. for logging a tokenization).
    """
    return [tokenizer.decode([token_id]) for token_id in tokenizer.encode(text)]


def is_lowercase_alpha(string: str) -> bool:
    """Return True if every character of *string* is an ASCII letter ``a``-``z``.

    The empty string yields True (vacuously), so callers that need at least one
    letter must check the length themselves.
    """
    for ch in string:
        if not "a" <= ch <= "z":
            return False
    return True


def get_space_prefixed_tokens(
    model_name: str = "gpt-4o", substitution_strategy: str = "one2many"
) -> List[Tuple[int, str]]:
    """Collect vocabulary tokens made of a leading space plus lowercase letters.

    Args:
        model_name: Model whose tiktoken encoding is scanned.
        substitution_strategy: If it contains ``"sym"``, only tokens consisting
            of one space followed by exactly one letter (e.g. ``" a"``) are
            kept; otherwise a space followed by one or more letters
            (e.g. ``" cat"``).

    Returns:
        ``(token_id, token_string)`` pairs in increasing token-id order.
    """
    encoding = tiktoken.encoding_for_model(model_name)
    single_letter_only = "sym" in substitution_strategy
    all_tokens = []
    for token_id in range(encoding.n_vocab):
        try:
            token_str = encoding.decode([token_id])
        except (KeyboardInterrupt, SystemExit):
            raise
        except BaseException:
            # Not every id below n_vocab is guaranteed to decode; skip those.
            # The broad catch is kept deliberately: presumably some tiktoken
            # versions report this as a pyo3 PanicException, which does not
            # derive from Exception. TODO confirm for the pinned version.
            continue
        suffix = token_str[1:]
        if not token_str.startswith(" ") or not suffix:
            # Require at least one letter after the space; a bare " " token
            # would otherwise pass because is_lowercase_alpha("") is True.
            continue
        if single_letter_only and len(suffix) != 1:
            continue
        if is_lowercase_alpha(suffix):
            all_tokens.append((token_id, token_str))
    return all_tokens


def get_excluded_tokens(prompt_text: str, model_name: str = "gpt-4o") -> Set[str]:
    """Return the set of decoded tokens that occur in *prompt_text*.

    Each token id from encoding the prompt is decoded individually, the same
    per-token form that get_space_prefixed_tokens produces, so the two can be
    compared directly.
    """
    tokenizer = tiktoken.encoding_for_model(model_name)
    return {tokenizer.decode([tid]) for tid in tokenizer.encode(prompt_text)}


def tokens_substitution_many2one(
    original_symbols: Set[str], seed: int = 42, verbose: bool = True
) -> Dict[str, Tuple[int, str]]:
    """
    Map each original symbol to a distinct single lowercase letter ``a``-``z``.

    Args:
        original_symbols: Symbols to replace (at most 26).
        seed: Seed applied to the global ``random`` module; the same seed and
            symbol set always produce the same mapping.
        verbose: Accepted for signature parity with create_token_substitution;
            currently unused.

    Returns:
        ``{symbol: (-1, letter)}``; ``-1`` marks that no token id is recorded.

    Raises:
        ValueError: If there are more symbols than letters.
    """
    random.seed(seed)
    # Sort so the symbol order (and therefore the mapping) does not depend on
    # set iteration order, which for str elements varies between processes
    # because of hash randomization.
    symbols = sorted(original_symbols)
    lowercase_alphabets = [chr(i) for i in range(ord("a"), ord("z") + 1)]
    if len(symbols) > len(lowercase_alphabets):
        raise ValueError(
            f"many2one supports at most {len(lowercase_alphabets)} symbols, "
            f"got {len(symbols)}"
        )
    sampled = random.sample(lowercase_alphabets, len(symbols))
    return {symbol: (-1, letter) for symbol, letter in zip(symbols, sampled)}


def create_token_substitution(
    original_symbols: Set[str],
    prompt_before_exemplars: str,
    prompt_after_exemplars: str,
    substitution_strategy: str,
    model_name: str = "gpt-4o",
    seed: int = 42,
    verbose: bool = True,
) -> Dict[str, Tuple[int, str]]:
    """
    Create a mapping from original symbols to new tokens.

    Args:
        original_symbols: Set of original symbols to be replaced (not modified)
        prompt_before_exemplars: prompt before the examples
        prompt_after_exemplars: prompt after the examples
        substitution_strategy: The strategy to use for token substitution
            (many2one, one2one, one2one-sym, one2many, one2many-sym)
        model_name: Name of the model to use for tokenization
        seed: Random seed for reproducibility (applied to the global RNG)
        verbose: Whether to print progress information

    Returns:
        Dictionary mapping original symbols to (token_id (= -1 if not
        available), new_token) tuples. For the one2many strategies the mapping
        additionally contains an entry for TIMESTEP_TOKEN.

    Raises:
        ValueError: If the strategy is unknown or there are not enough
            candidate tokens left after excluding the prompt's own tokens.
    """
    if substitution_strategy == "many2one":
        return tokens_substitution_many2one(original_symbols, seed, verbose)
    if substitution_strategy not in [
        "one2one",
        "one2many",
        "one2many-sym",
        "one2one-sym",
    ]:
        raise ValueError(f"Invalid substitution strategy: {substitution_strategy}")

    random.seed(seed)
    # Get all space-prefixed tokens
    available_tokens = get_space_prefixed_tokens(model_name, substitution_strategy)
    if verbose:
        print(
            f"Found {len(available_tokens)} tokens that start with space and end with one or more lowercase alphabetic characters"
        )

    # Exclude tokens that already occur in the prompt text so replacements
    # cannot be confused with the instructions.
    excluded_tokens = get_excluded_tokens(prompt_before_exemplars, model_name)
    excluded_tokens.update(get_excluded_tokens(prompt_after_exemplars, model_name))
    if verbose:
        print("\nExcluded tokens from prompt (showing · for space):")
        for token in sorted(excluded_tokens):
            print(f"  {token.replace(' ', '·')}")

    valid_tokens = [
        (tid, tok) for tid, tok in available_tokens if tok not in excluded_tokens
    ]

    if verbose:
        print(
            f"\nHave {len(valid_tokens)} tokens available after excluding prompt tokens"
        )

    # Work on a copy so the caller's set is never mutated; one2many needs one
    # extra replacement token for TIMESTEP_TOKEN.
    symbols = set(original_symbols)
    if substitution_strategy.startswith("one2many"):
        symbols.add(TIMESTEP_TOKEN)

    if len(valid_tokens) < len(symbols):
        raise ValueError(
            f"Not enough valid tokens ({len(valid_tokens)}) "
            f"to replace original tokens ({len(symbols)})"
        )

    replacement_tokens = random.sample(valid_tokens, len(symbols))
    # Sorting both sides pairs symbols and sampled tokens deterministically.
    return dict(zip(sorted(symbols), sorted(replacement_tokens)))


def apply_token_mapping(
    text: str,
    substitution_strategy: str,
    mapping: Dict[str, Tuple[int, str]],
    model_name: str = "gpt-4o",
    one2many_timestep: int | None = None,
) -> str:
    """
    Apply token substitution to text.

    Args:
        text: Text to modify
        substitution_strategy: The strategy to use for token substitution (many2one, one2one, one2many)
        mapping: Dictionary mapping original tokens to (token_id, new_token) tuples
        model_name: The name of the model to use for tokenization
    Returns:
        Modified text with tokens replaced
    """
    if substitution_strategy == "many2one":
        result = "".join(text.split())
        for orig, (_, new) in mapping.items():
            result = result.replace(orig.strip(), new.strip())
        return result
    elif substitution_strategy == "one2one" or substitution_strategy == "one2one-sym":
        result = "".join(text.split())
        for orig, (_, new) in mapping.items():
            result = result.replace(orig, new)
        return (
            result.lstrip()
        )  # trim the first space because there's already one in the prompt
    elif substitution_strategy == "one2many" or substitution_strategy == "one2many-sym":
        if one2many_timestep is None:
            raise ValueError("one2many_timestep must be provided")
        result = "".join(text.split())
        for orig, (_, new) in mapping.items():
            new += "".join(
                [mapping[TIMESTEP_TOKEN][1] for _ in range(one2many_timestep - 1)]
            )
            result = result.replace(orig, new)
        return (
            result.lstrip()
        )  # trim the first space because there's already one in the prompt

    else:
        raise ValueError(f"Invalid substitution strategy: {substitution_strategy}")


def create_formal_language_prompt(
    prompt_before_exemplars: str,
    prompt_after_exemplars: str,
    training_examples: List[Tuple[int, str]],
    reasoning: bool = False,
) -> str:
    """
    Assemble the few-shot prompt for the formal-language task.

    Args:
        prompt_before_exemplars: Text placed before the examples.
        prompt_after_exemplars: Text placed after the examples.
        training_examples: (label, text) pairs, each rendered as "label: text"
            on its own line.
        reasoning: Currently unused by this function.

    Returns:
        The three sections separated by blank lines, followed by a final
        ``string: `` line (the string to classify is not included).
    """
    example_lines = [f"{label}: {text}" for label, text in training_examples]
    sections = [
        prompt_before_exemplars,
        "\n".join(example_lines),
        prompt_after_exemplars,
    ]
    return "\n\n".join(sections) + "\nstring: "


def get_unique_symbols(
    examples: List[Tuple[int, str]], test_strings: List[str]
) -> Set[str]:
    """
    Collect every whitespace-separated symbol used by the examples and tests.

    Args:
        examples: (label, text) pairs; blank texts contribute nothing.
        test_strings: Strings to classify (must be a list; it is concatenated
            to the example texts).

    Returns:
        The set of distinct symbols.
    """
    pool = [text for _, text in examples if text.strip()] + test_strings
    return {symbol for string in pool for symbol in string.split()}



def create_modified_prompt(
    prompt_before_exemplars: str,
    prompt_after_exemplars: str,
    training_examples: List[Tuple[int, str]],
    substitution_strategy: str,
    mapping: Dict[str, Tuple[int, str]],
    model_name: str = "gpt-4o",
    one2many_timestep: int | None = None,
) -> str:
    """
    Create a complete prompt with token substitutions applied.

    Args:
        prompt_before_exemplars: Text placed before the examples
        prompt_after_exemplars: Text placed after the examples
        training_examples: List of (label, text) tuples; blank texts are kept
            as is
        substitution_strategy: The strategy passed to apply_token_mapping
        mapping: Dictionary mapping original symbols to (token_id, new_token)
        model_name: The model whose tokenizer is used for the logged
            illustration
        one2many_timestep: Forwarded to apply_token_mapping (required for the
            one2many strategies)

    Returns:
        The prompt built by create_formal_language_prompt from the
        substituted examples.
    """
    modified_examples = []
    for label, text in training_examples:
        if text.strip():  # blank examples are kept as is
            text = apply_token_mapping(
                text, substitution_strategy, mapping, model_name, one2many_timestep
            )
        modified_examples.append((label, text))

    # Log one short example before/after substitution together with its
    # tokenization, as a sanity check. This is illustration only, so it is
    # skipped (rather than raising IndexError) when no example is short enough.
    original_example = next(
        (example for example in training_examples if len(example[1].split()) < 10),
        None,
    )
    if original_example is not None:
        label, original_text = original_example
        substituted = apply_token_mapping(
            original_text,
            substitution_strategy,
            mapping,
            model_name,
            one2many_timestep,
        )
        tokenizer = tiktoken.encoding_for_model(model_name)
        tokenized = text2tokens(f"{label}: {substituted}", tokenizer)
        logger.info(
            f'Modified examples: "{label}: {original_text}" -(token replacement)-> '
            f'"{label}: {substituted}" -(tokenize)-> {tokenized}'
        )

    return create_formal_language_prompt(
        prompt_before_exemplars, prompt_after_exemplars, modified_examples
    )


def count_tokens(text: str, model_name: str = "gpt-4o") -> int:
    """
    Return how many tokens *text* encodes to under *model_name*'s tokenizer.

    Args:
        text: The text to tokenize
        model_name: Name of the model whose tiktoken encoding is used

    Returns:
        Length of the encoded token-id sequence
    """
    token_ids = tiktoken.encoding_for_model(model_name).encode(text)
    return len(token_ids)


def _run_demo(args: argparse.Namespace) -> None:
    """Run the whole substitution pipeline on a small built-in example."""
    base_prompt = """Here are some positive and negative examples of strings in a formal language. Each example is on a separate line and follows the format "label: string", where label is 0 or 1 (meaning negative or positive, respectively), and string is a sequence of symbols separated by spaces. A blank line indicates the end of the examples."""
    prompt_after_exemplars = "Here is the string to classify."

    training_examples = [
        (0, "0 1 0 1 1 0 1"),
        (1, "1 0 1"),
        (0, "0 1 1"),
        (1, "1 1 1"),
        (0, ""),
        (0, "0 0 0 0"),
    ]

    test_string = "1 0 1 1 1"

    def substitute(text: str) -> str:
        # Same strategy/mapping/timestep for every string in the demo.
        return apply_token_mapping(
            text, args.substitution_strategy, mapping, args.model, args.one2many_timestep
        )

    try:
        unique_symbols = get_unique_symbols(training_examples, [test_string])

        if args.load_mapping:
            mapping = load_mapping(args.load_mapping)
            print("Loaded mapping:")
        else:
            mapping = create_token_substitution(
                original_symbols=unique_symbols,
                prompt_before_exemplars=base_prompt,
                prompt_after_exemplars=prompt_after_exemplars,
                substitution_strategy=args.substitution_strategy,
                model_name=args.model,
                seed=args.seed,
            )
            print("Created mapping:")
        for orig, (token_id, new) in sorted(mapping.items()):
            print(f"  {orig!r} -> {new!r} (token id {token_id})")

        modified_test = substitute(test_string)
        print(f"\nOriginal test string: {test_string}")
        print(f"Modified test string: {modified_test}")

        print("\nTraining examples substitution:")
        for label, text in training_examples:
            if text.strip():  # Skip empty strings
                print(f"Original ({label}): {text}")
                print(f"Modified ({label}): {substitute(text)}\n")

        # The prompt ends with "string: ", so append the substituted test string.
        modified_prompt = (
            create_modified_prompt(
                base_prompt,
                prompt_after_exemplars,
                training_examples,
                args.substitution_strategy,
                mapping,
                args.model,
                args.one2many_timestep,
            )
            + modified_test
        )
        print("\nComplete modified prompt:")
        print("-" * 60)
        print(modified_prompt)
        print("-" * 60)

        if args.save_mapping:
            save_mapping(mapping, args.save_mapping)
            print(f"Mapping saved to {args.save_mapping}")

        if args.output_prompt:
            with open(args.output_prompt, "w", encoding="utf-8") as f:
                f.write(modified_prompt)
            print(f"Modified prompt saved to {args.output_prompt}")

    except ValueError as e:
        print(f"\nError in token substitution: {e}")


def main():
    """Command-line entry point: count tokens in a file or run the demo."""
    parser = argparse.ArgumentParser(
        description="Create and apply token substitutions for formal language examples"
    )
    parser.add_argument(
        "--model", default="gpt-4o", help="Model name to use for tokenization"
    )
    parser.add_argument("--seed", type=int, default=42, help="Random seed for sampling")
    parser.add_argument(
        "--save_mapping", type=str, help="Save mapping to specified JSON file"
    )
    parser.add_argument(
        "--load_mapping", type=str, help="Load mapping from specified JSON file"
    )
    parser.add_argument(
        "--output_prompt", type=str, help="Save modified prompt to specified text file"
    )
    parser.add_argument(
        "--count_tokens", type=str, help="Count tokens in the specified file"
    )
    parser.add_argument(
        "--demo", action="store_true", help="Run demonstration with sample data"
    )
    parser.add_argument(
        "--substitution_strategy",
        default="one2one",
        choices=["many2one", "one2one", "one2one-sym", "one2many", "one2many-sym"],
        help="Token substitution strategy used by --demo",
    )
    parser.add_argument(
        "--one2many_timestep",
        type=int,
        default=2,
        help="For one2many strategies: each symbol is followed by N-1 copies of "
        "the TIMESTEP token (used by --demo)",
    )

    args = parser.parse_args()

    if args.count_tokens:
        try:
            with open(args.count_tokens, "r", encoding="utf-8") as f:
                text = f.read()
            num_tokens = count_tokens(text, args.model)
        except Exception as e:  # CLI boundary: report and exit
            print(f"Error counting tokens: {e}")
            return
        print(f"Number of tokens in {args.count_tokens}: {num_tokens}")
        return

    if args.demo:
        _run_demo(args)
    else:
        parser.print_help()


if __name__ == "__main__":
    main()
