import os
import zipfile

import pytorch_lightning as pl
import requests
from torch.utils.data import DataLoader
from torchvision import transforms as T
from torchvision.datasets import CIFAR10, CIFAR100, ImageNet
from tqdm import tqdm
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch


class CIFAR10Data(pl.LightningDataModule):
    """Lightning data module serving the CIFAR-10 train and test splits.

    ``args`` must expose ``data_dir``, ``batch_size`` and ``num_workers``;
    they are read every time a dataloader is built. The datasets are opened
    without ``download=True``, so the CIFAR-10 files are expected to already
    be present under ``args.data_dir``.
    """

    def __init__(self, args):
        super().__init__()
        # NOTE(review): this relies on ``hparams`` being a plain settable
        # attribute of the base class - confirm against the installed
        # pytorch_lightning version.
        self.hparams = args
        # Per-channel statistics handed to T.Normalize in both loaders.
        self.mean = (0.4914, 0.4822, 0.4465)
        self.std = (0.2471, 0.2435, 0.2616)

    @staticmethod
    def download_weights():
        """Download the ``state_dicts.zip`` archive and unpack it.

        The archive is written to ``state_dicts.zip`` in the current working
        directory and extracted into ``./models``. The method uses no
        instance state, so it can be called on the class or on an instance.

        Raises:
            requests.HTTPError: if the server answers with an error status.
            Exception: if the number of bytes received differs from the
                ``content-length`` announced by the server.
        """
        url = (
            "https://rutgers.box.com/shared/static/gkw08ecs797j2et1ksmbg1w5t3idf5r5.zip"
        )
        block_size = 2 ** 20  # read the body in 1 MiB chunks
        received = 0

        # Streaming, so the archive is never held in memory as a whole; the
        # context manager releases the connection even if the download fails.
        with requests.get(url, stream=True) as r:
            # Fail early instead of saving an error page as the archive.
            r.raise_for_status()

            # Announced size in bytes; 0 when the header is missing.
            total_size = int(r.headers.get("content-length", 0))

            # The bar counts bytes and lets tqdm scale them with binary
            # prefixes.
            with open("state_dicts.zip", "wb") as f, tqdm(
                total=total_size, unit="B", unit_scale=True, unit_divisor=1024
            ) as t:
                for data in r.iter_content(block_size):
                    f.write(data)
                    received += len(data)
                    t.update(len(data))

        # Only verifiable when the server announced a size.
        if total_size != 0 and received != total_size:
            raise Exception("Error, something went wrong")

        print("Download successful. Unzipping file...")
        path_to_zip_file = os.path.join(os.getcwd(), "state_dicts.zip")
        directory_to_extract_to = os.path.join(os.getcwd(), "models")
        with zipfile.ZipFile(path_to_zip_file, "r") as zip_ref:
            zip_ref.extractall(directory_to_extract_to)
            print("Unzip file successful!")

    def train_dataloader(self):
        """Return a shuffled loader over the augmented training split.

        Images are randomly cropped (32 px, padding 4) and horizontally
        flipped before being converted to tensors and normalized. The last
        incomplete batch is dropped.
        """
        transform = T.Compose(
            [
                T.RandomCrop(32, padding=4),
                T.RandomHorizontalFlip(),
                T.ToTensor(),
                T.Normalize(self.mean, self.std),
            ]
        )
        dataset = CIFAR10(root=self.hparams.data_dir,
                          train=True, transform=transform)
        dataloader = DataLoader(
            dataset,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            shuffle=True,
            drop_last=True,
            pin_memory=True,
        )
        return dataloader

    def val_dataloader(self):
        """Return a loader over the ``train=False`` split without augmentation."""
        transform = T.Compose(
            [
                T.ToTensor(),
                T.Normalize(self.mean, self.std),
            ]
        )
        dataset = CIFAR10(root=self.hparams.data_dir,
                          train=False, transform=transform)
        dataloader = DataLoader(
            dataset,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            pin_memory=True,
        )
        return dataloader

    def test_dataloader(self):
        """Return the same loader as ``val_dataloader``; both use one split."""
        return self.val_dataloader()


