"""
Configuration file for wavelet transform experiments.
All hyperparameters are centralized here for easy experimentation.
"""

import os
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import torch

# =============================================================================
# DATASET CONFIGURATION
# =============================================================================
class DatasetConfig:
    """Settings for loading, embedding and splitting the text dataset."""

    # --- Loading and preprocessing ---
    # Sentence-transformer model used to embed the texts.
    EMBEDDING_MODEL = 'all-MiniLM-L6-v2'
    # Upper bound on how many texts get embedded (None means no limit).
    MAX_TEXTS = 5000
    # Share of the data used for training.
    TRAIN_FRACTION = 0.4
    # Share of the data held out as the evaluation (test) set.
    EVAL_FRACTION = 0.01
    RANDOM_SEED = 42

    # --- Post-processing of the embeddings ---
    # When True, center the data.
    STANDARDIZE_DATA = False
    # When True, also rescale the variance (normally left off for embeddings).
    SCALE_DATA = False

# =============================================================================
# TRAINING CONFIGURATION
# =============================================================================
class TrainingConfig:
    """General training settings shared by the experiments."""
    BATCH_SIZE = 128
    N_PAIRS_CHECK = 1500  # Number of pairs for distance evaluation
    TARGET_A = 8.0  # Target parameter for exponential energy decay

    # Device and parallelization.
    # NOTE: this is a fixed flag, NOT auto-detected here. Any check on the
    # number of available GPUs has to happen in the code that reads it
    # (presumably gated on torch.cuda.device_count() there — confirm).
    USE_DATA_PARALLEL = True
    
# =============================================================================
# NEURAL NETWORK HYPERPARAMETERS
# =============================================================================
class NeuralNetworkConfig:
    """Hyperparameters for the neural-network baseline."""

    # Hidden layer widths, first to last. NOTE: this list lives on the class,
    # so mutating it in place affects every instance; reassign it instead.
    HIDDEN_DIMS = [512, 256, 128]

    # --- Optimization ---
    LEARNING_RATE = 5e-5
    EPOCHS = 150

    # --- Loss weights (equal weighting of the two terms) ---
    GAMMA = 50.0   # Weight of the shape loss
    BETA = 50.0    # Weight of the distance loss

# =============================================================================
# WAVELET TRANSFORM HYPERPARAMETERS
# =============================================================================
class WaveletConfig:
    """Hyperparameters for the learned wavelet transform."""

    # --- Wavelet structure ---
    FILTER_LEN = 4   # Filter length
    N_LEVELS = 6     # Number of wavelet levels

    # --- Optimization ---
    LEARNING_RATE = 0.005
    # Raised from 10 epochs so the step decay below gets a chance to apply.
    EPOCHS = 100

    # --- Step learning-rate decay ---
    LR_DECAY_STEPS = 50     # Decay the LR every 50 epochs
    LR_DECAY_FACTOR = 0.7   # Multiply the LR by this factor at each decay

    # --- Loss weights ---
    GAMMA = 12.5   # Shape loss weight (an earlier setting was 25)
    BETA = 25      # Distance loss weight

# =============================================================================
# ORTHOGONAL WAVELET HYPERPARAMETERS
# =============================================================================
class OrthogonalWaveletConfig:
    """Hyperparameters for the orthogonal wavelet experiments."""

    # --- Optimization ---
    LEARNING_RATE = 2e-4
    EPOCHS = 300

    # --- Loss weights (same values as ForceOrderConfig) ---
    GAMMA = 0.5   # Weight of the shape loss
    BETA = 5.0    # Weight of the distance loss

    # Weight of the orthogonality constraint term.
    LAMBDA = 5.0

