"""
training.py - Training utilities and high-level training loops.

This module provides:
- Training function for transformer models (baseline and Bayesian variants)
- Compilation utilities with appropriate optimizers
- Training time tracking
"""

import gc
import time
import numpy as np
import tensorflow as tf
import pandas as pd

from modules.config import (
    set_seed,
    LEARNING_RATE,
    LR_DECAY_ALPHA,
    WEIGHT_DECAY,
)
from modules.bayesian_layers import set_kl_scale, set_dropout_active
from modules.model_builders import clean_summary


# =============================================================================
# KL ANNEALING CALLBACK
# =============================================================================

class KLAnnealingCallback(tf.keras.callbacks.Callback):
    """
    Callback for KL annealing/warmup during training.

    Keeps the KL scale at 0 for the first ``zero_epochs`` epochs, then raises
    it linearly until it reaches ``kl_scale_final`` at epoch ``warmup_epochs``
    and holds it there. Epoch numbers are the 0-based indices Keras passes to
    ``on_epoch_begin``.

    If ``warmup_epochs <= zero_epochs`` there is no ramp: the scale jumps from
    0 straight to ``kl_scale_final`` at epoch ``zero_epochs``.

    Parameters
    ----------
    model : tf.keras.Model
        The Bayesian model with variational layers
    kl_scale_final : float
        Final KL scale to reach after warmup
    warmup_epochs : int
        Epoch number at which KL reaches full value (default: 5)
    zero_epochs : int
        Number of initial epochs with KL=0 (default: 2)
    verbose : int
        Progress messages are printed when this is >= 1

    Example
    -------
    With zero_epochs=2 and warmup_epochs=5:
    - Epochs 0-1: KL = 0
    - Epoch 2: KL = 0 (start of ramp)
    - Epoch 3: KL = 0.33 * final
    - Epoch 4: KL = 0.67 * final
    - Epoch 5: KL = 1.0 * final
    - Epoch 6+: KL = final
    """

    def __init__(self, model, kl_scale_final, warmup_epochs=5, zero_epochs=2, verbose=1):
        super().__init__()
        # Kept under its own name rather than ``self.model``, which Keras
        # assigns itself when the callback is attached to ``fit``.
        self.target_model = model
        self.kl_scale_final = kl_scale_final
        self.warmup_epochs = warmup_epochs
        self.zero_epochs = zero_epochs
        self.verbose = verbose
        self.current_epoch = 0

    def _set_kl_scale(self, scale):
        """
        Set ``kl_scale`` on every layer that exposes it and return how many
        layers were updated.
        """
        # NOTE(review): only the model's top-level ``layers`` are visited;
        # variational layers nested inside composite layers would be missed.
        # Confirm this matches how the Bayesian models are built.
        count = 0
        for layer in self.target_model.layers:
            if hasattr(layer, 'kl_scale'):
                layer.kl_scale = float(scale)
                count += 1
        return count

    def _scale_for_epoch(self, epoch):
        """Return the KL scale to apply at the given 0-based epoch."""
        if epoch < self.zero_epochs:
            # Keep KL at 0 for the first zero_epochs
            return 0.0
        if epoch < self.warmup_epochs:
            # Linear increase between zero_epochs and warmup_epochs. Reaching
            # this branch implies zero_epochs <= epoch < warmup_epochs, so the
            # ramp length is strictly positive and the division is safe.
            ramp_epochs = self.warmup_epochs - self.zero_epochs
            progress = (epoch - self.zero_epochs) / ramp_epochs
            return self.kl_scale_final * progress
        # At warmup_epochs and beyond, use the full KL scale
        return self.kl_scale_final

    def on_epoch_begin(self, epoch, logs=None):
        """Apply the scheduled KL scale before the epoch starts."""
        self.current_epoch = epoch

        current_kl = self._scale_for_epoch(epoch)
        count = self._set_kl_scale(current_kl)

        if self.verbose >= 1:
            print(f"\n[KL Annealing] Epoch {epoch+1}: KL={current_kl:.6f} ({count} variational layers)")

    def on_train_end(self, logs=None):
        """Leave the model at the final KL scale, even if training stopped mid-ramp."""
        self._set_kl_scale(self.kl_scale_final)
        if self.verbose >= 1:
            print(f"\n[KL Annealing] Training complete. Final KL scale applied: {self.kl_scale_final:.6f}")


# =============================================================================
# TRAINING FUNCTIONS
# =============================================================================

