"""
data_loading.py - Data loading and preprocessing utilities.

This module provides functions for:
- Loading preprocessed numpy arrays (token ids, attention masks, labels)
- Creating TensorFlow datasets for training, validation, and OOD evaluation
- Data balance printing utilities
"""

import os
import numpy as np
import tensorflow as tf

from modules.config import (
    DATA_DIR,
    BATCH_SIZE,
    MAX_LEN,
    VOCAB_SIZE,
    OOD_SLICE,
    set_seed,
    SEED,

)


# DATA LOADING FUNCTIONS

def load_arrays(prefix, data_dir=DATA_DIR):
    """
    Load preprocessed numpy arrays for a dataset split.

    The three files ``<prefix>_ids.npy``, ``<prefix>_mask.npy`` and
    ``<prefix>_labels.npy`` are all checked for existence before any of them
    is read, so a partially populated data folder fails fast with the same
    descriptive error no matter which file is missing.

    Parameters
    ----------
    prefix : str
        Dataset prefix ('train', 'dev', or 'ood')
    data_dir : str
        Directory containing the preprocessed data files

    Returns
    -------
    tuple
        (ids, mask, labels) as numpy arrays with int32 dtype

    Raises
    ------
    FileNotFoundError
        If any of the three data files is not found
    """
    paths = [
        os.path.join(data_dir, f"{prefix}_{part}.npy")
        for part in ("ids", "mask", "labels")
    ]

    # Validate every file up front: avoids loading a large ids array only to
    # fail afterwards on a missing mask/labels file.
    for path in paths:
        if not os.path.exists(path):
            raise FileNotFoundError(
                f"Could not find {path}. "
                f"Ensure the data folder '{data_dir}' exists and contains the preprocessed files."
            )

    # copy=False skips the extra copy when the stored dtype is already int32.
    ids, mask, labels = (
        np.load(path).astype(np.int32, copy=False) for path in paths
    )

    return ids, mask, labels


def create_dataset(ids, masks, labels, batch_size=BATCH_SIZE, is_training=False, seed=SEED,
                   *, drop_remainder=True):
    """
    Create a TensorFlow dataset from numpy arrays.

    Each element of the resulting dataset is a tuple
    ``({"input_ids": ..., "attention_mask": ...}, labels)``.

    Parameters
    ----------
    ids : np.ndarray
        Token IDs array of shape (N, seq_len)
    masks : np.ndarray
        Attention mask array of shape (N, seq_len)
    labels : np.ndarray
        Labels array of shape (N,)
    batch_size : int
        Batch size (default: from config)
    is_training : bool
        Whether this is a training dataset (enables shuffling)
    seed : int
        Random seed for shuffling (for reproducibility)
    drop_remainder : bool
        Keyword-only. When True (the default), a final batch smaller than
        ``batch_size`` is dropped. Pass False for evaluation sets when every
        sample must be seen.

    Returns
    -------
    tf.data.Dataset
        Batched and prefetched TensorFlow dataset
    """
    # NOTE(review): set_seed comes from modules.config and is not visible
    # here; it is assumed to seed the global RNGs - confirm.
    set_seed(seed)
    features = {
        "input_ids": tf.cast(ids, tf.int32),
        "attention_mask": tf.cast(masks, tf.int32),
    }
    dataset = tf.data.Dataset.from_tensor_slices((features, labels))

    # Deliberately no dataset.cache(): the numpy arrays are already in RAM,
    # so caching here would double memory usage.

    if is_training:
        set_seed(seed)
        # The seed must be passed explicitly to shuffle() for reproducibility.
        # The buffer covers the whole dataset; it is clamped to at least 1 so
        # that empty input does not request a zero-size buffer.
        dataset = dataset.shuffle(
            max(len(ids), 1), seed=seed, reshuffle_each_iteration=True
        )

    # Fixed-size batches (drop_remainder=True) keep shapes static, which
    # matters for stability on TPUs/XLA.
    dataset = dataset.batch(batch_size, drop_remainder=drop_remainder)

    # Re-seed with the caller's seed (not set_seed's default) so the seed
    # requested for this dataset is the one left in effect afterwards.
    set_seed(seed)
    # Prefetch always goes last to buffer the NEXT batch while the
    # accelerator works on the current one.
    dataset = dataset.prefetch(tf.data.AUTOTUNE)

    return dataset


