import os
import urllib.request
import zipfile
from functools import partial
from pathlib import Path

import jax
import numpy as onp
import pandas as pd
from jax import jit
from jax import numpy as np
from jax import random
from torchvision import datasets, transforms


@jax.jit
def random_crop(batch, key):
    """Randomly translate a batch of images by reflect-padding and cropping.

    Axes 1 and 2 are reflect-padded by ``ADD`` pixels on each side, then a
    window with the original extent is cut out of the padded array. A single
    pair of offsets is drawn per call, so every image in the batch receives
    the same shift.

    Args:
        batch: 4-D array; the axes are unpacked below as (B, H, W, C).
        key: JAX PRNG key used to draw the two crop offsets.

    Returns:
        An array with the same shape as ``batch``.
    """
    ADD = 4
    B, H, W, C = batch.shape
    paddings = ((0, 0), (ADD, ADD), (ADD, ADD), (0, 0))
    padded = np.pad(batch, paddings, "reflect")
    # ``maxval`` is exclusive, so offsets lie in [0, 2 * ADD) and the largest
    # possible offset (2 * ADD) is never drawn. The range is kept as is so
    # that existing seeds keep producing the same crops.
    offsets = jax.random.randint(key, (2,), minval=0, maxval=2 * ADD)
    # Basic slicing needs concrete start indices and therefore cannot be
    # traced; dynamic slicing accepts traced offsets, which is what allows
    # this function to be jit-compiled.
    rows = jax.lax.dynamic_slice_in_dim(padded, offsets[0], H, axis=1)
    return jax.lax.dynamic_slice_in_dim(rows, offsets[1], W, axis=2)


@jax.jit
def random_flip(batch, key):
    """Return ``batch`` either untouched or reversed along axis 2.

    One integer in {0, 1} is drawn from ``key``; a non-zero draw keeps the
    batch as is, a zero draw flips it. The decision is shared by the whole
    batch.

    Args:
        batch: Array with at least three axes.
        key: JAX PRNG key used for the draw.

    Returns:
        An array with the same shape as ``batch``.
    """

    def unchanged():
        return batch

    def mirrored():
        return np.flip(batch, axis=2)

    coin = jax.random.randint(key, (), minval=0, maxval=2)
    return jax.lax.cond(coin, unchanged, mirrored)


def random_flip_crop(batch, key):
    """Apply ``random_crop`` followed by ``random_flip`` to ``batch``.

    ``key`` is split in two so that the crop and the flip each consume an
    independent PRNG key.
    """
    crop_key, flip_key = random.split(key)
    cropped = random_crop(batch, crop_key)
    return random_flip(cropped, flip_key)


class BatchLoader:
    """Iterate over ``(data, targets)`` in mini-batches.

    Each pass over the loader yields ``equal_batches`` batches of exactly
    ``batch_size`` samples, followed by one smaller batch holding the
    leftover samples when ``batch_size`` does not divide the dataset size.

    Args:
        data: Array of samples, indexed along axis 0.
        targets: Array of targets, indexed along axis 0 in step with ``data``.
        key: JAX PRNG key, or ``None``. With a key, every pass uses a fresh
            random permutation of the samples; without one the samples are
            visited in their stored order.
        batch_size: Number of samples per full batch.
        data_aug: Optional callable ``data_aug(batch, key)`` applied to the
            data of every batch. Requires ``key`` to be set.
    """

    def __init__(self, data, targets, key, batch_size, data_aug=None):
        self.data = data
        self.targets = targets
        self.batch_size = batch_size
        self.key = key
        self.data_aug = data_aug
        self.num_data = data.shape[0]
        # Number of full batches and the number of samples they cover.
        self.equal_batches = self.num_data // batch_size
        self.num_data_per_epoch = self.equal_batches * batch_size
        # One extra (smaller) batch when there are leftover samples.
        self.batches = self.equal_batches + int(
            self.num_data != self.num_data_per_epoch
        )

    def __iter__(self):
        """Yield ``(data_batch, target_batch)`` pairs for one pass.

        Raises:
            ValueError: If ``data_aug`` is set but the loader has no PRNG key.
        """
        if self.data_aug is not None and self.key is None:
            raise ValueError("BatchLoader needs a PRNG key when data_aug is set")

        if self.key is None:
            permutation = np.arange(self.num_data)
        else:
            self.key, cur_key = random.split(self.key)
            permutation = random.permutation(cur_key, self.num_data)

        index_batches = []
        if self.equal_batches:
            # jax.numpy.split is used instead of the array method, which is
            # not available on arrays in newer JAX releases. The guard avoids
            # splitting into zero sections when the dataset is smaller than
            # one batch.
            index_batches.extend(
                np.split(permutation[: self.num_data_per_epoch], self.equal_batches)
            )
        if self.num_data != self.num_data_per_epoch:
            index_batches.append(permutation[self.num_data_per_epoch :])

        for indices in index_batches:
            batch_data = self.data.take(indices, axis=0)
            # axis=0 keeps multi-dimensional targets (e.g. one row per
            # sample) intact instead of indexing into the flattened array.
            batch_targets = self.targets.take(indices, axis=0)
            if self.data_aug is not None:
                # A fresh key per batch so every batch is augmented
                # independently.
                self.key, aug_key = random.split(self.key)
                batch_data = self.data_aug(batch_data, aug_key)
            yield batch_data, batch_targets

    def __len__(self):
        """Return the number of batches produced by one pass."""
        return self.batches


