"""
Training Functions for Bayesian Neural Networks

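This module implements training routines for:
- Multiple Bayesian model architectures, built from a configuration dict
- Deep ensembles of deterministic networks
- Ensembles of low-rank Gaussian Bayesian networks
- Hyperparameter configuration and KL warmup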
"""

import time
import random
import gc
import numpy as np
import tensorflow as tf
from modules.model_builders import (
    build_dense_model,
    build_lowrank_gauss,
    set_kl_scale,
    compile_binary
)



def train_models(configs, X_train, y_train, X_test, y_test, feature_dim, class_weight,
                 include_ensemble=True, n_ensemble_members=5, seed=42,
                 batch_size=64, ensemble_epochs=32):
    """
    Train all Bayesian model architectures and optionally a deep ensemble.

    Each configuration is built, KL-scaled, compiled and fitted in a fresh
    Keras session that is reseeded with the same `seed`, so every model starts
    from an identical RNG state.

    Parameters:
    -----------
    configs : dict, maps model name -> configuration dict with keys
        "builder" (callable taking feature_dim and returning a Keras model),
        "kl_scale" (value passed to set_kl_scale) and "epochs" (int)
    X_train : pandas DataFrame, training features
    y_train : pandas Series, training labels
    X_test : pandas DataFrame, test features (also used as validation data)
    y_test : pandas Series, test labels (also used as validation data)
    feature_dim : int, number of input features
    class_weight : dict, class weights for imbalanced data
    include_ensemble : bool, whether to train deep ensemble (default: True)
    n_ensemble_members : int, number of ensemble members (default: 5)
    seed : int, random seed (default: 42)
    batch_size : int, mini-batch size for every model and for the ensemble
        members (default: 64)
    ensemble_epochs : int, training epochs per deep-ensemble member (default: 32)

    Returns:
    --------
    histories : dict, training history for each model
        (the ensemble entry is a list of per-member histories)
    times : dict, training time in seconds for each model
        (the ensemble entry is the summed time of all members)
    trained_models : dict, trained models (ensemble is stored as a list)
    """
    # Convert once; the same arrays are reused for every configuration.
    # The test split doubles as validation data, so the "val_*" entries in the
    # returned histories are test-set metrics.
    x_tr = X_train.values
    y_tr = y_train.values
    x_te = X_test.values
    y_te = y_test.values

    histories = {}
    times = {}
    trained_models = {}

    # Iterate over model configurations and train each separately
    for name, cfg in configs.items():
        # Fresh session + reseed to avoid cross-model interference.
        # set_random_seed seeds Python `random`, NumPy and TensorFlow together.
        tf.keras.backend.clear_session()
        gc.collect()
        tf.keras.utils.set_random_seed(seed)

        # Build, apply the configured KL scaling, then compile.
        model = cfg["builder"](feature_dim)
        set_kl_scale(model, cfg["kl_scale"])
        compile_binary(model)

        print(f"\n=== {name} ===")
        model.summary(line_length=120, expand_nested=False)

        # Train the model
        start = time.perf_counter()
        print(f"Training {name}")
        history = model.fit(
            x_tr, y_tr,
            validation_data=(x_te, y_te),
            epochs=cfg["epochs"],
            batch_size=batch_size,
            class_weight=class_weight,
            shuffle=True,
            verbose=1,
            callbacks=[]
        )
        elapsed = time.perf_counter() - start
        times[name] = elapsed
        histories[name] = history.history
        print(f"Training time for {name}: {elapsed/60:.2f} min ({elapsed:.1f} s)")
        trained_models[name] = model

    # Train deep ensemble if requested
    if include_ensemble:
        print(f"\n{'='*80}")
        print("TRAINING DEEP ENSEMBLE")
        print(f"{'='*80}")
        ensemble_models, ensemble_histories, ensemble_times = train_deep_ensemble(
            X_train, y_train, X_test, y_test, feature_dim,
            n_members=n_ensemble_members,
            # Members overfit when trained longer than ~32 epochs; the
            # deterministic model is capped at 32 for the same reason
            # (inspired by David Ruhe et al.'s MIMIC paper).
            epochs=ensemble_epochs,
            batch_size=batch_size,
            class_weight=class_weight,
            seed=seed
        )
        ensemble_total = sum(ensemble_times)
        # Store ensemble as a list of models
        trained_models["Deep Ensemble"] = ensemble_models
        histories["Deep Ensemble"] = ensemble_histories
        times["Deep Ensemble"] = ensemble_total
        print(f"Total ensemble training time: {ensemble_total/60:.2f} min ({ensemble_total:.1f} s)")

    return histories, times, trained_models


def train_deep_ensemble(X_train, y_train, X_val, y_val, feature_dim: int,
                        n_members: int = 5, epochs: int = 32,
                        batch_size: int = 128, class_weight: dict = None,
                        seed: int = 42):
    """
    Train n_members deterministic neural networks with different random seeds.

    Member m is trained in a fresh Keras session seeded with `seed + m`, so the
    members differ only through their RNG state.

    Parameters:
    -----------
    X_train : pandas DataFrame, training features
    y_train : pandas Series, training labels
    X_val : pandas DataFrame, validation features
    y_val : pandas Series, validation labels
    feature_dim : int, number of input features
    n_members : int, number of ensemble members (default: 5)
    epochs : int, number of training epochs (default: 32)
    batch_size : int, batch size for training (default: 128)
    class_weight : dict, class weights for imbalanced data
    seed : int, base random seed (default: 42)

    Returns:
    --------
    ensemble_models : list of trained models
    histories : list of training histories (Keras `History.history` dicts)
    times : list of training times in seconds for each member
    """
    # Loop-invariant: convert the data once rather than once per member.
    x_tr, y_tr = X_train.values, y_train.values
    val_data = (X_val.values, y_val.values)

    ensemble_models = []
    histories = []
    times = []

    for m in range(n_members):
        # Clear session to avoid cross-model interference
        tf.keras.backend.clear_session()
        gc.collect()
        # Different seed per member (seed, seed+1, ...) to encourage diversity.
        # set_random_seed seeds Python `random`, NumPy and TensorFlow together.
        tf.keras.utils.set_random_seed(seed + m)

        # NOTE(review): unlike train_models, compile_binary() is not called here;
        # presumably build_dense_model returns an already-compiled model — confirm.
        model = build_dense_model(feature_dim)
        print(f"Training ensemble member {m+1}/{n_members}")
        start = time.perf_counter()
        history = model.fit(
            x_tr, y_tr,
            validation_data=val_data,
            epochs=epochs,
            batch_size=batch_size,
            class_weight=class_weight,
            shuffle=True,
            verbose=1,
            callbacks=[]
        )
        elapsed = time.perf_counter() - start
        ensemble_models.append(model)
        histories.append(history.history)
        times.append(elapsed)
        print(f"Training time for ensemble member {m+1}: {elapsed/60:.2f} min ({elapsed:.1f} s)")

    return ensemble_models, histories, times
def train_lowrank_bayes_ensemble(
    n_members,
    X_train, y_train,
    X_val, y_val,
    feature_dim,
    class_weight,
    rank1, rank2,
    base_seed=42,
    epochs=256,
    batch_size=128,
    det_model=None,  # optional backbone for init
    callback=None,
):
    """
    Train an ensemble of low-rank Gaussian Bayesian models with different seeds.

    Member k is built with build_lowrank_gauss in a fresh Keras session seeded
    with `base_seed + k` and trained silently (verbose=0).

    Parameters:
    -----------
    n_members : int, number of ensemble members
    X_train : pandas DataFrame, training features
    y_train : pandas Series, training labels
    X_val : pandas DataFrame, validation features
    y_val : pandas Series, validation labels
    feature_dim : int, number of input features (passed as `input_dim`)
    class_weight : dict, class weights for imbalanced data
    rank1, rank2 : forwarded unchanged to build_lowrank_gauss
        (presumably the ranks of the low-rank factors — TODO confirm)
    base_seed : int, member k is seeded with base_seed + k (default: 42)
    epochs : int, number of training epochs per member (default: 256)
    batch_size : int, batch size for training (default: 128)
    det_model : optional model passed to build_lowrank_gauss as
        `init_from_deterministic` (backbone used for initialisation)
    callback : Keras callbacks passed to every member's fit() (default: None,
        meaning no callbacks). The same callback objects are reused for all
        members.

    Returns:
    --------
    members : list of trained models (training histories are not returned)
    """
    # None instead of a shared mutable `[]` default argument.
    callbacks = [] if callback is None else callback

    # Loop-invariant: convert the data once rather than once per member.
    x_tr, y_tr = X_train.values, y_train.values
    val_data = (X_val.values, y_val.values)

    members = []
    for k in range(n_members):
        seed = base_seed + k
        tf.keras.backend.clear_session()
        gc.collect()
        # Seed Python `random`, NumPy and TensorFlow together, matching the
        # other trainers in this module.
        tf.keras.utils.set_random_seed(seed)

        # NOTE(review): unlike train_models, neither set_kl_scale nor
        # compile_binary is called; presumably build_lowrank_gauss returns a
        # compiled model with its own KL scaling — confirm.
        model = build_lowrank_gauss(
            input_dim=feature_dim,
            rank1=rank1,
            rank2=rank2,
            init_from_deterministic=det_model,
        )

        model.fit(
            x_tr, y_tr,
            validation_data=val_data,
            epochs=epochs,
            batch_size=batch_size,
            class_weight=class_weight,
            shuffle=True,
            verbose=0,
            callbacks=callbacks,
        )
        members.append(model)
    return members