from __future__ import annotations

from dataclasses import dataclass, fields
from functools import partial
from pathlib import Path
from typing import Any, Literal
import yaml

from saws.warmstart.utils import (
    pad_zeros_model,
    slice_from_base_to_target,
    shrink_and_perturb,
    zero_centered_mup_perturb,
)
from saws.warmstart.hyperclone import hyperclone

# Registry of warmstarting strategies: maps a strategy name to the callable that
# implements it. Most entries are `shrink_and_perturb` or `zero_centered_mup_perturb`
# with keyword arguments pre-bound through `functools.partial`.
# Conventions used in the comments below: `shrinking_factor=1` is the "no shrinking"
# setting and `perturbation_sigma=0` the "no perturbation" setting; entries that do not
# bind them fall back to the helper's own defaults (referred to as "SnP-defaults").
# Insertion order matters: it is the order exposed by `WARMSTARTING_OPTIONS`.
WARMSTARTING_MAP = {
    # Pads the extra dimensions of the target model with zeros.
    "zeros": pad_zeros_model,
    # Retains muP initialization for the extra dimensions in the target model.
    "zeros_with_mup": partial(
        shrink_and_perturb,
        warm_type="zeros",
        shrinking_factor=1,
        mup_init=True
    ),
    # Zero padding, shrinking with SnP-defaults and muP initialization as perturbation.
    "snp_zeros_with_mup": partial(
        shrink_and_perturb,
        warm_type="zeros",
        mup_init=True
    ),

    # Samples rows-columns with replacement from the base model to the target model.
    "slice": slice_from_base_to_target,
    # Samples rows-columns from the base model and perturbs the target model with muP initialization.
    "slice_with_mup": partial(
        shrink_and_perturb,
        warm_type="slice",
        shrinking_factor=1,
        mup_init=True
    ),
    # Samples rows-columns, shrinking with SnP-defaults and muP initialization as perturbation.
    "snp_slice_with_mup":  partial(
        shrink_and_perturb,
        warm_type="slice",
        mup_init=True
    ),
    
    # Zero padding and then Shrink-and-perturb from SnP-defaults.
    "snp_zeros": partial(shrink_and_perturb, warm_type="zeros", mup_init=False),
    # Slice padding and then Shrink-and-perturb from SnP-defaults.
    "snp_slice": partial(shrink_and_perturb, warm_type="slice", mup_init=False),

    # Zero padding with shrinking only: perturbation is disabled (perturbation_sigma=0).
    "zeros_with_shrinking": partial(
        shrink_and_perturb,
        warm_type="zeros",
        perturbation_sigma=0,
        mup_init=False,
    ),
    # Zero padding with perturbation only: shrinking is disabled (shrinking_factor=1).
    "zeros_with_perturbation": partial(
        shrink_and_perturb,
        warm_type="zeros",
        shrinking_factor=1,
        mup_init=False,
    ),

    # Cloning the base model, with SnP-defaults shrinking and muP initialization.
    "snp_clone_with_mup": partial(
        shrink_and_perturb,
        warm_type="clone",
        mup_init=True,
        mirror=False,
    ),
    # Same as `snp_clone_with_mup` but with `mirror=True`.
    "snp_clone_mirror_with_mup": partial(
        shrink_and_perturb,
        warm_type="clone",
        mup_init=True,
        mirror=True,
    ),

    # Zero centering scaled weights instead of shrinking in SnP.
    "centering_zeros_mup": partial(
        zero_centered_mup_perturb,
        warm_type="zeros",
    ),
    "centering_clone_mup": partial(
        zero_centered_mup_perturb,
        warm_type="clone",
        mirror=False,
    ),
    "centering_clone_mirror_mup": partial(
        zero_centered_mup_perturb,
        warm_type="clone",
        mirror=True,
    ),

    # # dynamic layer-wise shrinking
    # "dynamic_layer_wise_shrinking": partial(
    #     shrink_and_perturb,
    #     warm_type="zeros",
    #     mup_init=True,
    #     shrinking_factor="layer-wise",  # dynamic layer-wise shrinking factors selected
    #     active_layer=None,
    # ),

    # SnP without shrinking (binds the same arguments as `zeros_with_mup`).
    "snp_zeros_with_mup_no_shrink": partial(
        shrink_and_perturb,
        warm_type="zeros",
        mup_init=True,
        shrinking_factor=1,
    ),
    # SnP without shrinking, only for hidden layers.
    "snp_zeros_with_mup_no_shrink_hidden": partial(
        shrink_and_perturb,
        warm_type="zeros",
        mup_init=True,
        shrinking_factor=1,
        active_layer="hidden",
    ),
    # SnP without shrinking, only for the un-embedding layer.
    "snp_zeros_with_mup_no_shrink_readout": partial(
        shrink_and_perturb,
        warm_type="zeros",
        mup_init=True,
        shrinking_factor=1,
        active_layer="readout",
    ),
    # Net2Net: slice padding, no shrinking, no muP initialization.
    "net2net": partial(
        shrink_and_perturb,
        warm_type="slice",
        mup_init=False,
        shrinking_factor=1,
    ),
    # WS slicing with no-shrink SnP (binds the same arguments as `slice_with_mup`).
    "snp_slice_with_mup_no_shrink": partial(
        shrink_and_perturb,
        warm_type="slice",
        mup_init=True,
        shrinking_factor=1,
    ),
    # WS with only zero padding (no shrinking, no perturbation).
    "pad_zeros": partial(
        shrink_and_perturb,
        warm_type="zeros",
        shrinking_factor=1,
        mup_init=False,
        perturbation_sigma=0,
    ),    
    # WS with only zero padding, for the embedding layer.
    "pad_zeros_embedding": partial(
        shrink_and_perturb,
        warm_type="zeros",
        shrinking_factor=1,
        mup_init=False,
        perturbation_sigma=0,
        active_layer="input",
    ),
    # WS with only zero padding, for hidden layers.
    "pad_zeros_hidden": partial(
        shrink_and_perturb,
        warm_type="zeros",
        shrinking_factor=1,
        mup_init=False,
        perturbation_sigma=0,
        active_layer="hidden",
    ),
    # SnP zeros with no perturbation (binds the same arguments as `zeros_with_shrinking`).
    # NOTE(review): this and the other `*_with_mup_no_perturb*` keys bind mup_init=False
    # despite the `with_mup` in their names — confirm this is intended.
    "snp_zeros_with_mup_no_perturb": partial(
        shrink_and_perturb,
        warm_type="zeros",
        mup_init=False,
        perturbation_sigma=0,
    ),    
    # SnP zeros with no perturbation, only for the embedding layer.
    "snp_zeros_with_mup_no_perturb_embedding": partial(
        shrink_and_perturb,
        warm_type="zeros",
        mup_init=False,
        perturbation_sigma=0,
        active_layer="input",
    ),    
    # SnP zeros with no perturbation, only for hidden layers.
    "snp_zeros_with_mup_no_perturb_hidden": partial(
        shrink_and_perturb,
        warm_type="zeros",
        mup_init=False,
        perturbation_sigma=0,
        active_layer="hidden",
    ),
    # SnP zeros with no perturbation, only for the un-embedding layer.
    "snp_zeros_with_mup_no_perturb_readout": partial(
        shrink_and_perturb,
        warm_type="zeros",
        mup_init=False,
        perturbation_sigma=0,
        active_layer="readout",
    ),
    # WS with only slice padding (no shrinking, no perturbation).
    "pad_slice": partial(
        shrink_and_perturb,
        warm_type="slice",
        shrinking_factor=1,
        mup_init=False,
        perturbation_sigma=0,
    ),
    # SnP slice with no perturbation.
    "snp_slice_with_mup_no_perturb": partial(
        shrink_and_perturb,
        warm_type="slice",
        mup_init=False,
        perturbation_sigma=0,
    ),
    # SnP slice with no perturbation, only for the embedding layer.
    "snp_slice_with_mup_no_perturb_embedding": partial(
        shrink_and_perturb,
        warm_type="slice",
        mup_init=False,
        perturbation_sigma=0,
        active_layer="input",
    ),
    # SnP slice with no perturbation, only for the un-embedding layer.
    "snp_slice_with_mup_no_perturb_readout": partial(
        shrink_and_perturb,
        warm_type="slice",
        mup_init=False,
        perturbation_sigma=0,
        active_layer="readout",
    ),
    # SnP clone without shrinking.
    "snp_clone_with_mup_no_shrink": partial(
        shrink_and_perturb,
        warm_type="clone",
        mup_init=True,
        shrinking_factor=1,
    ),
    # SnP clone with no perturbation.
    "snp_clone_with_mup_no_perturb": partial(
        shrink_and_perturb,
        warm_type="clone",
        mup_init=False,
        perturbation_sigma=0,
    ),
    # WS with only cloning (no shrinking, no perturbation).
    "pad_clone": partial(
        shrink_and_perturb,
        warm_type="clone",
        shrinking_factor=1,
        mup_init=False,
        perturbation_sigma=0,
    ),
    # SnP clone with no perturbation, only for the embedding layer.
    "snp_clone_with_mup_no_perturb_embedding": partial(
        shrink_and_perturb,
        warm_type="clone",
        mup_init=False,
        perturbation_sigma=0,
        active_layer="input",
    ),    
    # SnP clone with no perturbation, only for hidden layers.
    "snp_clone_with_mup_no_perturb_hidden": partial(
        shrink_and_perturb,
        warm_type="clone",
        mup_init=False,
        perturbation_sigma=0,
        active_layer="hidden",
    ),    
    # SnP clone with no perturbation, only for the un-embedding layer.
    # NOTE(review): this is the only active entry that binds mask_base=True; its
    # `_embedding` / `_hidden` siblings do not — confirm this is intentional.
    "snp_clone_with_mup_no_perturb_readout": partial(
        shrink_and_perturb,
        warm_type="clone",
        mup_init=False,
        perturbation_sigma=0,
        active_layer="readout",
        mask_base=True,
    ),
    # hypercloning TODO: only for testing, should be extended to snp
    "hyperclone": hyperclone,


    ########################
    # # layer-wise influence

    # WS snp zeros with mup only on hidden layer (attention)
    "snp_zeros_with_mup_hidden": partial(
        shrink_and_perturb,
        warm_type="zeros",
        mup_init=True,
        active_layer="hidden",  # `snp_zeros` + mup only on hidden layer
    ),
    # WS snp zeros with mup only on embedding layer
    "snp_zeros_with_mup_embedding": partial(
        shrink_and_perturb,
        warm_type="zeros",
        mup_init=True,
        active_layer="input",  # `snp_zeros` + mup only on embedding layer
    ),
    # WS snp zeros with mup only on unembedding layer
    "snp_zeros_with_mup_readout": partial(
        shrink_and_perturb,
        warm_type="zeros",
        mup_init=True,
        active_layer="readout",  # `snp_zeros` + mup only on unembedding layer
    ),    
    # WS snp zeros with mup only on embedding+unembedding layer
    "snp_zeros_with_mup_no-hidden": partial(
        shrink_and_perturb,
        warm_type="zeros",
        mup_init=True,
        active_layer="embeddings",  # `snp_zeros` + mup only on embedding+unembedding layer
    ),

    # WS snp clone with mup only on input (embedding) layer
    "snp_clone_with_mup_input": partial(
        shrink_and_perturb,
        warm_type="clone",
        mup_init=True,
        active_layer="input",  # `snp_clone` + mup only on input layer
    ),
    # WS snp clone with mup only on hidden layer
    "snp_clone_with_mup_hidden": partial(
        shrink_and_perturb,
        warm_type="clone",
        mup_init=True,
        active_layer="hidden",  # `snp_clone` + mup only on hidden layer
    ),
    # WS snp clone with mup only on readout (unembedding) layer
    "snp_clone_with_mup_readout": partial(
        shrink_and_perturb,
        warm_type="clone",
        mup_init=True,
        active_layer="readout",  # `snp_clone` + mup only on readout layer
    ),
    # WS snp clone with mup only on embedding+unembedding layer
    "snp_clone_with_mup_no-hidden": partial(
        shrink_and_perturb,
        warm_type="clone",
        mup_init=True,
        active_layer="embeddings",  # `snp_clone` + mup only on embedding+unembedding layer
    ),
}
    # "shrink_input_snp_zeros": partial(
    #     shrink_and_perturb,
    #     warm_type="zeros",
    #     mup_init=False,
    #     active_layer="input",  # `snp_zeros` + mup only on input layer
    # ),
    # "shrink_hidden_snp_zeros": partial(
    #     shrink_and_perturb,
    #     warm_type="zeros",
    #     mup_init=False,
    #     active_layer="hidden",  # `snp_zeros` + mup only on hidden layer
    # ),
    # "shrink_readout_snp_zeros": partial(
    #     shrink_and_perturb,
    #     warm_type="zeros",
    #     mup_init=False,
    #     active_layer="readout",  # `snp_zeros` + mup only on readout layer
    # ),
    # # layer-wise influence but for slice
    #     "shrink_input_snp_slice_with_mup": partial(
    #     shrink_and_perturb,
    #     warm_type="slice",
    #     mup_init=True,
    #     active_layer="input",  # `snp_slice` + mup only on input layer
    # ),
    # "shrink_hidden_snp_slice_with_mup": partial(
    #     shrink_and_perturb,
    #     warm_type="slice",
    #     mup_init=True,
    #     active_layer="hidden",  # `snp_slice` + mup only on hidden layer
    # ),
    # "shrink_readout_snp_slice_with_mup": partial(
    #     shrink_and_perturb,
    #     warm_type="slice",
    #     mup_init=True,
    #     active_layer="readout",  # `snp_slice` + mup only on readout layer
    # ),
    # "shrink_input_snp_slice": partial(
    #     shrink_and_perturb,
    #     warm_type="slice",
    #     mup_init=False,
    #     active_layer="input",  # `snp_slice` + mup only on input layer
    # ),
    # "shrink_hidden_snp_slice": partial(
    #     shrink_and_perturb,
    #     warm_type="slice",
    #     mup_init=False,
    #     active_layer="hidden",  # `snp_slice` + mup only on hidden layer
    # ),
    # "shrink_readout_snp_slice": partial(
    #     shrink_and_perturb,
    #     warm_type="slice",
    #     mup_init=False,
    #     active_layer="readout",  # `snp_slice` + mup only on readout layer
    # ),

    # # base masking
    # "snp_clone_with_mup_mask": partial(
    #     shrink_and_perturb,
    #     warm_type="clone",
    #     mup_init=True,
    #     mirror=False,
    #     mask_base=True,
    # ),
    # "snp_zeros_with_mup_mask": partial(
    #     shrink_and_perturb,
    #     warm_type="zeros",
    #     mup_init=True,
    #     mask_base=True,
    # ),
    # "shrink_input_snp_zeros_with_mup_mask": partial(
    #     shrink_and_perturb,
    #     warm_type="zeros",
    #     mup_init=True,
    #     active_layer="input",
    #     mask_base=True,
    # ),
    # "shrink_hidden_snp_zeros_with_mup_mask": partial(
    #     shrink_and_perturb,
    #     warm_type="zeros",
    #     mup_init=True,
    #     active_layer="hidden",
    #     mask_base=True,
    # ),
    # "shrink_readout_snp_zeros_with_mup_mask": partial(
    #     shrink_and_perturb,
    #     warm_type="zeros",
    #     mup_init=True,
    #     active_layer="readout",
    #     mask_base=True,
    # ),


