import torch
from torchvision import datasets, transforms
from datasets import load_dataset
from src.data.dataset import get_indices


def create_mnist_dataset(train=True):
    """Build the torchvision MNIST dataset with tensor conversion and normalization.

    Each image goes through ``ToTensor`` and is then normalized with mean
    ``0.1307`` and standard deviation ``0.3081``. The dataset root is
    ``../data`` and ``download=True`` is passed to torchvision.

    Args:
        train: Forwarded as the ``train`` argument of ``datasets.MNIST``.

    Returns:
        The constructed ``datasets.MNIST`` instance.
    """
    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize((0.1307,), (0.3081,))
    pipeline = transforms.Compose([to_tensor, normalize])
    return datasets.MNIST("../data", train=train, download=True, transform=pipeline)


def get_mnist_test_loaders(forget_classes=[7], batch_size=64, index_file=None):
    """Create unshuffled MNIST test loaders for the full, retain and forget sets.

    Args:
        forget_classes: Forwarded to ``get_indices`` to decide the
            forget/retain partition. Pass ``None`` to skip the partition.
            The default list is shared between calls and is never mutated
            by this function.
        batch_size: Batch size used by every loader.
        index_file: Forwarded unchanged to ``get_indices``.

    Returns:
        Dict with keys ``test_full_loader``, ``test_retain_loader`` and
        ``test_forget_loader``; the last two are ``None`` when
        ``forget_classes`` is ``None``.
    """

    def _ordered_loader(data):
        # Evaluation loaders: fixed order, keep the final partial batch.
        return torch.utils.data.DataLoader(
            data, batch_size=batch_size, shuffle=False, drop_last=False
        )

    test_set = create_mnist_dataset(train=False)
    loaders = {
        "test_full_loader": _ordered_loader(test_set),
        "test_retain_loader": None,
        "test_forget_loader": None,
    }
    if forget_classes is None:
        return loaders

    forget_idx, retain_idx = get_indices(test_set, forget_classes, index_file)
    forget_subset = torch.utils.data.Subset(test_set, forget_idx)
    retain_subset = torch.utils.data.Subset(test_set, retain_idx)
    loaders["test_retain_loader"] = _ordered_loader(retain_subset)
    loaders["test_forget_loader"] = _ordered_loader(forget_subset)
    return loaders


def create_cifar_dataset(train=True, cifar_100=False):
    """Build a torchvision CIFAR-10 or CIFAR-100 dataset with normalization.

    Images are converted with ``ToTensor`` and normalized per channel. The
    same mean/std constants are applied to both dataset variants.

    Args:
        train: Forwarded as the ``train`` argument of the torchvision dataset.
        cifar_100: When truthy build ``datasets.CIFAR100``; otherwise
            ``datasets.CIFAR10``.

    Returns:
        The constructed torchvision dataset, rooted at ``../data`` with
        ``download=True``.
    """
    # NOTE(review): one set of statistics is shared by CIFAR-10 and CIFAR-100;
    # confirm that matches the normalization the evaluated models were trained with.
    channel_mean = (0.5071, 0.4867, 0.4408)
    channel_std = (0.2675, 0.2565, 0.2761)
    pipeline = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize(channel_mean, channel_std)]
    )
    dataset_cls = datasets.CIFAR100 if cifar_100 else datasets.CIFAR10
    return dataset_cls(root="../data", train=train, download=True, transform=pipeline)


def get_cifar_test_loaders(
    forget_classes=[3],
    batch_size=64,
    cifar_100=False,
    index_file=None,
):
    """Create unshuffled CIFAR test loaders for the full, retain and forget sets.

    Args:
        forget_classes: Forwarded to ``get_indices`` to decide the
            forget/retain partition. Pass ``None`` to skip the partition.
            The default list is shared between calls and is never mutated
            by this function.
        batch_size: Batch size used by every loader.
        cifar_100: Forwarded to ``create_cifar_dataset`` to pick the variant.
        index_file: Forwarded unchanged to ``get_indices``.

    Returns:
        Dict with keys ``test_full_loader``, ``test_retain_loader`` and
        ``test_forget_loader``; the last two are ``None`` when
        ``forget_classes`` is ``None``.
    """
    # Shared settings: fixed order and no dropped tail batch.
    loader_options = {"batch_size": batch_size, "shuffle": False, "drop_last": False}

    test_set = create_cifar_dataset(train=False, cifar_100=cifar_100)
    full_loader = torch.utils.data.DataLoader(test_set, **loader_options)

    retain_loader = forget_loader = None
    if forget_classes is not None:
        forget_idx, retain_idx = get_indices(test_set, forget_classes, index_file)
        forget_subset = torch.utils.data.Subset(test_set, forget_idx)
        retain_subset = torch.utils.data.Subset(test_set, retain_idx)
        retain_loader = torch.utils.data.DataLoader(retain_subset, **loader_options)
        forget_loader = torch.utils.data.DataLoader(forget_subset, **loader_options)

    return {
        "test_full_loader": full_loader,
        "test_retain_loader": retain_loader,
        "test_forget_loader": forget_loader,
    }


