import os

import torch
import torchvision.transforms as transforms
import torchvision.utils
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10, MNIST, FashionMNIST
from torch.utils.data import Subset

from dataset import DiffusionMNIST, DiffusionCIFAR
import utils

# Transform pipelines: each list is wrapped in ``transforms.Compose`` by DLoader.
# The padded variants add 2 pixels on every side (MNIST's 28x28 -> 32x32).
# The ``dm_`` variants are the ones DLoader applies to diffusion-generated
# images; they skip the padding (presumably those images are already at the
# target size -- confirm against the generator).


def _binarize(x):
    """Round each pixel of ``x`` to the nearest integer.

    Used after ``ToTensor`` (which yields floats in [0, 1] for 8-bit images),
    so the result is a 0/1 image. This is a named module-level function rather
    than a lambda so the pipeline stays picklable, which DataLoader worker
    processes require under the ``spawn`` start method.
    """
    return torch.round(x)


img_transform_list = [transforms.ToTensor()]

MNIST_transform_list = [
    transforms.Pad(2),
    transforms.ToTensor(),
]

binMNIST_transform_list = [
    transforms.Pad(2),
    transforms.ToTensor(),
    _binarize,
]

dm_binMNIST_transform_list = [
    transforms.ToTensor(),
    _binarize,
]

FashinMNIST_transform_list = [
    transforms.Pad(2),
    transforms.ToTensor(),
]

dm_FashinMNIST_transform_list = [
    transforms.ToTensor(),
]

