#!/usr/bin/env python
# utils_data.py - Data loading and preprocessing utilities
# --------------------------------------------------------------------
import numpy as np
import pandas as pd
from pathlib import Path
from scipy.io import arff
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from torch.utils.data import TensorDataset, DataLoader
import torch

def load_arff_file(fp: Path):
    data, _ = arff.loadarff(str(fp))
    df = pd.DataFrame(data)
    for c in df.columns:
        if isinstance(df[c].iloc[0], (bytes, bytearray)):
            df[c] = df[c].str.decode('utf8')
    if 'class' in df.columns:
        y = df.pop('class')
    else:
        y = df.iloc[:, -1]
    y = y.astype('category').cat.codes.to_numpy(np.int64)
    return df, y

def _build_column_transformer(df: pd.DataFrame) -> ColumnTransformer:
    """Return an unfitted transformer that standardises numeric columns and one-hots the rest.

    ``object``/``category`` columns are one-hot encoded; categories unseen at
    fit time encode to all zeros (``handle_unknown='ignore'``). All other
    columns are standardised.
    """
    categorical = ['object', 'category']
    num_cols = df.select_dtypes(exclude=categorical).columns
    cat_cols = df.select_dtypes(include=categorical).columns
    return ColumnTransformer([
        ('num', StandardScaler(), num_cols),
        ('cat', OneHotEncoder(handle_unknown='ignore'), cat_cols),
    ])


def _to_dense_float32(X) -> np.ndarray:
    """Densify a (possibly sparse) transformer output and cast it to float32."""
    # ColumnTransformer may return a scipy sparse matrix when one-hot output dominates.
    if hasattr(X, 'toarray'):
        X = X.toarray()
    return X.astype(np.float32)


def preprocess(df: pd.DataFrame):
    """Fit the standard numeric/categorical encoding on *df* and return its features.

    The fitted transformer is discarded, so the same encoding cannot be
    re-applied to another frame. Use ``load_and_preprocess_data`` when a
    held-out split must share the training encoding.

    Returns:
        Dense ``float32`` NumPy feature matrix with one row per row of *df*.
    """
    return _to_dense_float32(_build_column_transformer(df).fit_transform(df))


def load_and_preprocess_data(data_root, dataset_name, fold):
    """Load one train/test fold from ``<data_root>/<dataset_name>/weka`` and encode it.

    Expects ``train_<dataset_name>-<fold>.arff`` and
    ``test_<dataset_name>-<fold>.arff`` in that directory. The column encoding
    is fitted on the training split only and then applied unchanged to the
    test split.

    Args:
        data_root: Root directory (``str`` or ``Path``).
        dataset_name: Dataset sub-directory and file-name stem.
        fold: Fold identifier interpolated into the file names.

    Returns:
        ``(Xtr, y_tr, Xte, y_te)`` -- dense ``float32`` feature matrices and
        the label arrays produced by ``load_arff_file``.

    Raises:
        FileNotFoundError: if either ARFF file does not exist; the message
            lists every missing path.
    """
    weka_dir = Path(data_root) / dataset_name / "weka"
    fp_tr = weka_dir / f"train_{dataset_name}-{fold}.arff"
    fp_te = weka_dir / f"test_{dataset_name}-{fold}.arff"

    missing = [str(fp) for fp in (fp_tr, fp_te) if not fp.exists()]
    if missing:
        raise FileNotFoundError(f"Missing files: {', '.join(missing)}")

    df_tr, y_tr = load_arff_file(fp_tr)
    df_te, y_te = load_arff_file(fp_te)

    # Fit on train only, then reuse on test so scaling statistics and one-hot
    # vocabularies are never estimated from held-out data.
    ct = _build_column_transformer(df_tr)
    Xtr = _to_dense_float32(ct.fit_transform(df_tr))
    Xte = _to_dense_float32(ct.transform(df_te))

    return Xtr, y_tr, Xte, y_te

def prepare_dataloaders(Xtr, ytr, Xte, yte, batch_size=2048):
    """Wrap train and test arrays in PyTorch ``DataLoader`` objects.

    Each array is copied into a fresh tensor with ``torch.tensor``. The
    training loader is built with ``shuffle=True`` and the test loader with
    ``shuffle=False``; both use the same ``batch_size``.

    Returns:
        ``(train_loader, test_loader)``.
    """
    def _loader(features, labels, shuffle):
        dataset = TensorDataset(torch.tensor(features), torch.tensor(labels))
        return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

    train_loader = _loader(Xtr, ytr, True)
    test_loader = _loader(Xte, yte, False)
    return train_loader, test_loader