import tiktoken
import random
import argparse
from typing import List, Tuple, Dict, Set
import json
from pathlib import Path
import transformers
from src.deepseek.utils import ALPHABET
from src.utils import save_mapping, load_mapping, print_token_mapping


def get_space_prefixed_tokens(
    tokenizer_path, model_name: str = "gpt-3.5-turbo"
) -> List[Tuple[int, str]]:
    """Collect space-prefixed, single-character tokens from the tokenizer vocab.

    Every id in ``range(tokenizer.vocab_size)`` is decoded on its own. A token
    is kept when its decoded text starts with a space, is exactly two
    characters long (the space plus one character) and is not ``" \\n"``.

    Args:
        tokenizer_path: Forwarded to
            ``transformers.AutoTokenizer.from_pretrained``.
        model_name: Must be ``"deepseek-chat"`` or ``"deepseek-reasoner"``.
            The default value is *not* one of these, so callers have to pass
            a supported name explicitly.

    Returns:
        List of ``(token_id, decoded_text)`` tuples. If loading the tokenizer
        (or walking the vocabulary) fails, the error is printed and whatever
        was collected so far -- possibly an empty list -- is returned.

    Raises:
        ValueError: If ``model_name`` is not supported.
    """
    if model_name not in ("deepseek-chat", "deepseek-reasoner"):
        raise ValueError(f"Model name {model_name} not supported")

    all_tokens = []
    try:
        # NOTE(review): trust_remote_code=True lets the tokenizer repository
        # run its own Python code; only use it with trusted paths.
        tokenizer = transformers.AutoTokenizer.from_pretrained(
            tokenizer_path, trust_remote_code=True
        )
        for token_id in range(tokenizer.vocab_size):
            try:
                token_text = tokenizer.decode([token_id])
                if (
                    token_text.startswith(" ")
                    and len(token_text) == 2
                    and token_text != " \n"
                ):
                    all_tokens.append((token_id, token_text))
            except Exception:
                # Best effort: ids that cannot be decoded are skipped.
                # ``Exception`` (rather than a bare ``except``) keeps
                # KeyboardInterrupt / SystemExit propagating.
                continue
    except Exception as e:
        print(f"Error loading tokenizer vocab: {e}")
    return all_tokens


def get_excluded_tokens(
    prompt_text: str, tokenizer_path, model_name: str = "deepseek-chat"
) -> Set[str]:
    """Return the space-prefixed tokens that occur in ``prompt_text``.

    The prompt is encoded with the tokenizer loaded from ``tokenizer_path``;
    each resulting id is decoded individually and kept when its text starts
    with a space. Decoding errors are printed and the offending id is skipped.

    Raises:
        ValueError: If ``model_name`` is neither ``"deepseek-chat"`` nor
            ``"deepseek-reasoner"``.
    """
    if model_name not in ("deepseek-chat", "deepseek-reasoner"):
        raise ValueError(f"Model name {model_name} not supported")

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        tokenizer_path, trust_remote_code=True
    )

    space_prefixed = set()
    for token_id in tokenizer.encode(prompt_text):
        try:
            piece = tokenizer.decode(token_id)
            if piece.startswith(" "):
                space_prefixed.add(piece)
        except Exception as e:
            print(f"Error decoding token_id {token_id}: {e}")
    return space_prefixed


def filter_alphabet(tokens: List[Tuple[int, str]]) -> List[Tuple[int, str]]:
    """Keep the tokens whose stripped text is made only of ALPHABET characters.

    Surrounding whitespace is ignored for the check but preserved in the
    returned tuples. A token that strips to the empty string passes the
    check trivially and is therefore kept.
    """
    return [
        (token_id, token_string)
        for token_id, token_string in tokens
        if all(char in ALPHABET for char in token_string.strip())
    ]