class DLoader:
    """Build (train, test) DataLoader pairs for the supported datasets.

    Names starting with ``"Diffusion-"`` train on images read from the
    ``-seeds-`` subdirectories of ``path``. ``"CIFAR10_Sub"`` trains on a
    seeded random subset of the CIFAR10 training split. Every other name
    trains on the full torchvision training split.

    In all cases:
    - the test loader iterates the real torchvision test split;
    - ``non_augmented_train_dataloader`` is refreshed as a side effect of
      :meth:`load_data`. It iterates the real training data, or the subset
      for ``CIFAR10_Sub``.

    The object is stateful across calls. The diffusion dataset and the
    CIFAR10 subset are built on first use and then reused. On later calls
    the diffusion dataset's ``load_next_dataset()`` is invoked.
    """

    _SUPPORTED_DATASETS = (
        "CIFAR10",
        "CIFAR10_Sub",
        "Diffusion-CIFAR10",
        "MNIST",
        "Diffusion-MNIST",
        "BinaryMNIST",
        "Diffusion-BinaryMNIST",
        "FashionMNIST",
        "Diffusion-FashionMNIST",
    )
    _DIFFUSION_PREFIX = "Diffusion-"
    # Only subdirectories of ``path`` whose name contains this marker are used.
    _SEED_DIR_MARKER = "-seeds-"

    def __init__(
        self,
        dataset_name,
        batch_size,
        seed,
        train_transform=None,
        path=None,
        augment=0.0,
        subset_portion=1.0,
        image_directory_cutoff_index=None,
    ):
        """Store the loader configuration; no data is read until load_data().

        Args:
            dataset_name: One of ``_SUPPORTED_DATASETS``. It is validated in
                :meth:`load_data`.
            batch_size: Batch size used for every DataLoader.
            seed: Passed to ``utils.set_seed`` before drawing the CIFAR10_Sub
                subset.
            train_transform: Accepted for interface compatibility; currently
                unused.
            path: Directory holding ``-seeds-`` subdirectories. Required for
                ``Diffusion-*`` datasets.
            augment: Accepted for interface compatibility; currently unused.
            subset_portion: Fraction of the CIFAR10 training split kept for
                ``CIFAR10_Sub``.
            image_directory_cutoff_index: Forwarded unchanged to the diffusion
                dataset constructor.
        """
        self.dataset_name = dataset_name
        self.batch_size = batch_size
        self.path = path
        self.diffusion_train_dataset = None
        self.subset_portion = subset_portion
        self.subset_train_dataset = None
        self.seed = seed
        self.non_augmented_train_dataloader = None
        self.image_directory_cutoff_index = image_directory_cutoff_index

    def load_data(self):
        """Build the loaders for ``self.dataset_name``.

        Returns:
            ``(train_dataloader, test_dataloader)``. Both shuffle and use
            ``self.batch_size``.

        Raises:
            ValueError: if ``dataset_name`` is not supported, or if a
                ``Diffusion-*`` dataset is first requested without ``path``.
        """
        name = self.dataset_name
        # Validate up front: the old if/elif chain fell through to an
        # UnboundLocalError on the return statement for unknown names.
        if name not in self._SUPPORTED_DATASETS:
            raise ValueError(
                f"Unsupported dataset_name {name!r}; expected one of: "
                f"{', '.join(self._SUPPORTED_DATASETS)}"
            )

        download = False
        dataset_cls, root, transform_list = self._real_dataset_spec(name)

        def make_real(train):
            # Fresh torchvision split; no augmentation is applied here.
            return dataset_cls(
                root=root,
                train=train,
                download=download,
                transform=torchvision.transforms.Compose(transform_list),
            )

        if name == "CIFAR10_Sub":
            train_dataset = self._get_subset_train_dataset(lambda: make_real(True))
            train_dataloader = self._dataloader(train_dataset)
            self.non_augmented_train_dataloader = self._dataloader(train_dataset)
        elif name.startswith(self._DIFFUSION_PREFIX):
            train_dataloader = self._dataloader(self._next_diffusion_dataset(name))
            # The non-augmented loader still iterates the real training split.
            self.non_augmented_train_dataloader = self._dataloader(make_real(True))
        else:
            train_dataset = make_real(True)
            train_dataloader = self._dataloader(train_dataset)
            self.non_augmented_train_dataloader = self._dataloader(train_dataset)

        test_dataloader = self._dataloader(make_real(False))
        return train_dataloader, test_dataloader

    def _dataloader(self, dataset):
        """Wrap ``dataset`` in a shuffling DataLoader with the configured batch size."""
        return DataLoader(dataset, batch_size=self.batch_size, shuffle=True)

    @staticmethod
    def _real_dataset_spec(name):
        """Return (torchvision class, root dir, transform list) for the real data behind ``name``."""
        if name in ("CIFAR10", "CIFAR10_Sub", "Diffusion-CIFAR10"):
            return CIFAR10, "../data/CIFAR10", img_transform_list
        if name in ("MNIST", "Diffusion-MNIST"):
            return MNIST, "../data", MNIST_transform_list
        if name in ("BinaryMNIST", "Diffusion-BinaryMNIST"):
            return MNIST, "../data", binMNIST_transform_list
        if name in ("FashionMNIST", "Diffusion-FashionMNIST"):
            return FashionMNIST, "../data", FashinMNIST_transform_list
        raise ValueError(f"Unsupported dataset_name {name!r}")

    @staticmethod
    def _diffusion_dataset_spec(name):
        """Return (dataset class, transform list, extra constructor kwargs) for a Diffusion-* name."""
        if name == "Diffusion-CIFAR10":
            return DiffusionCIFAR, img_transform_list, {"size_cap": 50000}
        if name == "Diffusion-MNIST":
            return DiffusionMNIST, MNIST_transform_list, {}
        if name == "Diffusion-BinaryMNIST":
            return DiffusionMNIST, dm_binMNIST_transform_list, {}
        if name == "Diffusion-FashionMNIST":
            return DiffusionMNIST, dm_FashinMNIST_transform_list, {}
        raise ValueError(f"Not a diffusion dataset: {name!r}")

    def _diffusion_directories(self):
        """Return the sorted paths of the ``-seeds-`` subdirectories of ``self.path``.

        ``os.listdir`` returns entries in arbitrary order. Sorting makes the
        list handed to the diffusion dataset reproducible. This matters for
        anything that depends on that order, presumably
        ``image_directory_cutoff_index`` and ``load_next_dataset()``.
        """
        if self.path is None:
            # Explicit raise instead of assert: asserts vanish under ``python -O``.
            raise ValueError("path must be set to load a Diffusion-* dataset")
        return sorted(
            os.path.join(self.path, entry)
            for entry in os.listdir(self.path)
            if self._SEED_DIR_MARKER in entry
        )

    def _next_diffusion_dataset(self, name):
        """Build the diffusion dataset on first use; otherwise call its load_next_dataset()."""
        if self.diffusion_train_dataset is None:
            dataset_cls, transform_list, extra_kwargs = self._diffusion_dataset_spec(name)
            self.diffusion_train_dataset = dataset_cls(
                image_directory_paths=self._diffusion_directories(),
                transform=torchvision.transforms.Compose(transform_list),
                image_directory_cutoff_index=self.image_directory_cutoff_index,
                **extra_kwargs,
            )
        else:
            self.diffusion_train_dataset.load_next_dataset()
        return self.diffusion_train_dataset

    def _get_subset_train_dataset(self, make_full_train):
        """Return the cached CIFAR10_Sub subset, drawing it on first use.

        ``make_full_train`` is a zero-argument callable that builds the full
        training split. The subset keeps ``int(len * subset_portion)`` indices
        drawn with ``torch.randperm`` after seeding via ``utils.set_seed``.
        """
        if self.subset_train_dataset is None:
            utils.set_seed(self.seed)
            full_train_dataset = make_full_train()
            lengths = len(full_train_dataset)
            lengths_subset = int(lengths * self.subset_portion)
            print(f"Using {lengths_subset} samples out of {lengths} samples for training.")
            indices = torch.randperm(lengths)[:lengths_subset]
            self.subset_train_dataset = Subset(full_train_dataset, indices)
        return self.subset_train_dataset
