# data_utils.py

import numpy as np
import torch
import random
from torch.utils.data import DataLoader, Subset
from datasets import load_dataset
from transformers import DataCollatorWithPadding

# Per-task settings for the supported GLUE tasks:
#   keys             -> text column(s) fed to the tokenizer (1 = single sentence, 2 = pair)
#   num_labels       -> number of target classes
#   validation_split -> name of the split used for evaluation
# Single-sentence tasks: sst2, cola. Sentence-pair tasks: mnli, qqp, rte, mrpc.
GLUE_TASK_CONFIG = {
    # Single-sentence classification
    'sst2': dict(keys=('sentence',), num_labels=2, validation_split='validation'),
    'cola': dict(keys=('sentence',), num_labels=2, validation_split='validation'),
    # Sentence-pair classification
    'mnli': dict(keys=('premise', 'hypothesis'), num_labels=3, validation_split='validation_matched'),
    'qqp': dict(keys=('question1', 'question2'), num_labels=2, validation_split='validation'),
    'rte': dict(keys=('sentence1', 'sentence2'), num_labels=2, validation_split='validation'),
    'mrpc': dict(keys=('sentence1', 'sentence2'), num_labels=2, validation_split='validation'),
}

def scramble_sentence(sentence: str, percentage: float) -> str:
    """
    Scramble a given fraction of the words in a sentence.

    A random subset of word positions is chosen, the words at those positions
    are shuffled among themselves and written back, and every other word keeps
    its place. Randomness comes from the global ``random`` module, so results
    follow its seed.

    Args:
        sentence: Input text; words are split on whitespace.
        percentage: Fraction of words to scramble. Values <= 0 return the
            sentence unchanged; values > 1 are treated as 1.

    Returns:
        The sentence with the selected words permuted, re-joined with single
        spaces. The original ``sentence`` is returned untouched when
        ``percentage <= 0`` or it has fewer than two words.

    Note:
        ``random.shuffle`` can leave some selected words in their original
        positions, so fewer words may actually move than were selected.
    """
    if percentage <= 0:
        return sentence
    words = sentence.split()
    num_words = len(words)
    if num_words < 2:
        return sentence  # a single word (or empty text) cannot be scrambled

    # Clamp so random.sample never asks for more positions than there are words.
    num_to_scramble = min(num_words, int(round(num_words * percentage)))

    # Pick the positions to scramble, then permute only the words at them.
    positions = sorted(random.sample(range(num_words), num_to_scramble))
    selected_words = [words[i] for i in positions]
    random.shuffle(selected_words)

    for position, word in zip(positions, selected_words):
        words[position] = word

    return " ".join(words)

def _get_glue_dataset(task_name, tokenizer, args=None):
    """
    Load and tokenize a single GLUE task.

    Args:
        task_name: A key of ``GLUE_TASK_CONFIG`` (e.g. 'sst2', 'mnli').
        tokenizer: Called as ``tokenizer(*texts, truncation=True,
            max_length=128, padding="max_length")`` on batches of raw text,
            where ``texts`` holds one column per entry in the task's 'keys'.
        args: Optional experiment config. When ``task_name == 'mnli'`` and
            ``args.scramble_percentage`` is positive, the 'hypothesis' column
            of every split is word-scrambled before tokenization. A missing or
            ``None`` ``scramble_percentage`` disables scrambling.

    Returns:
        ``(train_dataset, eval_dataset, num_labels)``: the torch-formatted
        'train' split and the task's configured validation split (tokenizer
        outputs plus a 'labels' column), and the task's number of classes.

    Raises:
        ValueError: If ``task_name`` is not in ``GLUE_TASK_CONFIG``.
    """
    if task_name not in GLUE_TASK_CONFIG:
        raise ValueError(f"Task '{task_name}' is not supported for GLUE. Please choose from {list(GLUE_TASK_CONFIG.keys())}")

    config = GLUE_TASK_CONFIG[task_name]
    sentence_keys = config['keys']
    validation_split = config['validation_split']
    num_labels = config['num_labels']

    print(f"Loading dataset for GLUE task: {task_name.upper()}")
    raw_datasets = load_dataset('glue', task_name)

    # Word scrambling for the task-discrepancy experiment (MNLI only).
    # getattr lets callers pass an args object without this attribute.
    scramble_perc = getattr(args, 'scramble_percentage', 0) or 0
    if task_name == 'mnli' and scramble_perc > 0:
        print(f"INFO: Applying {scramble_perc*100:.0f}% word scrambling to the 'hypothesis' of MNLI.")

        def scramble_examples(examples):
            # The 'hypothesis' column is specific to MNLI.
            examples['hypothesis'] = [scramble_sentence(h, scramble_perc) for h in examples['hypothesis']]
            return examples

        # Applied to every split in the DatasetDict (train, validation and test).
        # NOTE(review): datasets.map caches results keyed on a hash of the
        # function and its closure, not on the RNG state, so a rerun with a
        # different seed may reuse a previously cached scramble — confirm
        # whether that is acceptable for this experiment.
        raw_datasets = raw_datasets.map(scramble_examples, batched=True, desc="Scrambling hypotheses")

    def preprocess_function(examples):
        # One text column per configured key: single sentence or sentence pair.
        texts = [examples[key] for key in sentence_keys]
        return tokenizer(*texts, truncation=True, max_length=128, padding="max_length")

    tokenized_datasets = raw_datasets.map(preprocess_function, batched=True)

    # Drop the raw text columns (and 'idx'), keeping only 'label' alongside the
    # tokenizer outputs. Assumes every split shares the 'train' schema.
    original_columns = raw_datasets['train'].column_names
    columns_to_remove = [col for col in original_columns if col != 'label']

    tokenized_datasets = tokenized_datasets.remove_columns(columns_to_remove)
    tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
    tokenized_datasets.set_format("torch")

    train_dataset = tokenized_datasets['train']
    eval_dataset = tokenized_datasets[validation_split]

    return train_dataset, eval_dataset, num_labels