def normalize(X, axis=(0, 1, 2)):
    """Standardise ``X`` using its own mean and standard deviation.

    Both statistics are reduced over ``axis``. With the default axes and a
    4-D input laid out as (B, H, W, C) this gives one mean/std per entry of
    the last axis.

    Args:
        X: Array to standardise.
        axis: Axis or axes over which mean and standard deviation are taken.

    Returns:
        ``X`` minus its mean, divided by its standard deviation.
    """
    centred = X - X.mean(axis=axis)
    return centred / X.std(axis=axis)


class ImageDataset:
    """Image-classification dataset held in memory as JAX arrays.

    The train and test splits are obtained by calling ``dataset`` twice
    (``train=True`` / ``train=False``) and reading the ``data`` and
    ``targets`` attributes of the returned objects. Each split is normalised
    with its own statistics, optionally filtered and transformed, and, when a
    batch size is given, wrapped in a ``BatchLoader``.

    Args:
        dataset: Callable invoked as
            ``dataset(data_location, train=..., download=True, transform=...)``
            returning an object with ``data`` and ``targets`` attributes (and
            ``classes`` when ``classes`` is not supplied). Its ``__name__`` is
            stored as the dataset name.
        width: Image width, stored and used for ``input_shape``.
        height: Image height, stored and used for ``input_shape``.
        channels: Number of channels, stored and used for ``input_shape``.
        data_location: First positional argument passed to ``dataset``.
        data_limit: If truthy, keep only the first ``data_limit`` training
            samples (and their targets).
        batch_size: If not ``None``, ``train_loader`` and ``test_loader`` are
            created with this batch size.
        key: JAX PRNG key; only used when ``randomise`` is true.
        data_transform: Optional callable applied to the train and test data
            after normalisation and filtering.
        filter: Optional callable ``filter(data, targets)`` returning the
            filtered ``(data, targets)`` pair; applied to both splits.
        randomise: If true and ``key`` is given, both loaders receive a PRNG
            key.
        classes: Class labels; defaults to the ``classes`` attribute of the
            training dataset object.
        include_flip: If true, the training set is doubled by appending a
            copy flipped along axis 2.
        data_aug: Optional augmentation callable forwarded to the training
            ``BatchLoader``.
    """

    def __init__(
        self,
        dataset,
        width,
        height,
        channels,
        data_location,
        data_limit=None,
        batch_size=None,
        key=None,
        data_transform=None,
        filter=None,
        randomise=False,
        classes=None,
        include_flip=False,
        data_aug=None,
    ):
        # NOTE(review): this transform is handed to the dataset constructor,
        # but the arrays below are read straight from the ``data`` attribute;
        # presumably the transform never touches them -- confirm against the
        # dataset class in use.
        mean = (0.5,) * 3
        std = (0.5,) * 3
        transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize(mean, std)]
        )

        # ---- training split ----
        train_dataset = dataset(
            data_location, train=True, download=True, transform=transform
        )
        self.train_data = np.array(train_dataset.data, dtype=np.float32)
        self.train_targets = np.array(train_dataset.targets)
        if data_limit:
            self.train_data = self.train_data[:data_limit]
            # Targets must be truncated together with the data, otherwise the
            # two arrays no longer line up sample by sample.
            self.train_targets = self.train_targets[:data_limit]
        self.train_data = normalize(self.train_data)
        if include_flip:
            self.train_data = np.concatenate(
                [self.train_data, np.flip(self.train_data, axis=2)]
            )
            self.train_targets = np.concatenate(
                [self.train_targets, self.train_targets]
            )
        if filter is not None:
            self.train_data, self.train_targets = filter(
                self.train_data, self.train_targets
            )
        if data_transform is not None:
            self.train_data = data_transform(self.train_data)

        # ---- test split ----
        test_dataset = dataset(
            data_location, train=False, download=True, transform=transform
        )
        self.test_data = np.array(test_dataset.data, dtype=np.float32)
        self.test_targets = np.array(test_dataset.targets)
        # NOTE(review): the test split is normalised with its own mean/std
        # rather than the training statistics -- confirm this is intended.
        self.test_data = normalize(self.test_data)
        if filter is not None:
            self.test_data, self.test_targets = filter(
                self.test_data, self.test_targets
            )
        if data_transform is not None:
            self.test_data = data_transform(self.test_data)

        # ---- batch loaders (only created when a batch size is given) ----
        if batch_size is not None:
            k1 = k2 = None
            if key is not None and randomise:
                k1, k2 = random.split(key)
            self.train_loader = BatchLoader(
                self.train_data, self.train_targets, k1, batch_size, data_aug=data_aug
            )
            self.test_loader = BatchLoader(
                self.test_data, self.test_targets, k2, batch_size
            )

        if classes is None:
            classes = train_dataset.classes
        self.name = dataset.__name__
        self.num_outputs = len(classes)
        self.classes = classes
        self.width = width
        self.height = height
        self.channels = channels
        self.batch_size = batch_size
        # Leading dimension is the batch size, or the full training-set size
        # when no batch size was given.
        self.input_shape = (
            self.batch_size if self.batch_size else len(self.train_data),
            self.width,
            self.height,
            self.channels,
        )
        print(f"Dataset {self.name}:")
        print(
            f"{self.width} x {self.height} x {self.channels} images with {self.num_outputs} classes"
        )
        print(
            f"{len(self.train_data)} train points and {len(self.test_data)} test points with a batch size of {self.batch_size}."
        )

    def train_accuracy(self, params, forward, batched=False):
        """Return the accuracy of ``forward(params, .)`` on the training split.

        With ``batched=True`` the training loader is used (which exists only
        when a batch size was given); otherwise the whole split is evaluated
        in one call.
        """
        if batched:
            return self.batch_accuracy(self.train_loader, params, forward)
        return self.accuracy(self.train_data, self.train_targets, params, forward)

    def test_accuracy(self, params, forward, batched=False):
        """Return the accuracy of ``forward(params, .)`` on the test split.

        With ``batched=True`` the test loader is used (which exists only when
        a batch size was given); otherwise the whole split is evaluated in
        one call.
        """
        if batched:
            return self.batch_accuracy(self.test_loader, params, forward)
        return self.accuracy(self.test_data, self.test_targets, params, forward)

    # ``self`` and ``forward`` are static jit arguments: they are not traced
    # and must be hashable.
    @partial(jit, static_argnums=(0, 4))
    def accuracy(self, data, targets, params, forward):
        """Return the fraction of samples whose argmax over axis 1 of
        ``forward(params, data)`` equals the corresponding target."""
        return np.mean(np.argmax(forward(params, data), axis=1) == targets)

    def batch_accuracy(self, loader, params, forward):
        """Return the accuracy over all samples produced by ``loader``.

        Per-batch accuracies are weighted by the batch size, so a smaller
        final batch does not skew the result and the value agrees with the
        accuracy computed over the same samples in a single call.

        Args:
            loader: Iterable of ``(data, targets)`` batches.
            params: Parameters passed through to ``forward``.
            forward: Callable ``forward(params, data)`` returning per-class
                scores along axis 1.
        """
        correct = 0
        seen = 0
        for data, target in loader:
            batch_len = data.shape[0]
            correct += self.accuracy(data, target, params, forward) * batch_len
            seen += batch_len
        return correct / seen


