import re
from warnings import warn
from operator import getitem
from typing import Any, Optional, Union
from dataclasses import dataclass
from omegaconf import OmegaConf, MISSING

# Custom OmegaConf resolvers used by interpolation strings in the configuration.
# "switch": `${switch:<mapping>,<key>}` resolves to `<mapping>[<key>]` via `operator.getitem`.
OmegaConf.register_new_resolver("switch", getitem, use_cache=True)
# "eval": `${eval:<expr>}` resolves to the result of Python's builtin `eval` on `<expr>`.
# NOTE(review): `eval` executes arbitrary code, so any configuration value that can reach
# this resolver (including CLI overwrites) must come from a trusted source.
OmegaConf.register_new_resolver("eval", eval, use_cache=True)

from _utils import using_multiple_devices


def get_configuration(
    base_struct,
    setting_fn: Union[callable, str] = None,
    setting_fn_module=None,  # the module where setting_fn is located
    cli_overwrites: list[str] = None,
    set_read_only=True,
):
    """Build a fully resolved OmegaConf configuration from a structured base.

    Steps, in order:
      1. Create a structured config from `base_struct` and turn on struct mode.
      2. If `setting_fn` is given, call it with the config. Its return value is
         ignored, so it is expected to modify the config in place.
      3. If `cli_overwrites` is given, parse it with `OmegaConf.from_cli` and
         merge it on top of the config. A warning is emitted when
         `using_multiple_devices(config.devices, config.num_nodes)` is truthy
         and "batch_size" is not among the CLI overwrites. This check is only
         performed when CLI overwrites are present.
      4. Resolve all interpolations and set the read-only flag.

    Args:
        base_struct: Object accepted by `OmegaConf.structured`.
        setting_fn: Callable taking the config, or the name of an attribute of
            `setting_fn_module` to be used as that callable. Skipped if falsy.
        setting_fn_module: Object on which `setting_fn` is looked up when
            `setting_fn` is a string.
        cli_overwrites: List of CLI-style overwrite strings. Skipped if falsy.
        set_read_only: Value passed to `OmegaConf.set_readonly` for the result.

    Returns:
        The resolved configuration object.
    """
    cfg = OmegaConf.structured(base_struct)
    OmegaConf.set_struct(cfg, True)

    if setting_fn:
        # A string names an attribute of `setting_fn_module`; anything else is
        # called directly.
        configure = (
            getattr(setting_fn_module, setting_fn)
            if isinstance(setting_fn, str)
            else setting_fn
        )
        configure(cfg)

    if cli_overwrites:
        cli_cfg = OmegaConf.from_cli(cli_overwrites)
        cfg = OmegaConf.merge(cfg, cli_cfg)
        multi_device = using_multiple_devices(cfg.devices, cfg.num_nodes)
        if multi_device and "batch_size" not in cli_cfg:
            warn(
                "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n"
                'Multiple devices are used but "batch_size" is not overwritten. Be careful that "batch_size" should be set as targeted total batch size divided by number of devices.\n'
                "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!"
            )

    # Interpolations are resolved once here, before the dataloader touches the
    # configuration; otherwise the custom resolvers would run on every access
    # and slow training down a lot.
    OmegaConf.resolve(cfg)
    OmegaConf.set_readonly(cfg, set_read_only)
    return cfg


@dataclass
class BaseConfiguration:
    """Base schema of model, experiment, data-loading and optimization settings.

    Fields whose default is ``MISSING`` have no default value and must be set
    before they are read. Fields whose default is a ``"${switch:...}"`` string
    are OmegaConf interpolations: the ``switch`` resolver registered in this
    module looks up the entry keyed by ``scale`` in the inline mapping.

    NOTE(review): presumably passed (directly or via a subclass) as
    ``base_struct`` to ``get_configuration`` -- confirm against callers.
    """

    # Model
    scale: str = MISSING  # key into the `switch` mappings below: small / base / large
    embedding_size: int = "${switch:{small:128, base:768, large:1024},${scale}}"
    hidden_size: int = "${switch:{small:256, base:768, large:1024},${scale}}"
    intermediate_size: int = "${switch:{small:1024, base:3072, large:4096},${scale}}"
    num_attention_heads: int = "${switch:{small:4, base:12, large:16},${scale}}"
    num_hidden_layers: int = "${switch:{small:12, base:12, large:24},${scale}}"
    max_sequence_length: int = "${switch:{small:128, base:512, large:512},${scale}}"
    position_embedding_type: str = "absolute"
    dropout_prob: float = 0.1
    tokenizer: str = "google/electra-small-generator"  # NOTE: bert/electra small/base/large uncased tokenizers are all the same (?)

    # Experimentation
    ## `pytorch_lightning.Trainer`'s accelerator related arguments
    devices: Any = MISSING  # int / list / str
    num_nodes: int = 1
    accelerator: str = "gpu"
    strategy: Optional[str] = None
    ## Other
    seed: Optional[int] = None
    do_testing: bool = False  # False: do training & validation / True: do testing
    logger: Optional[str] = None  # None / "wandb" / "tensorboard"
    ## Saving and loading weights
    load_ckpt_path: Optional[str] = None
    ### - None : don't load a pretrained model
    ### - ".../xxx.ckpt" or "xxx.ckpt" : a path (relative to "PROJECT_ROOT/checkpoints") of a lightning checkpoint
    ### - ".../xxx.finetuning" or "xxx.finetuning" : a path (relative to "PROJECT_ROOT/checkpoints") of a directory in which the only checkpoint pretrained on the current task is found automatically. Useful for finding a finetuned checkpoint for testing.
    ### - otherwise : try to load a huggingface model (served as `model_name_or_path` in `from_pretrained`)
    save_ckpt_path: Optional[str] = None
    ### - not None : a path relative to "PROJECT_ROOT/checkpoints"
    ### - None: infer the checkpoint directory and the checkpoint name
    ###   - checkpoint directory: next to the loaded ckpt (see `get_trainer` in `_abastract_task/run.py`)
    ###       if `load_ckpt_path` is None : "<PROJECT_ROOT_DIR>/checkpoints"
    ###       else: the same path as `load_ckpt_path` but ending with ".finetuning" instead
    ###   - checkpoint name: the suggested checkpoint name passed to `get_trainer` will be used. (See <task>/run.py)

    # Data / Data loading
    datasets_cache_dir: Optional[str] = None
    datasets_num_proc: int = 1
    dataloader_num_workers: int = 3

    # Optimization
    num_steps: Optional[int] = None
    num_epochs: Optional[int] = None
    batch_size: int = MISSING
    weight_decay: float = MISSING
    gradient_clip_val: float = MISSING
    learning_rate: float = MISSING
    optimizer_eps: float = 1e-6
    optimizer_bias_correction: bool = False  # follow ELECTRA's impl.
    ## Choose only one of lr_warmup_steps / lr_warmup_fraction; disabled if None
    lr_warmup_steps: Optional[int] = None
    lr_warmup_fraction: Optional[float] = None
    ## Decay lr toward lower layers; disabled if None
    lr_layer_wise_decay_rate: Optional[float] = None
    lr_layer_wise_decay_original_way: bool = True
    mixed_precision_init_scale: Optional[float] = None
