import glob
import os
from typing import Tuple

import numpy as np
import torch
from datasets import load_dataset
from loguru import logger
from sklearn.preprocessing import QuantileTransformer
from torch.utils.data import DataLoader, Subset, TensorDataset
from torch.utils.data.sampler import SubsetRandomSampler
from torchvision import datasets, transforms
from tqdm import tqdm

# Limit torch's CPU intra-op parallelism to a single thread for this process.
torch.set_num_threads(1)

# Split fraction used by get_split_dataloader: for classes in the lower half of
# the label range, this fraction of each class's training samples goes to the
# first group (the rest to the second); for the upper half the roles are swapped.
split_ratio = 0.8


class CacheDataset(torch.utils.data.Dataset):
    """In-memory copy of everything a dataloader yields.

    The wrapped loader is iterated exactly once at construction time. The
    per-batch inputs are concatenated into ``self.data`` and the per-batch
    labels into ``self.targets``, so later indexing never re-runs the source
    dataset's transforms.

    Args:
        original_dataloader: any iterable of ``(inputs, labels)`` tensor pairs,
            typically a ``DataLoader``; it must yield at least one batch.
    """

    def __init__(self, original_dataloader: DataLoader):
        batches = [(inputs, labels) for inputs, labels in original_dataloader]
        self.data = torch.cat([inputs for inputs, _ in batches])
        self.targets = torch.cat([labels for _, labels in batches])

    def __len__(self):
        return self.data.size(0)

    def __getitem__(self, index):
        return self.data[index], self.targets[index]


