from latentis.transform.translate.aligner import Translator, MatrixAligner
from layskip.modules.mlp_translator import SGDMLPAligner
from latentis.transform.translate.functional import lstsq_align_state
from latentis.transform.base import StandardScaling


# Name of the column that holds the input image, keyed by local dataset name.
# The MNIST-style datasets and ImageNet use "image"; every CIFAR variant
# (including the CIFAR-100 fine/coarse aliases) uses "img".
DATASET2IMAGE_COLUMN = {
    **dict.fromkeys(("mnist", "fashion-mnist", "imagenet-1k"), "image"),
    **dict.fromkeys(
        ("cifar10", "cifar100", "cifar100-fine", "cifar100-coarse"),
        "img",
    ),
}

# Name of the column that holds the target label, keyed by local dataset name.
# CIFAR-100 exposes two label columns: the bare "cifar100" alias and
# "cifar100-fine" both read "fine_label", while "cifar100-coarse" reads
# "coarse_label" (20 classes per DATASET2NUM_CLASSES).
DATASET2LABEL_COLUMN = {
    **dict.fromkeys(
        ("mnist", "fashion-mnist", "imagenet-1k", "cifar10"),
        "label",
    ),
    "cifar100": "fine_label",
    "cifar100-fine": "fine_label",
    "cifar100-coarse": "coarse_label",
}

# Number of target classes, keyed by local dataset name. Each count must
# match the label column chosen in DATASET2LABEL_COLUMN, which is why the
# coarse CIFAR-100 variant has 20 classes instead of 100.
DATASET2NUM_CLASSES = {
    **dict.fromkeys(("mnist", "fashion-mnist"), 10),
    "imagenet-1k": 1000,
    "cifar10": 10,
    **dict.fromkeys(("cifar100", "cifar100-fine"), 100),
    "cifar100-coarse": 20,
}

# Short model name -> Hugging Face Hub repository id ("<org>/<short name>").
# Each short name is exactly the basename of its Hub repo, so only the owning
# organisation is listed per entry.
MODEL_NAME2HF_NAME = {
    short_name: f"{org}/{short_name}"
    for org, short_name in (
        ("WinKawaks", "vit-small-patch16-224"),
        ("facebook", "deit-small-patch16-224"),
        ("facebook", "dinov2-small"),
    )
}

# Local dataset name -> Hugging Face Hub dataset id. The three CIFAR-100
# entries point at the same Hub dataset and differ only in how they are
# consumed elsewhere (label column / class count), not in what is downloaded.
DATASET_NAME2HF_NAME = {
    "mnist": "mnist",
    "fashion-mnist": "zalando-datasets/fashion_mnist",
    "imagenet-1k": "ILSVRC/imagenet-1k",
    "cifar10": "cifar10",
    **dict.fromkeys(
        ("cifar100", "cifar100-fine", "cifar100-coarse"),
        "cifar100",
    ),
}

# Layer count per backbone. Unlike MODEL_NAME2HF_NAME, this table is keyed by
# the full Hugging Face Hub repo id rather than the short model name.
# NOTE(review): Swin-V2 is a hierarchical model, so the value 4 presumably
# counts its stages rather than individual transformer blocks; confirm
# against the code that consumes this table.
MODEL2NUM_LAYERS = {
    **dict.fromkeys(
        (
            "WinKawaks/vit-small-patch16-224",
            "facebook/deit-small-patch16-224",
            "microsoft/beit-base-patch16-224",
            "facebook/dinov2-small",
        ),
        12,
    ),
    "microsoft/swinv2-tiny-patch4-window8-256": 4,
}

def _build_linear_translator():
    """Return a new Translator whose aligner is a least-squares MatrixAligner.

    No input/output transforms are attached.
    """
    return Translator(
        aligner=MatrixAligner(name="linear", align_fn_state=lstsq_align_state),
    )


def _build_sgd_mlp_translator():
    """Return a new Translator backed by an SGDMLPAligner.

    The aligner runs 50 steps at lr=1e-3 with a fixed seed of 0. Both the
    source (x) and target (y) sides are wrapped in StandardScaling.
    """
    return Translator(
        aligner=SGDMLPAligner(num_steps=50, lr=1e-3, random_seed=0),
        x_transform=StandardScaling(),
        y_transform=StandardScaling(),
    )


# Registry of zero-argument translator factories. Values are callables rather
# than instances, so every call constructs a fresh Translator (with fresh
# aligner/transform objects) instead of sharing one mutable object.
NAME2TRANSLATORS = {
    "linear": _build_linear_translator,
    "sgd_mlp_aligner": _build_sgd_mlp_translator,
}