def build_dataloader_text(
    dataset,
    tokenizer,
    batch_size: int,
    shuffle: bool = True,
    num_workers: int = 0,
    drop_last: bool = False,
    **kwargs
):
    """Wrap a dataset of ``(encoding, label)`` pairs in a padding DataLoader.

    Each batch is collated by handing the list of encodings to
    ``tokenizer.pad(..., return_tensors="pt")`` and packing the labels into
    an ``int64`` tensor.

    Args:
        dataset: Dataset whose items unpack into ``(encoding, label)``.
        tokenizer: Object exposing ``pad(features, return_tensors=...)``.
        batch_size: Number of samples per batch.
        shuffle: Forwarded to ``DataLoader``.
        num_workers: Forwarded to ``DataLoader``.
        drop_last: Forwarded to ``DataLoader``.
        **kwargs: Extra keyword arguments forwarded to ``DataLoader``.

    Returns:
        A ``torch.utils.data.DataLoader`` yielding ``(padded_inputs, labels)``.
    """

    def collate_fn(batch):
        # Split the pairs, then pad the encodings as one batch.
        encodings = [encoding for encoding, _ in batch]
        targets = [target for _, target in batch]
        padded = tokenizer.pad(encodings, return_tensors="pt")
        return padded, torch.tensor(targets, dtype=torch.int64)

    return torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=collate_fn,
        drop_last=drop_last,
        num_workers=num_workers,
        **kwargs
    )


