import torch
from torch.optim import SGD, Adam


# Encoder and decoder model classes
from dl.models.decoders.base import Decoder1C, Decoder3C, MNISTDecoder1C
from dl.models.decoders.lie_group import LieDecoder1C, LieDecoder3C
from dl.models.encoders.base import Encoder1C, Encoder3C, MNISTEncoder1C
from dl.models.encoders.lie_group import LieEncoder1C, LieEncoder3C
from dl.models.CMCS.base import Unsuper_Encoder_dsprites
from dl.models.CMCS.base import (
    dSprites_Classifier,
    Shapes3D_Classifier,
    MPI3D_Classifier,
    MPI3D_Complex_Classifier,
)


# Base dataset loaders
from dl.datasets.dsprites import dSprites
from dl.datasets.shapes3d import Shapes3D
from dl.datasets.mpi3d_toy import MPI3D_toy
from dl.datasets.mpi3d_real import MPI3D_real
from dl.datasets.mpi3d_complex import MPI3D_complex
# FactorVAE dataset loaders
from dl.datasets.factorvae.dsprites import dSprites_f
from dl.datasets.factorvae.shapes3d import Shapes3D_f
from dl.datasets.factorvae.mpi3d_toy import MPI3D_toy_f
from dl.datasets.factorvae.mpi3d_real import MPI3D_real_f
from dl.datasets.factorvae.mpi3d_complex import MPI3D_complex_f


# Semi-supervised CMCS dataset loaders
from dl.datasets.cmcs_semisuper.dsprites import dSprites_Semi
from dl.datasets.cmcs_semisuper.mpi3d_toy import MPI3D_toy_Semi


# Baseline (encoder class, decoder class) pairs, keyed by dataset name.
# "dsprites" and "mmnist" have dedicated architectures; every other dataset
# listed here shares the Encoder3C/Decoder3C pair. (The "1C"/"3C" suffixes
# presumably denote input channel count -- confirm in dl.models.)
BASE_EN_DE = {
    "dsprites": (Encoder1C, Decoder1C),
    **{
        dataset: (Encoder3C, Decoder3C)
        for dataset in (
            "shapes3d",
            "car",
            "celeba",
            "cdsprites",
            "mpi3d_toy",
            "mpi3d_real",
            "mpi3d_complex",
        )
    },
    "mmnist": (MNISTEncoder1C, MNISTDecoder1C),
}

# Lie-group (encoder class, decoder class) pairs, keyed by dataset name.
# Only "dsprites" uses the 1C pair; the rest share the 3C pair. Unlike
# BASE_EN_DE there is no "mmnist" entry, so looking it up raises KeyError.
LIE_EN_DE = {
    "dsprites": (LieEncoder1C, LieDecoder1C),
    **{
        dataset: (LieEncoder3C, LieDecoder3C)
        for dataset in (
            "shapes3d",
            "car",
            "celeba",
            "cdsprites",
            "mpi3d_toy",
            "mpi3d_real",
            "mpi3d_complex",
        )
    },
}


# (encoder class, decoder class) pairs for the unsupervised CMCS setup.
# Only dSprites is registered; it pairs the CMCS encoder with the baseline
# Decoder1C.
CMCS_UNSUPER_EN_DE = dict(
    dsprites=(Unsuper_Encoder_dsprites, Decoder1C),
)


# Per-dataset pair of hidden sizes (presumably layer widths for the model;
# confirm at the call site). "dsprites" and "mmnist" use [256, 128]; all
# others use [256, 256]. Each key gets its own freshly built list, so
# mutating one entry does not affect the others.
DATA_HIDDEN_DIM = {
    dataset: [256, 128] if dataset in ("dsprites", "mmnist") else [256, 256]
    for dataset in (
        "dsprites",
        "shapes3d",
        "car",
        "smallnorb",
        "celeba",
        "cdsprites",
        "mpi3d_toy",
        "mpi3d_real",
        "mpi3d_complex",
        "mmnist",
    )
}

