"""
Configuration management system for TensorGalerkin
"""

import os
import json
from dataclasses import dataclass, field, asdict, fields
from typing import Optional, Dict, Any, List, Tuple, Union
from pathlib import Path

import toml
import yaml


def _annotation_allows_tuple(annotation: Any) -> bool:
    """Return True if *annotation* is a tuple type or a ``Union`` containing one.

    Inspects ``typing`` generics through their ``__origin__``/``__args__``
    attributes, so ``Tuple[float, float]`` and ``Optional[Tuple[...]]`` both
    match. String (postponed) annotations are not resolved and return False.
    """
    if annotation is tuple:
        return True
    origin = getattr(annotation, "__origin__", None)
    if origin is tuple:
        return True
    if origin is Union:
        return any(
            _annotation_allows_tuple(arg)
            for arg in getattr(annotation, "__args__", ())
        )
    return False


@dataclass
class BaseConfig:
    """Base configuration class with common parameters.

    Subclasses add their own fields; ``save``/``load`` serialize all fields
    (inherited ones included) to YAML, JSON or TOML, selected by file suffix.
    """

    # Device and computation settings
    device: str = "cpu"
    seed: int = 42
    verbose: bool = False

    # Paths and logging
    ckpt_path: str = "checkpoints/model.pt"
    log_dir: str = "logs"
    use_tensorboard: bool = False

    # Additional data for tracking
    datarow: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Convert config to a plain dictionary via ``dataclasses.asdict``."""
        return asdict(self)

    def save(self, path: Union[str, Path]) -> None:
        """Save configuration to file.

        The format is chosen from the (case-insensitive) suffix: ``.yaml`` /
        ``.yml``, ``.json`` or ``.toml``. Missing parent directories are
        created first.

        Args:
            path: Destination file path.

        Raises:
            ValueError: If the suffix is not one of the supported formats.
        """
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)
        suffix = path.suffix.lower()

        config_dict = self.to_dict()

        if suffix in ('.yaml', '.yml'):
            with open(path, 'w', encoding='utf-8') as f:
                # safe_dump (not dump) so the file round-trips through
                # yaml.safe_load in load(): yaml.dump tags tuples such as
                # MeshConfig.xlims as !!python/tuple, which safe_load rejects.
                yaml.safe_dump(config_dict, f, default_flow_style=False)
        elif suffix == '.json':
            with open(path, 'w', encoding='utf-8') as f:
                json.dump(config_dict, f, indent=2)
        elif suffix == '.toml':
            with open(path, 'w', encoding='utf-8') as f:
                toml.dump(config_dict, f)
        else:
            raise ValueError(f"Unsupported config format: {path.suffix}")

    @classmethod
    def load(cls, path: Union[str, Path]) -> 'BaseConfig':
        """Load a configuration of type ``cls`` from file.

        Supports the same suffixes as ``save``. An empty YAML document (or a
        JSON ``null``) yields a config with all default values. Lists read
        back for tuple-annotated fields are converted to tuples, since none of
        the supported formats has a native tuple type.

        Args:
            path: Source file path.

        Returns:
            A new instance of ``cls``.

        Raises:
            ValueError: If the suffix is not one of the supported formats.
            TypeError: If the file contains keys that are not fields of ``cls``.
        """
        path = Path(path)
        suffix = path.suffix.lower()

        if suffix in ('.yaml', '.yml'):
            with open(path, 'r', encoding='utf-8') as f:
                config_dict = yaml.safe_load(f)
        elif suffix == '.json':
            with open(path, 'r', encoding='utf-8') as f:
                config_dict = json.load(f)
        elif suffix == '.toml':
            with open(path, 'r', encoding='utf-8') as f:
                config_dict = toml.load(f)
        else:
            raise ValueError(f"Unsupported config format: {path.suffix}")

        # An empty YAML file parses to None; treat it as "all defaults".
        if config_dict is None:
            config_dict = {}

        return cls(**cls._restore_tuples(config_dict))

    @classmethod
    def _restore_tuples(cls, data: Any) -> Any:
        """Return *data* with list values of tuple-annotated fields made tuples.

        Non-dict input is returned unchanged so that ``cls(**data)`` reports
        the problem exactly as before.
        """
        if not isinstance(data, dict):
            return data
        tuple_fields = {
            f.name for f in fields(cls) if _annotation_allows_tuple(f.type)
        }
        return {
            key: tuple(value)
            if key in tuple_fields and isinstance(value, list)
            else value
            for key, value in data.items()
        }


@dataclass
class MeshConfig(BaseConfig):
    """Configuration for mesh generation.

    Pure parameter container: this class only declares fields on top of
    ``BaseConfig``; how they are consumed is decided by mesh-building code
    that is not visible in this module. Field declaration order determines
    the positional order of the generated ``__init__``.
    """

    # Grid parameters
    # NOTE(review): `grid` and `nx`/`ny` overlap; which of them the mesh
    # builders actually read is not visible here — confirm before relying on
    # either one.
    grid: int = 32
    nx: int = 32
    ny: int = 32

    # Geometry parameters
    element: str = "quad"  # "quad" or "tri"
    shape: str = "rectangle"  # "rectangle", "circle", "ellipse", etc.
    # Domain extent per axis, presumably (min, max) — TODO confirm against
    # the mesh builder.
    xlims: Tuple[float, float] = (0.0, 1.0)
    ylims: Tuple[float, float] = (0.0, 1.0)

    # Element connectivity
    use_dense_element: bool = False

    # Boundary conditions
    # presumably: use_free_boundary turns off the fixed boundary condition and
    # boundary_value is the value imposed otherwise — TODO confirm.
    use_free_boundary: bool = False
    boundary_value: float = 0.0

    # Mesh refinement (for gmsh-based meshes)
    # chara_length: assumed to be the gmsh characteristic element length.
    chara_length: float = 0.1
    # radius/center describe curved shapes; None presumably means "use the
    # builder's default". A list radius presumably gives per-axis radii
    # (e.g. for "ellipse") — TODO confirm.
    radius: Optional[Union[float, List[float]]] = None
    center: Optional[Tuple[float, float]] = None


@dataclass
class ModelConfig(BaseConfig):
    """Configuration for neural network models.

    Pure parameter container: declares model-architecture fields on top of
    ``BaseConfig``. Field declaration order determines the positional order
    of the generated ``__init__``.
    """

    # Architecture parameters
    gnn: str = "gcn"  # "gcn", "gat", "sage", "sgc", etc.
    n_hidden: int = 64
    n_layers: int = 3

    # Regularization
    # presumably dropout_in applies to the input features and dropout to the
    # hidden layers — TODO confirm in the model code.
    dropout_in: float = 0.0
    dropout: float = 0.0

    # Input/output features
    use_coord_feat: bool = False
    # None means "derive later": PoissonConfig.__post_init__ fills these in
    # (input_dim from use_coord_feat, output_dim = 1).
    input_dim: Optional[int] = None
    output_dim: Optional[int] = None

    # Advanced model features
    activation: str = "relu"
    use_input_norm: bool = True
    # presumably batch-norm and residual-connection toggles — TODO confirm.
    use_bn: bool = False
    use_res: bool = False

    # Multi-stage training
    # NOTE(review): semantics of n_boost / use_boost_scale are not visible
    # here; assumed to control the number of boosting stages and whether
    # their outputs are rescaled.
    n_boost: int = 1
    use_boost_scale: bool = False


@dataclass
class TrainingConfig(BaseConfig):
    """Configuration for training parameters.

    Pure parameter container: declares optimizer, schedule and data fields
    on top of ``BaseConfig``. Field declaration order determines the
    positional order of the generated ``__init__``.
    """

    # Optimizer settings
    optimizer: str = "adam"  # "adam", "lbfgs", "combine"
    lr: float = 0.001

    # Training schedule
    # NOTE(review): `epoch` vs `max_iter` — max_iter is presumably the
    # iteration cap for L-BFGS-style optimizers; confirm in the trainer.
    epoch: int = 100
    batch_size: int = 32
    max_iter: int = 1000

    # Scheduler settings
    # Only meaningful when use_scheduler is True (assumed; trainer not shown).
    use_scheduler: bool = False
    scheduler: str = "step"
    scheduler_step_size: int = 50
    scheduler_gamma: float = 0.5

    # Loss and evaluation
    # presumably eval_every_eps = evaluate every N epochs — TODO confirm.
    loss_scale: float = 1.0
    eval_every_eps: int = 1
    discount_factor: float = 1.0

    # Data parameters
    # NOTE: EquationConfig also declares `n_samples` (same default, 256);
    # configs that inherit both (e.g. PoissonConfig) end up with a single
    # field, so keep the two defaults in sync.
    n_samples: int = 256
    validation_ratio: float = 0.2

    # Early stopping
    # presumably: stop after `patience` epochs without an improvement larger
    # than `min_delta` — TODO confirm.
    patience: int = 50
    min_delta: float = 1e-6


@dataclass
class EquationConfig(BaseConfig):
    """Configuration for equation parameters.

    Pure parameter container: declares PDE/time-stepping fields on top of
    ``BaseConfig``. Field declaration order determines the positional order
    of the generated ``__init__``.
    """

    # Time parameters
    # presumably dt = time-step size and t = final simulation time — TODO
    # confirm.
    dt: float = 0.01
    t: float = 1.0

    # Equation-specific parameters
    a: float = 1.0  # Diffusion coefficient for Poisson/Heat
    c: float = 1.0  # Wave speed for wave equation

    # Prediction parameters
    predict_steps: int = 1
    window_size: int = 1

    # Physics-informed training
    # NOTE(review): meaning of use_fem_start / fast_residual is not visible
    # in this module; confirm in the solver/trainer code.
    use_fem_start: bool = False
    fast_residual: bool = True

    # Source function parameters
    use_analytical_dataset: bool = False
    K: int = 1  # Number of modes for analytical solutions
    # NOTE: TrainingConfig also declares `n_samples` (same default, 256);
    # configs that inherit both end up with a single field, so keep the two
    # defaults in sync.
    n_samples: int = 256


@dataclass
class PoissonConfig(MeshConfig, ModelConfig, TrainingConfig, EquationConfig):
    """Complete configuration for Poisson equation training.

    Flattens the mesh, model, training and equation parameter groups (plus
    the shared ``BaseConfig`` fields) into a single dataclass.
    """

    def __post_init__(self):
        """Fill in derived parameters after dataclass initialization.

        - ``input_dim`` defaults to 3 if ``use_coord_feat`` else 1
          (presumably one field value plus 2D coordinates — confirm).
        - ``output_dim`` defaults to 1.
        - ``device == "auto"`` resolves to "cuda" when
          ``torch.cuda.is_available()``, otherwise "cpu".
        - Creates the directory that will hold ``ckpt_path``, if any.
        """
        # Set input/output dimensions based on problem
        if self.input_dim is None:
            self.input_dim = 3 if self.use_coord_feat else 1
        if self.output_dim is None:
            self.output_dim = 1

        # Resolve "auto" device; torch is imported lazily so it is only
        # required when this option is used.
        if self.device == "auto":
            import torch
            self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # A bare filename such as "model.pt" has an empty dirname, and
        # os.makedirs("") raises FileNotFoundError, so only create a real
        # directory component.
        ckpt_dir = os.path.dirname(self.ckpt_path)
        if ckpt_dir:
            os.makedirs(ckpt_dir, exist_ok=True)


def create_default_configs(config_dir: Union[str, Path] = "configs") -> 'PoissonConfig':
    """Create default configuration files for different equation types.

    Currently writes only ``poisson_default.yaml``. Constructing the
    ``PoissonConfig`` also creates its checkpoint directory
    (``checkpoints/``) via ``PoissonConfig.__post_init__``.

    Args:
        config_dir: Directory to write the config files into. It is created,
            together with any missing parents, if needed. Defaults to
            ``"configs"`` relative to the current working directory.

    Returns:
        The ``PoissonConfig`` instance that was saved.
    """
    config_dir = Path(config_dir)
    config_dir.mkdir(parents=True, exist_ok=True)

    # Poisson equation config
    poisson_config = PoissonConfig(
        # Model settings
        gnn="gcn",
        n_hidden=64,
        n_layers=3,

        # Training settings
        optimizer="adam",
        lr=0.001,
        epoch=100,
        batch_size=16,

        # Mesh settings
        grid=32,
        element="quad",

        # Equation settings
        a=1.0,

        # Paths
        ckpt_path="checkpoints/poisson_model.pt",
        log_dir="logs/poisson"
    )

    poisson_path = config_dir / "poisson_default.yaml"
    poisson_config.save(poisson_path)
    print(f"Created default Poisson config: {poisson_path}")

    return poisson_config


if __name__ == "__main__":
    # Script entry point only (importing this module has no side effects):
    # writes configs/poisson_default.yaml and, through PoissonConfig's
    # __post_init__, creates the "checkpoints" directory.
    create_default_configs()