from copy import deepcopy
from pathlib import Path
from typing import List, Optional, Union

import torch
from hydra.utils import call
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, Subset, Dataset
from torchvision import transforms as transform_lib
from torchvision.datasets import CIFAR10, CIFAR100

from src.data.data_utils import split_subsets_train_val, split_dataset_train_val, add_attrs


class CIFARDataModule(LightningDataModule):
    """Standard CIFAR train, val, test splits and transforms, partitioned across clients.

    The training set is split into one train/validation pair per client
    (see :meth:`setup`). The client whose subsets the train/val dataloaders
    serve is ``current_client_idx``, advanced with :meth:`next_client`.

    >>> dm = CIFARDataModule(split_function=None)
    >>> dm.dataset.__name__
    'CIFAR10'
    """

    name = "CIFAR"

    def __init__(
            self,
            split_function,
            num_classes: int = 10,
            data_dir: Union[str, Path] = Path("/tmp"),
            val_split: float = 0.1,
            num_workers: int = 16,
            normalize: bool = False,
            seed: int = 42,
            batch_size: int = 32,
            num_clients: int = 3,
            fair_val: bool = False,
            *args,
            **kwargs,
    ):
        """
        Args:
            split_function: passed to ``hydra.utils.call`` together with
                ``dataset=...`` to partition a training dataset into per-client subsets.
            num_classes: 10 or 11 selects CIFAR10, 100 selects CIFAR100.
            data_dir: where to save/load the data.
            val_split: how many of the training images to use for the validation split
                (forwarded to the split helpers with ``seed``).
            num_workers: how many workers to use for loading data.
            normalize: If true applies per-channel mean/std normalization.
            seed: starting seed for RNG, forwarded to the train/val split helpers.
            batch_size: desired batch size.
            num_clients: number of clients; :meth:`next_client` cannot move past the last one.
            fair_val: if true, one validation split is taken from the full training set and
                every client receives its own copy of it; otherwise each client's subset is
                split into train/val separately.
            *args, **kwargs: forwarded to ``LightningDataModule.__init__``.

        Raises:
            ValueError: if ``num_classes`` is not 10, 11 or 100.
        """
        super().__init__(*args, **kwargs)

        self.data_dir = data_dir
        self.val_split = val_split
        self.num_workers = num_workers
        self.normalize = normalize
        self.seed = seed
        self.batch_size = batch_size
        self.num_clients = num_clients
        self.fair_val = fair_val
        self.split_function = split_function
        self.num_classes = num_classes

        # Raise explicitly: an `assert` would be stripped under `python -O`,
        # leaving `self.dataset` undefined for an unsupported value.
        if num_classes in (10, 11):
            # NOTE(review): 11 also maps to CIFAR10; presumably the extra class is
            # introduced elsewhere in the project — confirm.
            self.dataset = CIFAR10
            self.ds_mean = (0.49139968, 0.48215841, 0.44653091)
            self.ds_std = (0.24703223, 0.24348513, 0.26158784)
        elif num_classes == 100:
            self.dataset = CIFAR100
            self.ds_mean = (0.50707516, 0.48654887, 0.44091784)
            self.ds_std = (0.26733429, 0.25643846, 0.27615047)
        else:
            raise ValueError(
                f"Number of classes for CIFAR can be 10, 11 or 100, got {num_classes}."
            )

        # Built by setup() / transfer_setup(); `...` marks "not built yet".
        self.datasets_train: List[Subset] = ...
        self.datasets_val: List[Subset] = ...
        self.train_dataset: Dataset = ...
        self.val_dataset: Dataset = ...
        self.test_dataset: Dataset = ...

        self.current_client_idx = 0

    def prepare_data(self):
        """Saves CIFAR files to `data_dir`"""
        self.dataset(self.data_dir, train=True, download=True)
        self.dataset(self.data_dir, train=False, download=True)

    def setup(self, stage: Optional[str] = None):
        """Load the CIFAR training set and split it into per-client train/val subsets.

        Runs for the ``"fit"`` and ``"validate"`` stages and when called without a
        stage. Both of those stages read the per-client validation subsets. Other
        stages are no-ops, because the test split is built in :meth:`test_dataloader`.
        """
        if stage not in (None, "fit", "validate"):
            return

        # The same training images are loaded twice: augmented copies for training
        # and un-augmented copies that the split helpers use for validation.
        self.train_dataset = self._with_long_targets(
            self._load(train=True, transform=self.aug_transforms)
        )
        self.val_dataset = self._with_long_targets(
            self._load(train=True, transform=self.default_transforms)
        )

        if self.fair_val:
            # A single train/val split of the full training set; only the train part
            # is partitioned, and every client gets its own copy of the val part.
            train_subset, val_subset = split_dataset_train_val(
                train_dataset=self.train_dataset,
                val_split=self.val_split,
                seed=self.seed,
                val_dataset=self.val_dataset,
            )
            self.datasets_train = call(self.split_function, dataset=train_subset)
            # NOTE(review): deepcopy of a Subset also copies its underlying dataset
            # (the full training data) once per client; presumably separate objects
            # are needed because add_attrs mutates them — confirm before using a
            # shallow copy.
            self.datasets_val = [deepcopy(val_subset) for _ in range(self.num_clients)]
            add_attrs(self.datasets_train, self.datasets_val)
        else:
            # Partition first, then split each client's subset:
            # results are [train1, t2, ..., tn], [val1, v2, ..., vn].
            subsets = call(self.split_function, dataset=self.train_dataset)
            self.datasets_train, self.datasets_val = split_subsets_train_val(
                subsets, self.val_split, self.seed, val_dataset=self.val_dataset
            )

    def transfer_setup(self):
        """Load full (unsplit) train, val and test CIFAR datasets.

        Train uses the augmenting transforms. Val (built from the training images)
        and test use the default transforms.
        """
        # NOTE(review): unlike setup() and test_dataloader(), targets are left as
        # loaded here (not converted to a long tensor) — confirm this is intended.
        self.train_dataset = self._load(train=True, transform=self.aug_transforms)
        self.val_dataset = self._load(train=True, transform=self.default_transforms)
        self.test_dataset = self._load(train=False, transform=self.default_transforms)

    def next_client(self):
        """Advance to the next client so later dataloaders serve its subsets.

        Raises:
            IndexError: if the current client is already the last one. The index
                is left unchanged in that case.
        """
        # Check before incrementing so a failed call does not leave the index out of range.
        if self.current_client_idx + 1 >= self.num_clients:
            raise IndexError(
                "Client number shouldn't exceed the selected number of clients "
                f"({self.num_clients})."
            )
        self.current_client_idx += 1

    def train_dataloader(self):
        """Shuffled loader over the current client's training subset (incomplete last batch dropped)."""
        # check this: https://pytorch-lightning.readthedocs.io/en/stable/guides/data.html#multiple-training-dataloaders
        return self._make_loader(
            self._client_subset(self.datasets_train, "train"),
            shuffle=True,
            drop_last=True,
        )

    def val_dataloader(self):
        """Ordered loader over the current client's validation subset (taken from the training images)."""
        return self._make_loader(
            self._client_subset(self.datasets_val, "validation"),
            shuffle=False,
            drop_last=False,
        )

    def test_dataloader(self):
        """Ordered loader over the CIFAR test split, with targets as a long tensor."""
        dataset = self._with_long_targets(
            self._load(train=False, transform=self.default_transforms)
        )
        return self._make_loader(dataset, shuffle=False, drop_last=False)

    @property
    def default_transforms(self):
        """Evaluation transforms: ToTensor, plus Normalize when ``normalize`` is set."""
        return self._compose([transform_lib.ToTensor()])

    @property
    def aug_transforms(self):
        """Training transforms: 32px random crop (4px padding), random horizontal flip,
        ToTensor, plus Normalize when ``normalize`` is set."""
        return self._compose([
            transform_lib.RandomCrop(32, padding=4),
            transform_lib.RandomHorizontalFlip(),
            transform_lib.ToTensor(),
        ])

    # ------------------------------------------------------------------ helpers

    def _load(self, train: bool, transform):
        """Instantiate the selected CIFAR class from ``data_dir`` without downloading."""
        return self.dataset(self.data_dir, train=train, download=False, transform=transform)

    @staticmethod
    def _with_long_targets(dataset):
        """Replace ``dataset.targets`` with an int64 tensor in place and return ``dataset``."""
        dataset.targets = torch.as_tensor(dataset.targets, dtype=torch.long)
        return dataset

    def _client_subset(self, datasets, split: str):
        """Return the current client's entry of ``datasets``, failing clearly before setup()."""
        if datasets is ...:
            raise RuntimeError(
                f"No {split} datasets available; call setup('fit') before requesting dataloaders."
            )
        return datasets[self.current_client_idx]

    def _make_loader(self, dataset, *, shuffle: bool, drop_last: bool) -> DataLoader:
        """Build a pinned-memory DataLoader with this module's batch size and worker count."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            drop_last=drop_last,
            pin_memory=True,
        )

    def _compose(self, cifar_transforms):
        """Append per-channel normalization when enabled, then compose the transform list."""
        if self.normalize:
            cifar_transforms.append(
                transform_lib.Normalize(mean=self.ds_mean, std=self.ds_std)
            )
        return transform_lib.Compose(cifar_transforms)