class MNIST(ImageDataset):
    """``ImageDataset`` preset for ``datasets.MNIST``: 28 x 28, one channel."""

    def __init__(self, data_location, **kwargs):
        """Build the dataset stored at ``data_location``.

        Extra keyword arguments are forwarded to ``ImageDataset``.
        """

        def add_channel_axis(images):
            # Append a trailing channel axis (axis 3) to the image array.
            return np.expand_dims(images, axis=3)

        super().__init__(
            datasets.MNIST,
            width=28,
            height=28,
            channels=1,
            data_location=data_location,
            data_transform=add_channel_axis,
            **kwargs,
        )


class FashionMNIST(ImageDataset):
    """``ImageDataset`` preset for ``datasets.FashionMNIST``: 28 x 28, one channel."""

    def __init__(self, data_location, **kwargs):
        """Build the dataset stored at ``data_location``.

        Extra keyword arguments are forwarded to ``ImageDataset``.
        """

        def add_channel_axis(images):
            # Append a trailing channel axis (axis 3) to the image array.
            return np.expand_dims(images, axis=3)

        super().__init__(
            datasets.FashionMNIST,
            width=28,
            height=28,
            channels=1,
            data_location=data_location,
            data_transform=add_channel_axis,
            **kwargs,
        )

class CIFAR10(ImageDataset):
    """``ImageDataset`` preset for ``datasets.CIFAR10``: 32 x 32, three channels."""

    def __init__(self, data_location, **kwargs):
        """Build the dataset stored at ``data_location``.

        Extra keyword arguments are forwarded to ``ImageDataset``.
        """
        image_spec = {"width": 32, "height": 32, "channels": 3}
        super().__init__(
            datasets.CIFAR10,
            data_location=data_location,
            **image_spec,
            **kwargs,
        )

class CIFAR100(ImageDataset):
    """``ImageDataset`` preset for ``datasets.CIFAR100``: 32 x 32, three channels."""

    def __init__(self, data_location, **kwargs):
        """Build the dataset stored at ``data_location``.

        Extra keyword arguments are forwarded to ``ImageDataset``.
        """
        image_spec = {"width": 32, "height": 32, "channels": 3}
        super().__init__(
            datasets.CIFAR100,
            data_location=data_location,
            **image_spec,
            **kwargs,
        )