def load_all_datasets(data_dir=DATA_DIR, batch_size=BATCH_SIZE, seed=SEED, kl_scale=0.001):
    """
    Load all datasets (train, validation/dev, OOD).

    Parameters
    ----------
    data_dir : str
        Directory containing preprocessed data
    batch_size : int
        Batch size for all datasets
    seed : int
        Random seed for shuffling (for reproducibility)
    kl_scale : float
        Numerator of the KL weight; the weight is ``kl_scale / num_batches``
        (default: 0.001)

    Returns
    -------
    dict
        Dictionary containing:
        - 'train_dataset': Training TF dataset
        - 'val_dataset': Validation TF dataset (SST-2 dev)
        - 'ood_dataset': OOD TF dataset (IMDB)
        - 'train_ids', 'train_mask', 'train_labels': Raw training arrays
        - 'dev_ids', 'dev_mask', 'dev_labels': Raw validation arrays
        - 'ood_ids', 'ood_mask', 'ood_labels': Raw OOD arrays
        - 'vocab_size': Vocabulary size
        - 'num_batches': Number of training batches
        - 'kl_weight': Computed KL weight for Bayesian models

    Raises
    ------
    ValueError
        If ``batch_size`` is not positive, or the training split holds fewer
        samples than one batch (the KL weight would be undefined)
    FileNotFoundError
        Propagated from ``load_arrays`` when a data file is missing
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}.")

    print(f"Loading data from {data_dir}... (seed={seed})")

    # Load In-Distribution (SST-2)
    train_ids, train_mask, train_labels = load_arrays("train", data_dir)
    dev_ids, dev_mask, dev_labels = load_arrays("dev", data_dir)

    # Load Out-of-Distribution (IMDb)
    ood_ids, ood_mask, ood_labels = load_arrays("ood", data_dir)

    print(f"Train (SST-2): {train_ids.shape}")
    print(f"Dev (SST-2):   {dev_ids.shape}")
    print(f"OOD (IMDb):    {ood_ids.shape}")

    # Number of full training batches; validated before building any dataset
    # so that a too-small training split fails early with a clear message
    # instead of a bare ZeroDivisionError in the KL weight computation.
    num_batches = len(train_labels) // batch_size
    if num_batches == 0:
        raise ValueError(
            f"Training split has {len(train_labels)} samples, which is fewer than "
            f"one batch of size {batch_size}; cannot compute the KL weight."
        )

    # Create datasets - passing seed for reproducibility
    train_dataset = create_dataset(
        train_ids, train_mask, train_labels,
        batch_size=batch_size, is_training=True, seed=seed
    )
    val_dataset = create_dataset(
        dev_ids, dev_mask, dev_labels,
        batch_size=batch_size, is_training=False, seed=seed
    )
    ood_dataset = create_dataset(
        ood_ids, ood_mask, ood_labels,
        batch_size=batch_size, is_training=False, seed=seed
    )

    # Compute KL weight
    kl_weight = kl_scale / num_batches

    print(f"Ready. KL Weight: {kl_weight:.6f}")

    return {
        'train_dataset': train_dataset,
        'val_dataset': val_dataset,
        'ood_dataset': ood_dataset,
        'train_ids': train_ids,
        'train_mask': train_mask,
        'train_labels': train_labels,
        'dev_ids': dev_ids,
        'dev_mask': dev_mask,
        'dev_labels': dev_labels,
        'ood_ids': ood_ids,
        'ood_mask': ood_mask,
        'ood_labels': ood_labels,
        'vocab_size': VOCAB_SIZE,
        'num_batches': num_batches,
        'kl_weight': kl_weight,
    }


def prepare_evaluation_data(dev_ids, dev_mask, dev_labels, ood_ids, ood_mask, ood_labels,
                            ood_slice=OOD_SLICE):
    """
    Bundle validation and OOD arrays into model-input dictionaries.

    The validation arrays are passed through untouched; the OOD arrays are
    truncated to their first ``ood_slice`` entries.

    Parameters
    ----------
    dev_ids, dev_mask, dev_labels : np.ndarray
        Validation data arrays
    ood_ids, ood_mask, ood_labels : np.ndarray
        OOD data arrays
    ood_slice : int
        Number of OOD samples to use (to prevent OOM)

    Returns
    -------
    tuple
        (X_test_dict, y_test_arr, X_ood_dict, y_ood_arr), where both dicts
        have the keys 'input_ids' and 'attention_mask'
    """
    # One slice object shared by all OOD arrays keeps them aligned.
    leading = slice(None, ood_slice)

    in_dist_inputs = dict(input_ids=dev_ids, attention_mask=dev_mask)
    ood_inputs = dict(
        input_ids=ood_ids[leading],
        attention_mask=ood_mask[leading],
    )

    return in_dist_inputs, dev_labels, ood_inputs, ood_labels[leading]


# DATA BALANCE UTILITIES


def print_balance(name, labels):
    """
    Report how many samples of each class a dataset contains.

    Prints a header with the dataset name and sample count, then one line
    per distinct label value with its count and percentage share. A label
    equal to 1 is shown as "Positive"; every other value as "Negative".

    Parameters
    ----------
    name : str
        Name of the dataset (for display)
    labels : np.ndarray
        Labels array
    """
    classes, frequencies = np.unique(labels, return_counts=True)
    n_samples = len(labels)

    print(f"--- {name} ({n_samples} samples) ---")

    for cls, n_in_class in zip(classes, frequencies):
        if cls == 1:
            sentiment = "Positive"
        else:
            sentiment = "Negative"
        share = (n_in_class / n_samples) * 100
        line = f"  Class {cls} ({sentiment}): {n_in_class} ({share:.1f}%)"
        print(line)


def print_all_balances(train_labels, dev_labels, ood_labels):
    """
    Report the class balance of the training, validation and OOD splits.

    Calls ``print_balance`` once per split, in that order.

    Parameters
    ----------
    train_labels, dev_labels, ood_labels : np.ndarray
        Labels arrays for each split
    """
    splits = (
        ("Training Set", train_labels),
        ("Validation Set (Dev)", dev_labels),
        ("OOD Set (IMDB)", ood_labels),
    )
    for title, split_labels in splits:
        print_balance(title, split_labels)