def get_dataset(dataset_name: str) -> Tuple[TensorDataset, TensorDataset]:
    """Return the ``(train, test)`` datasets for ``dataset_name``.

    ``"mnist"``, ``"cifar10"`` and ``"cifar100"`` are loaded through torchvision
    with ``ToTensor`` followed by per-channel ``Normalize``. Any other name is
    read as ``clf_num/<name>.csv`` from the ``inria-soda/tabular-benchmark``
    Hugging Face dataset. In that case:

    - rows are shuffled with ``random_state=0``;
    - all columns but the last are features, mapped through a
      ``QuantileTransformer`` with normal output;
    - the last column is the label, re-coded to ``0..K-1`` in sorted order;
    - the first ``size`` rows become the train split and the next ``size``
      rows the test split, with ``size = min(n_rows // 2, 10000)``.

    NOTE(review): despite the annotation, the image branch returns torchvision
    dataset objects rather than ``TensorDataset``. MNIST is stored under
    ``../data`` while CIFAR uses ``./data``; confirm that is intentional.
    """
    # name -> (torchvision dataset class, root directory, channel means, channel stds)
    image_specs = {
        "mnist": (datasets.MNIST, "../data", (0.1307,), (0.3081,)),
        "cifar10": (
            datasets.CIFAR10,
            "./data",
            (0.4914, 0.4822, 0.4465),
            (0.2023, 0.1994, 0.2010),
        ),
        "cifar100": (
            datasets.CIFAR100,
            "./data",
            (0.5071, 0.4865, 0.4409),
            (0.2673, 0.2564, 0.2762),
        ),
    }

    spec = image_specs.get(dataset_name)
    if spec is not None:
        dataset_cls, root, mean, std = spec
        transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize(mean, std)]
        )
        # Only the train call passes download=True; the test split reuses that root.
        train_set = dataset_cls(root, train=True, download=True, transform=transform)
        test_set = dataset_cls(root, train=False, transform=transform)
        return train_set, test_set

    # Tabular benchmark: every name that is not an image dataset.
    hf_dataset = load_dataset(
        "inria-soda/tabular-benchmark", data_files=f"clf_num/{dataset_name}.csv"
    ).with_format("torch")
    frame = hf_dataset["train"].to_pandas()
    frame = frame.sample(frac=1.0, random_state=0)  # deterministic row shuffle
    raw_features = frame.iloc[:, :-1].values
    raw_labels = frame.iloc[:, -1].values

    scaler = QuantileTransformer(output_distribution="normal", random_state=0)
    features = scaler.fit_transform(raw_features)
    # Map each distinct label to its position in sorted order.
    code_of = {label: code for code, label in enumerate(np.unique(raw_labels))}
    labels = np.array([code_of[label] for label in raw_labels])

    half = min(len(features) // 2, 10000)
    train_rows = slice(0, half)
    test_rows = slice(half, 2 * half)
    train_set = TensorDataset(
        torch.tensor(features[train_rows]),
        torch.tensor(labels[train_rows], dtype=torch.int64),
    )
    test_set = TensorDataset(
        torch.tensor(features[test_rows]),
        torch.tensor(labels[test_rows], dtype=torch.int64),
    )
    return train_set, test_set


def get_dataloader(
    dataset_name: str, batch_size: int, cpu: bool = False
) -> Tuple[DataLoader, DataLoader]:
    """Build the shuffled train loader and the ordered test loader for a dataset.

    ``dataset_name`` must match the stem of some ``config/dataset/*.yaml`` file,
    with the path resolved relative to the current working directory. Image
    datasets (mnist / cifar10 / cifar100) are first materialized into a
    ``CacheDataset``, so their transforms run only once.

    Args:
        dataset_name: dataset identifier understood by ``get_dataset``.
        batch_size: batch size of both loaders.
        cpu: if True, ``pin_memory`` is disabled.

    Returns:
        ``(train_loader, test_loader)``.
    """
    logger.info("Preparing normal dataloader...")
    shared_kwargs = {
        "batch_size": batch_size,
        "pin_memory": not cpu,
        "num_workers": 1,
    }
    train_kwargs = {**shared_kwargs, "shuffle": True}
    test_kwargs = {**shared_kwargs, "shuffle": False}

    known_names = {
        os.path.splitext(os.path.basename(path))[0]
        for path in glob.glob("config/dataset/*.yaml")
    }
    assert dataset_name in known_names

    train_set, test_set = get_dataset(dataset_name)

    if dataset_name not in ("mnist", "cifar10", "cifar100"):
        train_loader = DataLoader(train_set, **train_kwargs)
        test_loader = DataLoader(test_set, **test_kwargs)
    else:
        # Walk each split once in order, then serve batches from memory.
        raw_train = DataLoader(train_set, batch_size=batch_size, shuffle=False)
        raw_test = DataLoader(test_set, batch_size=batch_size, shuffle=False)
        logger.info("Preparing dataset cache...")
        cached_train = CacheDataset(raw_train)
        cached_test = CacheDataset(raw_test)
        train_loader = DataLoader(cached_train, **train_kwargs)
        test_loader = DataLoader(cached_test, **test_kwargs)
    logger.info("Dataloader preparation is completed")

    return train_loader, test_loader


def get_split_dataloader(
    dataset_name: str, batch_size: int, cpu: bool = False
) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """Return two label-skewed train loaders over one training set, plus a test loader.

    The training set from ``get_dataset`` is partitioned into two disjoint
    groups. For each class in the lower half of the label range, a random
    ``split_ratio`` fraction of its samples goes to group A and the rest to
    group B; for the upper half the roles are swapped. The partition uses
    NumPy's global RNG (``np.random.shuffle``), so it is reproducible only if
    the caller seeds ``np.random``.

    For the image datasets, each group (and the test set) is materialized once
    into a ``CacheDataset``, and the cached train loaders reshuffle every epoch.
    For tabular datasets, the train loaders draw through ``SubsetRandomSampler``.

    Args:
        dataset_name: stem of a ``config/dataset/*.yaml`` file.
        batch_size: batch size of all returned loaders.
        cpu: if True, ``pin_memory`` is disabled.

    Returns:
        ``(train_loader_a, train_loader_b, test_loader)``.
    """
    logger.info("Preparing split dataloader...")
    image_datasets = ("mnist", "cifar10", "cifar100")

    def split_indices(n, split):
        # Random permutation of range(n) cut at int(n * split).
        indices = np.arange(n)
        np.random.shuffle(indices)
        split_point = int(n * split)
        return indices[:split_point], indices[split_point:]

    def get_targets(dataset):
        # Read labels without running the image transforms: torchvision datasets
        # expose them as `.targets`; other datasets are walked once.
        if dataset_name in image_datasets:
            return np.asarray(dataset.targets)
        return np.array([int(target) for (_, target) in dataset])

    def create_balanced_sampler(dataset, split_ratio):
        # Despite the name, the two groups are deliberately class-skewed (see docstring).
        # Assumes labels are 0..n_classes-1 (true for the datasets get_dataset builds).
        targets = get_targets(dataset)
        n_classes = len(np.unique(targets))

        indices_first_group = []
        indices_second_group = []
        for i in range(n_classes):
            indices = np.where(targets == i)[0]

            if i < n_classes // 2:
                first_indices, second_indices = split_indices(len(indices), split_ratio)
            else:
                second_indices, first_indices = split_indices(len(indices), split_ratio)

            indices_first_group.extend(indices[first_indices])
            indices_second_group.extend(indices[second_indices])

        sampler_first_group = SubsetRandomSampler(indices_first_group)
        sampler_second_group = SubsetRandomSampler(indices_second_group)

        return sampler_first_group, sampler_second_group

    # No "shuffle" key here: DataLoader rejects shuffle together with a sampler,
    # which the tabular branch uses. The cached image branch adds it explicitly.
    train_kwargs = {
        "batch_size": batch_size,
        "pin_memory": False if cpu else True,
        "num_workers": 1,
    }
    test_kwargs = {
        "batch_size": batch_size,
        "shuffle": False,
        "pin_memory": False if cpu else True,
        "num_workers": 1,
    }

    yaml_paths = glob.glob("config/dataset/*.yaml")
    dataset_candidates = [
        os.path.splitext(os.path.basename(filepath))[0] for filepath in yaml_paths
    ]
    assert dataset_name in dataset_candidates

    dataset1, dataset2 = get_dataset(dataset_name)
    sampler_first_group, sampler_second_group = create_balanced_sampler(
        dataset1, split_ratio
    )

    if dataset_name in image_datasets:
        _train_loader_a = DataLoader(
            dataset1, batch_size=batch_size, sampler=sampler_first_group
        )
        _train_loader_b = DataLoader(
            dataset1, batch_size=batch_size, sampler=sampler_second_group
        )
        _test_loader = DataLoader(dataset2, batch_size=batch_size)
        logger.info("Preparing dataset cache...")
        # Each source loader is consumed exactly once.
        cache_train_dataset_a = CacheDataset(_train_loader_a)
        cache_train_dataset_b = CacheDataset(_train_loader_b)
        cache_test_dataset = CacheDataset(_test_loader)
        # shuffle=True so the cached train data is reshuffled every epoch,
        # matching the sampler-based tabular path and get_dataloader.
        train_loader_a = DataLoader(cache_train_dataset_a, shuffle=True, **train_kwargs)
        train_loader_b = DataLoader(cache_train_dataset_b, shuffle=True, **train_kwargs)
        test_loader = DataLoader(cache_test_dataset, **test_kwargs)
    else:
        train_loader_a = DataLoader(
            dataset1, sampler=sampler_first_group, **train_kwargs
        )
        train_loader_b = DataLoader(
            dataset1, sampler=sampler_second_group, **train_kwargs
        )
        test_loader = DataLoader(dataset2, **test_kwargs)
    logger.info("Dataloader preparation is completed")
    return train_loader_a, train_loader_b, test_loader


if __name__ == "__main__":
    # Smoke test: for every dataset with a config/dataset/*.yaml file, build the
    # normal and the split loaders and log how many samples each one yields.
    dataset_candidates = sorted(
        os.path.splitext(os.path.basename(path))[0]
        for path in glob.glob("config/dataset/*.yaml")
    )

    logger.info("normal dataloader")
    for dataset_name in tqdm(dataset_candidates):
        logger.info(f"dataset_name={dataset_name}")
        train_loader, test_loader = get_dataloader(dataset_name, batch_size=512)
        train_size = sum(len(batch) for batch, _ in train_loader)
        test_size = sum(len(batch) for batch, _ in test_loader)
        logger.info(f"train-size: {train_size}, test-size: {test_size}")

    logger.info("split dataloader")
    for dataset_name in tqdm(dataset_candidates):
        logger.info(f"dataset_name={dataset_name}")
        loader_a, loader_b, test_loader = get_split_dataloader(
            dataset_name, batch_size=512
        )
        # Label sums give a quick view of how skewed each group's classes are.
        train_a_size = target_a = 0
        for batch, labels in loader_a:
            train_a_size += len(batch)
            target_a += torch.sum(labels)
        train_b_size = target_b = 0
        for batch, labels in loader_b:
            train_b_size += len(batch)
            target_b += torch.sum(labels)
        test_size = sum(len(batch) for batch, _ in test_loader)
        logger.info(
            f"train-a-size: {train_a_size}, train-b-size: {train_b_size}, test-size: {test_size}, target_a_sum: {target_a}, target_b_sum: {target_b}"
        )