class AGTextDataset(torch.utils.data.Dataset):
    """Map-style dataset over ``ag_news`` that tokenizes each sample on access.

    Items are ``(encoding, label)`` pairs, where ``encoding`` is whatever the
    tokenizer returns for the sample's ``"text"`` field (truncated to
    ``max_length``) and ``label`` is the sample's ``"label"`` field.
    """

    def __init__(self, tokenizer, train=True, max_length=128):
        """Load the ``train`` or ``test`` split and remember the tokenizer settings."""
        super().__init__()
        split_name = "train" if train else "test"
        self.dataset = load_dataset("ag_news")[split_name]
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of samples in the loaded split."""
        return len(self.dataset)

    def __getitem__(self, index):
        """Return the tokenized text and the label of the sample at ``index``."""
        record = self.dataset[index]
        encoding = self.tokenizer(
            record["text"],
            truncation=True,
            max_length=self.max_length,
        )
        return encoding, record["label"]


def create_ag_news_dataset(tokenizer, train=True, max_length=128):
    """Construct an ``AGTextDataset`` from the given tokenizer and settings.

    Args:
        tokenizer: Tokenizer stored on the dataset and applied per sample.
        train: Forwarded to ``AGTextDataset`` to choose the split.
        max_length: Truncation length forwarded to ``AGTextDataset``.

    Returns:
        The new ``AGTextDataset`` instance.
    """
    dataset = AGTextDataset(
        tokenizer=tokenizer,
        train=train,
        max_length=max_length,
    )
    return dataset


def get_ag_news_test_loaders(
    tokenizer,
    forget_classes=[3],
    batch_size=64,
    index_file=None,
):
    """Create unshuffled AG News test loaders for the full, retain and forget sets.

    Args:
        tokenizer: Used both to build the dataset and to pad batches.
        forget_classes: Forwarded to ``get_indices`` to decide the
            forget/retain partition. Pass ``None`` to skip the partition.
            The default list is shared between calls and is never mutated
            by this function.
        batch_size: Batch size used by every loader.
        index_file: Forwarded unchanged to ``get_indices``.

    Returns:
        Dict with keys ``test_full_loader``, ``test_retain_loader`` and
        ``test_forget_loader``; the last two are ``None`` when
        ``forget_classes`` is ``None``.
    """

    def _ordered_loader(data):
        # Evaluation loaders: fixed order, keep the final partial batch.
        return build_dataloader_text(
            data, tokenizer, batch_size, shuffle=False, drop_last=False
        )

    test_set = create_ag_news_dataset(tokenizer=tokenizer, train=False)
    loaders = {
        "test_full_loader": _ordered_loader(test_set),
        "test_retain_loader": None,
        "test_forget_loader": None,
    }
    if forget_classes is None:
        return loaders

    forget_idx, retain_idx = get_indices(test_set, forget_classes, index_file)
    forget_subset = torch.utils.data.Subset(test_set, forget_idx)
    retain_subset = torch.utils.data.Subset(test_set, retain_idx)
    loaders["test_retain_loader"] = _ordered_loader(retain_subset)
    loaders["test_forget_loader"] = _ordered_loader(forget_subset)
    return loaders


class ImageNetDataset(torch.utils.data.Dataset):
    """Map-style dataset over ``evanarlian/imagenet_1k_resized_256``.

    Items are ``(preprocess_fn(image), label)`` pairs built from each
    sample's ``"image"`` and ``"label"`` fields.
    """

    def __init__(self, preprocess_fn, train=True):
        """Load the ``train`` or ``val`` split and store the preprocessing callable."""
        super().__init__()
        split_name = "train" if train else "val"
        self.dataset = load_dataset("evanarlian/imagenet_1k_resized_256")[split_name]
        self.preprocess_fn = preprocess_fn

    def __len__(self):
        """Return the number of samples in the loaded split."""
        return len(self.dataset)

    def __getitem__(self, index):
        """Return the preprocessed image and the label of the sample at ``index``."""
        record = self.dataset[index]
        # NOTE(review): the image is passed to preprocess_fn as stored; any
        # mode conversion is presumably handled there — confirm.
        processed = self.preprocess_fn(record["image"])
        return processed, record["label"]


def create_imagenet_dataset(preprocess_fn, train=True):
    """Construct an ``ImageNetDataset`` with the given preprocessing callable.

    Args:
        preprocess_fn: Callable applied to every image on access.
        train: Forwarded to ``ImageNetDataset`` to choose the split.

    Returns:
        The new ``ImageNetDataset`` instance.
    """
    dataset = ImageNetDataset(
        preprocess_fn=preprocess_fn,
        train=train,
    )
    return dataset


def get_imagenet_test_loaders(
    preprocess_fn,
    forget_classes=[3],
    batch_size=64,
    index_file=None,
):
    """Create unshuffled ImageNet loaders for the full, retain and forget sets.

    The dataset is built with ``train=False``.

    Args:
        preprocess_fn: Callable forwarded to ``create_imagenet_dataset``.
        forget_classes: Forwarded to ``get_indices`` to decide the
            forget/retain partition. Pass ``None`` to skip the partition.
            The default list is shared between calls and is never mutated
            by this function.
        batch_size: Batch size used by every loader.
        index_file: Forwarded unchanged to ``get_indices``.

    Returns:
        Dict with keys ``test_full_loader``, ``test_retain_loader`` and
        ``test_forget_loader``; the last two are ``None`` when
        ``forget_classes`` is ``None``.
    """
    # Shared settings: fixed order and no dropped tail batch.
    loader_options = {"batch_size": batch_size, "shuffle": False, "drop_last": False}

    eval_set = create_imagenet_dataset(preprocess_fn=preprocess_fn, train=False)
    full_loader = torch.utils.data.DataLoader(eval_set, **loader_options)

    retain_loader = forget_loader = None
    if forget_classes is not None:
        forget_idx, retain_idx = get_indices(eval_set, forget_classes, index_file)
        forget_subset = torch.utils.data.Subset(eval_set, forget_idx)
        retain_subset = torch.utils.data.Subset(eval_set, retain_idx)
        retain_loader = torch.utils.data.DataLoader(retain_subset, **loader_options)
        forget_loader = torch.utils.data.DataLoader(forget_subset, **loader_options)

    return {
        "test_full_loader": full_loader,
        "test_retain_loader": retain_loader,
        "test_forget_loader": forget_loader,
    }


class DBPediaTextDataset(torch.utils.data.Dataset):
    """Map-style dataset over ``fancyzhx/dbpedia_14`` that tokenizes on access.

    Items are ``(encoding, label)`` pairs, where ``encoding`` is whatever the
    tokenizer returns for the sample's ``"content"`` field (truncated to
    ``max_length``) and ``label`` is the sample's ``"label"`` field.
    """

    def __init__(self, tokenizer, train=True, max_length=128):
        """Load the ``train`` or ``test`` split and remember the tokenizer settings."""
        super().__init__()
        split_name = "train" if train else "test"
        self.dataset = load_dataset("fancyzhx/dbpedia_14")[split_name]
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of samples in the loaded split."""
        return len(self.dataset)

    def __getitem__(self, index):
        """Return the tokenized content and the label of the sample at ``index``."""
        record = self.dataset[index]
        encoding = self.tokenizer(
            record["content"],
            truncation=True,
            max_length=self.max_length,
        )
        return encoding, record["label"]


