from omegaconf import OmegaConf, DictConfig
import yaml

import os
from typing import Optional, Any
from dataclasses import dataclass

def flatten_trainer_config(cfg: DictConfig) -> dict:
    """Flatten the ``training`` and ``model`` sections of *cfg* into one dict.

    Interpolations are resolved before flattening. Only one level is
    flattened: the entries of ``training`` and ``model`` become top-level
    keys, and any nested values inside them are kept as-is.

    Args:
        cfg: Root config. ``training``, ``model`` and ``device`` are all
            optional. A missing or null section is treated as empty.

    Returns:
        A new plain ``dict``. Keys from ``model`` overwrite keys of the same
        name from ``training``. ``'device'`` is always present and set to the
        top-level ``device`` value, or ``None`` when absent.
    """
    # Convert DictConfig to plain Python containers, resolving ${...} interpolations.
    cfg_dict = OmegaConf.to_container(cfg, resolve=True)

    flat_dict = {}
    # Order matters: later sections win on key collisions.
    for section in ('training', 'model'):
        # `or {}` also covers a section explicitly set to null (e.g. `model: null`);
        # `.get(section, {})` alone would return None and `update(None)` raises TypeError.
        flat_dict.update(cfg_dict.get(section) or {})

    flat_dict['device'] = cfg_dict.get('device')

    return flat_dict


@dataclass()
class ROIDatasetConfig:
    """Options describing which ROI time-series dataset to load and how.

    Every field has a default, so ``ROIDatasetConfig()`` is valid on its own.
    The field order below is also the positional constructor order.
    """

    # Data location and selection.
    sourcedir: str = './data'
    dataset: str = 'ukb-rest'
    roi: str = 'schaefer200'  # presumably the parcellation/atlas name -- confirm with loader

    # Signal-related settings (units/semantics not visible here -- TODO confirm).
    initial_noise: Optional[int] = 10
    temporal_resolution: Optional[float] = 0.735  # presumably sampling interval (TR) in seconds

    # Prediction target.
    target_feature: Optional[str] = None
    regression: bool = False

    # Sample length and subset options.
    dynamic_length: Optional[int] = 480  # presumably timepoints per sample -- confirm
    selected_disorder: Optional[str] = None
    add_task_dataset: bool = False


@dataclass
class TrainingConfig:
    """Seeds, data-loading options and optimisation settings for training."""

    # Random seeds
    random_seed: int = 42
    data_random_seed: int = 0

    # Dataset configuration
    test_ratio: float = 0.2
    num_workers: int = 8
    pin_memory: bool = True

    # Training configuration
    num_epochs: int = 100
    num_eval_epochs: int = 1
    lr: float = 1e-3
    # float, not int: OmegaConf structured configs reject float values for
    # int-typed fields, so e.g. wd=1e-4 would fail validation. Ints are still
    # accepted for float fields.
    wd: float = 0.0
    batch_size: int = 128
    gpu: int = 0
    save_only_best: bool = True


@dataclass()
class ModelConfig:
    """Model hyperparameters.

    All fields have defaults. The first three are ``None`` by default, i.e.
    unset unless the caller supplies them. The field order below is the
    positional constructor order.
    """

    # Optional encoder settings (unset by default).
    encoder: Optional[str] = None
    n_head: Optional[int] = None
    multiplier_feedforward: Optional[int] = None

    # Sizes and layer counts.
    num_basis: int = 20
    state_dim: int = 128
    rep_dim: int = 128
    n_layer: int = 2
    n_layer_encoder: int = 2
    n_layer_decoder: int = 2
    n_layer_control: int = 2
    out_dim: int = 200  # presumably matches the ROI count (e.g. schaefer200) -- confirm

    # Remaining scalar hyperparameters (semantics defined by the model code).
    drop_out: float = 0.0
    ts: float = 0.02
    lamda_1: float = 1e-6
    lamda_2: float = 1e-8
    init_sigma: float = 10.0


# will be removed
@dataclass
class PretrainConfig:  # ukb-rest, schaefer200
    """Legacy flat config holding training and model settings in one class.

    Its fields repeat those of ``TrainingConfig``, followed by the
    non-optional fields of ``ModelConfig``, with the same defaults.
    """

    # Random seeds
    random_seed: int = 42
    data_random_seed: int = 0

    # Dataset configuration
    test_ratio: float = 0.2
    num_workers: int = 8
    pin_memory: bool = True

    # Training configuration
    num_epochs: int = 100
    num_eval_epochs: int = 1
    lr: float = 1e-3
    # float, not int: OmegaConf structured configs reject float values for
    # int-typed fields, so e.g. wd=1e-4 would fail validation.
    wd: float = 0.0
    batch_size: int = 128
    gpu: int = 0
    save_only_best: bool = True

    # Model configuration
    num_basis: int = 20
    state_dim: int = 128
    rep_dim: int = 128
    n_layer: int = 2
    n_layer_encoder: int = 2
    n_layer_decoder: int = 2
    n_layer_control: int = 2
    out_dim: int = 200

    drop_out: float = 0.0
    ts: float = 0.02
    lamda_1: float = 1e-6
    lamda_2: float = 1e-8
    init_sigma: float = 10.0




@dataclass
class DownstreamTrainingConfig:
    """Seeds, data splits and optimisation settings for downstream training."""

    # Random seeds
    random_seed: int = 42
    data_random_seed: int = 0

    # Dataset configuration
    test_ratio: float = 0.2
    val_ratio: float = 0.2
    num_workers: int = 8
    pin_memory: bool = True

    # Training configuration
    num_epochs: int = 100
    num_eval_epochs: int = 1
    lr: float = 1e-3
    # float, not int: OmegaConf structured configs reject float values for
    # int-typed fields, so e.g. wd=1e-4 would fail validation.
    wd: float = 0.0
    batch_size: int = 128
    gpu: int = 0


@dataclass
class DownstreamModelConfig:
    """Hyperparameters for the downstream model."""

    # Architecture (these are model settings, not random seeds).
    num_layers: int = 2
    dim_hidden: int = 128
    output_size: int = 1

    # Regularization
    dropout_rate: float = 0.0

    # Boolean flag; its effect is defined by the consuming code -- TODO document.
    reset: bool = False