def train_transformer_models(configs, train_dataset, val_dataset, ood_dataset, epochs=5):
    """
    Train and evaluate multiple transformer models.

    For each entry in ``configs`` a fresh model is built, compiled with an
    Adam optimizer on a cosine-decay learning-rate schedule, trained with
    early stopping on ``val_acc``, and evaluated on the validation and OOD
    datasets.

    Weight decay is applied only to configurations without a ``kl_scale``
    (treated as deterministic baselines). Configurations with a ``kl_scale``
    (treated as Bayesian) use plain Adam so that no weight decay is applied
    to variational parameters.

    Parameters
    ----------
    configs : dict
        Mapping model name -> configuration with keys:
            - 'builder': function returning a new Keras model (it is compiled here)
            - 'kl_scale': float or None. If not None, set this KL scale on all variational layers
            - 'epochs': int. Number of training epochs (optional, uses default if not specified)
            - 'use_kl_annealing': bool. If True, use KL annealing (default: False for backward compatibility)
            - 'kl_warmup_epochs': int. Epoch at which KL reaches full value (default: 5)
            - 'kl_zero_epochs': int. Number of initial epochs with KL=0 (default: 2)
    train_dataset : tf.data.Dataset
        Training dataset (SST-2 train)
    val_dataset : tf.data.Dataset
        In-distribution validation dataset (SST-2 dev)
    ood_dataset : tf.data.Dataset
        Out-of-distribution dataset (IMDB test)
    epochs : int
        Default number of epochs if not specified in configs

    Returns
    -------
    histories : dict
        Training histories for each model
    times : dict
        Training duration for each model (seconds)
    trained_models : dict
        Dictionary of trained model objects

    Raises
    ------
    ValueError
        If ``train_dataset`` is infinite, since the cosine schedule needs a
        finite number of steps per epoch.
    """
    histories = {}
    times = {}
    trained_models = {}

    # Ask tf.data for the batch count instead of materialising every batch in
    # memory; only stream through the data when the cardinality is unknown.
    # The count depends on the dataset alone, so it is computed once.
    steps_per_epoch = int(train_dataset.cardinality().numpy())
    if steps_per_epoch == tf.data.INFINITE_CARDINALITY:
        raise ValueError(
            "train_dataset is infinite; a finite number of batches per epoch "
            "is required to size the learning-rate schedule"
        )
    if steps_per_epoch == tf.data.UNKNOWN_CARDINALITY:
        steps_per_epoch = sum(1 for _ in train_dataset)

    for name, cfg in configs.items():
        # Clear previous graph/session and reset seeds
        tf.keras.backend.clear_session()
        gc.collect()
        set_seed()

        # Build model via builder closure
        model = cfg['builder']()
        # NOTE(review): the single-model helpers in this module call this with
        # active=False before training -- confirm which setting is intended.
        set_dropout_active(model, active=True)

        # Handle KL scale and annealing
        kl_scale = cfg.get('kl_scale')
        is_bayesian = kl_scale is not None
        anneal_kl = is_bayesian and cfg.get('use_kl_annealing', False)
        kl_warmup = cfg.get('kl_warmup_epochs', 5)
        kl_zero = cfg.get('kl_zero_epochs', 2)

        if is_bayesian:
            # With annealing, start at 0 and let the callback ramp the scale;
            # otherwise apply the final value immediately.
            set_kl_scale(model, 0.0 if anneal_kl else kl_scale)

        # Compile the model
        n_epochs = cfg.get('epochs', epochs)
        total_steps = steps_per_epoch * n_epochs

        lr_schedule = tf.keras.optimizers.schedules.CosineDecay(
            initial_learning_rate=LEARNING_RATE,
            decay_steps=total_steps,
            alpha=LR_DECAY_ALPHA
        )

        # Keras optimizers apply decoupled weight decay whenever `weight_decay`
        # is passed, Adam included. Only pass it for non-Bayesian models so
        # variational parameters are never decayed.
        optimizer_kwargs = {'learning_rate': lr_schedule}
        if not is_bayesian:
            optimizer_kwargs['weight_decay'] = WEIGHT_DECAY
        optimizer = tf.keras.optimizers.Adam(**optimizer_kwargs)

        model.compile(
            optimizer=optimizer,
            loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
            metrics=[tf.keras.metrics.SparseCategoricalAccuracy(name='acc')],
            jit_compile=False
        )

        # Print summary
        print(f"\n{'='*70}")
        print(f"=== {name} ===")
        print(f"{'='*70}")
        if anneal_kl:
            print(f"KL Annealing: ON (zero_epochs={kl_zero}, warmup_epochs={kl_warmup}, final_kl={kl_scale:.6f})")
        elif is_bayesian:
            print(f"KL Scale: {kl_scale:.6f} (no annealing)")
        print(f"{'='*70}")
        clean_summary(model)

        # Callbacks
        callbacks = []

        # KL Annealing callback (if enabled for Bayesian models)
        if anneal_kl:
            callbacks.append(
                KLAnnealingCallback(
                    model=model,
                    kl_scale_final=kl_scale,
                    warmup_epochs=kl_warmup,
                    zero_epochs=kl_zero,
                    verbose=1
                )
            )

        # Early stopping on validation accuracy (higher is better).
        callbacks.append(
            tf.keras.callbacks.EarlyStopping(
                monitor='val_acc',
                mode='max',
                patience=5,
                restore_best_weights=True,
                verbose=1
            )
        )

        # No ReduceLROnPlateau: the optimizer's learning rate is owned by the
        # CosineDecay schedule and cannot be overridden by a callback (newer
        # Keras versions raise a TypeError when the callback tries).

        # Train
        start = time.time()
        history = model.fit(
            train_dataset,
            validation_data=val_dataset,
            epochs=n_epochs,
            callbacks=callbacks
        )
        times[name] = time.time() - start
        histories[name] = history.history
        trained_models[name] = model

        # Evaluate on ID and OOD
        set_dropout_active(model, active=False)
        print(f"\n{name} ID eval:")
        model.evaluate(val_dataset)
        print(f"{name} OOD eval:")
        model.evaluate(ood_dataset)

    return histories, times, trained_models


