"""
Data Loading and Preprocessing Utilities for Bayesian LSTM

This module contains all data-related functions:
- Data loading
- Scaling and preprocessing
- OOD data loading
"""

import os
import numpy as np
import joblib
from sklearn.preprocessing import StandardScaler


def load_beijing_data(data_dir="beijing_data"):
    """
    Read the raw Beijing PM2.5 arrays and their metadata from disk.

    Six ``.npy`` arrays (features and targets for the train, validation
    and test splits) and one ``metadata.pkl`` file are read from
    ``data_dir``. The shapes of the training arrays are printed.

    Args:
        data_dir: Directory containing the data files

    Returns:
        A 7-tuple, in this order:
        X_train_raw, X_val_raw, X_test_raw: Raw feature sequences (N, T, F)
        y_train_raw, y_val_raw, y_test_raw: Raw target values (N, 1)
        meta: Metadata object as returned by ``joblib.load``
    """
    # File names listed in the same order as the returned tuple.
    array_files = (
        "X_train_raw.npy",
        "X_val_raw.npy",
        "X_test_raw.npy",
        "y_train_raw.npy",
        "y_val_raw.npy",
        "y_test_raw.npy",
    )
    arrays = [np.load(os.path.join(data_dir, fname)) for fname in array_files]

    # Metadata is read last, after all arrays have loaded successfully.
    meta = joblib.load(os.path.join(data_dir, "metadata.pkl"))

    # Only the training shapes are reported.
    train_features, train_targets = arrays[0], arrays[3]
    print("Raw data shapes:")
    print(f"  X_train_raw: {train_features.shape}")
    print(f"  y_train_raw: {train_targets.shape}")

    return (*arrays, meta)


def _scale_sequences(scaler, X_raw, n_features, split_name):
    """Apply a fitted per-feature scaler to a 3-D array, preserving its shape."""
    if X_raw.ndim != 3 or X_raw.shape[-1] != n_features:
        # Without this check, reshape(-1, n_features) could silently mix
        # features whenever the element count happens to be divisible.
        raise ValueError(
            f"X_{split_name}_raw must have shape (N, T, {n_features}), "
            f"got {X_raw.shape}"
        )
    # (N, T, F) -> (N*T, F): every timestep becomes one row so that each
    # feature column is scaled with the same statistics at every timestep.
    flat = X_raw.reshape(-1, n_features)
    # Restore the array's own (N, T, F); T may differ between splits.
    return scaler.transform(flat).reshape(X_raw.shape)


def preprocess_data(X_train_raw, X_val_raw, X_test_raw,
                    y_train_raw, y_val_raw, y_test_raw,
                    data_dir="beijing_data"):
    """
    Scale and preprocess the data for LSTM input.

    Two ``StandardScaler`` objects are fitted on the training split only
    (one for features, one for targets) and then applied to all three
    splits. Both fitted scalers are saved to ``data_dir`` as
    ``scaler_X.pkl`` and ``scaler_y.pkl``; the directory is created if it
    does not exist yet.

    Args:
        X_train_raw, X_val_raw, X_test_raw: Raw feature sequences (N, T, F).
            All splits must share the same F; N and T may differ per split.
        y_train_raw, y_val_raw, y_test_raw: Raw target values
        data_dir: Directory to save scalers

    Returns:
        X_train, X_val, X_test: Scaled feature sequences (N, T, F)
        y_train, y_val, y_test: Scaled target values (N, 1)
        scaler_X, scaler_y: Fitted scalers

    Raises:
        ValueError: If a feature array is not 3-D or its feature count
            differs from that of the training features.
    """
    if X_train_raw.ndim != 3:
        raise ValueError(
            f"X_train_raw must have shape (N, T, F), got {X_train_raw.shape}"
        )
    n_features = X_train_raw.shape[-1]

    # Initialize scalers
    scaler_X = StandardScaler()
    scaler_y = StandardScaler()

    # Fit on training data only, so no validation/test statistics leak
    # into the scaling parameters.
    scaler_X.fit(X_train_raw.reshape(-1, n_features))
    scaler_y.fit(y_train_raw)

    # Transform all sets; each feature array keeps its own (N, T, F) shape
    X_train = _scale_sequences(scaler_X, X_train_raw, n_features, "train")
    X_val = _scale_sequences(scaler_X, X_val_raw, n_features, "val")
    X_test = _scale_sequences(scaler_X, X_test_raw, n_features, "test")

    y_train = scaler_y.transform(y_train_raw)
    y_val = scaler_y.transform(y_val_raw)
    y_test = scaler_y.transform(y_test_raw)

    print("\nScaled data shapes:")
    print(f"  X_train for LSTM: {X_train.shape}")
    print(f"  y_train: {y_train.shape}")
    print(f"  X_val for LSTM: {X_val.shape}")
    print(f"  y_val: {y_val.shape}")
    print(f"  X_test for LSTM: {X_test.shape}")
    print(f"  y_test: {y_test.shape}")

    # Save scalers for inverse transform later. Make sure the target
    # directory exists first so the work above is not lost to a
    # FileNotFoundError; an empty data_dir means the current directory.
    if data_dir:
        os.makedirs(data_dir, exist_ok=True)
    joblib.dump(scaler_X, os.path.join(data_dir, "scaler_X.pkl"))
    joblib.dump(scaler_y, os.path.join(data_dir, "scaler_y.pkl"))
    print("\nScalers saved")

    return (X_train, X_val, X_test,
            y_train, y_val, y_test,
            scaler_X, scaler_y)