def create_token_substitution(
    encoding_technique: str,
    symbols: Set[str],
    prompt_text: str,
    model_name: str = "deepseek-chat",
    tokenizer_path: str = None,
    seed: int = 42,
    verbose: bool = True,
) -> Dict[str, Tuple[int, str]]:
    """
    Create a mapping from original symbols to replacement tokens.

    The global ``random`` generator is re-seeded with ``seed`` on every call
    (including for ``"none"``), so the sampled mapping is reproducible.

    Args:
        encoding_technique: ``"none"`` returns ``None``. ``"many_to_one"``
            samples replacement characters from ``ALPHABET``.
            ``"one_to_many"`` samples tokenizer tokens and additionally
            reserves one token under the ``" "`` key. Any other value is
            handled as a plain one-to-one mapping onto tokenizer tokens.
        symbols: Set of original symbols to be replaced
        prompt_text: Text whose space-prefixed tokens must not be used as
            replacements
        model_name: Name of the model to use for tokenization
        tokenizer_path: Tokenizer location, forwarded to the tokenizer helpers
        seed: Random seed for reproducibility
        verbose: Whether to print progress information

    Returns:
        ``None`` for ``"none"``; otherwise a dictionary mapping each symbol
        to a ``(token_id, new_token)`` tuple. ``token_id`` is ``None`` for
        ``"many_to_one"``. Sorted symbols are paired with sorted replacements.

    Raises:
        ValueError: If fewer valid tokenizer tokens are available than are
            needed (one per symbol, plus one space token for
            ``"one_to_many"``).
    """
    random.seed(seed)
    if encoding_technique == "none":
        return None

    ordered_symbols = sorted(symbols)

    if encoding_technique == "many_to_one":
        # NOTE(review): random.sample needs ALPHABET to be a sequence (sets
        # are rejected on Python 3.11+) -- confirm against its definition.
        sampled = random.sample(ALPHABET, len(ordered_symbols))
        replacement_tokens = [(None, token) for token in sampled]
        return dict(zip(ordered_symbols, sorted(replacement_tokens)))

    # Candidate replacements: space-prefixed tokens restricted to ALPHABET
    available_tokens = get_space_prefixed_tokens(tokenizer_path, model_name)
    available_tokens = filter_alphabet(available_tokens)

    # Tokens already present in the prompt must not be reused
    excluded_tokens = get_excluded_tokens(
        prompt_text, tokenizer_path=tokenizer_path, model_name=model_name
    )
    valid_tokens = [
        (tid, tok) for tid, tok in available_tokens if tok not in excluded_tokens
    ]

    if verbose:
        print(f"Found {len(available_tokens)} tokens that start with space")
        print("\nExcluded tokens from prompt (showing · for space):")
        for token in sorted(excluded_tokens):
            print(f"  {token.replace(' ', '·')}")
        print(
            f"\nHave {len(valid_tokens)} tokens available after excluding prompt tokens"
        )

    # one_to_many needs one extra token that stands in for the space symbol,
    # so the availability check has to account for it as well.
    needs_space_token = encoding_technique == "one_to_many"
    required = len(ordered_symbols) + (1 if needs_space_token else 0)
    if len(valid_tokens) < required:
        raise ValueError(
            f"Not enough valid tokens ({len(valid_tokens)}) "
            f"to replace original tokens ({len(symbols)})"
            + (" plus one space token" if needs_space_token else "")
        )

    replacement_tokens = random.sample(valid_tokens, required)
    if needs_space_token:
        # The last sampled token becomes the space symbol; the remaining
        # ones are sorted before being paired with the sorted symbols.
        space_token = replacement_tokens[-1]
        mapping = dict(zip(ordered_symbols, sorted(replacement_tokens[:-1])))
        mapping[" "] = space_token
    else:
        mapping = dict(zip(ordered_symbols, sorted(replacement_tokens)))
    return mapping


def apply_token_mapping(
    text: str,
    mapping: Dict[str, Tuple[int, str]],
    strip_spaces: bool = True,
    encoding_technique: str = "none",
    k: int = 2,
) -> str:
    """
    Apply token substitution to text.

    All substitutions are performed in one left-to-right pass over the
    stripped input, so text inserted by one substitution is never rewritten
    by another one. When several targets match at the same position the
    longest target wins.

    Args:
        text: Text to modify (leading/trailing whitespace is stripped first)
        mapping: Dictionary mapping original tokens to (token_id, new_token)
            tuples. The ``" "`` entry is never used as a target. ``None`` is
            treated like an empty mapping.
        strip_spaces: Whether to strip spaces from tokens before replacement
        encoding_technique: With ``"one_to_many"`` every replacement is
            followed by ``k - 1`` copies of the ``" "`` entry's token. With
            ``"many_to_one"`` all space characters are removed from the
            result. Any other value performs plain replacement.
        k: Number of tokens each symbol expands to for ``"one_to_many"``

    Returns:
        Modified text with tokens replaced

    Raises:
        KeyError: For ``"one_to_many"`` when ``mapping`` has symbols to
            replace but no ``" "`` entry.
    """
    import re  # function-scope import: the module does not import re

    result = text.strip()
    entries = mapping.items() if mapping is not None else ()
    one_to_many = encoding_technique == "one_to_many"

    # target text -> replacement text; the first entry wins for duplicates
    replacements = {}
    for orig, (_, new) in entries:
        if orig == " ":
            continue
        target = orig.strip() if strip_spaces else orig
        if not target or target in replacements:
            # An empty target would match between every pair of characters
            continue
        replacement = new.strip()
        if one_to_many:
            # Pad with the (space-prefixed) space token so that every symbol
            # expands to k tokens
            replacement += mapping[" "][1] * (k - 1)
        replacements[target] = replacement

    if replacements:
        # Longest alternatives first so that a symbol which is a prefix of
        # another symbol cannot shadow it
        pattern = re.compile(
            "|".join(
                re.escape(target)
                for target in sorted(replacements, key=len, reverse=True)
            )
        )
        result = pattern.sub(lambda match: replacements[match.group(0)], result)

    if encoding_technique == "many_to_one":
        # Many-to-one encoding: drop the separating spaces so the
        # replacement characters are concatenated
        result = result.replace(" ", "")

    return result


