import copy
from dataclasses import dataclass, field, fields
from pathlib import Path
from typing import Optional, Dict, Union, List


@dataclass
class FinetuneConfig:
    """Settings for an OpenVLA fine-tuning run.

    Every field has a default, so ``FinetuneConfig()`` is already a complete
    config; the named presets in this module override a subset of fields.
    List-valued fields use ``default_factory`` so each instance gets its own
    list. How each field is consumed lives in the training code, which is not
    part of this module.
    """
    # fmt: off
    vla_path: str = "openvla/openvla-7b"             # Path to OpenVLA model (on HuggingFace Hub or stored locally)
    seed: int = 42                                    # Random seed for reproducibility
    # Dataset
    data_root_dirs: List[str] = field(default_factory=list)  # Directories containing RLDS datasets (for training)
    eval_data_root_dirs: List[str] = field(default_factory=list)  # Directories containing RLDS datasets for evaluation
    dataset_name: str = "aloha_scoop_x_into_bowl"    # Name of fine-tuning dataset (e.g., `aloha_scoop_x_into_bowl`)
    run_root_dir: Path = Path("runs")                # Path to directory to store logs & checkpoints
    shuffle_buffer_size: int = 100_000               # Dataloader shuffle buffer size (can reduce if OOM errors occur)
    use_dummy_dataset: bool = False                  # If True, uses dummy dataset for testing
    from_scratch: bool = False                       # If True, trains from scratch (presumably skips loading pretrained
                                                     #   `vla_path` weights -- confirm in training code)

    # Algorithm and architecture
    use_l1_regression: bool = True                   # If True, trains continuous action head with L1 regression objective
    use_diffusion: bool = False                      # If True, trains continuous action head with diffusion modeling objective (DDIM)
    num_diffusion_steps_train: int = 50              # (When `use_diffusion==True`) Number of diffusion steps used for training
    use_film: bool = False                           # If True, uses FiLM to infuse language inputs into visual features
    num_images_in_input: int = 1                     # Number of images in the VLA input (default: 1)
    use_proprio: bool = False                        # If True, includes robot proprioceptive state in input
    targeting_strategy: str = "none"                 # One of "dot", "mask", "heatmap", or "none"

    # Training configuration
    batch_size: int = 4                              # Batch size per device (total batch size = batch_size * num GPUs)
    learning_rate: float = 5e-4                      # Learning rate
    lr_warmup_steps: int = 0                         # Number of steps to warm up learning rate (from 10% to 100%)
    num_steps_before_decay: int = 20_000             # Number of steps before LR decays by 10x
    grad_accumulation_steps: int = 1                 # Number of gradient accumulation steps
    max_steps: int = 30_000                          # Max number of training steps
    use_val_set: bool = False                        # If True, uses validation set and logs validation metrics
    val_freq: int = 10                               # (When `use_val_set==True`) Validation set logging frequency in steps
    val_time_limit: int = 180                        # (When `use_val_set==True`) Time limit for computing validation metrics
    save_freq: int = 500                             # Checkpoint saving frequency in steps
    save_latest_checkpoint_only: bool = False        # If True, saves only 1 checkpoint, overwriting latest checkpoint
                                                     #   (If False, saves all checkpoints)
    resume: bool = True                              # If True, resumes from checkpoint
    resume_step: Optional[int] = None                # (When `resume==True`) Step number that we are resuming from
    image_aug: bool = True                           # If True, trains with image augmentations (HIGHLY RECOMMENDED)
    diffusion_sample_freq: int = 50                  # (When `use_diffusion==True`) Frequency for sampling in steps

    # LoRA
    use_lora: bool = True                            # If True, uses LoRA fine-tuning
    lora_rank: int = 32                              # Rank of LoRA weight matrix
    lora_dropout: float = 0.0                        # Dropout applied to LoRA weights
    merge_lora_during_training: bool = False         # If True, merges LoRA weights and saves result during training
                                                     #   Note: Merging can be very slow on some machines. If so, set to
                                                     #         False and merge final checkpoint offline!

    # Logging
    wandb_entity: str = ""                           # Name of WandB entity
    wandb_project: str = "openvla_oft"               # Name of WandB project
    run_id_note: Optional[str] = None                # Extra note to add to end of run ID for logging
    run_id_override: Optional[str] = None            # Optional string to override the run ID with
    wandb_log_freq: int = 10                         # WandB logging frequency in steps

    # fmt: on


def create_config_with_overrides(overrides: Dict) -> FinetuneConfig:
    """
    Create a FinetuneConfig with specific overrides.

    Only keys naming a constructor field of ``FinetuneConfig`` are accepted.
    Anything else is rejected, including methods and dunder attributes such
    as ``__class__``, which a plain ``hasattr`` check would let through.

    Args:
        overrides: Mapping of field name -> value to override in the base config.

    Returns:
        A new FinetuneConfig built with the specified overrides; fields not
        mentioned keep their declared defaults.

    Raises:
        ValueError: If a key in ``overrides`` is not a FinetuneConfig field.
    """
    # Restrict to real dataclass fields accepted by __init__.
    valid_keys = {f.name for f in fields(FinetuneConfig) if f.init}
    for key in overrides:
        if key not in valid_keys:
            raise ValueError(f"Invalid config parameter: {key}")
    # Build through the dataclass __init__ instead of setattr-after-construction,
    # so overrides follow the same construction path as defaults.
    return FinetuneConfig(**overrides)