def train_baseline_model(builder, train_dataset, val_dataset, ood_dataset, epochs=5):
    """
    Train a single baseline transformer model.

    Uses the AdamW optimizer (with weight decay) on a cosine-decay
    learning-rate schedule, with early stopping on ``val_loss``. After
    training, the model is evaluated on the validation and OOD datasets.

    Parameters
    ----------
    builder : callable
        Function that returns a new model
    train_dataset : tf.data.Dataset
        Training dataset
    val_dataset : tf.data.Dataset
        Validation dataset
    ood_dataset : tf.data.Dataset
        OOD evaluation dataset
    epochs : int
        Number of training epochs

    Returns
    -------
    tuple
        (model, history, training_time) where ``history`` is the Keras
        ``History.history`` dict and ``training_time`` is in seconds

    Raises
    ------
    ValueError
        If ``train_dataset`` is infinite, since the cosine schedule needs a
        finite number of steps per epoch.
    """
    tf.keras.backend.clear_session()
    gc.collect()
    set_seed()

    model = builder()
    set_dropout_active(model, active=False)

    # Ask tf.data for the batch count instead of materialising every batch in
    # memory; only stream through the data when the cardinality is unknown.
    steps_per_epoch = int(train_dataset.cardinality().numpy())
    if steps_per_epoch == tf.data.INFINITE_CARDINALITY:
        raise ValueError(
            "train_dataset is infinite; a finite number of batches per epoch "
            "is required to size the learning-rate schedule"
        )
    if steps_per_epoch == tf.data.UNKNOWN_CARDINALITY:
        steps_per_epoch = sum(1 for _ in train_dataset)
    total_steps = steps_per_epoch * epochs

    lr_schedule = tf.keras.optimizers.schedules.CosineDecay(
        initial_learning_rate=LEARNING_RATE,
        decay_steps=total_steps,
        alpha=LR_DECAY_ALPHA
    )

    # Use AdamW for baseline (weight decay is beneficial)
    optimizer = tf.keras.optimizers.AdamW(
        learning_rate=lr_schedule,
        weight_decay=WEIGHT_DECAY
    )

    model.compile(
        optimizer=optimizer,
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy(name='acc')],
        jit_compile=False
    )

    print("\n" + "="*70)
    print("=== Baseline Transformer ===")
    print("="*70)
    clean_summary(model)

    # Callbacks
    early_stopping = tf.keras.callbacks.EarlyStopping(
        monitor='val_loss',
        patience=5,
        restore_best_weights=True,
        verbose=1
    )
    # No ReduceLROnPlateau: the optimizer's learning rate is owned by the
    # CosineDecay schedule and cannot be overridden by a callback (newer
    # Keras versions raise a TypeError when the callback tries).

    start = time.time()
    history = model.fit(
        train_dataset,
        validation_data=val_dataset,
        epochs=epochs,
        callbacks=[early_stopping]
    )
    training_time = time.time() - start

    set_dropout_active(model, active=False)
    print("\nBaseline ID eval:")
    model.evaluate(val_dataset)
    print("Baseline OOD eval:")
    model.evaluate(ood_dataset)

    return model, history.history, training_time


