import torch
import random
import requests
import json
from pathlib import Path
import os
import string

# def get_sample_problems(task_type="math"):
#     """
#     Return a list of sample problems based on the task type.
    
#     Args:
#         task_type (str): Type of problems to return. Options: "math", "text", "logic"
    
#     Returns:
#         list: A list of problem strings
#     """
#     if task_type == "math":
#         return [
#             "If John has 5 apples and eats 2, how many apples does he have left?",
#             "A store sells notebooks for $4 each. If you buy 3 notebooks and pay with a $20 bill, how much change will you receive?",
#             "If a train travels 60 miles per hour for 2.5 hours, how far does it travel?",
#             "What is the area of a rectangle with length 8 cm and width 5 cm?",
#             "If a pizza costs $12 and is cut into 8 equal slices, how much does each slice cost?",
#             "A car travels at 65 miles per hour. How far will it travel in 3.5 hours?",
#             "If 3x + 7 = 22, what is the value of x?",
#             "A box contains 24 red marbles and 36 blue marbles. What fraction of the marbles are red?",
#             "If a shirt costs $25 and is on sale for 20% off, what is the sale price?",
#             "The sum of three consecutive integers is 72. What is the smallest of these integers?"
#         ]
#     elif task_type == "text":
#         return [
#             "Summarize the main benefits of regular exercise for overall health.",
#             "Explain how photosynthesis works in plants.",
#             "Describe the water cycle and its importance for Earth's ecosystems.",
#             "What are the main causes of climate change?",
#             "Explain the difference between renewable and non-renewable energy sources.",
#             "Describe the process of digestion in humans.",
#             "What are the key features of a democratic government?",
#             "Explain how the internet works in simple terms.",
#             "Describe the structure and function of DNA.",
#             "What were the main causes and effects of the Industrial Revolution?"
#         ]
#     elif task_type == "logic":
#         return [
#             "If all A are B, and all B are C, what can we conclude about A and C?",
#             "If it's not true that 'if it's raining, then the ground is wet', what can we conclude?",
#             "If either the butler or the maid is guilty, and the butler has an alibi, who is guilty?",
#             "All birds can fly. Penguins are birds. Can penguins fly? Explain the logical issue.",
#             "If I always bring an umbrella when it rains, and I didn't bring an umbrella today, what can we conclude?",
#             "If P implies Q, and Q is false, what can we say about P?",
#             "If no A are B, and some C are B, what can we conclude about A and C?",
#             "If either X or Y is true, and X is false, what must be true?",
#             "If whenever I study, I pass the test, and I failed the test, what can we conclude?",
#             "If all dogs bark, and Fido doesn't bark, what can we conclude about Fido?"
#         ]
#     else:
#         raise ValueError(f"Unknown task type: {task_type}. Choose from 'math', 'text', or 'logic'.")

def load_gsm8k_problems(num_samples=10, cache_dir="data_cache"):
    """
    Load problems from the GSM8K dataset.

    A readable cache file in ``cache_dir`` is preferred. Otherwise the first
    100 training questions are downloaded, a random subset is drawn from them
    and that subset is written to the cache. If the download or parsing fails,
    a small built-in list of arithmetic word problems is returned instead (the
    built-in list is never cached).

    Note: only the sampled subset is cached, so a cache written by a call with
    a small ``num_samples`` caps how many problems later calls can return.

    Args:
        num_samples (int): Number of problems to load
        cache_dir (str): Directory to cache the dataset

    Returns:
        list: A list of problem strings (at most ``num_samples`` long)
    """
    cache_path = Path(cache_dir) / "gsm8k_samples.json"

    # Create cache directory if it doesn't exist
    os.makedirs(cache_dir, exist_ok=True)

    # Prefer cached data; an unreadable/corrupt cache is treated as a miss
    if cache_path.exists():
        try:
            with open(cache_path, 'r', encoding='utf-8') as f:
                return json.load(f)[:num_samples]
        except (OSError, ValueError) as e:
            print(f"Ignoring unreadable GSM8K cache {cache_path}: {e}")

    # If not cached, download a sample
    try:
        url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/train.jsonl"
        # A timeout keeps a stalled connection from hanging the caller forever
        response = requests.get(url, timeout=30)
        response.raise_for_status()

        # Parse JSONL; only the first 100 records form the sampling pool
        lines = response.text.strip().split('\n')
        problems = [json.loads(line)['question'] for line in lines[:100]]

        # Randomly sample
        sampled_problems = random.sample(problems, min(num_samples, len(problems)))

        # Cache the data
        with open(cache_path, 'w', encoding='utf-8') as f:
            json.dump(sampled_problems, f)

        return sampled_problems

    except Exception as e:
        # Deliberate best-effort boundary: any failure degrades to the
        # built-in problems rather than propagating.
        print(f"Error loading GSM8K dataset: {e}")
        fallback_problems = [
            "If John has 5 apples and eats 2, how many apples does he have left?",
            "A store sells notebooks for $4 each. If you buy 3 notebooks and pay with a $20 bill, how much change will you receive?",
            "If a train travels 60 miles per hour for 2.5 hours, how far does it travel?",
            "What is the area of a rectangle with length 8 cm and width 5 cm?",
            "If a pizza costs $12 and is cut into 8 equal slices, how much does each slice cost?",
            "A car travels at 65 miles per hour. How far will it travel in 3.5 hours?",
            "If 3x + 7 = 22, what is the value of x?",
            "A box contains 24 red marbles and 36 blue marbles. What fraction of the marbles are red?",
            "If a shirt costs $25 and is on sale for 20% off, what is the sale price?",
            "The sum of three consecutive integers is 72. What is the smallest of these integers?",
        ]
        return fallback_problems[:num_samples]