# Names of every registered warmstarting strategy, in registration order.
WARMSTARTING_OPTIONS = [*WARMSTARTING_MAP]


@dataclass
class WarmstartConfig:
    """Configuration for warmstarting a target model from a base model.

    Attributes:
        activate: Whether warmstarting is enabled. When true, ``base_model_path``
            must be set and must exist at construction time.
        base_model_path: Location of the base model. Normalized to a ``Path``
            (or ``None`` when unset/empty) after construction and after
            ``from_path``.
        warmstart_type: Key into ``WARMSTARTING_MAP`` selecting the strategy.
        warmstarting_args: Extra keyword arguments that ``warmer`` binds on top
            of the selected strategy. ``None`` is normalized to ``{}``.
        retain_optimizer: Not read by this class; presumably consumed by the
            training code — verify against callers.
        restart_dataloader: Not read by this class; presumably consumed by the
            training code — verify against callers.
    """
    activate: bool = False
    base_model_path: str | Path | None = None

    # Must be one of the keys of WARMSTARTING_MAP; checked when `warmer` is called.
    warmstart_type: str = "zeros"
    warmstarting_args: dict[str, Any] | None = None

    retain_optimizer: bool = False

    restart_dataloader: bool = True

    def __post_init__(self):
        """Validate the base model path (when active) and normalize field values.

        Raises:
            AssertionError: If ``activate`` is true and ``base_model_path`` is
                missing or does not exist on disk.
        """
        if self.activate:
            # Raised explicitly (not via `assert`) so the checks survive `python -O`;
            # AssertionError is kept so existing handlers still match.
            if not self.base_model_path:
                raise AssertionError("Base model path must be provided for warmstarting.")
            if not Path(self.base_model_path).exists():
                raise AssertionError(
                    f"Base model path {self.base_model_path} does not exist."
                )
        self._normalize()

    def _normalize(self) -> None:
        """Coerce ``warmstarting_args`` to a dict and ``base_model_path`` to ``Path | None``."""
        if self.warmstarting_args is None:
            self.warmstarting_args = {}
        self.base_model_path = Path(self.base_model_path) if self.base_model_path else None

    def warmer(self) -> partial:
        """Return the selected warmstarting callable with ``warmstarting_args`` bound.

        Raises:
            KeyError: If ``warmstart_type`` is not a key of ``WARMSTARTING_MAP``.
        """
        try:
            strategy = WARMSTARTING_MAP[self.warmstart_type]
        except KeyError:
            raise KeyError(
                f"Unknown warmstart_type {self.warmstart_type!r}; "
                f"expected one of {list(WARMSTARTING_MAP)}."
            ) from None
        # Tolerate `warmstarting_args` having been reset to None after construction.
        return partial(strategy, **(self.warmstarting_args or {}))

    def is_active(self) -> bool:
        """Return whether warmstarting is enabled."""
        return self.activate

    def from_path(self, path: str | Path) -> None:
        """Overwrite attributes of this config in place from a YAML mapping file.

        Every top-level key in the file is set as an attribute, including keys
        that are not declared fields. Values are normalized afterwards the same
        way as in ``__post_init__``, but the existence check on
        ``base_model_path`` is not repeated, so the path may still be supplied
        later through ``set_path``.

        Args:
            path: Location of the YAML file.
        """
        with open(path, "r", encoding="utf-8") as f:
            config = yaml.safe_load(f)
        if config is None:
            # An empty YAML document loads as None; treat it as "no overrides".
            config = {}
        for key, value in config.items():
            setattr(self, key, value)
        self._normalize()

    def to_dict(self):
        """Return a ``{field name: current value}`` dict of the dataclass fields.

        Values are returned by reference, not copied.
        """
        return {field.name: getattr(self, field.name) for field in fields(self)}

    def set_path(self, path: str | Path) -> None:
        """Set ``base_model_path``, converting a string to a ``Path``."""
        self.base_model_path = Path(path) if isinstance(path, str) else path

    # NOTE(review): this makes `instance.__dict__` resolve to the bound `to_dict`
    # method instead of the real attribute dict, which looks like it breaks
    # `vars()`, `copy.copy`/`deepcopy` and pickling of instances. Kept unchanged
    # because callers may rely on `cfg.__dict__()` — confirm and consider removing.
    __dict__ = to_dict
# end of file