"""Configuration for Multi-Seed Robustness experiments.

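This is a copy of config.py intended to use reduced epoch counts for faster
training:
- EPOCHS_BASELINE: 5 (was 20)
- EPOCHS_ENSEMBLE: 5 (was 20)
- EPOCHS_BAYESIAN: 20 (unchanged - Bayesian models need more epochs)

NOTE(review): the constants defined in this file currently set
EPOCHS_BASELINE and EPOCHS_ENSEMBLE to 20, not 5. Confirm which values are
intended and bring this list and the constants into agreement.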
"""

import os
import gc
import random
import numpy as np

# Enable determinism.
# These variables are assigned before the `import tensorflow` statement that
# follows, so they are already present in the process environment when
# TensorFlow is imported. Keep this ordering when editing the import block.
os.environ['TF_DETERMINISTIC_OPS'] = '1'
os.environ['TF_CUDNN_DETERMINISTIC'] = '1'

import tensorflow as tf
import tensorflow_probability as tfp

# =============================================================================
# GPU & SESSION SETUP
# =============================================================================

def setup_gpu(memory_limit_mb=30*1024):
    """
    Reset the Keras session, configure every visible GPU, and enable determinism.

    Each GPU reported by TensorFlow either gets a single logical device capped
    at ``memory_limit_mb`` or, when no cap is requested, has memory growth
    switched on. A ``RuntimeError`` raised while configuring the GPUs is
    printed and not re-raised. Whether or not a GPU is present, JIT is
    turned off, the global mixed-precision policy is set to "float32", and
    op determinism is enabled.

    Parameters:
    -----------
    memory_limit_mb : int or None
        Hard per-GPU memory cap in MB (default: 30GB = 30*1024 MB).
        Pass None to skip the cap and rely on memory growth instead.
    """
    # Drop existing Keras state and collect garbage before configuring devices.
    tf.keras.backend.clear_session()
    gc.collect()
    tf.config.optimizer.set_jit(False)

    gpus = tf.config.list_physical_devices('GPU')
    cap_memory = memory_limit_mb is not None

    # With no GPUs the loop body never runs, so nothing is configured or printed.
    try:
        for device in gpus:
            if not cap_memory:
                # No hard limit requested: let the allocation grow on demand.
                tf.config.experimental.set_memory_growth(device, True)
                print(f"GPU Memory Growth Enabled for {len(gpus)} GPU(s).")
                continue

            # Hard limit: expose the GPU as one logical device with a fixed budget.
            capped_device = tf.config.LogicalDeviceConfiguration(
                memory_limit=memory_limit_mb
            )
            tf.config.set_logical_device_configuration(device, [capped_device])
            print(f"GPU Memory Limited to {memory_limit_mb/1024:.1f} GB.")
    except RuntimeError as e:
        # Best effort: report the problem and continue with the remaining setup.
        print(f"GPU Config Error: {e}")

    tf.keras.mixed_precision.set_global_policy("float32")
    tf.config.experimental.enable_op_determinism()


# =============================================================================
# GLOBAL SEEDING
# =============================================================================

# Default seed; used as the default argument of set_seed().
SEED = 42


def set_seed(seed=SEED):
    """Seed the Python, NumPy and TensorFlow random number generators.

    Before seeding, ``PYTHONHASHSEED`` is written to the environment, the
    Keras session is cleared and a garbage collection is run.

    Parameters:
    -----------
    seed : int
        Value passed to every seeding call (default: SEED).
    """
    os.environ["PYTHONHASHSEED"] = str(seed)

    tf.keras.backend.clear_session()
    gc.collect()

    # Applied in this order: Python's random module, NumPy, TensorFlow, Keras.
    seeders = (
        random.seed,
        np.random.seed,
        tf.random.set_seed,
        tf.keras.utils.set_random_seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)


# =============================================================================
# DATA CONFIGURATION
# =============================================================================

# NOTE(review): the meanings below are inferred from the constant names and
# values only; confirm against the data-loading code.
DATA_DIR = "processed_data_agnews"  # directory name; presumably preprocessed AG News data
BATCH_SIZE = 64  # also the default batch_size of compute_kl_weight()
MAX_LEN = 64  # presumably the maximum sequence length in tokens — TODO confirm
VOCAB_SIZE = 30522  # presumably the tokenizer vocabulary size — TODO confirm
OOD_SLICE = 7600  # presumably the number of out-of-distribution examples used — TODO confirm



# =============================================================================
# MODEL ARCHITECTURE
# =============================================================================

# Model hyperparameters, looked up by key.
# NOTE(review): key meanings are inferred from their names; confirm against
# the model-building code.
MODEL_CONFIG = {
    'd_model': 256,  # presumably the hidden/embedding width
    'n_layers': 4,  # presumably the number of transformer layers
    'n_heads': 4,  # presumably the number of attention heads
    'd_ff': 512,  # presumably the feed-forward inner width
    # NOTE(review): DATA_DIR's name suggests AG News, a four-class dataset;
    # confirm that 2 classes is intentional here.
    'num_classes': 2,
}

# Settings for the low-rank variant.
LOWRANK_CONFIG = {
    'rank': 16,  # presumably the rank of a low-rank factorization — TODO confirm
}




# TRAINING CONFIGURATION (MODIFIED FOR SEED EXPERIMENTS)

# Epoch counts per model family.
# NOTE(review): the baseline and ensemble counts are both 20, i.e. they are
# NOT reduced. Confirm whether a smaller value (e.g. 5) was intended for the
# faster seed experiments before relying on these numbers.
EPOCHS_BASELINE = 20  # NOTE(review): currently 20 (not reduced) — confirm
EPOCHS_BAYESIAN = 20  # Keep full training for Bayesian
EPOCHS_ENSEMBLE = 20  # NOTE(review): currently 20 (not reduced) — confirm

# Optimizer settings (meanings inferred from names — confirm against the training code).
LEARNING_RATE = 2e-4
LR_DECAY_ALPHA = 0.01  # presumably the `alpha` of a learning-rate decay schedule — TODO confirm
WEIGHT_DECAY = 0.01

N_ENSEMBLE_MEMBERS = 5  # presumably the number of models trained for the ensemble


# KL-term settings for the Bayesian models (meanings inferred from names — confirm).
KL_WARMUP_EPOCHS = 5
KL_SCALE_ATTENTION = 0.01
KL_SCALE_EMBEDDING = 0.001


# KL DIVERGENCE

def compute_kl_weight(num_train_samples, batch_size=BATCH_SIZE):
    """Return the KL weight for Bayesian models: 0.1 divided by batches per epoch.

    Parameters:
    -----------
    num_train_samples : int
        Number of training samples.
    batch_size : int
        Samples per batch (default: BATCH_SIZE).

    Returns:
    --------
    float
        0.1 / ceil(num_train_samples / batch_size)
    """
    # A partially filled final batch still counts as a batch, hence the ceil.
    batches_exact = num_train_samples / batch_size
    batches_per_epoch = int(np.ceil(batches_exact))
    kl_weight = 0.1 / batches_per_epoch
    return kl_weight


# EVALUATION CONFIGURATION


# NOTE(review): meanings are inferred from the constant names — confirm
# against the evaluation code.
N_MC_SAMPLES = 100  # presumably the number of Monte Carlo forward passes per input
N_ECE_BINS = 15  # presumably the bin count for expected calibration error
IMPROVED_N_MC_SAMPLES = 100  # same value as N_MC_SAMPLES

# Both dropout rates are 0.0.
DROPOUT_RATE_BASELINE = 0.0
DROPOUT_RATE_BAYESIAN = 0.0


# OUTPUT PATHS


# Relative paths for generated artifacts. The CSV and figure paths below are
# standalone strings; they are not built from OUTPUT_DIR in this file.
OUTPUT_DIR = "outputs"
METRICS_CSV = "results_csv/bayesian_transformer_metrics.csv"
METRICS_CSV_WITH_ENSEMBLE = "results_csv/bayesian_transformer_metrics_with_ensemble.csv"
CONVERGENCE_PLOT = "figures/convergence_comparison.png"
MODEL_COMPARISON_PLOT = "figures/model_comparison_with_ensemble.png"
UNCERTAINTY_DIST_PLOT = "figures/uncertainty_distributions_with_ensemble.png"



# Short alias for the TensorFlow Probability distributions module; this is
# the object returned by initialize().
tfd = tfp.distributions


# INITIALIZATION

def initialize():
    """Prepare the runtime: configure the GPU, seed the RNGs, report versions.

    Calls setup_gpu() and set_seed() with their default arguments, prints the
    TensorFlow and TensorFlow Probability versions, and returns the
    module-level ``tfd`` alias.

    Returns:
    --------
    module
        ``tfd`` (``tfp.distributions``)
    """
    setup_gpu()
    set_seed()

    tf_version = tf.__version__
    tfp_version = tfp.__version__
    print(f"\nTensorFlow version: {tf_version}")
    print(f"TensorFlow Probability version: {tfp_version}")

    return tfd