def load_mmlu_problems(subject="logical_fallacies", num_samples=10, cache_dir="data_cache"):
    """
    Load problems from the MMLU dataset using Hugging Face datasets.

    Each problem is the question text followed by one line per answer choice,
    labelled ``A.``, ``B.``, ... in order. When the test split holds more than
    ``num_samples`` items a random subset is returned; otherwise all items are.

    Args:
        subject (str): Subject to load problems from. The short names "logic",
            "text", "math", "mmlu_logic" and "mmlu_reading" are mapped to
            MMLU subject names; anything else is passed through unchanged.
        num_samples (int): Number of samples to load
        cache_dir (str): Directory to cache the dataset

    Returns:
        list: A list of problem strings; an empty list if loading fails
    """
    # Map our simplified names to actual MMLU subject names
    subject_map = {
        "logic": "logical_fallacies",
        "text": "high_school_european_history",
        "math": "high_school_mathematics",
        "mmlu_logic": "logical_fallacies",
        "mmlu_reading": "high_school_european_history"
    }

    # Use mapped subject if available
    actual_subject = subject_map.get(subject, subject)

    try:
        # Import datasets here to avoid making it a required dependency
        from datasets import load_dataset

        print(f"Loading MMLU dataset for subject: {actual_subject}")

        # Load the dataset from Hugging Face
        dataset = load_dataset("cais/mmlu", actual_subject, split="test", cache_dir=cache_dir)

        # Convert to our format: question line, then "A. ...", "B. ...", etc.
        problems = []
        for item in dataset:
            options = "".join(
                f"{chr(65 + i)}. {choice}\n"
                for i, choice in enumerate(item["choices"])
            )
            problems.append(f"{item['question']}\n{options}")

        # Sample if needed
        if len(problems) > num_samples:
            problems = random.sample(problems, num_samples)

        print(f"Loaded {len(problems)} problems from MMLU {actual_subject}")
        return problems

    except Exception as e:
        # Deliberate best-effort boundary (missing `datasets` package, network
        # failure, unknown subject, ...). Return an empty list so callers
        # always receive the documented list type instead of None.
        print(f"Error loading MMLU dataset: {e}")
        return []

def prepare_batch(samples, tokenizer, max_length=256):
    """
    Tokenize samples into a device-resident batch with causal-LM labels.

    The tokenizer is asked to pad and truncate every sample to ``max_length``
    and to return PyTorch tensors. Every returned tensor is moved to CUDA
    when it is available (CPU otherwise), and a ``"labels"`` entry is added
    as a copy of ``"input_ids"``.

    Args:
        samples: Text input(s) handed to the tokenizer as-is
        tokenizer: Callable tokenizer accepting ``padding``, ``truncation``,
            ``max_length`` and ``return_tensors`` keyword arguments
        max_length (int): Length every sequence is padded/truncated to

    Returns:
        dict: The tokenizer's tensors on the chosen device, plus "labels"
    """
    # Pick the target device up front
    target_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    tokenized = tokenizer(
        samples,
        padding="max_length",
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )

    # Relocate each tensor the tokenizer produced
    batch = {}
    for field, tensor in tokenized.items():
        batch[field] = tensor.to(target_device)

    # Labels are an independent copy of the input ids
    batch["labels"] = batch["input_ids"].clone()
    return batch

