# loader.py

import os
import csv
import numpy as np
import tensorflow as tf
from collections import defaultdict
import argparse


# -------------------------------
# 1. Directory to Label Mapping
# -------------------------------

# Directory-name tokens -> class indices. The index order defines the
# one-hot position used by the training labels, so do not reorder entries.
alpha_map = {"0.2": 0, "0.3": 1, "0.4": 2}
q0_map = {"1": 0, "1.5": 1, "2.0": 2, "2.5": 3}
energy_loss_map = {"MMAT": 0, "MLBT": 1}


def _as_float(text):
    """Return ``float(text)``, or ``None`` when ``text`` is not numeric."""
    try:
        return float(text)
    except ValueError:
        return None


def _lookup_label(mapping, key, field):
    """Map one directory-name token to its class index.

    An exact string match is tried first. Failing that, numeric tokens are
    matched by value so equivalent spellings such as "1" / "1.0" or "2" / "2.0"
    resolve to the same class (the keys above mix both spellings).

    Raises:
        ValueError: if ``key`` matches no entry of ``mapping``.
    """
    if key in mapping:
        return mapping[key]
    numeric = _as_float(key)
    if numeric is not None:
        for candidate, index in mapping.items():
            if _as_float(candidate) == numeric:
                return index
    raise ValueError(
        f"Unknown {field} value {key!r} (expected one of {sorted(mapping)})"
    )


def parse_labels_from_dir(dir_name):
    """Parse a directory name ``<energy_loss>_<alpha_s>_<q0>`` into class indices.

    Example: ``"MLBT_0.4_2.5"`` -> ``(1, 2, 3)``.

    Args:
        dir_name: base name of a dataset sub-directory.

    Returns:
        ``(energy_loss, alpha, q0)`` as integer class indices taken from
        ``energy_loss_map``, ``alpha_map`` and ``q0_map``.

    Raises:
        ValueError: if the name does not have exactly three ``_``-separated
            parts or any part is not a known token. (Previously any energy-loss
            token other than "MMAT" was silently labelled 1.)
    """
    parts = dir_name.split('_')
    if len(parts) != 3:
        raise ValueError(
            f"Directory name {dir_name!r} is not of the form "
            "'<energy_loss>_<alpha_s>_<q0>'"
        )
    energy_loss_str, alpha_str, q0_str = parts
    energy_loss = _lookup_label(energy_loss_map, energy_loss_str, "energy_loss")
    alpha = _lookup_label(alpha_map, alpha_str, "alpha_s")
    q0 = _lookup_label(q0_map, q0_str, "q0")
    return (energy_loss, alpha, q0)


# -----------------------------------------
# 2. File and Label Generator (Python level)
# -----------------------------------------

def file_label_generator(root_dir):
    """Yield ``(file_path, label_tuple)`` for every ``.npy`` file in the dataset.

    ``root_dir`` is expected to contain one sub-directory per label
    combination, named as accepted by ``parse_labels_from_dir``. Non-directory
    entries (e.g. cached CSV files) and hidden directories (names starting
    with ".", such as editor/notebook checkpoint folders) are skipped. Only
    files directly inside each sub-directory are considered; there is no
    recursion.

    Directory and file names are visited in sorted order. ``os.listdir``
    returns entries in arbitrary, filesystem-dependent order, and sorting makes
    repeated scans yield the same sequence.

    Args:
        root_dir: dataset root directory.

    Yields:
        ``(path, (energy_loss, alpha, q0))`` with ``path`` joined onto ``root_dir``.
    """
    for dir_name in sorted(os.listdir(root_dir)):
        if dir_name.startswith('.'):
            continue
        dir_path = os.path.join(root_dir, dir_name)
        if not os.path.isdir(dir_path):
            continue
        label_tuple = parse_labels_from_dir(dir_name)
        for file_name in sorted(os.listdir(dir_path)):
            if file_name.endswith(".npy"):
                yield os.path.join(dir_path, file_name), label_tuple

