"""Optimization setups."""

from dataclasses import dataclass, asdict
from typing import Optional


def training_strategy(strategy, lr=None, epochs=None, dryrun=False):
    """Parse training strategy into a fresh ``Hyperparameters`` object.

    Args:
        strategy: Strategy name, matched case-insensitively. Only ``'darts'``
            is currently recognized.
        lr: If not None, overrides ``param_lr`` on the returned object.
        epochs: If not None, overrides ``epochs`` on the returned object.
        dryrun: Stored as ``dryrun`` on the returned object.

    Returns:
        A copy of the matching preset with the overrides applied and an extra
        ``validate`` attribute set to 1. The module-level preset itself is
        left untouched, so overrides from one call never leak into the next.

    Raises:
        ValueError: If ``strategy`` is not recognized.
    """
    from dataclasses import replace

    # Resolve the name before touching any preset so unknown names fail fast.
    if strategy.lower() == 'darts':
        preset = DARTS
    else:
        raise ValueError('Unknown training strategy.')

    # Work on a copy: assigning onto the shared preset would make every later
    # call (and any other reader of the preset) see these overrides.
    defs = replace(preset)

    # Overwrite some hyperparameters if desired
    if epochs is not None:
        defs.epochs = epochs
    if lr is not None:
        defs.param_lr = lr
    defs.dryrun = dryrun
    # ``validate`` is not a declared dataclass field; it is attached as a plain
    # instance attribute and therefore does not appear in ``asdict()`` output.
    defs.validate = 1
    return defs


@dataclass
class Hyperparameters:
    """Default usual parameters, not intended for parsing.

    Mutable record of the optimization settings for one training strategy.
    Most settings come in ``param_*`` / ``alpha_*`` pairs configuring two
    separate optimizers (NOTE(review): presumably model weights vs.
    architecture parameters; confirm against the trainer).

    Field order is part of the interface (positional construction), so new
    fields must only be appended, with defaults.
    """

    # Human-readable preset name.
    name: str

    # Optimizer settings, one value per optimized parameter group.
    epochs: int
    param_optimizer: str
    alpha_optimizer: str
    param_lr: float
    alpha_lr: float
    param_scheduler: str
    alpha_scheduler: str
    alpha_warmup: bool
    param_warmup: bool
    param_weight_decay: float
    alpha_weight_decay: float

    dryrun: bool

    # Update-rule switches; semantics are defined by the consuming trainer.
    project: bool
    update: str
    mirror_descent: bool
    decay_entropy: bool
    fuse: bool
    # Bare ``Optional`` is not a valid type; the only value seen here is a
    # string, and ``None`` is allowed.
    norm: Optional[str]

    def asdict(self):
        """Return the declared fields as a plain dict.

        Uses ``dataclasses.asdict``, so attributes attached outside the
        declared fields (e.g. ``validate``) are not included.
        """
        return asdict(self)


# Preset chosen by ``training_strategy`` for the strategy name 'darts'.
# Arguments are grouped by prefix; all are passed by keyword, so order here
# does not matter.
DARTS = Hyperparameters(
    name='DARTS',
    epochs=25,
    dryrun=False,

    # ``param_*`` optimizer group.
    param_optimizer='SGD',
    param_lr=0.025,
    param_scheduler='annealing',
    param_warmup=False,
    param_weight_decay=3e-4,

    # ``alpha_*`` optimizer group.
    alpha_optimizer='GD',
    alpha_lr=1e-2,
    alpha_scheduler='none',
    alpha_warmup=False,
    alpha_weight_decay=3e-4,

    # Update-rule switches.
    update='liu',
    project=False,
    mirror_descent=False,
    decay_entropy=False,
    fuse=False,
    norm='DARTS',
)