def generate_pure_nonsense(min_length=20, max_length=100, num_samples=10):
    """Build gibberish strings out of random characters and symbols.

    Each sample draws characters from a random selection of at least three
    alphabets (ASCII letters, digits, punctuation, the Greek and Cyrillic
    code-point ranges, printable emoticons), groups them into tokens of 1-10
    characters, joins the tokens with one randomly chosen separator and then
    splices in 1-5 extra symbol characters at random positions.

    Args:
        min_length (int): Minimum number of tokens in each sample
        max_length (int): Maximum number of tokens in each sample
        num_samples (int): Number of nonsense samples to generate

    Returns:
        list: A list of nonsensical strings
    """
    greek_range = ''.join(map(chr, range(0x0370, 0x03FF)))
    cyrillic_range = ''.join(map(chr, range(0x0400, 0x04FF)))
    emoticon_range = ''.join(
        c for c in map(chr, range(0x1F600, 0x1F64F)) if c.isprintable()
    )
    alphabets = [
        string.ascii_letters,
        string.digits,
        string.punctuation,
        greek_range,
        cyrillic_range,
        emoticon_range,
    ]
    joiners = [' ', '\n', '\t', '   ', ' ⁄ ', ' | ', ' ~ ', ' § ']
    extra_symbols = '☢☣⚠⌘☯✿⚛⚕⚚'

    samples = []
    for _sample in range(num_samples):
        # Character pool: between three and all of the alphabets, concatenated
        how_many_sets = random.randint(3, len(alphabets))
        pool = ''.join(random.sample(alphabets, k=how_many_sets))

        # Random "tokens": each is 1-10 characters drawn from the pool
        token_count = random.randint(min_length, max_length)
        pieces = [
            ''.join(random.choice(pool) for _ in range(random.randint(1, 10)))
            for _ in range(token_count)
        ]

        # A single separator is used for the whole sample
        text = random.choice(joiners).join(pieces)

        # Splice in a few extra symbols (skipped for very short texts)
        for _ in range(random.randint(1, 5)):
            if len(text) <= 2:
                continue
            cut = random.randint(0, len(text) - 1)
            symbol = random.choice(extra_symbols)
            text = text[:cut] + symbol + text[cut:]

        samples.append(text)

    return samples

def load_humaneval_problems(num_samples=10, cache_dir="data_cache"):
    """
    Fetch Python programming prompts from the HumanEval dataset.

    A cache file in ``cache_dir`` is used when present. Otherwise the test
    split is loaded through Hugging Face ``datasets``, each prompt is prefixed
    with a "# Write a Python function" header, up to ``num_samples`` prompts
    are randomly selected and that selection is cached. If loading fails for
    any reason, a short list of hand-written coding prompts is returned.

    Args:
        num_samples (int): Number of problems to load
        cache_dir (str): Directory to cache the dataset

    Returns:
        list: A list of programming problem strings
    """
    cache_file = Path(cache_dir) / "humaneval_samples.json"

    # Make sure the cache directory is there before touching it
    os.makedirs(cache_dir, exist_ok=True)

    # Serve from the cache when a previous call already stored a selection
    if cache_file.exists():
        with open(cache_file, 'r') as handle:
            return json.load(handle)[:num_samples]

    try:
        # Imported lazily so `datasets` stays an optional dependency
        from datasets import load_dataset

        print("Loading HumanEval dataset...")

        records = load_dataset("openai_humaneval", split="test", cache_dir=cache_dir)

        # Each problem is the dataset prompt under a short header comment
        prompts = [
            f"# Write a Python function\n{record['prompt']}" for record in records
        ]

        # Only subsample when there are more prompts than requested
        chosen = random.sample(prompts, num_samples) if len(prompts) > num_samples else prompts

        # Persist the selection for later calls
        with open(cache_file, 'w') as handle:
            json.dump(chosen, handle)

        print(f"Loaded {len(chosen)} problems from HumanEval")
        return chosen

    except Exception as err:
        print(f"Error loading HumanEval dataset: {err}")
        # Hand-crafted coding prompts used when the dataset is unavailable
        return [
            "# Write a function to find the maximum value in a list\ndef find_max(numbers):",
            "# Write a function to check if a string is a palindrome\ndef is_palindrome(text):",
            "# Write a function to calculate the factorial of a number\ndef factorial(n):",
            "# Write a function to sort a list of integers\ndef custom_sort(numbers):",
            "# Write a function to check if a number is prime\ndef is_prime(n):",
        ][:num_samples]

