import argparse
import json
import random
import re
from pathlib import Path
from typing import Dict, List, Optional, Set, Tuple

import tiktoken
import transformers
from transformers import AutoTokenizer

from src.utils import ALPHABET

def get_space_prefixed_tokens(model_name: str = "meta-llama/Meta-Llama-3-8B-Instruct")-> List[Tuple[int, str]]:
    """
    Collect vocabulary entries whose token string begins with the marker "_".

    The tokenizer for ``model_name`` is loaded with ``use_fast=False`` and its
    full vocabulary is scanned. For every matching entry, each "_" in the
    token string is replaced by a space.

    Args:
        model_name: Name passed to ``AutoTokenizer.from_pretrained``

    Returns:
        List of (token_id, token_text) tuples for the matching entries
    """
    # NOTE(review): the marker tested here is the ASCII underscore. An earlier
    # version of this code used U+2581 ("▁") instead — confirm which marker
    # the target tokenizer actually uses for a leading space.
    vocab = AutoTokenizer.from_pretrained(model_name, use_fast=False).get_vocab()

    space_prefixed_tokens = []
    for piece, piece_id in vocab.items():
        if piece.startswith('_'):
            space_prefixed_tokens.append((piece_id, piece.replace('_', ' ')))

    print(f"space_prefixed token: {len(space_prefixed_tokens)}")
    return space_prefixed_tokens

def get_excluded_tokens(prompt_text: str, model_name: str = "meta-llama/Meta-Llama-3-8B-Instruct") -> Set[str]:
    """
    Return the space-prefixed tokens that occur in ``prompt_text``.

    The prompt is tokenized, each token is converted to its id and decoded
    back to text individually, and only decoded strings starting with a
    space character are kept.

    Args:
        prompt_text: Text to tokenize
        model_name: Name passed to ``AutoTokenizer.from_pretrained``

    Returns:
        Set of decoded token strings that start with " "
    """
    tok = AutoTokenizer.from_pretrained(model_name, use_fast=False)

    # Round-trip through ids so each token is rendered by decode().
    piece_ids = tok.convert_tokens_to_ids(tok.tokenize(prompt_text))

    excluded_tokens = set()
    for piece_id in piece_ids:
        rendered = tok.decode([piece_id])
        if rendered.startswith(" "):
            excluded_tokens.add(rendered)
    return excluded_tokens

def filter_alphabet(tokens: List[Tuple[int, str]]) -> List[Tuple[int, str]]:
    """
    Keep only the tokens whose stripped text is made of ALPHABET characters.

    A token whose text is empty after stripping passes the check, because
    there is no character that could fail it.

    Args:
        tokens: List of (token_id, token_text) tuples

    Returns:
        List of the (token_id, token_text) tuples that passed, in input order
    """
    kept = [
        (tid, text)
        for tid, text in tokens
        if not any(ch not in ALPHABET for ch in text.strip())
    ]
    print(f"Filtered tokens: {len(kept)}")
    return kept

def create_token_substitution(
    encoding_technique: str,
    symbols: Set[str],
    prompt_text: str,
    model_name: str = "",
    seed: int = 42,
    verbose: bool = True,
) -> Optional[Dict[str, Tuple[int, str]]]:
    """
    Create a mapping from original symbols to replacement tokens.

    Behaviour per ``encoding_technique``:

    * ``"none"``: no mapping is built and ``None`` is returned.
    * ``"many_to_one"``: ``len(symbols)`` entries are sampled from ALPHABET;
      the token id of every replacement is ``None``.
    * ``"one_to_many"``: ``len(symbols) + 1`` tokenizer tokens are sampled; the
      last sampled token is stored under the key ``" "``.
    * any other value: ``len(symbols)`` tokenizer tokens are sampled.

    For the tokenizer-based techniques the candidates are the tokens returned
    by ``get_space_prefixed_tokens`` and kept by ``filter_alphabet``, minus
    the tokens ``get_excluded_tokens`` finds in ``prompt_text``. Symbols are
    paired in sorted order with the sorted replacement tuples.

    Args:
        encoding_technique: Which substitution scheme to build (see above)
        symbols: Set of original symbols to be replaced
        prompt_text: Text to extract excluded tokens from
        model_name: Name of the model to use for tokenization
        seed: Random seed for reproducibility (reseeds the global ``random``
            module)
        verbose: Whether to print progress information about the candidate
            tokens

    Returns:
        Dictionary mapping original symbols to (token_id, new_token) tuples,
        or ``None`` when ``encoding_technique`` is ``"none"``

    Raises:
        ValueError: If there are fewer candidate tokens than the technique
            needs
    """
    random.seed(seed)
    if encoding_technique == "none":
        return None

    ordered_symbols = sorted(symbols)

    if encoding_technique == "many_to_one":
        # NOTE(review): random.sample requires a sequence on Python 3.11+ —
        # confirm ALPHABET is a str/list rather than a set.
        sampled = random.sample(ALPHABET, len(symbols))
        replacement_tokens = [(None, token) for token in sampled]
        mapping = dict(zip(ordered_symbols, sorted(replacement_tokens)))
    else:
        # Candidate pool: space-prefixed tokens restricted to ALPHABET.
        available_tokens = get_space_prefixed_tokens(model_name)
        available_tokens = filter_alphabet(available_tokens)

        # Tokens already present in the prompt must not be reused.
        excluded_tokens = get_excluded_tokens(prompt_text, model_name=model_name)
        valid_tokens = [
            (tid, tok) for tid, tok in available_tokens
            if tok not in excluded_tokens
        ]

        # Honour the caller's ``verbose`` flag instead of forcing it on.
        if verbose:
            print(f"Found {len(available_tokens)} tokens that start with space")
            print(f"\nExcluded tokens from prompt (showing · for space):")
            for token in sorted(excluded_tokens):
                print(f"  {token.replace(' ', '·')}")
            print(f"\nHave {len(valid_tokens)} tokens available after excluding prompt tokens")

        # one_to_many needs one extra token that stands in for the space.
        needs_space_token = encoding_technique == "one_to_many"
        required = len(symbols) + (1 if needs_space_token else 0)
        if len(valid_tokens) < required:
            raise ValueError(
                f"Not enough valid tokens ({len(valid_tokens)}) "
                f"to replace original tokens ({len(symbols)})"
                + (" plus one space token" if needs_space_token else "")
            )

        replacement_tokens = random.sample(valid_tokens, required)
        if needs_space_token:
            # The last sampled token is reserved for the space symbol; the
            # remaining ones are sorted and paired with the sorted symbols.
            space_token = replacement_tokens[-1]
            mapping = dict(zip(ordered_symbols, sorted(replacement_tokens[:-1])))
            mapping[" "] = space_token
        else:
            mapping = dict(zip(ordered_symbols, sorted(replacement_tokens)))

    print(f"mapping: \n{mapping}")
    return mapping