# =============================================================================
# FORCE ORDER LOSS CONFIGURATION
# =============================================================================
class ForceOrderConfig:
    """Settings for the force-order loss, which uses triplet ranking."""

    # --- Loss weights ---
    GAMMA = 0.5   # Shape loss weight, kept low (an earlier setting was 4)
    BETA = 5.0    # Distance / ordering loss weight

    # --- Triplet ranking ---
    MARGIN = 0.1        # Margin used by the triplet ranking loss
    MAX_TRIPLETS = 256  # Cap on how many triplets are sampled

# =============================================================================
# CORRELATION LOSS CONFIGURATION
# =============================================================================
class CorrelationConfig:
    """Settings for the distance-correlation loss."""

    # Loss weights: the shape term is meant to stay small next to the
    # correlation term.
    GAMMA = 2     # Shape loss weight
    BETA = 10.0   # Distance correlation loss weight

    # Cap on how many pairs are sampled when computing the correlation.
    MAX_PAIRS = 512

# =============================================================================
# EXPONENTIAL CORRELATION LOSS CONFIGURATION
# =============================================================================
class ExpCorrelationConfig:
    """Settings for the exponential distance-correlation loss."""

    # Loss weights: the shape term is meant to stay small next to the
    # exponential correlation term.
    GAMMA = 2     # Shape loss weight
    BETA = 10.0   # Exp distance correlation loss weight

    # Cap on how many pairs are sampled when computing the correlation.
    MAX_PAIRS = 512

# =============================================================================
# FORCE SHAPE LOSS CONFIGURATION
# =============================================================================
class ForceShapeConfig:
    """Settings for the force-shape loss, which prioritizes energy compaction."""

    # --- Loss weights ---
    GAMMA = 200.0  # Shape loss weight; high because compaction is the priority
    BETA = 1.0     # Distance loss weight

    # --- Target geometric decay ratio for each band of dimensions ---
    TARGET_RATIO = 0.88          # Early dimensions
    MIDDLE_TARGET_RATIO = 0.93   # Middle dimensions
    LATE_TARGET_RATIO = 0.98     # Late dimensions

    # --- Band boundaries ---
    EARLY_DIM_THRESHOLD = 30     # Dimensions 1-30 count as "early"
    MIDDLE_DIM_THRESHOLD = 100   # Dimensions 31-100 count as "middle"

    # --- Per-dimension weighting ---
    DECAY_SCALE = 50.0           # Controls how quickly importance drops off
    # Extra weight for the very first dimensions. NOTE(review): the original
    # comment said "1-20", which differs from EARLY_DIM_THRESHOLD (30);
    # confirm which range the loss actually boosts.
    EARLY_BOOST = 2.0
    INCREASING_PENALTY = 20.0    # Penalty applied when energy ratios increase

    # --- Energy compaction ---
    TARGET_FRACTION = 0.7        # Target share of energy in the early dims (presumably the first EARLY_DIMS)
    EARLY_DIMS = 50              # Number of "early" dimensions used for compaction
    COMPACTION_WEIGHT = 2.0      # Weight of the compaction loss

# =============================================================================
# EVALUATION CONFIGURATION
# =============================================================================
class EvaluationConfig:
    """Switches and parameters that control which evaluations and plots run."""

    # --- Metrics ---
    COMPUTE_DISTANCE_PRESERVATION = True
    COMPUTE_ENERGY_ANALYSIS = True
    COMPUTE_GEOMETRIC_RATIOS = True
    COMPUTE_ORDERING_METRICS = True

    # --- Plots ---
    PLOT_TRAINING_HISTORIES = True
    PLOT_ENERGY_CURVES = True
    SAVE_PLOTS = False  # When True, plots are written to files

    # Points, as percentages of the total number of dimensions, at which the
    # energy is checked.
    ENERGY_CHECK_PERCENTAGES = [10, 25, 50]

    # --- k-NN recall ---
    COMPUTE_KNN_RECALL = True
    K_FOR_RECALL = 10           # k used for the recall metric
    KNN_SAMPLE_FRACTION = 0.1   # presumably the fraction of points sampled for recall — confirm

