"""Central configuration for Bayesian Transformer experiments."""

import os
import gc
import random
import numpy as np

# Request deterministic behaviour through TensorFlow environment variables.
# These assignments sit above the `import tensorflow` statement below, so the
# variables are already in the process environment when TensorFlow is loaded.
os.environ['TF_DETERMINISTIC_OPS'] = '1'
os.environ['TF_CUDNN_DETERMINISTIC'] = '1'

# Enable GPU memory growth BEFORE importing TensorFlow.
# This prevents TF from grabbing all available GPU memory at startup.
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'

import tensorflow as tf
import tensorflow_probability as tfp

# GPU & SESSION SETUP

def setup_gpu():
    """Reset TensorFlow state and apply the GPU / determinism settings.

    Steps, in order:
      1. Clear the Keras backend session and run a garbage-collection pass.
      2. Turn the `tf.config.optimizer` JIT setting off.
      3. For every physical GPU TensorFlow reports, switch on memory growth.
         A `RuntimeError` raised while doing so is printed, not re-raised,
         so the remaining steps still run.
      4. Set the global mixed-precision policy to "float32".
      5. Enable TensorFlow op determinism.

    Takes no arguments and returns None; progress is reported via `print`.
    """
    # Start from a clean slate before touching device configuration.
    tf.keras.backend.clear_session()
    gc.collect()
    tf.config.optimizer.set_jit(False)

    visible_gpus = tf.config.list_physical_devices('GPU')
    if visible_gpus:
        try:
            for device in visible_gpus:
                tf.config.experimental.set_memory_growth(device, True)
            print(f"GPU Memory Growth Enabled for {len(visible_gpus)} GPU(s).")
        except RuntimeError as err:
            # Best effort: report the problem and carry on with the rest.
            print(f"GPU Config Error: {err}")

    # Numeric policy and determinism are applied whether or not a GPU exists.
    tf.keras.mixed_precision.set_global_policy("float32")
    tf.config.experimental.enable_op_determinism()


# GLOBAL SEEDING

# Default seed: used as the default `seed` argument of set_seed() below.
SEED = 42


def set_seed(seed=SEED):
    """Seed Python, NumPy and TensorFlow random number generators.

    Args:
        seed: Seed value handed to every seeding call. Defaults to `SEED`.

    Side effects:
        Writes `PYTHONHASHSEED` into `os.environ`, clears the Keras backend
        session and runs a garbage-collection pass before seeding.

    NOTE: CPython reads `PYTHONHASHSEED` at interpreter startup, so the
    assignment here does not change string hashing in the process that is
    already running; it is only seen by child processes that inherit the
    environment.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)

    # Drop existing Keras state and unreferenced objects before re-seeding.
    tf.keras.backend.clear_session()
    gc.collect()

    # Order: Python's `random`, NumPy, TensorFlow, then the Keras helper.
    # NOTE(review): the Keras helper is documented to seed the first three
    # itself, so the explicit calls look redundant — kept to preserve behaviour.
    seeders = (
        random.seed,
        np.random.seed,
        tf.random.set_seed,
        tf.keras.utils.set_random_seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)



# DATA CONFIGURATION

# Relative directory name for the preprocessed dataset. The code that reads it
# is not in this file.
DATA_DIR = "processed_data_agnews"
# Batch size; also the default `batch_size` argument of compute_kl_weight().
BATCH_SIZE = 64
# NOTE(review): presumably the maximum token sequence length — confirm against
# the data pipeline.
MAX_LEN = 64
# NOTE(review): 30522 equals the BERT-base-uncased vocabulary size; presumably
# that tokenizer is used — confirm.
VOCAB_SIZE = 30522
# NOTE(review): presumably the number of examples taken for the
# out-of-distribution evaluation split — confirm against the caller.
OOD_SLICE = 7600




# MODEL ARCHITECTURE


# Transformer hyperparameters, looked up by key. The consumers of this dict
# are not in this file; the per-key notes are inferred from the key names and
# should be confirmed against the model code.
MODEL_CONFIG = {
    'd_model': 256,  # presumably the hidden / embedding width
    'n_layers': 4,  # presumably the number of transformer blocks
    'n_heads': 4,  # presumably the attention heads per block
    'd_ff': 512,  # presumably the feed-forward inner width
    # NOTE(review): DATA_DIR suggests AG News (a 4-class corpus); confirm that
    # 2 is intentional, e.g. a class subset with the rest held out as OOD.
    'num_classes': 2,
}

# Settings for the low-rank variant.
# NOTE(review): presumably the rank of a low-rank factorisation — confirm.
LOWRANK_CONFIG = {
    'rank': 16,
}




# TRAINING CONFIGURATION

# Epoch counts, one per model family (baseline, Bayesian, ensemble).
# All three are currently the same value.
EPOCHS_BASELINE = 20
EPOCHS_BAYESIAN = 20
EPOCHS_ENSEMBLE = 20

# Optimizer settings. How they are consumed is not shown in this file.
LEARNING_RATE = 2e-4
# NOTE(review): presumably the `alpha` (final LR fraction) of a decay
# schedule — confirm against the training code.
LR_DECAY_ALPHA = 0.01
WEIGHT_DECAY = 0.01

# NOTE(review): presumably the number of independently trained models in the
# ensemble — confirm.
N_ENSEMBLE_MEMBERS = 5


# KL-term settings for the Bayesian models.
# NOTE(review): presumably the KL term is ramped up over the first
# KL_WARMUP_EPOCHS epochs and scaled per layer type by the two factors below —
# confirm against the training loop. These are separate from the value
# returned by compute_kl_weight().
KL_WARMUP_EPOCHS = 5
KL_SCALE_ATTENTION = 0.01
KL_SCALE_EMBEDDING = 0.001


# KL DIVERGENCE

def compute_kl_weight(num_train_samples, batch_size=BATCH_SIZE, base_weight=0.1):
    """Compute the KL weight for Bayesian models.

    The weight is `base_weight` divided by the number of batches in one pass
    over the training data, where the batch count is
    `ceil(num_train_samples / batch_size)` (a trailing partial batch counts
    as one batch).

    Args:
        num_train_samples: Number of training examples; must be positive.
        batch_size: Examples per batch; must be positive. Defaults to
            `BATCH_SIZE`.
        base_weight: Value spread across the batches. Defaults to 0.1, the
            constant that used to be hard-coded here.

    Returns:
        `base_weight / num_batches`.

    Raises:
        ValueError: If `num_train_samples` or `batch_size` is not positive.
            Without this check a zero sample count fails with an opaque
            ZeroDivisionError and a negative one can yield a negative weight.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size!r}")
    if num_train_samples <= 0:
        raise ValueError(
            f"num_train_samples must be positive, got {num_train_samples!r}"
        )

    # Both inputs are positive here, so num_batches is at least 1.
    num_batches = int(np.ceil(num_train_samples / batch_size))
    return base_weight / num_batches


