import json
import os
import re
import sys
from pathlib import Path
from typing import List, Tuple, Set, Dict

import requests
from dotenv import load_dotenv

# The 26 lowercase ASCII letters, in alphabetical order, one string per letter.
ALPHABET = [chr(code_point) for code_point in range(ord("a"), ord("z") + 1)]

def last_binary_digit(s):
    """Return the last '0' or '1' character that occurs in ``s``.

    ``s`` is scanned from its end towards its start.

    Returns:
        The matching character, ``-1`` when ``s`` holds no binary digit, or
        ``None`` when scanning raises (for example when ``s`` cannot be
        reversed); in that case the error is printed.
    """
    try:
        # Lazily walk the input backwards; the first hit is the last digit.
        candidates = (ch for ch in reversed(s) if ch in ('0', '1'))
        return next(candidates, -1)
    except Exception as err:
        print(f"Error: {err}")
        return None

def first_binary_digit(s):
    """Return the first '0' or '1' character that occurs in ``s``.

    ``s`` is scanned from its start towards its end.

    Returns:
        The matching character, ``-1`` when ``s`` holds no binary digit, or
        ``None`` when scanning raises (for example when ``s`` is not
        iterable); in that case the error is printed.
    """
    try:
        # The generator stops at the earliest binary digit it encounters.
        candidates = (ch for ch in s if ch in ('0', '1'))
        return next(candidates, -1)
    except Exception as err:
        print(f"Error: {err}")
        return None

def output_decoder(output, prompt_template):
    """Decode the model output based on the prompt template.

    Args:
        output: Raw model output, expected to be a string.
        prompt_template: ``"io_prompt"`` (the output with periods removed is
            parsed as an integer) or ``"zsr_prompt"`` (the last '0'/'1'
            character in the output is the answer).

    Returns:
        The decoded integer, or ``-1`` when nothing can be decoded.

    Raises:
        ValueError: If ``prompt_template`` is not one of the names above.
    """
    if prompt_template == "io_prompt":
        try:
            return int(output.replace(".", ""))
        except (ValueError, AttributeError):
            # AttributeError covers non-string output such as None, which is
            # reported the same way as unparsable text.
            print(f"Error decoding output: {output}")
            return -1

    if prompt_template == "zsr_prompt":
        digit = last_binary_digit(output)
        if digit is None:
            # last_binary_digit returns None when it could not scan the
            # output; int(None) would raise TypeError, so fall back to the
            # same -1 sentinel used by the io_prompt branch.
            print(f"Error decoding output: {output}")
            return -1
        return int(digit)

    raise ValueError(f"Unknown prompt template: {prompt_template}")
        
def read_prompt_template(prompt_template_path):
    """Read the prompt template from a file.

    Args:
        prompt_template_path: Path of the template file to read.

    Returns:
        The full file content as a string.

    Note:
        Any failure to read the file prints a message and terminates the
        process via ``sys.exit(1)`` instead of raising to the caller.
    """
    try:
        # Decode explicitly as UTF-8 so the result does not depend on the
        # platform's default encoding.
        with open(prompt_template_path, 'r', encoding="utf-8") as file:
            return file.read()
    except FileNotFoundError:
        print(f"Prompt template file not found: {prompt_template_path}")
        sys.exit(1)
    except Exception as e:
        print(f"Error reading prompt template: {e}")
        sys.exit(1)

def get_unique_symbols(examples: List[Tuple[str, int]], test_strings: List[str]) -> Set[str]:
    """
    Extract unique tokens from all examples and test strings.

    Tokens are the whitespace-separated pieces of each string; empty or
    whitespace-only strings contribute nothing.

    Args:
        examples: List of (text, label) tuples; only the text is inspected
        test_strings: List of test strings to classify

    Returns:
        Set of unique tokens
    """
    unique_symbols: Set[str] = set()

    # Each example is unpacked as (text, label); the label is ignored.
    for text, _ in examples:
        unique_symbols.update(text.split())

    # Iterating directly (rather than concatenating lists) accepts any
    # iterable of strings for test_strings.
    for string in test_strings:
        unique_symbols.update(string.split())

    return unique_symbols


def print_label_stats(data: List[Tuple[str, int]], name: str = ""):
    """Print how many entries of ``data`` carry label 0 and label 1.

    Args:
        data: List of (text, label) tuples; only the label is inspected.
        name: Heading printed in front of the word "statistics".

    Nothing is printed when ``data`` is empty.
    """
    total = len(data)
    if not total:
        return

    observed = [lbl for _text, lbl in data]

    print(f"\n{name} statistics:")
    print(f"Total examples: {total}")
    # Report the count and its share of the total for each binary label.
    for value in (0, 1):
        count = observed.count(value)
        print(f"Label {value}: {count} ({count/total*100:.1f}%)")

def create_modified_prompt(prompt_template: str, training_examples: List[Tuple[str, int]], mapping: Dict[str, Tuple[int, str]], encoding_technique: str="none", k: int=1) -> str:
    """
    Create a complete prompt with token substitutions applied.

    Args:
        prompt_template: Template name, forwarded unchanged to
            create_formal_language_prompt
        training_examples: List of (text, label) tuples for training examples
        mapping: Dictionary mapping original tokens to (token_id, new_token)
            tuples; forwarded unchanged to apply_token_mapping
        encoding_technique: One of "one_to_one", "one_to_many" or
            "many_to_one". The default "none" is not supported and raises.
        k: Forwarded unchanged to apply_token_mapping

    Returns:
        Complete formatted prompt text with substituted tokens

    Raises:
        ValueError: If encoding_technique is not one of the supported values
    """
    supported_techniques = ("one_to_one", "one_to_many", "many_to_one")
    if encoding_technique not in supported_techniques:
        raise ValueError(f"Encoding technique {encoding_technique} not supported")

    # Apply substitution to training examples
    modified_examples = []
    for text, label in training_examples:
        if text.strip():
            text = apply_token_mapping(text, mapping, encoding_technique=encoding_technique, k=k)
        # Empty / whitespace-only strings are kept as is.
        modified_examples.append((text, label))

    # Create the complete modified prompt
    return create_formal_language_prompt(prompt_template, modified_examples)


def create_formal_language_prompt(prompt_template: str, training_examples: List[Tuple[str, int]]) -> str:
    """
    Create a complete prompt for formal language learning task.

    The template file for the given template name is read and every
    occurrence of the text "exemplars" in it is replaced by the formatted
    training examples, one "label: text" line per example.

    Args:
        prompt_template: Template name, "zsr_prompt" or "io_prompt"
        training_examples: List of (text, label) tuples for training examples

    Returns:
        Complete formatted prompt text

    Raises:
        ValueError: If prompt_template is not a known name, or if the
            template file does not exist
    """
    # Template paths are relative, so they resolve against the current
    # working directory.
    template_files = {
        "zsr_prompt": "prompts/zsr_prompt_template.txt",
        "io_prompt": "prompts/io_prompt_template.txt",
    }
    if prompt_template not in template_files:
        # Without this check the path variable would be left unbound.
        raise ValueError(f"Unknown prompt template: {prompt_template}")
    base_prompt_path = Path(template_files[prompt_template])

    # Format training examples
    examples_text = "\n".join(f"{label}: {text}" for text, label in training_examples)

    try:
        base_prompt = base_prompt_path.read_text(encoding="utf-8")
    except FileNotFoundError as err:
        raise ValueError(f"Base prompt file {base_prompt_path} not found") from err

    return base_prompt.replace("exemplars", examples_text)