def generate_random_vocab_tokens(tokenizer, min_length=20, max_length=100, num_samples=10):
    """Decode sequences of randomly drawn vocabulary ids into text.

    For every sample a length is drawn from ``[min_length, max_length]`` and
    that many ids are drawn uniformly from ``[5, vocab_size - 5]``, then
    decoded with special tokens skipped. If the decoded text is blank, one
    retry is made with ids drawn from ``[100, vocab_size - 100]``; the result
    of the retry is kept even if it is blank too.

    Args:
        tokenizer: Object exposing ``vocab_size`` and
            ``decode(ids, skip_special_tokens=...)``
        min_length (int): Minimum length of each sequence in tokens
        max_length (int): Maximum length of each sequence in tokens
        num_samples (int): Number of nonsense samples to generate

    Returns:
        list: A list of nonsensical strings made of actual tokens
    """
    vocab_size = tokenizer.vocab_size
    print(f"Generating random token sequences from vocabulary size: {vocab_size}")

    def _draw_and_decode(count, margin):
        # Stay `margin` ids away from both ends of the vocabulary, where
        # special tokens are expected to sit.
        ids = [random.randint(margin, vocab_size - margin) for _ in range(count)]
        return tokenizer.decode(torch.tensor([ids])[0], skip_special_tokens=True)

    sequences = []
    for _ in range(num_samples):
        count = random.randint(min_length, max_length)

        text = _draw_and_decode(count, 5)
        if not text.strip():
            # Blank output: retry once with a wider margin
            text = _draw_and_decode(count, 100)

        sequences.append(text)

    return sequences

def get_dataset(dataset_name="math", num_samples=10, tokenizer=None):
    """
    Dispatch to the loader or generator registered under ``dataset_name``.

    Recognised names: "gsm8k", "mmlu_logic", "mmlu_reading", "mmlu_math",
    "nonsense", "random_tokens" and "code". Note that the default value
    "math" is not one of them, so calling this without a name raises
    ValueError.

    Args:
        dataset_name (str): Name of the dataset to load
        num_samples (int): Number of samples to load
        tokenizer: Tokenizer to use for generating random tokens (needed for "random_tokens" dataset)

    Returns:
        list: A list of problem strings

    Raises:
        ValueError: If the name is not recognised, or if "random_tokens" is
            requested without a tokenizer
    """
    def _random_tokens():
        # Random vocabulary tokens cannot be produced without a tokenizer
        if tokenizer is None:
            raise ValueError("Tokenizer must be provided for 'random_tokens' dataset")
        return generate_random_vocab_tokens(tokenizer, min_length=20, max_length=100, num_samples=num_samples)

    # Ordered (name, loader) pairs; loaders are thunks so nothing is
    # evaluated until a name matches.
    routes = (
        ("gsm8k", lambda: load_gsm8k_problems(num_samples)),
        ("mmlu_logic", lambda: load_mmlu_problems("logical_fallacies", num_samples)),
        ("mmlu_reading", lambda: load_mmlu_problems("high_school_european_history", num_samples)),
        ("mmlu_math", lambda: load_mmlu_problems("high_school_mathematics", num_samples)),
        # Older character-level nonsense generator, kept for backward compatibility
        ("nonsense", lambda: generate_pure_nonsense(min_length=20, max_length=100, num_samples=num_samples)),
        ("random_tokens", _random_tokens),
        ("code", lambda: load_humaneval_problems(num_samples)),
    )

    for name, loader in routes:
        if dataset_name == name:
            return loader()

    raise ValueError(f"Unknown dataset: {dataset_name}")