def train_bayesian_model(builder, kl_weight, train_dataset, val_dataset, ood_dataset, epochs=15,
                         name="Bayesian Model", use_kl_annealing=True, kl_warmup_epochs=4, kl_zero_epochs=2):
    """
    Train a single Bayesian transformer model.

    Uses plain Adam (no weight decay) on a cosine-decay learning-rate
    schedule, since weight decay interferes with variational inference
    optimization. Training uses early stopping on ``val_loss`` and, when
    enabled, KL annealing. After training, the model is evaluated on the
    validation and OOD datasets.

    Parameters
    ----------
    builder : callable
        Function that returns a new model
    kl_weight : float
        KL divergence weight for variational inference (final value after warmup)
    train_dataset : tf.data.Dataset
        Training dataset
    val_dataset : tf.data.Dataset
        Validation dataset
    ood_dataset : tf.data.Dataset
        OOD evaluation dataset
    epochs : int
        Number of training epochs (default: 15)
    name : str
        Model name for display
    use_kl_annealing : bool
        Whether to use KL annealing (default: True)
    kl_warmup_epochs : int
        Epoch at which KL reaches full value (default: 4)
    kl_zero_epochs : int
        Number of initial epochs with KL=0 (default: 2)

    Returns
    -------
    tuple
        (model, history, training_time) where ``history`` is the Keras
        ``History.history`` dict and ``training_time`` is in seconds

    Raises
    ------
    ValueError
        If ``train_dataset`` is infinite, since the cosine schedule needs a
        finite number of steps per epoch.
    """
    tf.keras.backend.clear_session()
    gc.collect()
    set_seed()

    model = builder()
    set_dropout_active(model, active=False)

    # If using KL annealing, start with KL=0 (callback will handle ramping)
    # Otherwise, set the final KL weight immediately
    set_kl_scale(model, 0.0 if use_kl_annealing else kl_weight)

    # Ask tf.data for the batch count instead of materialising every batch in
    # memory; only stream through the data when the cardinality is unknown.
    steps_per_epoch = int(train_dataset.cardinality().numpy())
    if steps_per_epoch == tf.data.INFINITE_CARDINALITY:
        raise ValueError(
            "train_dataset is infinite; a finite number of batches per epoch "
            "is required to size the learning-rate schedule"
        )
    if steps_per_epoch == tf.data.UNKNOWN_CARDINALITY:
        steps_per_epoch = sum(1 for _ in train_dataset)
    total_steps = steps_per_epoch * epochs

    lr_schedule = tf.keras.optimizers.schedules.CosineDecay(
        initial_learning_rate=LEARNING_RATE,
        decay_steps=total_steps,
        alpha=LR_DECAY_ALPHA
    )

    # Plain Adam, deliberately without `weight_decay`: Keras optimizers apply
    # decoupled weight decay whenever that argument is passed, which would
    # decay the variational parameters.
    optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)

    model.compile(
        optimizer=optimizer,
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy(name='acc')],
        jit_compile=False
    )

    print("\n" + "="*70)
    print(f"=== {name} ===")
    print("="*70)
    if use_kl_annealing:
        print(f"KL Annealing: ON (zero_epochs={kl_zero_epochs}, warmup_epochs={kl_warmup_epochs}, final_kl={kl_weight:.6f})")
    else:
        print(f"KL Annealing: OFF (constant_kl={kl_weight:.6f})")
    print("="*70)
    clean_summary(model)

    # Callbacks
    callbacks = []

    # KL Annealing callback (if enabled)
    if use_kl_annealing:
        callbacks.append(
            KLAnnealingCallback(
                model=model,
                kl_scale_final=kl_weight,
                warmup_epochs=kl_warmup_epochs,
                zero_epochs=kl_zero_epochs,
                verbose=1
            )
        )

    # Early stopping
    # NOTE(review): if the validation loss includes the scaled KL term, then
    # with annealing the KL=0 epochs look artificially good and
    # restore_best_weights may roll back to them -- confirm this is intended.
    callbacks.append(
        tf.keras.callbacks.EarlyStopping(
            monitor='val_loss',
            patience=5,
            restore_best_weights=True,
            verbose=1
        )
    )

    # No ReduceLROnPlateau: the optimizer's learning rate is owned by the
    # CosineDecay schedule and cannot be overridden by a callback (newer
    # Keras versions raise a TypeError when the callback tries).

    start = time.time()
    history = model.fit(
        train_dataset,
        validation_data=val_dataset,
        epochs=epochs,
        callbacks=callbacks
    )
    training_time = time.time() - start

    set_dropout_active(model, active=False)
    print(f"\n{name} ID eval:")
    model.evaluate(val_dataset)
    print(f"{name} OOD eval:")
    model.evaluate(ood_dataset)

    return model, history.history, training_time


