import itertools
import pdb
from copy import deepcopy
from typing import Tuple

import numpy as np
import torch
import torch.utils.data as data

from dl.datasets.dsprites import dSprites
from dl.datasets.mpi3d_toy import MPI3D_toy
from dl.datasets.shapes3d import Shapes3D

# Per-sample arrays that each split restricts to its own indices. Every other
# attribute of the dataset object is deep-copied unchanged.
_PER_SAMPLE_ATTRS = ("data", "latents_values", "latents_classes")


def _factor_grid_indices(dataset, factor_ranges) -> np.ndarray:
    """Return the sample indices of every combination in a factor grid.

    Args:
        dataset: Object exposing ``factor_to_idx(factors)``. It is called with
            one integer factor vector per grid point, and its result is used as
            a sample index.
        factor_ranges: One iterable of values per factor, in the order
            ``factor_to_idx`` expects. A fixed factor is given as a one-element
            tuple.

    Returns:
        ``np.array`` of the values returned by ``factor_to_idx``, in the order
        ``itertools.product`` enumerates the grid. The last factor varies
        fastest, exactly like the equivalent nested ``for`` loops.
    """
    return np.array(
        [
            dataset.factor_to_idx(np.array(combo))
            for combo in itertools.product(*factor_ranges)
        ]
    )


def _subset(dataset, indices):
    """Return a deep copy of ``dataset`` restricted to the samples at ``indices``.

    Deep-copying the whole dataset first would duplicate the full per-sample
    arrays only to throw them away after indexing. Instead, those arrays are
    pre-seeded into the ``deepcopy`` memo so they are reused rather than
    cloned, and are then replaced by their indexed subset.

    All other attributes are still deep-copied. The one exception is any other
    attribute that aliases one of the per-sample arrays: it will keep
    referencing the original array.
    """
    # deepcopy returns memo[id(obj)] for any object already in the memo, so the
    # copy starts out referencing the original arrays instead of cloning them.
    # A fresh memo per call keeps the train and test copies independent.
    memo = {}
    for name in _PER_SAMPLE_ATTRS:
        array = getattr(dataset, name)
        memo[id(array)] = array
    subset = deepcopy(dataset, memo)
    for name in _PER_SAMPLE_ATTRS:
        # Integer-array (advanced) indexing returns a new array, so the subset
        # does not share storage with ``dataset``. This assumes the attributes
        # are NumPy arrays or tensors; confirm against the dataset classes.
        setattr(subset, name, getattr(dataset, name)[indices])
    return subset


def _split_dataset(dataset, drop_idx):
    """Split ``dataset`` into ``(train, test)`` copies; ``dataset`` is untouched.

    ``test`` holds the samples at ``drop_idx``, in that order. ``train`` holds
    every other sample in ascending index order.
    """
    keep_idx = np.delete(np.arange(len(dataset)), drop_idx)
    return _subset(dataset, keep_idx), _subset(dataset, drop_idx)


# Annotations are strings: ``-> (A, B)`` is a tuple expression, not a type.
def r2r_dsprites(dataset: "dSprites") -> "Tuple[dSprites, dSprites]":
    """Split dSprites into train/test by holding out one region of factor space.

    Factor vectors are ``[shape, scale, orientation, x, y]``. The test split
    gets every sample with ``shape == 0`` and ``x`` in ``16..31``, for
    ``scale`` in ``0..5``, ``orientation`` in ``0..39`` and ``y`` in ``0..31``.
    Presumably those ranges span each factor fully; confirm against dSprites.
    The train split gets all remaining samples.

    Args:
        dataset: Full dataset; it is not modified.

    Returns:
        ``(train_dataset, test_dataset)``: deep copies of ``dataset`` whose
        ``data``, ``latents_values`` and ``latents_classes`` are restricted to
        the respective split.
    """
    shape = 0
    factor_ranges = [
        (shape,),        # shape (held-out value)
        range(0, 6),     # scale
        range(0, 40),    # orientation
        range(16, 32),   # x (held-out range)
        range(0, 32),    # y
    ]
    return _split_dataset(dataset, _factor_grid_indices(dataset, factor_ranges))


def r2r_shape3d(dataset: "Shapes3D") -> "Tuple[Shapes3D, Shapes3D]":
    """Split Shapes3D into train/test by holding out one region of factor space.

    Factor vectors are ``[floor_hue, wall_hue, object_hue, scale, shape,
    orientation]``. The test split gets every sample with ``shape == 3`` and
    ``object_hue`` in ``6..9``, for ``floor_hue`` and ``wall_hue`` in
    ``0..9``, ``scale`` in ``0..7`` and ``orientation`` in ``0..14``.
    Presumably those ranges span each factor fully; confirm against Shapes3D.
    The train split gets all remaining samples.

    Args:
        dataset: Full dataset; it is not modified.

    Returns:
        ``(train_dataset, test_dataset)``: deep copies of ``dataset`` whose
        ``data``, ``latents_values`` and ``latents_classes`` are restricted to
        the respective split.
    """
    shape = 3
    factor_ranges = [
        range(0, 10),    # floor_hue
        range(0, 10),    # wall_hue
        range(6, 10),    # object_hue (held-out range)
        range(0, 8),     # scale
        (shape,),        # shape (held-out value)
        range(0, 15),    # orientation
    ]
    return _split_dataset(dataset, _factor_grid_indices(dataset, factor_ranges))


def r2r_mpi3dtoy(dataset: "MPI3D_toy") -> "Tuple[MPI3D_toy, MPI3D_toy]":
    """Split MPI3D-toy into train/test by holding out one region of factor space.

    Factor vectors are ``[objcol, shape, objsize, cameraheight, backcol, posx,
    posy]``. The test split gets every sample with ``shape == 2`` and ``posy``
    in ``20..39``, for ``objcol`` in ``0..5``, ``objsize`` in ``0..1``,
    ``cameraheight`` and ``backcol`` in ``0..2`` and ``posx`` in ``0..39``.
    Presumably those ranges span each factor fully; confirm against MPI3D_toy.
    The train split gets all remaining samples.

    Args:
        dataset: Full dataset; it is not modified.

    Returns:
        ``(train_dataset, test_dataset)``: deep copies of ``dataset`` whose
        ``data``, ``latents_values`` and ``latents_classes`` are restricted to
        the respective split.
    """
    shape = 2
    factor_ranges = [
        range(0, 6),     # objcol
        (shape,),        # shape (held-out value)
        range(0, 2),     # objsize
        range(0, 3),     # cameraheight
        range(0, 3),     # backcol
        range(0, 40),    # posx
        range(20, 40),   # posy (held-out range)
    ]
    return _split_dataset(dataset, _factor_grid_indices(dataset, factor_ranges))