def get_data(dataset_name, task_name, tokenizer, args=None):
    """
    Entry point for loading a dataset collection.

    Only the 'glue' collection is supported; ``task_name`` picks the GLUE
    task and ``args`` is forwarded to the GLUE loader.

    Returns:
        ``(train_dataset, eval_dataset, num_labels, data_collator)``, where the
        collator is a ``DataCollatorWithPadding`` built from ``tokenizer``.

    Raises:
        ValueError: If ``dataset_name`` is not 'glue'.
    """
    if dataset_name != 'glue':
        raise ValueError(f"Dataset collection '{dataset_name}' not recognized.")

    train_dataset, eval_dataset, num_labels = _get_glue_dataset(task_name, tokenizer, args)
    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
    return train_dataset, eval_dataset, num_labels, data_collator

def _stratified_sample(dataset, num_samples_to_select, num_labels):
    """
    Performs stratified sampling from a HuggingFace dataset to ensure balanced class representation.
    """
    labels = np.array(dataset['labels'])
    indices_by_label = {label: np.where(labels == label)[0] for label in range(num_labels)}

    base_samples_per_class = num_samples_to_select // num_labels
    remainder = num_samples_to_select % num_labels

    all_sampled_indices = []
    
    for label in range(num_labels):
        num_for_this_class = base_samples_per_class + (1 if label < remainder else 0)
        available_indices = indices_by_label.get(label, np.array([]))
        
        if len(available_indices) == 0:
            print(f"Warning: No samples found for class {label}. Skipping.")
            continue
        
        if len(available_indices) < num_for_this_class:
            print(f"Warning: Class {label} has only {len(available_indices)} samples, but {num_for_this_class} were requested. Taking all.")
            sampled_indices = available_indices
        else:
            sampled_indices = np.random.choice(available_indices, num_for_this_class, replace=False)
            
        all_sampled_indices.extend(sampled_indices.tolist())

    np.random.shuffle(all_sampled_indices)
    return all_sampled_indices

def _random_sample(dataset, num_samples_to_select):
    """ Performs simple random sampling from a dataset. """
    total_size = len(dataset)
    all_indices = np.arange(total_size)
    sampled_indices = np.random.choice(all_indices, size=num_samples_to_select, replace=False)
    return sampled_indices.tolist()

def create_dataloaders_and_subsets(train_dataset, eval_dataset, args, data_collator, num_labels):
    """
    Sample the train/validation subsets and wrap each in a DataLoader.

    - Train: stratified sample of ``args.train_set_size`` examples (capped at
      the dataset size), iterated in shuffled order.
    - Validation: uniform random sample of ``args.val_set_size`` examples
      (capped at the dataset size), iterated in fixed order.

    Both loaders use ``args.batch_size`` and ``data_collator``.

    Returns:
        ``(train_loader, eval_loader, train_subset, eval_subset)``
    """
    # --- training subset (stratified) ---
    train_total = len(train_dataset)
    if args.train_set_size > train_total:
        print(f"Warning: Requested train_set_size ({args.train_set_size}) > dataset size ({train_total}). Using full set.")
    train_subset = Subset(
        train_dataset,
        _stratified_sample(train_dataset, min(args.train_set_size, train_total), num_labels),
    )

    # --- validation subset (simple random) ---
    eval_total = len(eval_dataset)
    if args.val_set_size > eval_total:
        print(f"Warning: Requested val_set_size ({args.val_set_size}) > dataset size ({eval_total}). Using full set.")
    eval_subset = Subset(
        eval_dataset,
        _random_sample(eval_dataset, min(args.val_set_size, eval_total)),
    )

    loader_kwargs = dict(batch_size=args.batch_size, collate_fn=data_collator)
    train_loader = DataLoader(train_subset, shuffle=True, **loader_kwargs)
    eval_loader = DataLoader(eval_subset, **loader_kwargs)

    print(f"Created a training set with {len(train_subset)} samples (stratified).")
    print(f"Created a validation set with {len(eval_subset)} samples (random).")

    return train_loader, eval_loader, train_subset, eval_subset