def save_file_label_list(file_label_list, filename, root_dir):
    """Write ``(file_path, (energy_loss, alpha, q0))`` pairs to a CSV file.

    The CSV has the header ``file_path,energy_loss,alpha,q0``. Each path is
    stored relative to ``root_dir`` and always with ``/`` separators, so the
    file stays valid if the dataset root is moved or read on another OS.

    Args:
        file_label_list: iterable of ``(path, (energy_loss, alpha, q0))`` pairs.
        filename: destination CSV path; an existing file is overwritten.
        root_dir: directory the stored paths are made relative to.
    """
    # newline='' is required by the csv module so it controls line endings.
    with open(filename, 'w', newline='', encoding='utf-8') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(['file_path', 'energy_loss', 'alpha', 'q0'])
        writer.writerows(
            [os.path.relpath(file_path, root_dir).replace(os.sep, '/'), energy_loss, alpha, q0]
            for file_path, (energy_loss, alpha, q0) in file_label_list
        )

def load_file_label_list(filename, root_dir):
    """Read a ``file_path,energy_loss,alpha,q0`` CSV into ``(path, label)`` pairs.

    Args:
        filename: CSV file to read; the header row must contain the columns
            ``file_path``, ``energy_loss``, ``alpha`` and ``q0``.
        root_dir: directory the stored relative paths are joined onto.

    Returns:
        list of ``(os.path.join(root_dir, file_path), (energy_loss, alpha, q0))``
        with integer labels, in file order.

    Raises:
        KeyError: if a required column is missing.
        ValueError: if a label cell is not an integer.
    """
    # newline='' is what the csv module expects for reader file objects too.
    with open(filename, 'r', newline='', encoding='utf-8') as csvfile:
        return [
            (
                os.path.join(root_dir, row['file_path']),
                (int(row['energy_loss']), int(row['alpha']), int(row['q0'])),
            )
            for row in csv.DictReader(csvfile)
        ]


# ----------------------------------------------------------
# 3. Stratified Split for Train/Val/Test (Balanced and Stable)
# ----------------------------------------------------------

def split_file_list(file_label_list, train_ratio=0.8, val_ratio=0.1, test_ratio=0.1, random_seed=42):
    """Stratified train/val/test split of ``(file_path, label_tuple)`` pairs.

    Files are grouped by label tuple and each group is split on its own, so
    every label keeps roughly the same proportions in all three splits.
    For a group of ``total`` files:

    * train gets the first ``int(train_ratio * total)`` shuffled files,
    * val gets the next ``int(val_ratio * total)`` files,
    * test gets the remainder. ``test_ratio`` is accepted for signature
      symmetry but is not used, because the test split absorbs rounding
      leftovers.

    Groups with fewer than 3 files get one file in train and, if present, one
    in val.

    The result depends only on the input pairs and ``random_seed``, not on
    their order. Labels and paths are sorted before shuffling, and a private
    ``np.random.RandomState`` is used, so the global NumPy RNG is neither
    reseeded nor advanced.

    Args:
        file_label_list: iterable of ``(file_path, label_tuple)`` pairs.
        train_ratio: fraction of each label group assigned to train.
        val_ratio: fraction of each label group assigned to validation.
        test_ratio: unused (see above).
        random_seed: seed for the private RNG.

    Returns:
        ``(train_list, val_list, test_list)``, each a shuffled list of
        ``(file_path, label_tuple)`` pairs.
    """
    rng = np.random.RandomState(random_seed)

    label_to_files = defaultdict(list)
    for file_path, label in file_label_list:
        label_to_files[label].append(file_path)

    train_list, val_list, test_list = [], [], []

    for label in sorted(label_to_files):
        files = sorted(label_to_files[label])
        rng.shuffle(files)
        total = len(files)

        if total < 3:
            # Too few files for ratio-based cuts: one to train, one to val.
            train_end, val_end = 1, 2
        else:
            train_end = int(train_ratio * total)
            val_end = train_end + int(val_ratio * total)

        train_list.extend((fp, label) for fp in files[:train_end])
        val_list.extend((fp, label) for fp in files[train_end:val_end])
        test_list.extend((fp, label) for fp in files[val_end:])

    # Mix labels within each split so batches are not label-ordered.
    rng.shuffle(train_list)
    rng.shuffle(val_list)
    rng.shuffle(test_list)

    return train_list, val_list, test_list