def _suturing_overrides(run_id: str, **extra) -> Dict:
    """Return the override dict shared by all suturing presets, updated with ``extra``.

    A fresh dict (with fresh ``data_root_dirs`` / ``eval_data_root_dirs``
    lists) is built on every call, so no two presets share a mutable list.

    Args:
        run_id: Value for ``run_id_override``.
        **extra: Preset-specific overrides; these take precedence over the
            shared values below.

    Returns:
        A dict suitable for ``create_config_with_overrides``.
    """
    overrides = {
        "run_id_override": run_id,
        "dataset_name": "suturing_1-9_w_no_throw_for_1-2",
        "data_root_dirs": ["suturing_training_data"],
        "eval_data_root_dirs": ["suturing_eval_data"],
        "use_film": True,
        "num_images_in_input": 3,
        "use_val_set": True,
        "val_freq": 10,
        "save_freq": 500,
        "lr_warmup_steps": 100,
        "learning_rate": 5e-5,
    }
    overrides.update(extra)
    return overrides


# Settings shared by the 4-image (mask / heatmap) presets: relative to the
# FinetuneConfig defaults they halve the per-device batch and double gradient
# accumulation. Values are immutable ints, so sharing this dict via ** is safe.
_FOUR_IMAGE_OVERRIDES = {
    "num_images_in_input": 4,
    "batch_size": 2,
    "grad_accumulation_steps": 2,
}


configs = {
    "default": FinetuneConfig(),

    "test": create_config_with_overrides({
        "dataset_name": "testing",
        "use_film": True,
        "num_images_in_input": 3,
        "use_val_set": True,
        "val_freq": 10,
        "use_dummy_dataset": True,
    }),
    ########### SUTURING ###########
    "suturing_final_datasets_lr_5e5": create_config_with_overrides(
        _suturing_overrides("suturing_final_datasets_lr_5e5")
    ),
    "suturing_final_datasets_lr_5e5_mask": create_config_with_overrides(
        _suturing_overrides(
            "suturing_final_datasets_lr_5e5_mask",
            save_freq=250,
            targeting_strategy="mask",
            **_FOUR_IMAGE_OVERRIDES,
        )
    ),
    "suturing_final_datasets_lr_5e5_dot": create_config_with_overrides(
        _suturing_overrides(
            "suturing_final_datasets_lr_5e5_dot",
            targeting_strategy="dot",
        )
    ),
    "suturing_final_datasets_lr_5e5_heatmap": create_config_with_overrides(
        _suturing_overrides(
            "suturing_final_datasets_lr_5e5_heatmap",
            targeting_strategy="heatmap",
            **_FOUR_IMAGE_OVERRIDES,
        )
    ),
    "suturing_final_datasets_lr_5e5_mask_update": create_config_with_overrides(
        _suturing_overrides(
            "suturing_final_datasets_lr_5e5_mask_update",
            targeting_strategy="mask",
            **_FOUR_IMAGE_OVERRIDES,
        )
    ),
    "suturing_final_datasets_lr_5e5_mask_normalized": create_config_with_overrides(
        _suturing_overrides(
            "suturing_final_datasets_lr_5e5_mask_normalized_dist",
            targeting_strategy="mask",
            **_FOUR_IMAGE_OVERRIDES,
        )
    ),
    "suturing_final_datasets_lr_5e5_dot_normalized": create_config_with_overrides(
        _suturing_overrides(
            "suturing_final_datasets_lr_5e5_dot_normalized_dist",
            targeting_strategy="dot",
        )
    ),
    "suturing_final_datasets_lr_5e5_heatmap_normalized": create_config_with_overrides(
        _suturing_overrides(
            "suturing_final_datasets_lr_5e5_heatmap_normalized_dist",
            targeting_strategy="heatmap",
            **_FOUR_IMAGE_OVERRIDES,
        )
    ),
    # NOTE(review): dataset_name here is not a suturing dataset, yet the data
    # roots are still the suturing dirs -- confirm that this dataset lives there.
    "suturing_all_data_pretrain_lr_5e5": create_config_with_overrides(
        _suturing_overrides(
            "suturing_all_data_pretrain_lr_5e5_dist",
            dataset_name="tissue_lift_electro_pickuphandover_chole_knot_tying",
            targeting_strategy="none",
        )
    ),
    "suturing_final_datasets_lr_5e5_dot_scratch": create_config_with_overrides(
        _suturing_overrides(
            "suturing_final_datasets_lr_5e5_dot_scratch_dist",
            targeting_strategy="dot",
            from_scratch=True,
        )
    ),
}


def get_config_by_name(name: str) -> FinetuneConfig:
    """
    Get a predefined FinetuneConfig by name.

    A deep copy of the registered preset is returned. Callers can therefore
    mutate the result, including list fields such as ``data_root_dirs``,
    without altering the shared entry in ``configs`` seen by later lookups.

    Args:
        name: Name of the config to retrieve (a key of ``configs``).

    Returns:
        An independent FinetuneConfig with the specified settings.

    Raises:
        ValueError: If the specified config name is not found.
    """
    try:
        preset = configs[name]
    except KeyError:
        raise ValueError(
            f"Config '{name}' not found. Available configs: {', '.join(configs.keys())}"
        ) from None
    return copy.deepcopy(preset)
