"""
Configuration Settings for Bayesian LSTM

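This module centralizes the settings used to keep experiments reproducible (random seeds, TensorFlow setup, model and training hyperparameters, plot names and colors) and provides ExperimentTracker for resumable multi-seed runs.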
"""

import os
import gc
import random
import json
import pickle
from datetime import datetime
from pathlib import Path
import numpy as np

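# TensorFlow is imported lazily inside the functions below rather than at
# module level, so importing this config module does not load TensorFlow.
# Note that configure_environment() imports TensorFlow *before* it sets the
# TF_* environment variables (mirroring the reference notebook).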



# Global random seed: default argument of set_seed() and the seed applied by
# configure_environment().
SEED: int = 42


def configure_environment():
    """
    Configure TensorFlow and environment settings for reproducibility.

    In order, this:
      1. imports TensorFlow and TensorFlow Probability and prints their versions;
      2. sets the TF_GPU_ALLOCATOR, TF_DETERMINISTIC_OPS, PYTHONHASHSEED and
         TF_CUDNN_DETERMINISTIC environment variables;
      3. disables XLA JIT, sets the global Keras precision policy to float32,
         seeds all RNGs with SEED and enables TensorFlow op determinism;
      4. enables memory growth on every visible GPU;
      5. calls set_seed(SEED), which also clears the Keras session.

    Matches the working BayesLSTM_clean.ipynb configuration exactly; the
    statement order mirrors that notebook, so keep it unchanged when editing
    if results must stay reproducible against it.
    """
    # Import TensorFlow first (same as working notebook)
    import tensorflow as tf
    import tensorflow_probability as tfp

    print(f"TensorFlow version: {tf.__version__}")
    print(f"TensorFlow Probability version: {tfp.__version__}")

    # Environment settings (same order as working notebook)
    # NOTE(review): these are set *after* TensorFlow was imported above; any
    # variable TensorFlow only reads at import/device-initialization time may
    # not take effect in this process -- confirm.
    # NOTE(review): Python reads PYTHONHASHSEED at interpreter startup, so
    # setting it here presumably only affects child processes -- confirm.
    os.environ["TF_GPU_ALLOCATOR"] = "cuda_malloc_async"
    os.environ['TF_DETERMINISTIC_OPS'] = '1'
    os.environ['PYTHONHASHSEED'] = str(SEED)
    os.environ['TF_CUDNN_DETERMINISTIC'] = '1'

    # Disable XLA JIT and force float32 everywhere (no mixed precision).
    tf.config.optimizer.set_jit(False)
    tf.keras.mixed_precision.set_global_policy("float32")
    tf.keras.utils.set_random_seed(SEED)
    tf.config.experimental.enable_op_determinism()

    # Per the TF docs, set_random_seed above already seeds Python, NumPy and
    # TensorFlow; these explicit calls are redundant but kept to mirror the
    # notebook.
    np.random.seed(SEED)
    random.seed(SEED)
    tf.random.set_seed(SEED)

    # Allocate GPU memory on demand instead of reserving it all up front.
    # NOTE(review): TensorFlow only allows this before the GPUs are
    # initialized, so calling configure_environment() a second time in the
    # same process (e.g. re-running a notebook cell) may raise RuntimeError.
    for g in tf.config.list_physical_devices("GPU"):
        tf.config.experimental.set_memory_growth(g, True)

    print(f"Random seed configured: {SEED}")
    print("Deterministic ops: ENABLED")
    # Clears the Keras session and re-applies all seeds (see set_seed).
    set_seed(SEED)


def set_seed(seed=SEED):
    """
    Reset Keras state and reseed every random number generator.

    The PYTHONHASHSEED environment variable is updated, the Keras session is
    cleared (discarding previously built models) and garbage is collected.
    NumPy and TensorFlow are then seeded, followed by
    ``tf.keras.utils.set_random_seed``, which per the TF docs also seeds
    Python's ``random`` module.

    Args:
        seed: Integer seed to apply. Defaults to the module-level ``SEED``.
    """
    import tensorflow as tf

    os.environ.update(PYTHONHASHSEED=str(seed))

    # Throw away existing Keras/TF state, then reclaim the freed memory.
    tf.keras.backend.clear_session()
    gc.collect()

    # Same order as before: NumPy, TF global seed, then the Keras helper.
    seeders = (np.random.seed, tf.random.set_seed, tf.keras.utils.set_random_seed)
    for apply_seed in seeders:
        apply_seed(seed)


# ==============================================================================
# Model Configuration
# ==============================================================================

# LSTM Architecture
LSTM_HIDDEN = 64
NUM_LAYERS = 2
INPUT_SIZE = 15  # Number of features
SEQUENCE_LENGTH = 24  # Number of timesteps

# Training Configuration
BATCH_SIZE = 64
EPOCHS = 150
LEARNING_RATE = 1e-3
KL_WARMUP_EPOCHS = 0
KL_SCALE = 0.25

# Low-Rank Configuration
# NOTE(review): two entries with NUM_LAYERS = 2 suggests one rank per LSTM
# layer -- confirm against the model code.
RANK = [14, 20]

# Deep Ensemble Configuration
ENSEMBLE_SIZE = 5
ENSEMBLE_EPOCHS = 100

# Evaluation Configuration
NUM_MC_SAMPLES = 250
CONFIDENCE_LEVEL = 0.95

# Data Directory
DATA_DIR = "beijing_data"


# ==============================================================================
# Model Names and Colors (for visualization)
# ==============================================================================

# Display name of the SVD-initialised low-rank model. It embeds RANK, so it is
# built once here and reused both in MODEL_NAMES and as a MODEL_COLORS key,
# keeping the two in sync if RANK changes.
_SVD_MODEL_NAME = f"Low-Rank Bayesian (SVD Init, r={RANK})"

MODEL_NAMES = [
    'Deterministic',
    'Full-Rank Bayesian',
    'Low-Rank Bayesian',
    _SVD_MODEL_NAME,
    'Rank-1 Bayesian',
    'Deep Ensemble',
]

MODEL_COLORS = {
    'Deterministic': '#1f77b4',
    'Full-Rank Bayesian': '#ff7f0e',
    'Low-Rank Bayesian': '#2ca02c',
    _SVD_MODEL_NAME: '#8c564b',
    'Rank-1 Bayesian': '#9467bd',
    'Deep Ensemble': '#d62728',
}

# For bar plots (list form): one color per entry of MODEL_NAMES, in the same
# order. Derived from MODEL_COLORS so the list cannot drift from the mapping.
MODEL_COLORS_LIST = [MODEL_COLORS[name] for name in MODEL_NAMES]


# ==============================================================================
# Prior Parameters for Bayesian Models
# ==============================================================================

def get_prior_params():
    """
    Return the prior hyperparameters used by the Bayesian models.

    TensorFlow is imported lazily, so this is meant to be called after TF
    has been configured.

    Returns:
        dict with keys 'pi' (0.5), 'sigma1' (1.0) and 'sigma2'
        (``tf.exp(-6.0)``, a TensorFlow tensor), in that order.
    """
    import tensorflow as tf

    prior = dict(pi=0.5, sigma1=1.0)
    prior['sigma2'] = tf.exp(-6.0)
    return prior


# ==============================================================================
# Retention Rates for Selective Prediction
# ==============================================================================

# Retention fractions evaluated for selective prediction: 1.0 down to 0.70 in
# steps of 0.05.
RETENTION_RATES = [
    1.0,
    0.95,
    0.90,
    0.85,
    0.80,
    0.75,
    0.70,
]


# ==============================================================================
# Multi-Seed Experiment Tracking
# ==============================================================================

# Seeds ExperimentTracker falls back to when no (or an empty) seed list is given.
DEFAULT_SEEDS = [
    42,
    123,
    456,
    2026,
]


class ExperimentTracker:
    """
    Tracks experiment progress and manages checkpoints for multi-seed experiments.

    Enables resumable experiments by saving progress after each completed run.
    Progress (completed keys, timings, reuse markers) is stored as JSON in
    ``progress.json``, and per-run results are pickled into ``results.pkl``.
    Both files are written atomically (temp file + ``os.replace``), so an
    interrupted save cannot leave a truncated checkpoint behind.

    Each experiment is identified by the key ``f"{seed}_{model_name}"``.

    Args:
        checkpoint_dir: Directory to store progress and results files
            (created if missing).
        seeds: List of seeds being used (for summary calculations). A copy is
            stored; falls back to a copy of DEFAULT_SEEDS when None or empty.
        model_names: List of model names being tested (for summary
            calculations). A copy is stored.

    Example:
        tracker = ExperimentTracker(
            checkpoint_dir='multi_seed_results',
            seeds=[42, 123, 456],
            model_names=['Deterministic', 'Full-Rank Bayesian', 'Low-Rank Bayesian']
        )

        for seed in seeds:
            for model_name in model_names:
                if tracker.is_completed(seed, model_name):
                    continue
                # Run experiment...
                tracker.save_result(seed, model_name, metrics, training_time)
    """

    def __init__(self, checkpoint_dir, seeds=None, model_names=None):
        self.checkpoint_dir = Path(checkpoint_dir)
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        self.progress_file = self.checkpoint_dir / "progress.json"
        self.results_file = self.checkpoint_dir / "results.pkl"
        # Copy so that mutating tracker.seeds / tracker.model_names can never
        # modify the caller's lists or the shared module-level DEFAULT_SEEDS.
        self.seeds = list(seeds or DEFAULT_SEEDS)
        self.model_names = list(model_names or [])
        self.load_progress()

    @staticmethod
    def _key(seed, model_name):
        """Return the string key identifying one (seed, model) experiment."""
        return f"{seed}_{model_name}"

    @staticmethod
    def _atomic_write(path, write_fn, binary=False):
        """Write ``path`` via a sibling temp file + os.replace so it is never left half-written."""
        tmp_path = path.with_name(path.name + ".tmp")
        mode, encoding = ('wb', None) if binary else ('w', 'utf-8')
        with open(tmp_path, mode, encoding=encoding) as f:
            write_fn(f)
        # os.replace is atomic when source and target share a filesystem,
        # which a sibling temp file guarantees.
        os.replace(tmp_path, path)

    def load_progress(self):
        """Load existing progress/results from disk, or initialize new tracking."""
        if self.progress_file.exists():
            with open(self.progress_file, 'r', encoding='utf-8') as f:
                self.progress = json.load(f)
            # Backfill keys that older or hand-edited progress files may lack;
            # is_completed, save_result and mark_reused index them directly.
            self.progress.setdefault('completed', [])
            self.progress.setdefault('timings', {})
            self.progress.setdefault('reused_seed_42', [])
            print(f"\nResuming: {len(self.progress['completed'])} experiments already completed")
        else:
            self.progress = {
                'completed': [],
                'start_time': datetime.now().isoformat(),
                'timings': {},
                'reused_seed_42': []
            }

        # Load results.
        # NOTE: pickle.load can execute arbitrary code; only point
        # checkpoint_dir at results files written by this tracker.
        if self.results_file.exists():
            with open(self.results_file, 'rb') as f:
                self.results = pickle.load(f)
        else:
            self.results = {}

    def is_completed(self, seed, model_name):
        """Return True if (seed, model_name) is recorded as completed."""
        return self._key(seed, model_name) in self.progress['completed']

    def mark_reused(self, seed, model_name):
        """
        Mark that an existing model was reused for (seed, model_name).

        Only updates in-memory progress; it is written to disk by the next
        save_result call.
        """
        key = self._key(seed, model_name)
        if key not in self.progress['reused_seed_42']:
            self.progress['reused_seed_42'].append(key)

    def save_result(self, seed, model_name, metrics, training_time, reused=False):
        """
        Record one finished experiment and checkpoint everything to disk.

        Args:
            seed: Seed the run used.
            model_name: Name of the model that was run.
            metrics: Metrics to store (must be picklable).
            training_time: Training time in seconds (must be JSON-serializable;
                printed with one decimal place).
            reused: True if an existing model was reused instead of trained.
        """
        key = self._key(seed, model_name)

        # Store result
        self.results[key] = {
            'seed': seed,
            'model': model_name,
            'metrics': metrics,
            'training_time': training_time,
            'reused': reused,
            'timestamp': datetime.now().isoformat()
        }

        # Update progress
        if key not in self.progress['completed']:
            self.progress['completed'].append(key)
        self.progress['timings'][key] = training_time

        # Save checkpoint. Results are persisted before progress: if we are
        # interrupted between the two writes, the key is not yet marked
        # completed and is simply re-run on resume, rather than being skipped
        # with no stored result.
        self._atomic_write(
            self.results_file, lambda f: pickle.dump(self.results, f), binary=True
        )
        self._atomic_write(
            self.progress_file, lambda f: json.dump(self.progress, f, indent=2)
        )

        status = "Reused" if reused else "Trained"
        print(f"{status}: {key} ({training_time:.1f}s)")

    def get_summary(self):
        """
        Summarize progress over the configured seeds x model_names grid.

        Only completed keys that belong to the current grid are counted, so
        entries left over from a run with different seeds or models cannot
        push 'remaining' below zero or 'progress_pct' above 100. When the
        grid is empty (no model names), every completed key is counted.

        Returns:
            dict with 'total', 'completed', 'remaining', 'progress_pct' and
            'reused' (number of keys marked via mark_reused).
        """
        expected = {self._key(s, m) for s in self.seeds for m in self.model_names}
        done = set(self.progress['completed'])
        total = len(expected)
        completed = len(done & expected) if expected else len(done)
        return {
            'total': total,
            'completed': completed,
            'remaining': max(total - completed, 0),
            'progress_pct': (completed / total) * 100 if total > 0 else 0,
            'reused': len(self.progress.get('reused_seed_42', [])),
        }

    def print_summary(self):
        """Print a formatted summary of experiment progress."""
        summary = self.get_summary()
        print(f"\n{'='*60}")
        print("Experiment Progress")
        print(f"{'='*60}")
        print(f"  Completed: {summary['completed']}/{summary['total']} ({summary['progress_pct']:.1f}%)")
        print(f"  Remaining: {summary['remaining']}")
        print(f"  Reused (seed 42): {summary['reused']}")
        print(f"{'='*60}\n")