class CIFAR100Data(pl.LightningDataModule):
    """Lightning data module serving the CIFAR-100 train and test splits.

    ``args`` must provide ``data_dir``, ``batch_size`` and ``num_workers``.
    The datasets are opened without ``download=True``, so the files are
    expected to exist under ``args.data_dir`` already.
    """

    def __init__(self, args):
        super().__init__()
        self.hparams = args
        # Per-channel statistics passed to T.Normalize by both loaders.
        self.mean = [0.5070, 0.4865, 0.4409]
        self.std = [0.2673, 0.2564, 0.2761]

    def train_dataloader(self):
        """Build the shuffled, augmented training loader.

        Random crop (32 px, padding 4) and random horizontal flip are applied
        before tensor conversion and normalization; the trailing incomplete
        batch is dropped.
        """
        augmentations = [T.RandomCrop(32, padding=4), T.RandomHorizontalFlip()]
        tensor_steps = [T.ToTensor(), T.Normalize(self.mean, self.std)]
        train_set = CIFAR100(
            root=self.hparams.data_dir,
            train=True,
            transform=T.Compose(augmentations + tensor_steps),
        )
        return DataLoader(
            train_set,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            shuffle=True,
            drop_last=True,
            pin_memory=True,
        )

    def val_dataloader(self):
        """Build the loader over the ``train=False`` split, no augmentation."""
        preprocessing = T.Compose([T.ToTensor(), T.Normalize(self.mean, self.std)])
        held_out_set = CIFAR100(
            root=self.hparams.data_dir,
            train=False,
            transform=preprocessing,
        )
        return DataLoader(
            held_out_set,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            pin_memory=True,
        )

    def test_dataloader(self):
        """Testing reuses the validation loader (same ``train=False`` split)."""
        return self.val_dataloader()


