"""
Data loading utilities for Heart Disease dataset
"""
import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
import torch
import os


def set_seed(seed=42):
    """Seed every random number generator used here so runs are repeatable.

    Seeds Python's built-in ``random`` module, NumPy's global generator and
    PyTorch's CPU generator, plus all CUDA devices when CUDA is available.

    Args:
        seed: Integer seed applied to every generator (default: 42).
    """
    import random as _py_random

    # Same order as before: stdlib, NumPy, then PyTorch.
    for seed_fn in (_py_random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)


def load_heart_disease_data(data_path="heart_disease.csv", test_size=0.1, val_size_from_remaining=1/9, seed=42,
                            target_col="target"):
    """
    Load the Heart Disease CSV and return a stratified, standardized train/validation/test split.

    The test set is split off first (``test_size`` of all rows). The validation
    set is then taken from the remaining rows (``val_size_from_remaining`` of
    that remainder). With the defaults (0.1 and 1/9) the train:validation:test
    ratio is 8:1:1. Both splits are stratified on the label. Features are
    standardized with a StandardScaler fitted on the training split only, so no
    validation/test statistics leak into training.

    Args:
        data_path: Path to the CSV file.
        test_size: Fraction of all rows used for the test set (default: 0.1).
        val_size_from_remaining: Fraction of the non-test rows used for validation
            (default: 1/9, which makes the overall split 8:1:1).
        seed: Random seed passed to both ``train_test_split`` calls.
        target_col: Name of the label column (default: 'target'). Every other
            column is used as a feature.

    Returns:
        X_train, X_val, X_test, y_train, y_val, y_test, feature_names
        X_* are standardized float32 arrays; y_* are int64 label arrays;
        feature_names lists the feature columns in CSV order.

    Raises:
        FileNotFoundError: If ``data_path`` does not exist.
        KeyError: If ``target_col`` is not a column of the CSV.
        ValueError: If any label is missing, non-integral, or negative.
    """
    print("Loading Heart Disease dataset...")

    # Check if data file exists
    if not os.path.exists(data_path):
        raise FileNotFoundError(f"Data file not found: {data_path}")

    # Load data
    data = pd.read_csv(data_path)
    print(f"Data shape: {data.shape}")

    if target_col not in data.columns:
        raise KeyError(
            f"Target column '{target_col}' not found in {data_path}; "
            f"available columns: {list(data.columns)}"
        )

    # Separate features and target
    feature_names = [col for col in data.columns if col != target_col]
    X = data[feature_names].values

    # np.bincount (below) and stratified splitting need non-negative integer labels.
    # pandas reads the label column as float when it holds values such as "1.0"
    # or contains NaN, and np.bincount then raises TypeError. Normalize to int64
    # here, rejecting labels that cannot be represented exactly.
    y = data[target_col].to_numpy()
    if not np.issubdtype(y.dtype, np.integer):
        y_float = np.asarray(y, dtype=np.float64)
        if not np.all(np.isfinite(y_float)) or np.any(y_float != np.round(y_float)):
            raise ValueError(f"Column '{target_col}' must contain integer class labels with no missing values")
        y = y_float
    y = y.astype(np.int64)
    if np.any(y < 0):
        raise ValueError(f"Column '{target_col}' must contain non-negative class labels")

    print(f"Number of features: {len(feature_names)}")
    print(f"Number of samples: {len(y)}")
    print(f"Class distribution: {np.bincount(y)}")
    print(f"Feature names: {feature_names}")

    # First split: hold out the test set (test_size of all rows)
    X_temp, X_test, y_temp, y_test = train_test_split(
        X, y, test_size=test_size, random_state=seed, stratify=y
    )

    # Second split: divide the remainder into train and validation.
    # With the defaults, val_size_from_remaining = 1/9 makes train:val:test = 8:1:1.
    X_train, X_val, y_train, y_val = train_test_split(
        X_temp, y_temp, test_size=val_size_from_remaining, random_state=seed, stratify=y_temp
    )

    # Standardize features; the scaler is fitted on the training split only
    scaler = StandardScaler()
    X_train = scaler.fit_transform(X_train)
    X_val = scaler.transform(X_val)
    X_test = scaler.transform(X_test)

    # Convert to proper data types
    X_train = np.asarray(X_train, dtype=np.float32)
    X_val = np.asarray(X_val, dtype=np.float32)
    X_test = np.asarray(X_test, dtype=np.float32)
    y_train = np.asarray(y_train, dtype=np.int64)
    y_val = np.asarray(y_val, dtype=np.int64)
    y_test = np.asarray(y_test, dtype=np.int64)

    print(f"Train set: {X_train.shape}, Validation set: {X_val.shape}, Test set: {X_test.shape}")
    print(f"Train class distribution: {np.bincount(y_train)}")
    print(f"Validation class distribution: {np.bincount(y_val)}")
    print(f"Test class distribution: {np.bincount(y_test)}")

    return X_train, X_val, X_test, y_train, y_val, y_test, feature_names


def preprocess_features_by_groups(X_train, X_val, X_test, group_dict, device='cpu'):
    """
    Split the train/val/test feature matrices into per-group float32 tensors.

    Args:
        X_train, X_val, X_test: 2-D NumPy feature arrays (rows are samples).
            Columns are selected with ``X[:, feature_indices]``.
        group_dict: Mapping of group name -> sequence of column indices.
            Groups with an empty index sequence are skipped entirely.
        device: PyTorch device (string or ``torch.device``) on which the
            returned tensors are created.

    Returns:
        (group_train_data, group_val_data, group_test_data, group_feature_info):
        The first three are dicts mapping group name -> float32 tensor of shape
        (n_samples, len(feature_indices)). The fourth is a dict mapping
        group name -> {'input_dim': len(feature_indices)}.
    """
    splits = (X_train, X_val, X_test)
    group_train_data, group_val_data, group_test_data = {}, {}, {}
    outputs = (group_train_data, group_val_data, group_test_data)
    group_feature_info = {}

    for group_name, feature_indices in group_dict.items():
        # len() rather than truthiness so NumPy index arrays are accepted too.
        if len(feature_indices) == 0:
            continue

        for X, out in zip(splits, outputs):
            # Indexing with a list/array of indices already returns a new array,
            # so the tensor never aliases the caller's data. as_tensor builds it
            # directly with the target dtype and device, replacing the legacy
            # torch.FloatTensor constructor followed by .to(device).
            out[group_name] = torch.as_tensor(X[:, feature_indices], dtype=torch.float32, device=device)

        # Store feature info
        group_feature_info[group_name] = {'input_dim': len(feature_indices)}

    return group_train_data, group_val_data, group_test_data, group_feature_info