import pdb
from copy import deepcopy
from itertools import product
from typing import Iterable

import numpy as np
import torch
import torch.utils.data as data

from dl.datasets.dsprites import dSprites
from dl.datasets.shapes3d import Shapes3D
from dl.datasets.mpi3d_toy import MPI3D_toy

# Attributes that hold one entry per sample and are therefore re-indexed when
# a dataset is split into train/test parts.
_SPLIT_ATTRS = ("data", "latents_values", "latents_classes")


def _holdout_split(dataset, factor_rows):
    """Split ``dataset`` into train/test copies by holding out factor combinations.

    Args:
        dataset: Dataset to split. It must provide ``factor_to_idx``,
            ``__len__`` and the attributes named in ``_SPLIT_ATTRS``, each
            indexable with an integer array. It is not modified.
        factor_rows: Iterable of factor-index tuples; each one is converted to
            a sample index with ``dataset.factor_to_idx(np.array(row))``.

    Returns:
        ``(train_dataset, test_dataset)``: deep copies of ``dataset`` whose
        per-sample attributes hold, respectively, every sample that was not
        held out and exactly the held-out samples (in ``factor_rows`` order).
    """
    # Pin the dtype so an empty hold-out still yields a valid integer index
    # array; np.array([]) would default to float64, which cannot index.
    drop_idx = np.array(
        [dataset.factor_to_idx(np.array(row)) for row in factor_rows],
        dtype=np.int64,
    )
    keep_idx = np.delete(np.arange(len(dataset)), drop_idx)

    def _copy_without_sample_arrays():
        # Pre-seeding deepcopy's memo with the per-sample arrays makes the copy
        # reference them instead of duplicating them. They are replaced right
        # below by fancy-indexed (freshly allocated) arrays, so full-size
        # copies of them would be pure time/memory overhead.
        # NOTE(review): any *other* attribute aliasing one of these arrays now
        # stays shared with ``dataset`` rather than being deep-copied.
        memo = {}
        for name in _SPLIT_ATTRS:
            arr = getattr(dataset, name)
            memo[id(arr)] = arr
        return deepcopy(dataset, memo)

    # A fresh memo per copy: reusing one would return the same copy twice.
    train_dataset = _copy_without_sample_arrays()
    test_dataset = _copy_without_sample_arrays()
    for name in _SPLIT_ATTRS:
        setattr(train_dataset, name, getattr(train_dataset, name)[keep_idx])
        setattr(test_dataset, name, getattr(test_dataset, name)[drop_idx])
    return train_dataset, test_dataset


def r2e_dsprites(
    dataset: "dSprites",
    *,
    shape: int = 1,
    scale: int = 1,
    orientations: Iterable[int] = range(14, 28),
    xs: Iterable[int] = range(21, 32),
    ys: Iterable[int] = range(21, 32),
) -> "tuple[dSprites, dSprites]":
    """Split a dSprites dataset into train/test by holding out a factor block.

    Every factor combination in ``{shape} x {scale} x orientations x xs x ys``
    is looked up with ``dataset.factor_to_idx`` (factor order: shape, scale,
    orientation, x, y) and moved to the test split; everything else is train.
    The defaults reproduce the original hard-coded hold-out.

    Args:
        dataset: Source dataset; it is not modified.
        shape: Held-out shape index.
        scale: Held-out scale index.
        orientations: Held-out orientation indices.
        xs: Held-out x-position indices.
        ys: Held-out y-position indices.

    Returns:
        ``(train_dataset, test_dataset)`` copies of ``dataset`` restricted to
        the kept and held-out samples respectively.
    """
    rows = product([shape], [scale], orientations, xs, ys)
    return _holdout_split(dataset, rows)


def r2e_shape3d(
    dataset: "Shapes3D",
    *,
    floor_hues: Iterable[int] = range(6, 10),
    wall_hues: Iterable[int] = range(6, 10),
    object_hues: Iterable[int] = range(6, 10),
    scale: int = 7,
    shape: int = 1,
    orientation: int = 7,
) -> "tuple[Shapes3D, Shapes3D]":
    """Split a Shapes3D dataset into train/test by holding out a factor block.

    Every combination in
    ``floor_hues x wall_hues x object_hues x {scale} x {shape} x {orientation}``
    is looked up with ``dataset.factor_to_idx`` (factor order: floor hue, wall
    hue, object hue, scale, shape, orientation) and moved to the test split;
    everything else is train. The defaults reproduce the original hold-out.

    Args:
        dataset: Source dataset; it is not modified.
        floor_hues: Held-out floor-hue indices.
        wall_hues: Held-out wall-hue indices.
        object_hues: Held-out object-hue indices.
        scale: Held-out scale index.
        shape: Held-out shape index.
        orientation: Held-out orientation index.

    Returns:
        ``(train_dataset, test_dataset)`` copies of ``dataset`` restricted to
        the kept and held-out samples respectively.
    """
    rows = product(
        floor_hues, wall_hues, object_hues, [scale], [shape], [orientation]
    )
    return _holdout_split(dataset, rows)


def r2e_mpi3dtoy(
    dataset: "MPI3D_toy",
    *,
    objcols: Iterable[int] = range(3, 6),
    shape: int = 2,
    objsize: int = 0,
    cameraheight: int = 1,
    backcol: int = 1,
    posxs: Iterable[int] = range(20, 40),
    posys: Iterable[int] = range(20, 40),
) -> "tuple[MPI3D_toy, MPI3D_toy]":
    """Split an MPI3D-toy dataset into train/test by holding out a factor block.

    Every combination in ``objcols x {shape} x {objsize} x {cameraheight} x
    {backcol} x posxs x posys`` is looked up with ``dataset.factor_to_idx``
    (factor order: object colour, shape, object size, camera height,
    background colour, x position, y position) and moved to the test split;
    everything else is train. The defaults reproduce the original hold-out.

    Args:
        dataset: Source dataset; it is not modified.
        objcols: Held-out object-colour indices.
        shape: Held-out shape index.
        objsize: Held-out object-size index.
        cameraheight: Held-out camera-height index.
        backcol: Held-out background-colour index.
        posxs: Held-out x-position indices.
        posys: Held-out y-position indices.

    Returns:
        ``(train_dataset, test_dataset)`` copies of ``dataset`` restricted to
        the kept and held-out samples respectively.
    """
    rows = product(
        objcols, [shape], [objsize], [cameraheight], [backcol], posxs, posys
    )
    return _holdout_split(dataset, rows)