class ImageNetDataModule(pl.LightningDataModule):
    """Lightning data module over an ImageNet-style folder tree.

    ``args.data_dir`` must contain ``train`` and ``val`` sub-directories that
    ``torchvision.datasets.ImageFolder`` can read; ``args.batch_size`` and
    ``args.num_workers`` configure every dataloader. The ``val`` folder is
    used for both validation and testing.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        # Per-channel statistics passed to transforms.Normalize.
        self.mean = (0.485, 0.456, 0.406)
        self.std = (0.229, 0.224, 0.225)
        # Declared up front so setup() can tell whether it has already run.
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None
        # Build the datasets eagerly so the dataloaders are usable right
        # after construction.
        self.setup()

    def setup(self, stage=None):
        """Create the transforms and the train/val/test datasets once.

        Args:
            stage: accepted for interface compatibility; not used, all
                datasets are built regardless of its value.

        Calls after the first one are no-ops: building an ``ImageFolder``
        scans the directory tree, and ``__init__`` has already done that.
        """
        if self.train_dataset is not None:
            return

        self.train_transform = transforms.Compose([
            transforms.RandomResizedCrop(size=224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=self.mean, std=self.std)
        ])
        self.val_transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=self.mean, std=self.std)
        ])

        traindir = os.path.join(self.args.data_dir, 'train')
        valdir = os.path.join(self.args.data_dir, 'val')
        imagenet_train = datasets.ImageFolder(
            root=traindir,
            transform=self.train_transform)

        imagenet_val = datasets.ImageFolder(
            root=valdir,
            transform=self.val_transform)

        self.train_dataset = imagenet_train
        self.val_dataset = imagenet_val
        # No separate test folder is read; testing reuses the val dataset.
        self.test_dataset = imagenet_val

    def train_dataloader(self):
        """Return a shuffled loader over the training dataset.

        Shuffling matches the CIFAR data modules in this file; ImageFolder
        lists its samples class by class, so an unshuffled loader would
        yield batches dominated by a single class.
        """
        return DataLoader(self.train_dataset,
                          batch_size=self.args.batch_size,
                          num_workers=self.args.num_workers,
                          shuffle=True,
                          pin_memory=True)

    def val_dataloader(self):
        """Return a sequential loader over the validation dataset."""
        return DataLoader(self.val_dataset,
                          batch_size=self.args.batch_size,
                          num_workers=self.args.num_workers,
                          pin_memory=True)

    def test_dataloader(self):
        """Return a sequential loader over the test dataset."""
        return DataLoader(self.test_dataset,
                          batch_size=self.args.batch_size,
                          num_workers=self.args.num_workers,
                          pin_memory=True)


class DataLoaderCreator:
    """Build a non-augmented DataLoader for the dataset named by ``opt.dataset``.

    ``opt`` must expose ``dataset`` (``'cifar10'``, ``'cifar100'`` or
    ``'path'``) and ``data_folder``; for ``'path'`` it must also expose
    ``mean`` and ``std``. After a loader has been created, ``mean``, ``std``,
    ``normalize`` and ``val_transform`` hold the values that were used.
    """

    def __init__(self, opt):
        self.opt = opt
        # Filled in by the _create_*_dataset methods.
        self.normalize = None
        self.mean = None
        self.std = None
        self.val_transform = None

    def create_data_loader(self, batch_size=256, num_workers=8, shuffle=False):
        """Create the DataLoader for the configured dataset.

        Args:
            batch_size: samples per batch.
            num_workers: number of loader worker processes.
            shuffle: whether to reshuffle the samples every epoch.

        Returns:
            A ``torch.utils.data.DataLoader`` with ``pin_memory=True``.

        Raises:
            ValueError: if ``opt.dataset`` is not a supported name.
        """
        builders = {
            'cifar10': self._create_cifar10_dataset,
            'cifar100': self._create_cifar100_dataset,
            'path': self._create_image_folder_dataset,
        }
        build = builders.get(self.opt.dataset)
        if build is None:
            raise ValueError(
                'dataset not supported: {}'.format(self.opt.dataset))
        train_dataset = build()

        return torch.utils.data.DataLoader(
            train_dataset, batch_size=batch_size, shuffle=shuffle,
            num_workers=num_workers, pin_memory=True)

    @staticmethod
    def _parse_stats(value):
        """Return per-channel statistics from an option value.

        Lists and tuples are returned unchanged; anything else is evaluated
        as a Python expression, e.g. ``"(0.5, 0.5, 0.5)"``.
        """
        if isinstance(value, (list, tuple)):
            return value
        # SECURITY(review): eval() executes arbitrary code contained in the
        # option value. Kept so existing option strings keep working; if the
        # values are always plain literals, ast.literal_eval is the safer
        # choice.
        return eval(value)

    def _use_statistics(self, mean, std):
        """Record ``mean``/``std`` and build the tensor+normalize transform."""
        self.mean = mean
        self.std = std
        self.normalize = transforms.Normalize(mean=self.mean, std=self.std)
        self.val_transform = transforms.Compose([
            transforms.ToTensor(),
            self.normalize,
        ])
        return self.val_transform

    def _create_cifar10_dataset(self):
        """Return CIFAR-10 (torchvision's default split), downloading if needed."""
        # NOTE(review): this std differs from the one CIFAR10Data uses in
        # this file - confirm which set of statistics is intended.
        transform = self._use_statistics(
            (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
        return datasets.CIFAR10(root=self.opt.data_folder,
                                transform=transform,
                                download=True)

    def _create_cifar100_dataset(self):
        """Return CIFAR-100 (torchvision's default split), downloading if needed."""
        transform = self._use_statistics(
            (0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))
        return datasets.CIFAR100(root=self.opt.data_folder,
                                 transform=transform,
                                 download=True)

    def _create_image_folder_dataset(self):
        """Return an ImageFolder over ``opt.data_folder``.

        Normalization statistics come from ``opt.mean`` and ``opt.std``.
        """
        transform = self._use_statistics(
            self._parse_stats(self.opt.mean),
            self._parse_stats(self.opt.std))
        return datasets.ImageFolder(root=self.opt.data_folder,
                                    transform=transform)