def apply_token_mapping(
    text: str,
    mapping: Optional[Dict[str, Tuple[int, str]]],
    strip_spaces: bool = True,
    encoding_technique: str = "none",
    k: int = 1,
) -> str:
    """
    Apply token substitution to text.

    All substitutions are performed in a single pass over the stripped input,
    so text inserted by one replacement is never rewritten by another (e.g. a
    mapping that swaps "a" and "b" really swaps them). When one target is a
    prefix of another, the longer target wins.

    Args:
        text: Text to modify; leading/trailing whitespace is stripped first
        mapping: Dictionary mapping original tokens to (token_id, new_token)
            tuples. The key " " is never used as a target. ``None`` or an
            empty mapping leaves the text unchanged apart from stripping.
        strip_spaces: Whether to strip whitespace from the original tokens
            before searching for them
        encoding_technique: ``"one_to_many"`` appends the token stored under
            the key " " ``k - 1`` times to every replacement;
            ``"many_to_one"`` removes all space characters from the result;
            any other value substitutes tokens one for one
        k: Total length factor for ``"one_to_many"`` (replacement token plus
            ``k - 1`` copies of the space token)

    Returns:
        Modified text with tokens replaced

    Raises:
        KeyError: If ``encoding_technique`` is ``"one_to_many"`` and the
            mapping has no " " entry
    """
    result = text.strip()

    # Build target -> replacement first so the substitution can run in one pass.
    table: Dict[str, str] = {}
    for orig, (_, new) in (mapping or {}).items():
        if orig == " ":
            continue
        target = orig.strip() if strip_spaces else orig
        if not target:
            # An empty search string would match between every character.
            continue
        replacement = new.strip()
        if encoding_technique == "one_to_many":
            # Pad each replacement with k-1 copies of the space token.
            replacement += mapping[" "][1] * (k - 1)
        # If two keys reduce to the same target, the first one wins.
        table.setdefault(target, replacement)

    if table:
        # Longest alternatives first so "ab" is preferred over "a".
        alternatives = sorted(table, key=len, reverse=True)
        pattern = re.compile("|".join(re.escape(item) for item in alternatives))
        result = pattern.sub(lambda match: table[match.group(0)], result)

    if encoding_technique == "many_to_one":
        # many-to-one encoding: drop the spaces separating the tokens
        result = result.replace(" ", "")

    return result

