import torch
from torch.optim import SGD, Adam


# Encoders and decoders
from cg.models.decoders.base import Decoder1C, Decoder3C
from cg.models.decoders.lie_group import LieDecoder1C, LieDecoder3C
from cg.models.encoders.base import Encoder1C, Encoder3C
from cg.models.encoders.lie_group import LieEncoder1C, LieEncoder3C
# from cg.models.CMCS.base import Unsuper_Encoder_dsprites
from cg.models.CMCS.base import (
    dSprites_Classifier,
    Shapes3D_Classifier,
    MPI3D_Classifier,
    MPI3D_Complex_Classifier,
)

# Base dataset loaders
from cg.datasets.dsprites import dSprites
from cg.datasets.shapes3d import Shapes3D
from cg.datasets.mpi3d_toy import MPI3D_toy

# FactorVAE dataset loaders (imports currently commented out)
# from cg.datasets.factorvae.dsprites import dSprites_f
# from cg.datasets.factorvae.shapes3d import Shapes3D_f
# from cg.datasets.factorvae.mpi3d_toy import MPI3D_toy_f
# from cg.datasets.factorvae.mpi3d_real import MPI3D_real_f
# from cg.datasets.factorvae.mpi3d_complex import MPI3D_complex_f

from cg.datasets.r2e import r2e_dsprites, r2e_shape3d, r2e_mpi3dtoy
from cg.datasets.r2r import r2r_shape3d, r2r_dsprites, r2r_mpi3dtoy

# Dataset name -> (encoder class, decoder class) for the base encoder/decoder
# models imported from cg.models.{encoders,decoders}.base.
# Only "dsprites" uses the *1C variants; every other dataset shares the *3C
# pair (presumably 1-channel vs 3-channel images — TODO confirm).
BASE_EN_DE = {
    "dsprites": (Encoder1C, Decoder1C),
    **{
        dataset: (Encoder3C, Decoder3C)
        for dataset in ("shapes3d", "car", "celeba", "cdsprites", "mpi3d_toy")
    },
}

# Dataset name -> (encoder class, decoder class) for the Lie-group variants
# imported from cg.models.{encoders,decoders}.lie_group.
# Mirrors BASE_EN_DE: "dsprites" uses the *1C pair, all others the *3C pair.
LIE_EN_DE = {
    "dsprites": (LieEncoder1C, LieDecoder1C),
    **{
        dataset: (LieEncoder3C, LieDecoder3C)
        for dataset in ("shapes3d", "car", "celeba", "cdsprites", "mpi3d_toy")
    },
}


# CMCS_UNSUPER_EN_DE = {
#     "dsprites": (Unsuper_Encoder_dsprites, Decoder1C),
# }

# Dataset name -> list of hidden-layer widths (presumably consumed by the
# model constructors — TODO confirm against callers).
# "dsprites" is the only dataset with a narrower second layer; every other
# dataset gets its own fresh [256, 256] list (the comprehension re-evaluates
# the literal per key, so entries are independent objects).
DATA_HIDDEN_DIM = {
    "dsprites": [256, 128],
    **{
        dataset: [256, 256]
        for dataset in (
            "shapes3d",
            "car",
            "smallnorb",
            "celeba",
            "cdsprites",
            "mpi3d_toy",
            "mpi3d_real",
            "mpi3d_complex",
        )
    },
}

# Dataset name -> step count (presumably the number of training iterations
# per dataset — TODO confirm against the training loop).
DATA_STEPS = dict(
    dsprites=300_000,
    shapes3d=500_000,
    car=300_000,
    smallnorb=500_000,
    celeba=1_000_000,
    cdsprites=600_000,
    mpi3d_toy=1_000_000,
    mpi3d_real=1_000_000,
    mpi3d_complex=500_000,
)

# Dataset name -> dataset class from cg.datasets.{dsprites,shapes3d,mpi3d_toy}.
BASE_DATA = dict(
    dsprites=dSprites,
    shapes3d=Shapes3D,  # _3DshapeDataLoader
    mpi3d_toy=MPI3D_toy,
)
# Factor_DATA = {
#     "dsprites": dSprites_f,
#     "shapes3d": Shapes3D_f,  # _3DshapeDataLoader,
#     "mpi3d_toy": MPI3D_toy_f,
#     "mpi3d_real": MPI3D_real_f,
#     "mpi3d_complex": MPI3D_complex_f,
# }

# Lower-case optimizer name -> torch.optim optimizer class.
OPTIMIZER = dict(
    sgd=SGD,
    adam=Adam,
)

# Dataset name -> 1-D integer tensor of per-factor sizes (these appear to be
# the number of distinct values of each ground-truth factor, in dataset
# order — TODO confirm against the dataset loaders).
FACTOR_INFORM = {
    dataset: torch.tensor(sizes)
    for dataset, sizes in (
        ("dsprites", [3, 6, 40, 32, 32]),
        ("shapes3d", [10, 10, 10, 8, 4, 15]),
        ("mpi3d_toy", [6, 6, 2, 3, 3, 40, 40]),
        ("mpi3d_real", [6, 6, 2, 3, 3, 40, 40]),
        ("mpi3d_complex", [4, 4, 2, 3, 3, 40, 40]),
    )
}
# Dataset name -> factor-classifier class from cg.models.CMCS.base.
# mpi3d_toy and mpi3d_real share MPI3D_Classifier; mpi3d_complex has its own.
FACTOR_CLASSIFIER = dict(
    dsprites=dSprites_Classifier,
    shapes3d=Shapes3D_Classifier,
    mpi3d_toy=MPI3D_Classifier,
    mpi3d_real=MPI3D_Classifier,
    mpi3d_complex=MPI3D_Complex_Classifier,
)

# Dataset name -> list of prior sizes. The leading entries equal the
# per-factor sizes in FACTOR_INFORM for the same dataset; dsprites and the
# MPI3D variants are then padded with 10s up to ten entries in total.
# NOTE(review): shapes3d has six entries and no padding — confirm intended.
PRIOR_LIST = {
    "dsprites": [3, 6, 40, 32, 32] + [10] * 5,
    "shapes3d": [10, 10, 10, 8, 4, 15],
    "mpi3d_toy": [6, 6, 2, 3, 3, 40, 40] + [10] * 3,
    "mpi3d_real": [6, 6, 2, 3, 3, 40, 40] + [10] * 3,
    "mpi3d_complex": [4, 4, 2, 3, 3, 40, 40] + [10] * 3,
}

# r2e r2r separateors
R2E_R2R = {"r2e": {
                "dsprites": r2e_dsprites,
                "shapes3d": r2e_shape3d,
                "mpi3d_toy": r2e_mpi3dtoy
            },
           "r2r": {
               "dsprites": r2r_dsprites,
               "shapes3d": r2r_shape3d,
               "mpi3d_toy": r2r_mpi3dtoy
           }}