# ---------------------------------------------------
# 4. TensorFlow Dataset Generator (Lazy loading .npy)
# ---------------------------------------------------

def tf_dataset_generator(file_label_list, global_max):
    """Yield normalized events and multi-output labels for ``from_generator``.

    Each ``.npy`` file is loaded lazily, cast to float32, divided by
    ``global_max`` and given a trailing channel axis. A 2-D ``(H, W)`` event
    becomes ``(H, W, 1)``; the pipeline in this module declares 32x32.
    Dividing by ``global_max`` maps values into [0, 1] only if the data is
    non-negative and ``global_max`` is the dataset maximum. NOTE(review):
    confirm the data contains no negative values.

    One-hot labels are built with NumPy instead of ``tf.one_hot``. This
    generator runs as plain Python inside ``Dataset.from_generator``, so an
    eager TF op per sample only adds overhead; the produced values are the same.

    Args:
        file_label_list: iterable of ``(file_path, (energy_loss, alpha, q0))``
            with ``alpha`` in ``range(3)`` and ``q0`` in ``range(4)``.
        global_max: non-zero normalization constant.

    Yields:
        ``(event, labels)`` where ``labels`` maps ``'energy_loss_output'`` to a
        float32 array of shape (1,), ``'alpha_output'`` to a float32 one-hot
        of shape (3,) and ``'q0_output'`` to a float32 one-hot of shape (4,).

    Raises:
        ValueError: if ``global_max`` is zero, or an alpha/q0 label is out of
            range (``tf.one_hot`` would silently have produced all zeros).
    """
    if global_max == 0:
        raise ValueError("global_max must be non-zero")
    # float32 scalar keeps the result float32 regardless of NumPy promotion rules.
    scale = np.float32(global_max)
    alpha_eye = np.eye(3, dtype=np.float32)
    q0_eye = np.eye(4, dtype=np.float32)

    for file_path, label in file_label_list:
        alpha, q0 = int(label[1]), int(label[2])
        if not (0 <= alpha < 3 and 0 <= q0 < 4):
            raise ValueError(f"Label out of range for {file_path!r}: {label!r}")

        event = np.load(file_path).astype(np.float32) / scale
        event = event[..., np.newaxis]  # add channel axis

        yield event, {
            'energy_loss_output': np.array([label[0]], dtype=np.float32),
            'alpha_output': alpha_eye[alpha].copy(),
            'q0_output': q0_eye[q0].copy(),
        }


# --------------------------------------------
# 5. TensorFlow Dataset Pipeline Construction
# --------------------------------------------

def build_tf_dataset(file_label_list, global_max, batch_size=512, buffer_size=10000, shuffle=True,
                     seed=None, event_shape=(32, 32, 1)):
    """Build a batched, prefetched ``tf.data.Dataset`` over ``file_label_list``.

    Events are produced lazily by ``tf_dataset_generator``, so files are read
    only when the dataset is iterated.

    Args:
        file_label_list: list of ``(file_path, (energy_loss, alpha, q0))`` pairs.
        global_max: normalization constant forwarded to the generator.
        batch_size: examples per batch; the final batch may be smaller.
        buffer_size: shuffle buffer size (ignored when ``shuffle`` is False).
        shuffle: whether to shuffle examples before batching.
        seed: optional shuffle seed for reproducible ordering. ``None`` keeps
            the previous unseeded behaviour.
        event_shape: static shape of one normalized event, including the
            trailing channel axis. It must match what the generator yields;
            the default keeps the previous hard-coded (32, 32, 1).

    Returns:
        A ``tf.data.Dataset`` of ``(events, labels)`` batches, where ``labels``
        is a dict with ``'energy_loss_output'`` (B, 1), ``'alpha_output'``
        (B, 3) and ``'q0_output'`` (B, 4), all float32.
    """
    output_signature = (
        tf.TensorSpec(shape=tuple(event_shape), dtype=tf.float32),
        {
            'energy_loss_output': tf.TensorSpec(shape=(1,), dtype=tf.float32),
            'alpha_output': tf.TensorSpec(shape=(3,), dtype=tf.float32),
            'q0_output': tf.TensorSpec(shape=(4,), dtype=tf.float32),
        },
    )
    # from_generator needs a callable so each epoch starts a fresh generator.
    dataset = tf.data.Dataset.from_generator(
        lambda: tf_dataset_generator(file_label_list, global_max),
        output_signature=output_signature,
    )

    if shuffle:
        dataset = dataset.shuffle(buffer_size=buffer_size, seed=seed)

    return dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)