def create_formal_language_prompt(prompt_template: str, training_examples: List[Tuple[str, int]]) -> str:
    """
    Create a complete prompt for formal language learning task.

    The template file selected by ``prompt_template`` is read and every
    occurrence of the placeholder word "exemplars" in it is replaced by the
    formatted training examples (one "label: text" line per example).

    Args:
        prompt_template: Template name, either "zsr_prompt" or "io_prompt";
            the file is looked up under ``prompts/`` relative to the current
            working directory
        training_examples: List of (text, label) tuples for training examples

    Returns:
        Complete formatted prompt text

    Raises:
        ValueError: If ``prompt_template`` is not a known template name or the
            template file does not exist
    """
    template_files = {
        "zsr_prompt": "prompts/zsr_prompt_template.txt",
        "io_prompt": "prompts/io_prompt_template.txt",
    }
    # Fail with a clear error instead of an unbound local variable.
    if prompt_template not in template_files:
        raise ValueError(
            f"Unknown prompt template {prompt_template!r}; "
            f"expected one of {sorted(template_files)}"
        )
    base_prompt_path = Path(template_files[prompt_template])

    # Format training examples
    examples_text = "\n".join(f"{label}: {text}" for text, label in training_examples)

    try:
        with open(base_prompt_path, "r", encoding="utf-8") as file:
            base_prompt = file.read()
    except FileNotFoundError as err:
        raise ValueError(f"Base prompt file {base_prompt_path} not found") from err

    return base_prompt.replace("exemplars", examples_text)

def get_unique_symbols(examples: List[Tuple[int, str]], test_strings: List[str]) -> Set[str]:
    """
    Gather the distinct whitespace-separated symbols used in the data.

    Args:
        examples: List of example tuples; each tuple is unpacked as
            (text, label) and only the text is used. Examples whose text is
            blank are ignored.
        test_strings: List of test strings to classify

    Returns:
        Set of unique symbols found across the example texts and test strings
    """
    non_blank_texts = []
    for text, _ in examples:
        if text.strip():
            non_blank_texts.append(text)

    all_strings = non_blank_texts + test_strings
    print(f"all_strings: {all_strings}")

    # Whitespace-split every string and pool the pieces.
    unique_symbols = {symbol for string in all_strings for symbol in string.split()}

    print(f"unique_symbols: {unique_symbols}")
    return unique_symbols

def save_mapping(mapping: Dict[str, Tuple[int, str]], output_file: str):
    """
    Write a token mapping to ``output_file`` as indented JSON.

    Each original token becomes a key whose value is an object with the
    fields "id" and "token".
    """
    payload = {}
    for symbol, (token_id, token_text) in mapping.items():
        payload[symbol] = {"id": token_id, "token": token_text}

    with open(output_file, 'w') as handle:
        json.dump(payload, handle, indent=2)

def load_mapping(input_file: str) -> Dict[str, Tuple[int, str]]:
    """
    Read a token mapping from a JSON file.

    Every entry's "id" and "token" fields are turned back into an
    (id, token) tuple keyed by the original token.
    """
    with open(input_file) as handle:
        raw = json.load(handle)

    restored = {}
    for symbol, entry in raw.items():
        restored[symbol] = (entry["id"], entry["token"])
    return restored

def print_token_mapping(mapping: Dict[str, Tuple[int, str]], title: str = "Token Mapping"):
    """
    Print the token mapping as a tab-separated table, sorted by original token.

    Spaces in both the original and the replacement token are rendered as "·"
    so they stay visible.
    """
    print(f"\n{title}:")
    print("Original Token\tNew Token ID\tNew Token (showing · for space)")
    print("-" * 60)
    for symbol in sorted(mapping):
        token_id, token_text = mapping[symbol]
        print(
            symbol.replace(' ', '·'),
            token_id,
            token_text.replace(' ', '·'),
            sep="\t",
        )

def create_modified_prompt(prompt_template: str, training_examples: List[Tuple[int, str]], mapping: Dict[str, Tuple[int, str]], encoding_technique: str="none", k: int=1) -> str:
    """
    Build the full prompt after substituting tokens in the training examples.

    Args:
        prompt_template: Template name forwarded to create_formal_language_prompt
        training_examples: List of example tuples; each tuple is unpacked as
            (text, label). Blank texts are passed through unchanged.
        mapping: Dictionary mapping original tokens to (token_id, new_token) tuples
        encoding_technique: One of "one_to_one", "one_to_many" or "many_to_one"
        k: Repetition factor forwarded to apply_token_mapping

    Returns:
        Complete formatted prompt text with substituted tokens

    Raises:
        ValueError: If encoding_technique is not one of the supported values
    """
    if encoding_technique not in ("one_to_one", "one_to_many", "many_to_one"):
        raise ValueError(f"Encoding technique {encoding_technique} not supported")

    # Substitute tokens in every non-blank example text; labels are untouched.
    modified_examples = [
        (
            apply_token_mapping(text, mapping, encoding_technique=encoding_technique, k=k)
            if text.strip()
            else text,
            label,
        )
        for text, label in training_examples
    ]

    return create_formal_language_prompt(prompt_template, modified_examples)

def count_tokens(text: str, model_name: str = "gpt-3.5-turbo") -> int:
    """
    Return how many tokens ``text`` encodes to under a model's tiktoken encoding.

    Args:
        text: The text to tokenize
        model_name: Model name used to look up the tiktoken encoding

    Returns:
        Number of tokens in the text
    """
    token_ids = tiktoken.encoding_for_model(model_name).encode(text)
    return len(token_ids)