# =============================================================================
# OUTPUT CONFIGURATION
# =============================================================================
class OutputConfig:
    """Controls how experiment outputs are named, selected and formatted."""

    # --- Naming ---
    OUTPUT_PREFIX = "wavelet_experiment"
    TIMESTAMP_OUTPUT = True        # Append a timestamp to output file names

    # --- Artifacts to write ---
    SAVE_TRANSFORMED_DATA = True
    SAVE_EVALUATION_RESULTS = True
    SAVE_MODEL_WEIGHTS = False     # Rarely needed for analysis

    # --- Formats ---
    DATA_FORMAT = 'csv'      # One of 'csv', 'npz', 'pickle'
    RESULTS_FORMAT = 'csv'   # Format of the evaluation results

# =============================================================================
# EXPERIMENT PRESETS
# =============================================================================
class ExperimentPresets:
    """Predefined configurations for common experiments.

    Every preset starts from a fresh ``get_default_config()`` dict, so the set
    of sections is defined in exactly one place. It then overrides a few
    attributes. Overrides are assigned on the section *instances*, so the
    class-level defaults are left untouched.
    """

    @staticmethod
    def get_quick_test():
        """Fast configuration for testing (larger batches, fewer epochs)."""
        config = get_default_config()

        # Override for quick testing
        config['training'].BATCH_SIZE = 256
        config['nn'].EPOCHS = 50
        config['wavelet'].EPOCHS = 50
        config['ortho_wavelet'].EPOCHS = 30

        return config

    @staticmethod
    def get_high_quality():
        """High-quality configuration for final results."""
        config = get_default_config()

        # Override for high quality
        config['training'].BATCH_SIZE = 1024
        config['nn'].EPOCHS = 300
        config['wavelet'].EPOCHS = 300
        # NOTE(review): this is LOWER than OrthogonalWaveletConfig.EPOCHS (300),
        # so the "high quality" preset trains the orthogonal wavelet for fewer
        # epochs than the defaults do. Kept as-is; confirm it is intended.
        config['ortho_wavelet'].EPOCHS = 200

        return config

    @staticmethod
    def get_compaction_focused():
        """High-quality preset with stronger emphasis on energy compaction."""
        config = ExperimentPresets.get_high_quality()

        # Emphasize compaction
        config['force_shape'].GAMMA = 500.0
        config['force_shape'].TARGET_RATIO = 0.85
        config['force_shape'].TARGET_FRACTION = 0.8

        return config

    @staticmethod
    def get_small_train_large_eval():
        """High-quality preset with a small training set and a large evaluation set."""
        config = ExperimentPresets.get_high_quality()

        # Use small training set, large evaluation set
        config['dataset'].TRAIN_FRACTION = 0.2  # 20% for training
        config['dataset'].EVAL_FRACTION = 0.8   # 80% for evaluation

        return config

# =============================================================================
# DEFAULT CONFIGURATION
# =============================================================================
def get_default_config():
    """Build a configuration dict with every section at its class defaults.

    Returns:
        A new dict mapping section name to a freshly created config instance,
        so callers can override attributes without affecting other configs.
    """
    sections = (
        ('dataset', DatasetConfig),
        ('training', TrainingConfig),
        ('nn', NeuralNetworkConfig),
        ('wavelet', WaveletConfig),
        ('ortho_wavelet', OrthogonalWaveletConfig),
        ('force_order', ForceOrderConfig),
        ('force_shape', ForceShapeConfig),
        ('correlation', CorrelationConfig),
        ('exp_correlation', ExpCorrelationConfig),
        ('evaluation', EvaluationConfig),
        ('output', OutputConfig),
    )
    return {name: section_cls() for name, section_cls in sections}