# =============================================================================
# UTILITY FUNCTIONS
# =============================================================================

def get_clean_metrics(models, dataset):
    """
    Get clean metrics (Accuracy, NLL, KL) for multiple models.

    For every model the total loss and accuracy come from ``model.evaluate``,
    the NLL is the sparse categorical cross-entropy between the dataset labels
    and the model's predictions, and the KL term is derived as
    ``total_loss - NLL`` (reported as 0.0 for models whose name contains
    "Baseline"). A formatted table is printed along the way.

    Parameters
    ----------
    models : dict
        Dictionary mapping model names to models
    dataset : tf.data.Dataset
        Dataset to evaluate on; iterated as ``(inputs, labels)`` batches

    Returns
    -------
    pd.DataFrame
        One row per model with columns "Model", "Accuracy", "NLL" and "KL"
    """
    results = []
    print(f"{'Model':<20} | {'Acc (Mean)':<10} | {'NLL (Data Fit)':<15} | {'KL (Complexity)':<15}")
    print("-" * 70)

    # The loss object holds no per-model state, so build it once.
    cce = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)

    for name, model in models.items():
        # Run standard evaluation. The unpacking assumes the model was
        # compiled with exactly one metric (accuracy).
        total_loss, acc = model.evaluate(dataset, verbose=0)

        # Collect labels and predictions in the SAME pass over the dataset so
        # they stay aligned even if the dataset yields batches in a different
        # order on each iteration (e.g. shuffling).
        label_batches = []
        pred_batches = []
        for x, y in dataset:
            label_batches.append(y.numpy())
            pred_batches.append(model.predict_on_batch(x))
        y_true = np.concatenate(label_batches)
        y_pred = np.concatenate(pred_batches)

        # Calculate Pure NLL (Cross Entropy)
        nll_value = cce(y_true, y_pred).numpy()

        # Derive KL (Total Loss - NLL)
        if "Baseline" in name:
            kl_value = 0.0
        else:
            kl_value = total_loss - nll_value

        print(f"{name:<20} | {acc:.2%}     | {nll_value:.4f}          | {kl_value:.1f}")

        results.append({
            "Model": name,
            "Accuracy": acc,
            "NLL": nll_value,
            "KL": kl_value
        })

    return pd.DataFrame(results)


def compare_deterministic_accuracies(models, dataset):
    """
    Tabulate evaluation accuracy and loss for several models.

    Each model is scored with ``model.evaluate`` and a formatted table is
    printed while the rows are collected.

    Parameters
    ----------
    models : dict
        Dictionary mapping model names to models
    dataset : tf.data.Dataset
        Dataset to evaluate on

    Returns
    -------
    pd.DataFrame
        One row per model with columns "Model", "Mode", "Accuracy" and
        "Loss"; accuracy and loss are stored as pre-formatted strings
    """
    header = f"{'Model Name':<25} | {'Acc (Mean-Only)':<18} | {'Loss':<10}"
    print(header)
    print("-" * 60)

    rows = []
    for model_name, net in models.items():
        # evaluate() runs the model in inference mode; presumably the Bayesian
        # layers then use their posterior means -- confirm in bayesian_layers.
        eval_loss, eval_acc = net.evaluate(dataset, verbose=0)

        acc_text = f"{eval_acc:.2%}"
        loss_text = f"{eval_loss:.4f}"

        rows.append(
            dict(
                Model=model_name,
                Mode="Deterministic (Mean)",
                Accuracy=acc_text,
                Loss=loss_text,
            )
        )

        print(f"{model_name:<25} | {acc_text}           | {loss_text}")

    return pd.DataFrame(rows)


def print_training_times(times):
    """
    Write a short report of per-model training durations to stdout.

    Parameters
    ----------
    times : dict
        Dictionary mapping model names to training times (seconds)
    """
    row_template = "  {}: {:.1f} seconds"

    print("\nTraining times:")
    for label, seconds in times.items():
        print(row_template.format(label, seconds))