def create_dbpedia_dataset(tokenizer, train=True, max_length=128):
    """Construct a ``DBPediaTextDataset`` from the given tokenizer and settings.

    Args:
        tokenizer: Tokenizer stored on the dataset and applied per sample.
        train: Forwarded to ``DBPediaTextDataset`` to choose the split.
        max_length: Truncation length forwarded to ``DBPediaTextDataset``.

    Returns:
        The new ``DBPediaTextDataset`` instance.
    """
    dataset = DBPediaTextDataset(
        tokenizer=tokenizer,
        train=train,
        max_length=max_length,
    )
    return dataset


def get_dbpedia_test_loaders(
    tokenizer,
    forget_classes=[3],
    batch_size=64,
    index_file=None,
):
    """Create unshuffled DBPedia test loaders for the full, retain and forget sets.

    Args:
        tokenizer: Used both to build the dataset and to pad batches.
        forget_classes: Forwarded to ``get_indices`` to decide the
            forget/retain partition. Pass ``None`` to skip the partition.
            The default list is shared between calls and is never mutated
            by this function.
        batch_size: Batch size used by every loader.
        index_file: Forwarded unchanged to ``get_indices``.

    Returns:
        Dict with keys ``test_full_loader``, ``test_retain_loader`` and
        ``test_forget_loader``; the last two are ``None`` when
        ``forget_classes`` is ``None``.
    """

    def _ordered_loader(data):
        # Evaluation loaders: fixed order, keep the final partial batch.
        return build_dataloader_text(
            data, tokenizer, batch_size, shuffle=False, drop_last=False
        )

    test_set = create_dbpedia_dataset(tokenizer=tokenizer, train=False)
    full_loader = _ordered_loader(test_set)

    retain_loader = forget_loader = None
    if forget_classes is not None:
        forget_idx, retain_idx = get_indices(test_set, forget_classes, index_file)
        forget_subset = torch.utils.data.Subset(test_set, forget_idx)
        retain_subset = torch.utils.data.Subset(test_set, retain_idx)
        retain_loader = _ordered_loader(retain_subset)
        forget_loader = _ordered_loader(forget_subset)

    return {
        "test_full_loader": full_loader,
        "test_retain_loader": retain_loader,
        "test_forget_loader": forget_loader,
    }


def split_dataset(dataset, val_per, seed=42):
    """Randomly partition ``dataset`` into training and validation subsets.

    The validation subset gets ``int(val_per * len(dataset))`` samples
    (truncated toward zero) and the training subset gets the rest. The
    permutation comes from a dedicated generator seeded with ``seed``, so
    identical arguments always produce the identical partition.

    Args:
        dataset: Sized dataset to split.
        val_per: Fraction of samples assigned to validation; must be in
            ``[0, 1]``.
        seed: Seed for the generator that drives the split.

    Returns:
        ``(train_dataset, val_dataset)`` as produced by
        ``torch.utils.data.random_split``.

    Raises:
        ValueError: If ``val_per`` is not within ``[0, 1]`` (NaN included).
    """
    # A fraction outside [0, 1] yields one negative length while the two
    # lengths still sum to len(dataset), so random_split would not reject it
    # and would hand back wrongly sized subsets. Fail loudly instead.
    if not 0 <= val_per <= 1:
        raise ValueError(f"val_per must be within [0, 1], got {val_per!r}")

    total_size = len(dataset)
    val_size = int(val_per * total_size)
    train_size = total_size - val_size
    generator = torch.Generator().manual_seed(seed)
    train_dataset, val_dataset = torch.utils.data.random_split(
        dataset, [train_size, val_size], generator=generator
    )
    return train_dataset, val_dataset