# EVALUATION CONFIGURATION

# NOTE(review): presumably the number of stochastic forward passes used for
# Monte Carlo prediction at evaluation time — confirm.
N_MC_SAMPLES = 200
# NOTE(review): presumably the bin count for Expected Calibration Error —
# confirm.
N_ECE_BINS = 15
# NOTE(review): sample count for the "improved" evaluation path; note that it
# is lower than N_MC_SAMPLES — confirm this is intended.
IMPROVED_N_MC_SAMPLES = 100

# Dropout rates per model family; both are currently 0.0.
DROPOUT_RATE_BASELINE = 0.0
DROPOUT_RATE_BAYESIAN = 0.0



# OUTPUT PATHS

# Relative output locations. Nothing in this file creates these directories.
# NOTE(review): the CSV / figure paths below do not include OUTPUT_DIR;
# whether callers join them onto it is not visible here — confirm.
OUTPUT_DIR = "outputs"
METRICS_CSV = "results_csv/bayesian_transformer_metrics.csv"
METRICS_CSV_WITH_ENSEMBLE = "results_csv/bayesian_transformer_metrics_with_ensemble.csv"
CONVERGENCE_PLOT = "figures/convergence_comparison.png"
MODEL_COMPARISON_PLOT = "figures/model_comparison_with_ensemble.png"
UNCERTAINTY_DIST_PLOT = "figures/uncertainty_distributions_with_ensemble.png"

# Short alias for the TensorFlow Probability distributions module; this is the
# object initialize() returns.
tfd = tfp.distributions


def initialize():
    """Run GPU setup and seeding, report library versions, and return `tfd`.

    Calls `setup_gpu()` and then `set_seed()` with its default seed, prints
    the TensorFlow and TensorFlow Probability version strings, and hands back
    the module-level `tfd` alias (`tfp.distributions`).

    Returns:
        The `tfp.distributions` module.
    """
    setup_gpu()
    set_seed()

    version_lines = (
        f"\nTensorFlow version: {tf.__version__}",
        f"TensorFlow Probability version: {tfp.__version__}",
    )
    for line in version_lines:
        print(line)

    return tfd