# -------------------------------
# Split Saving/Loading Utilities
# -------------------------------

def save_split_to_csv(file_label_list, filename, root_dir):
    """Write one train/val/test split to ``filename`` as CSV.

    A split file uses exactly the same format as the file-label cache (header
    ``file_path,energy_loss,alpha,q0``, paths relative to ``root_dir``). This
    therefore delegates to ``save_file_label_list`` instead of keeping a
    second copy of the same writer.
    """
    save_file_label_list(file_label_list, filename, root_dir)


def load_split_from_csv(filename, root_dir):
    """Load one train/val/test split CSV as ``(absolute_path, label)`` pairs.

    Split files share the file-label cache format, so this delegates to
    ``load_file_label_list`` instead of duplicating the reader. Relative
    paths are joined onto ``root_dir`` and labels are parsed as integers.
    """
    return load_file_label_list(filename, root_dir)

def calculate_global_min_with_tf(root_dir):
    """Return the smallest value across all ``.npy`` files in the dataset.

    Scans each immediate sub-directory of ``root_dir`` (non-directories such
    as the CSV caches are ignored) and takes the minimum of every ``.npy``
    array after casting it to float32.

    The name is kept for backward compatibility, but the reduction now uses
    NumPy. Each file is already loaded on the host, so copying it to a GPU for
    one ``reduce_min`` only adds a transfer. The previous version printed
    "defaulting to CPU" and then returned ``None`` on machines without a GPU;
    this version works on any machine.

    Args:
        root_dir: dataset root directory.

    Returns:
        The global minimum as a Python float, or ``None`` when no non-empty
        ``.npy`` file was found.
    """
    global_min = float('inf')

    for dir_name in os.listdir(root_dir):
        dir_path = os.path.join(root_dir, dir_name)
        if not os.path.isdir(dir_path):
            continue
        for file_name in os.listdir(dir_path):
            if not file_name.endswith(".npy"):
                continue
            event_data = np.load(os.path.join(dir_path, file_name)).astype(np.float32)
            if event_data.size == 0:
                continue  # np.min raises on empty arrays
            file_min = float(event_data.min())
            # NaN minima compare False and are therefore skipped.
            if file_min < global_min:
                global_min = file_min

    return None if global_min == float('inf') else global_min


# -------------------------------
# 6. Main Function for Testing
# -------------------------------
def _load_or_create_splits(root_dir, random_seed):
    """Return ``(train_list, val_list, test_list)`` for the dataset in ``root_dir``.

    Priority 1: if ``train_files.csv``, ``val_files.csv`` and
    ``test_files.csv`` all exist in ``root_dir``, they are loaded unchanged.
    Priority 2: otherwise the file-label list is loaded from
    ``file_labels.csv``, or built by scanning the dataset and cached there.
    It is then split with ``split_file_list`` using ``random_seed``, and the
    three split CSVs are written into ``root_dir`` for later runs.
    """
    split_paths = [
        os.path.join(root_dir, "train_files.csv"),
        os.path.join(root_dir, "val_files.csv"),
        os.path.join(root_dir, "test_files.csv"),
    ]
    file_label_cache = os.path.join(root_dir, "file_labels.csv")

    if all(os.path.exists(path) for path in split_paths):
        print(f"[INFO] Found existing splits in '{root_dir}'. Loading splits directly...")
        train_list, val_list, test_list = [load_split_from_csv(path, root_dir) for path in split_paths]
        return train_list, val_list, test_list

    print("[INFO] Splits not found. Checking for cached file-label list...")

    if os.path.exists(file_label_cache):
        print(f"[INFO] Found cached file-label list '{file_label_cache}'.")
        file_label_list = load_file_label_list(file_label_cache, root_dir)
    else:
        print("[INFO] Cached file-label list not found. Scanning dataset directory to generate...")
        file_label_list = list(file_label_generator(root_dir))
        print(f"[INFO] Total files found: {len(file_label_list)}")
        save_file_label_list(file_label_list, file_label_cache, root_dir)
        print(f"[INFO] File-label list cached to '{file_label_cache}'.")

    print("[INFO] Performing stratified split...")
    train_list, val_list, test_list = split_file_list(file_label_list, random_seed=random_seed)

    # Persist the splits so later runs reuse exactly the same partition.
    for split_list, path in zip((train_list, val_list, test_list), split_paths):
        save_split_to_csv(split_list, path, root_dir)
    print(f"[INFO] Splits saved inside dataset root '{root_dir}'.")
    return train_list, val_list, test_list


