from dataclasses import dataclass, field
from functools import partial
from typing import Any, Optional
from _abstract_task.configuration import BaseConfiguration, get_configuration, MISSING


@dataclass
class GLUEConfiguration(BaseConfiguration):
    """Structured configuration for fine-tuning on a GLUE-style task.

    Extends ``BaseConfiguration`` (from ``_abstract_task.configuration``) with
    task-selection, optimisation and data options. Fields defaulting to
    ``MISSING`` have no default here and presumably must be supplied (e.g. via
    overrides) before the config is resolved.

    Several defaults are interpolation strings (``"${...}"``) even though the
    field is annotated ``float``. They are presumably resolved by an
    OmegaConf-style config system that accepts interpolations on typed nodes.
    NOTE(review): confirm this before "fixing" the annotations. The annotations
    likely drive schema validation, and field order drives the dataclass
    ``__init__`` signature.
    """

    # Experimentation
    # Task identifier to train/evaluate on; required (no default).
    task: str = MISSING
    # Optional dataset cache location; None presumably defers to the loader's
    # default cache directory — TODO confirm at the consumer.
    datasets_cache_dir: Optional[str] = None

    # Optimization
    batch_size: int = 32
    weight_decay: float = 0.0
    gradient_clip_val: float = 1.0
    # Chosen by model size: the value of ``${scale}`` (small/base/large) selects
    # an entry from the mapping. Assumes a custom ``switch`` resolver and a
    # ``scale`` field defined elsewhere (not visible here) — TODO confirm.
    learning_rate: float = "${switch:{small:3e-4, base:1e-4, large:5e-5},${scale}}"
    # Presumably the fraction of total training steps spent in LR warmup.
    lr_warmup_fraction: float = 0.1
    # Per-layer LR decay factor, selected by ``${scale}`` like ``learning_rate``.
    lr_layer_wise_decay_rate: float = "${switch:{small:0.8, base:0.8, large:0.9},${scale}}"
    # Required (``MISSING``). ``Optional`` lets it be set explicitly to null.
    num_epochs: Optional[int] = MISSING

    # Data
    # NOTE(review): meaning is defined by the data pipeline, which is not
    # visible in this file — document it at the consumer.
    double_unordered: bool = True


# Shortcut for ``get_configuration`` with ``base_struct`` pre-bound to
# ``GLUEConfiguration``. All other arguments are forwarded unchanged. Because
# this is a ``functools.partial``, a caller may still override ``base_struct``
# by passing it as a keyword.
glue_config = partial(get_configuration, base_struct=GLUEConfiguration)

# Default number of training epochs per task, keyed by task name (insertion
# order is preserved for anything that iterates the mapping).
#
# ``axb`` and ``axg`` map to the *string* "null", not ``None``. This presumably
# lets the value parse as null when spliced into a config override string
# (matching ``num_epochs: Optional[int]``). NOTE(review): confirm against the
# caller before changing it to ``None``.
NUM_EPOCHS = dict(
    rte=10,
    cb=10,
    copa=3,
    multirc=3,
    wic=3,
    wsc=3,
    boolq=10,
    record=3,
    axb="null",
    axg="null",
)