import sys
from dataclasses import dataclass
from functools import partial
from typing import Optional
from omegaconf import MISSING
from _abstract_task.configuration import BaseConfiguration, get_configuration


@dataclass
class PretrainingConfiguration(BaseConfiguration):
    """Configuration schema for pretraining, extending ``BaseConfiguration``.

    Defaults written as ``"${...}"`` strings are stored verbatim by this
    class. They are presumably OmegaConf interpolations resolved later through
    project-registered ``switch``/``eval`` resolvers (TODO confirm). The
    ``scale`` key they reference is not declared in this class, so it
    presumably comes from ``BaseConfiguration``.
    """

    # Data
    # NOTE(review): annotated as list[str] but the default is a tuple
    # (dataclasses reject a plain list as a default) -- confirm the config
    # loader converts or accepts it.
    datasets: list[str] = ("openwebtext",)
    # The three probabilities default to None. The `mlm` setting function in
    # this module assigns all three; `electra` assigns only mask_probability.
    mask_probability: Optional[float] = None
    replace_probability: Optional[float] = None
    original_probability: Optional[float] = None

    # Optimization
    weight_decay: float = 0.01
    gradient_clip_val: float = 1.0
    lr_warmup_steps: int = 10000

    # Model
    # No default (omegaconf MISSING): a setting function has to assign it
    # (`mlm` sets False, `electra` sets True).
    use_electra: bool = MISSING
    # Annotated int, but the default is an interpolation string that selects
    # 4 / 3 / 4 for the small / base / large scale.
    electra_generator_size_divisor: int = "${switch:{small:4, base:3, large:4},${scale}}"

    # Text Structure Prediction
    # None by default; the `mlm_tsp` setting function sets the weight to 1.
    tsp_loss_weight: Optional[float] = None
    tsp_shuffling_fixing_prob: Optional[float] = 0.85
    # NOTE(review): the three flags below presumably switch off parts of the
    # TSP objective (ablations) -- their consumers are not visible here.
    tsp_without_hierarchy: bool = False
    tsp_without_order: bool = False
    tsp_without_paragraph: bool = False

    # Sequence Classification
    # Set by the `mlm_nsp` / `mlm_sop` / `mlm_sso` setting functions.
    inter_segment_task: Optional[str] = None  # None, "nsp", "sop", "sso"


# `get_configuration` pre-bound to the PretrainingConfiguration schema and to
# this module object.
# NOTE(review): presumably `get_configuration` looks up the setting functions
# defined in this file (mlm, electra, ...) on `setting_fn_module` -- confirm
# against _abstract_task.configuration.
pretraining_config = partial(
    get_configuration,
    base_struct=PretrainingConfiguration,
    setting_fn_module=sys.modules[__name__],  # this module
)

# =====================
# Base
# =====================


def mlm(config):
    """Apply the ``mlm`` pretraining setting to *config* in place.

    Disables ELECTRA and assigns the masking probabilities, batch size,
    number of steps, learning rate and mixed-precision initial scale. All
    scale-dependent values are written as ``${switch:...}`` interpolation
    strings keyed on ``${scale}``; nothing is resolved here.

    Args:
        config: Any object accepting attribute assignment for the fields set
            below (presumably an OmegaConf config built from
            ``PretrainingConfiguration`` -- confirm against the caller).

    Returns:
        None. ``config`` is mutated.
    """
    # Base rate shared by the three probabilities (0.15 at every scale); it is
    # split 80% / 10% / 10% between mask / replace / original.
    base_rate = "${switch:{small:0.15, base:0.15, large:0.15},${scale}}"

    config.use_electra = False
    config.mask_probability = "${eval: " + base_rate + " * 0.8}"
    config.replace_probability = "${eval: " + base_rate + " * 0.1}"
    config.original_probability = "${eval: " + base_rate + " * 0.1}"
    config.batch_size = "${switch:{small:128, base:256, large:288},${scale}}"
    config.num_steps = "${switch:{small:1450000, base:1000000, large:1000000},${scale}}"
    # Plain assignment: the former `: float` annotation on this attribute
    # target was never stored and mislabelled a string value.
    config.learning_rate = "${switch:{small:5e-4, base:2e-4, large:1e-4},${scale}}"
    config.mixed_precision_init_scale = 2 ** 15


def electra(config):
    """Apply the ``electra`` pretraining setting to *config* in place.

    Enables ELECTRA and assigns the mask probability, number of steps, batch
    size, learning rate and mixed-precision initial scale. Scale-dependent
    values are stored as unresolved ``${switch:...}`` interpolation strings.
    ``replace_probability`` and ``original_probability`` are not touched.

    Args:
        config: Object accepting attribute assignment for the fields below.

    Returns:
        None. ``config`` is mutated.
    """
    # Insertion order matches the order in which the attributes are assigned.
    overrides = {
        "use_electra": True,
        "mask_probability": "${switch:{small:0.15, base:0.15, large:0.25},${scale}}",
        "num_steps": "${switch:{small:1000000, base:766000, large:400000},${scale}}",
        "batch_size": "${switch:{small:128, base:256, large:2048},${scale}}",
        "learning_rate": "${switch:{small:5e-4, base:2e-4, large:2e-4},${scale}}",
        "mixed_precision_init_scale": 3000,
    }
    for field_name, field_value in overrides.items():
        setattr(config, field_name, field_value)


def mlm_tsp(config):
    """Apply the ``mlm`` setting, then set ``tsp_loss_weight`` to 1.

    Mutates *config* in place and returns None.
    """
    mlm(config)
    loss_weight = 1
    config.tsp_loss_weight = loss_weight
    # NOTE(review): the two values below are kept from the original author;
    # presumably reference measurements per GPU model -- their meaning is not
    # stated in this file.
    # a100: 12.09434986114502
    # gv100: 12.098687171936035


def mlm_nsp(config):
    """Apply the ``mlm`` setting, then select ``"nsp"`` as inter-segment task.

    Mutates *config* in place and returns None.
    """
    mlm(config)
    task_name = "nsp"  # presumably next-sentence prediction
    config.inter_segment_task = task_name


def mlm_sop(config):
    """Apply the ``mlm`` setting, then select ``"sop"`` as inter-segment task.

    Mutates *config* in place and returns None.
    """
    mlm(config)
    task_name = "sop"  # presumably sentence-order prediction
    config.inter_segment_task = task_name


def mlm_sso(config):
    """Apply the ``mlm`` setting, then select ``"sso"`` as inter-segment task.

    Mutates *config* in place and returns None.
    """
    mlm(config)
    task_name = "sso"  # NOTE(review): abbreviation not expanded in this file
    config.inter_segment_task = task_name