# Number of steps per dataset (presumably the training-iteration budget;
# confirm at the call site).
DATA_STEPS = dict(
    dsprites=300_000,
    shapes3d=500_000,
    car=300_000,
    smallnorb=500_000,
    celeba=1_000_000,
    cdsprites=600_000,
    mpi3d_toy=1_000_000,
    mpi3d_real=1_000_000,
    mpi3d_complex=500_000,
    mmnist=50_000,
)

# Base dataset classes (from dl.datasets), keyed by dataset name.
BASE_DATA = dict(
    dsprites=dSprites,
    shapes3d=Shapes3D,
    mpi3d_toy=MPI3D_toy,
    mpi3d_real=MPI3D_real,
    mpi3d_complex=MPI3D_complex,
)
# FactorVAE dataset classes (from dl.datasets.factorvae), keyed by dataset
# name.
Factor_DATA = {
    "dsprites": dSprites_f,
    "shapes3d": Shapes3D_f,
    "mpi3d_toy": MPI3D_toy_f,
    "mpi3d_real": MPI3D_real_f,
    "mpi3d_complex": MPI3D_complex_f,
}
# UPPER_SNAKE_CASE alias consistent with the other registries in this module.
# The original mixed-case name is kept so existing imports keep working.
FACTOR_DATA = Factor_DATA

# Dataset classes for the semi-supervised CMCS setup, keyed by dataset name.
# Only dSprites and MPI3D-toy have dedicated semi-supervised loaders. Every
# other dataset reuses its BASE_DATA class, and key order matches BASE_DATA.
# NOTE(review): confirm that falling back to the non-semi loaders for
# shapes3d / mpi3d_real / mpi3d_complex is intended.
SEMI_DATA = {
    **BASE_DATA,
    "dsprites": dSprites_Semi,
    "mpi3d_toy": MPI3D_toy_Semi,
}


# torch optimizer classes selectable by lowercase name.
OPTIMIZER = dict(sgd=SGD, adam=Adam)

# Per-dataset 1-D integer tensor of value counts for each ground-truth
# factor. The order is presumably the dataset's own factor order; confirm
# against the loaders in dl.datasets.
FACTOR_INFORM = {
    dataset: torch.tensor(sizes)
    for dataset, sizes in (
        ("dsprites", [3, 6, 40, 32, 32]),
        ("shapes3d", [10, 10, 10, 8, 4, 15]),
        ("mpi3d_toy", [6, 6, 2, 3, 3, 40, 40]),
        ("mpi3d_real", [6, 6, 2, 3, 3, 40, 40]),
        ("mpi3d_complex", [4, 4, 2, 3, 3, 40, 40]),
    )
}

# Factor classifier class per dataset. MPI3D toy and real share
# MPI3D_Classifier; mpi3d_complex has its own class.
FACTOR_CLASSIFIER = dict(
    dsprites=dSprites_Classifier,
    shapes3d=Shapes3D_Classifier,
    mpi3d_toy=MPI3D_Classifier,
    mpi3d_real=MPI3D_Classifier,
    mpi3d_complex=MPI3D_Complex_Classifier,
)

# Per-dataset list of category counts. The leading entries equal the factor
# sizes in FACTOR_INFORM. dsprites and the MPI3D variants are then padded
# with 10s up to 10 entries, while shapes3d is left at its 6 factor sizes.
# NOTE(review): the meaning of the padded 10s and why shapes3d is not padded
# are not visible here; confirm against the consumer of PRIOR_LIST.
PRIOR_LIST = {
    "dsprites": [3, 6, 40, 32, 32] + [10] * 5,
    "shapes3d": [10, 10, 10, 8, 4, 15],
    "mpi3d_toy": [6, 6, 2, 3, 3, 40, 40] + [10] * 3,
    "mpi3d_real": [6, 6, 2, 3, 3, 40, 40] + [10] * 3,
    "mpi3d_complex": [4, 4, 2, 3, 3, 40, 40] + [10] * 3,
}