@dataclass
class HyperparameterSearchConfig:
    """Search space and runtime settings for a hyperparameter search.

    Each list-valued field holds the candidate values for one hyperparameter.
    A field left as ``None`` (by default or passed explicitly) is filled by
    ``__post_init__`` with a fresh list copied from ``_DEFAULT_GRIDS``, so no
    two instances share a mutable list.
    """
    # Core hyperparameters to tune
    kernel_sizes: Optional[List[int]] = None              # Size of wavelet kernels
    num_layers: Optional[List[int]] = None                # Number of neural network layers
    learning_rates: Optional[List[float]] = None          # Learning rates

    # Loss function parameters
    target_a_values: Optional[List[float]] = None         # Exponential decay parameter
    gamma_beta_ratios: Optional[List[float]] = None       # Ratio of gamma (shape) to beta (distance) loss weights
    base_loss_weights: Optional[List[float]] = None       # Base loss weight (beta will use this, gamma = ratio * beta)

    # Training parameters to tune
    batch_sizes: Optional[List[int]] = None
    epochs: Optional[List[int]] = None
    early_stopping_patience: Optional[List[int]] = None
    lr_decay_factors: Optional[List[float]] = None
    lr_decay_steps: Optional[List[int]] = None

    # Search configuration
    max_workers: int = 8                                  # For parallel processing
    save_frequency: int = 10                              # Save results every N experiments

    # Default candidate values for each list field that is left as None.
    # They are stored as immutable tuples and copied into a new list per
    # instance. The name has no annotation, so @dataclass treats it as a plain
    # class attribute rather than a constructor field.
    _DEFAULT_GRIDS = {
        'kernel_sizes': (3, 5, 7, 9, 11),
        'num_layers': (2, 3, 4, 5),
        'learning_rates': (0.001, 0.005, 0.01, 0.05),
        'target_a_values': (4.0, 6.0, 8.0, 10.0, 12.0),    # Exponential decay rates
        'gamma_beta_ratios': (0.1, 0.5, 1.0, 2.0, 5.0),    # Shape/distance loss ratio
        'base_loss_weights': (25.0, 50.0, 100.0),          # Base weight for beta (distance loss)
        'batch_sizes': (32, 64, 128, 256),
        'epochs': (100, 200, 300),
        'early_stopping_patience': (10, 20, 30),
        'lr_decay_factors': (0.5, 0.7, 0.9),
        'lr_decay_steps': (50, 100, 150),
    }

    def __post_init__(self):
        """Replace every ``None`` search-space field with a copy of its default grid."""
        for name, grid in self._DEFAULT_GRIDS.items():
            if getattr(self, name) is None:
                setattr(self, name, list(grid))

@dataclass
class VecFileConfig:
    """Options for reading vectors from a ``.vec`` file (ANN benchmark format)."""

    # Path of the .vec file to read (empty string by default).
    file_path: str = ""
    # Maximum number of vectors to load; -1 loads all of them.
    max_vectors: int = -1
    # Normalize each vector to unit length.
    normalize: bool = True
    # Name of the element data type for the vectors.
    dtype: str = "float32"
    
def create_hyperparam_search_config() -> HyperparameterSearchConfig:
    """Return a search configuration that uses every default grid."""
    default_search = HyperparameterSearchConfig()
    return default_search

def create_quick_hyperparam_search_config() -> HyperparameterSearchConfig:
    """Return a reduced search space for fast test runs.

    Each grid is narrowed to two or three candidates and the worker pool is
    smaller. ``save_frequency`` keeps its dataclass default.
    """
    quick_space = dict(
        kernel_sizes=[5, 7],
        num_layers=[3, 4],
        learning_rates=[0.005, 0.01],
        target_a_values=[6.0, 8.0, 10.0],
        gamma_beta_ratios=[0.5, 1.0],
        base_loss_weights=[25.0, 50.0],
        batch_sizes=[64, 128],
        epochs=[50, 100],
        early_stopping_patience=[10, 15],
        lr_decay_factors=[0.7, 0.9],
        lr_decay_steps=[30, 50],
        max_workers=4,
    )
    return HyperparameterSearchConfig(**quick_space)