def main():
    """Command-line smoke test for the loader.

    The script resolves train/val/test splits (see ``_load_or_create_splits``),
    builds the three ``tf.data`` pipelines and prints the shapes of one
    training batch. With ``--compute_global_min`` it prints the dataset
    minimum and exits; this replaces the previously commented-out code path.
    """
    parser = argparse.ArgumentParser(description="TensorFlow DataLoader for ML-JET dataset with smart caching and splits")
    parser.add_argument('--root_dir', type=str, required=True, help='Path to dataset root directory')
    parser.add_argument('--global_max', type=float, required=True, help='Global max for normalization')
    parser.add_argument('--batch_size', type=int, default=512, help='Batch size for DataLoader')
    parser.add_argument('--buffer_size', type=int, default=10000, help='Shuffle buffer size')
    parser.add_argument('--random_seed', type=int, default=42, help='Random seed for reproducibility')
    parser.add_argument('--compute_global_min', action='store_true',
                        help='Only compute and print the global minimum value of the dataset, then exit')
    args = parser.parse_args()

    # The shell does not expand "~" in every form (e.g. --root_dir=~/data).
    root_dir = os.path.expanduser(args.root_dir)

    if args.compute_global_min:
        print("[INFO] Calculating global minimum...")
        global_min = calculate_global_min_with_tf(root_dir)
        if global_min is None:
            print("[ERROR] Failed to calculate global minimum.")
            return
        print(f"[INFO] Global minimum value in dataset: {global_min}")
        return

    train_list, val_list, test_list = _load_or_create_splits(root_dir, args.random_seed)
    print(f"Training set size: {len(train_list)}")
    print(f"Validation set size: {len(val_list)}")
    print(f"Test set size: {len(test_list)}")

    print("[INFO] Building TensorFlow datasets for training/validation/testing...")
    common = dict(batch_size=args.batch_size, buffer_size=args.buffer_size)
    train_dataset = build_tf_dataset(train_list, args.global_max, shuffle=True, **common)
    # Validation/test pipelines are only constructed here to check that they build.
    for eval_list in (val_list, test_list):
        build_tf_dataset(eval_list, args.global_max, shuffle=False, **common)

    print("[INFO] Dataset pipeline built successfully. Example batch:")

    for x, y in train_dataset.take(1):
        print("Input batch shape:", x.shape)
        print("Energy Loss label batch shape:", y['energy_loss_output'].shape)
        print("Alpha label batch shape:", y['alpha_output'].shape)
        print("Q0 label batch shape:", y['q0_output'].shape)

    print("✅ DataLoader pipeline ready with smart caching and split management.")


# -------------------------------
# Entry Point for Command-Line
# -------------------------------

if __name__ == "__main__":
    # Run the command-line smoke test only when executed as a script, not when imported.
    main()

# Example usage:
# python data/loader.py --root_dir ~/data/jet_ml_benchmark_config_01_to_09_alpha_0.2_0.3_0.4_q0_1.5_2.0_2.5_MMAT_MLBT_size_1000_balanced_unshuffled --global_max 121.79151153564453 --batch_size 512