def load_and_preprocess_data(data_dir="beijing_data"):
    """
    Load the raw data and scale it in a single call.

    Chains ``load_beijing_data`` and ``preprocess_data``, using the same
    ``data_dir`` both as the source of the raw files and as the place
    where the fitted scalers are saved.

    Args:
        data_dir: Directory containing the data files

    Returns:
        A 9-tuple, in this order:
        X_train, X_val, X_test: Scaled feature sequences
        y_train, y_val, y_test: Scaled target values
        scaler_X, scaler_y: Fitted scalers
        meta: Metadata dictionary
    """
    # The loader returns the six raw arrays followed by the metadata.
    *raw_arrays, meta = load_beijing_data(data_dir)

    # The raw arrays are already in the positional order preprocess_data
    # expects; it returns the six scaled arrays followed by both scalers.
    processed = preprocess_data(*raw_arrays, data_dir)

    return (*processed, meta)


def load_ood_data(scaler_X, scaler_y, data_dir="beijing_data"):
    """
    Read the OOD (out-of-distribution) arrays and scale them.

    ``X_raw_ood.npy`` and ``y_raw_ood.npy`` are loaded from ``data_dir``
    and passed through the given, already fitted scalers. The shapes of
    the scaled arrays are printed.

    Args:
        scaler_X: Fitted feature scaler
        scaler_y: Fitted target scaler
        data_dir: Directory containing the OOD data files

    Returns:
        X_ood: Scaled OOD features (N, T, F)
        y_ood: Scaled OOD targets (N, 1)
        y_ood_original: OOD targets mapped back to the original scale
            with ``scaler_y.inverse_transform`` and flattened to 1-D
    """
    features_path = os.path.join(data_dir, "X_raw_ood.npy")
    targets_path = os.path.join(data_dir, "y_raw_ood.npy")
    raw_features = np.load(features_path)
    raw_targets = np.load(targets_path)

    n_sequences, n_steps, n_features = raw_features.shape

    # The scaler works on 2-D input: collapse (N, T, F) to (N*T, F),
    # scale, then restore the original 3-D layout.
    scaled_rows = scaler_X.transform(raw_features.reshape(-1, n_features))
    X_ood = scaled_rows.reshape(n_sequences, n_steps, n_features)

    y_ood = scaler_y.transform(raw_targets)
    # Original-scale targets are recovered from the scaled values.
    y_ood_original = scaler_y.inverse_transform(y_ood).flatten()

    print(f"OOD X shape: {X_ood.shape}")
    print(f"OOD y shape: {y_ood.shape}")

    return X_ood, y_ood, y_ood_original


def load_scalers(data_dir="beijing_data"):
    """
    Read the feature and target scalers back from disk.

    Loads ``scaler_X.pkl`` and then ``scaler_y.pkl`` from ``data_dir``
    with ``joblib.load``.

    Args:
        data_dir: Directory containing the scaler files

    Returns:
        scaler_X, scaler_y: Loaded scalers
    """
    scaler_X, scaler_y = (
        joblib.load(os.path.join(data_dir, filename))
        for filename in ("scaler_X.pkl", "scaler_y.pkl")
    )
    return scaler_X, scaler_y