def create_formal_language_prompt(
    prompt_template: str, training_examples: List[Tuple[str, int]]
) -> str:
    """
    Create a complete prompt for formal language learning task.

    Args:
        prompt_template: ``"zsr_prompt"`` or ``"io_prompt"`` select the
            template files under ``prompts/``; any other value is looked up
            as ``./src/deepseek/prompts_tuning/<prompt_template>.txt``.
            Paths are relative to the current working directory.
        training_examples: List of (text, label) tuples for training
            examples; each one is rendered as ``"<label>: <text>"`` on its
            own line.

    Returns:
        The template text with every occurrence of the literal word
        ``exemplars`` replaced by the rendered examples.

    Raises:
        ValueError: If the template file does not exist (chained from the
            underlying ``FileNotFoundError``).
    """
    # Format training examples
    examples_text = "\n".join(f"{label}: {text}" for text, label in training_examples)

    # Resolve the template file
    if prompt_template == "zsr_prompt":
        base_prompt_path = Path("prompts/zsr_prompt_template.txt")
    elif prompt_template == "io_prompt":
        base_prompt_path = Path("prompts/io_prompt_template.txt")
    else:
        base_prompt_path = Path(f"./src/deepseek/prompts_tuning/{prompt_template}.txt")

    try:
        with open(base_prompt_path, "r", encoding="utf-8") as file:
            base_prompt = file.read()
    except FileNotFoundError as err:
        # Keep the original error attached so the traceback shows the path
        # that was actually opened
        raise ValueError(f"Base prompt file {base_prompt_path} not found") from err

    return base_prompt.replace("exemplars", examples_text)


def get_unique_symbols(
    examples: List[Tuple[str, int]], test_strings: List[str]
) -> Set[str]:
    """
    Extract unique symbols from all examples and test strings.

    Args:
        examples: List of (text, label) tuples for examples; only the text
            is used
        test_strings: Test strings to classify (any iterable of strings)

    Returns:
        Set of the whitespace-separated symbols found in any of the strings
    """
    unique_symbols: Set[str] = set()

    # Empty / whitespace-only strings split into nothing, so they need no
    # special casing.
    for text, _ in examples:
        unique_symbols.update(text.split())
    for string in test_strings:
        unique_symbols.update(string.split())

    return unique_symbols




def create_modified_prompt(
    prompt_template: str,
    training_examples: List[Tuple[int, str]],
    mapping: Dict[str, Tuple[int, str]],
    encoding_technique: str = "none",
    k: int = 2,
) -> str:
    """
    Build the full prompt after substituting tokens in the training examples.

    Args:
        prompt_template: Template identifier handed to
            create_formal_language_prompt
        training_examples: Training pairs; each one is unpacked as
            (text, label) and only the text is rewritten
        mapping: Dictionary mapping original tokens to (token_id, new_token) tuples
        encoding_technique: One of "one_to_one", "one_to_many" or
            "many_to_one"
        k: Forwarded to apply_token_mapping

    Returns:
        Complete formatted prompt text with substituted tokens

    Raises:
        ValueError: For any other encoding technique (including the
            default "none")
    """
    if encoding_technique not in ("one_to_one", "one_to_many", "many_to_one"):
        raise ValueError(f"Encoding technique {encoding_technique} not supported")

    # Blank texts are passed through untouched; everything else is rewritten
    rewritten_examples = [
        (
            apply_token_mapping(
                sample, mapping, encoding_technique=encoding_technique, k=k
            )
            if sample.strip()
            else sample,
            tag,
        )
        for sample, tag in training_examples
    ]

    return create_formal_language_prompt(prompt_template, rewritten_examples)


def count_tokens(text: str, model_name: str = "gpt-3.5-turbo") -> int:
    """
    Return how many tokens ``text`` occupies under a model's tiktoken encoding.

    Args:
        text: The text to tokenize
        model_name: Model whose encoding is looked up via
            ``tiktoken.encoding_for_model``

    Returns:
        Length of the encoded token sequence
    """
    token_ids = tiktoken.encoding_for_model(model_name).encode(text)
    return len(token_ids)
