import array
import gzip
import os
from os import path
import struct
import urllib.request
from typing import Tuple

import numpy as np
from jax._src.random import PRNGKey
from omegaconf import DictConfig
from jax import random


def mnist_data_stream(cfg, rng: PRNGKey):
    """Build infinite shuffled MNIST training streams plus the full test set.

    Loads MNIST via ``mnist`` and records the training-set size in
    ``cfg.dataset.num_train_samples`` (this mutates ``cfg``).

    Args:
        cfg: Config providing the ``dataset`` section used by ``mnist``,
            ``training_params.batch_size`` and
            ``algorithm.hist_estimation_batch_size``.
        rng: JAX PRNG key; split into independent keys for the two streams.

    Returns:
        ``(train_stream, regularizer_stream, test_stream)``. The first two are
        infinite generators of ``(images, labels, sensitives)`` batches,
        reshuffled every epoch; ``test_stream`` is the full
        ``(images, labels, sensitives)`` test tuple. Images are reshaped to
        ``(N, 1, 28, 28, 1)``.
    """
    (
        train_images,
        train_labels,
        train_sensitives,
        test_images,
        test_labels,
        test_sensitives,
    ) = mnist(cfg=cfg, permute_train=False)
    cfg.dataset.num_train_samples = train_images.shape[0]

    # Independent keys for the training stream and the regularizer stream.
    train_rng, hist_estimation_rng = random.split(rng)
    # Reshape flattened rows back into images.
    train_images = np.reshape(train_images, (-1, 1, 28, 28, 1))
    test_images = np.reshape(test_images, (-1, 1, 28, 28, 1))
    test_stream = (test_images, test_labels, test_sensitives)
    num_train = train_images.shape[0]

    def _shuffled_batches(key, batch_size):
        """Yield training batches forever, drawing a new permutation each epoch.

        The last batch of an epoch is smaller when ``batch_size`` does not
        divide the number of training samples.
        """
        num_complete_batches, leftover = divmod(num_train, batch_size)
        num_batches = num_complete_batches + bool(leftover)
        while True:
            # Split a fresh key per epoch: reusing a single key would replay
            # exactly the same permutation in every epoch.
            key, perm_key = random.split(key)
            # Bring the permutation to host memory once per epoch rather than
            # converting a device array on every batch index.
            perm = np.asarray(random.permutation(key=perm_key, x=num_train))
            for i in range(num_batches):
                batch_idx = perm[i * batch_size : (i + 1) * batch_size]
                yield (
                    train_images[batch_idx],
                    train_labels[batch_idx],
                    train_sensitives[batch_idx],
                )

    def data_train_stream():
        # Generator body: the batch size is read lazily on the first next().
        yield from _shuffled_batches(train_rng, cfg.training_params.batch_size)

    def data_regularizer_stream():
        # Generator body: the batch size is read lazily on the first next().
        yield from _shuffled_batches(
            hist_estimation_rng, cfg.algorithm.hist_estimation_batch_size
        )

    return (data_train_stream(), data_regularizer_stream(), test_stream)


def _download(data_path, url, filename):
    """Download a url to a file in the JAX data temp directory."""
    if not path.exists(data_path):
        os.makedirs(data_path)

    out_file = path.join(data_path, filename)
    if not path.isfile(out_file):
        urllib.request.urlretrieve(url, out_file)
        print(f"downloaded {url} to {data_path}")


def _partial_flatten(x: np.ndarray) -> np.ndarray:
    """Flatten all but the first dimension of an ndarray."""
    return np.reshape(x, (x.shape[0], -1))


def _one_hot(x: np.ndarray, k: int, dtype=np.float32) -> np.ndarray:
    """Create a one-hot encoding of x of size k."""
    return np.array(x[:, None] == np.arange(k), dtype)


def mnist_raw(cfg: DictConfig) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Download (if needed) and parse the raw gzipped MNIST IDX files.

    Files are fetched from ``cfg.dataset.base_url`` into
    ``cfg.dataset.data_path``.

    Returns:
        ``(train_images, train_labels, test_images, test_labels)`` as writable
        uint8 arrays; images have shape ``(N, rows, cols)`` and labels ``(N,)``.

    Raises:
        ValueError: if a file's IDX magic number or payload size disagrees
            with its header (e.g. a corrupted or truncated download).
    """

    def check_payload(filename, magic, expected_magic, size, expected_size):
        # Fail early with a clear message instead of an opaque reshape error
        # (or silently wrong data) further downstream.
        if magic != expected_magic:
            raise ValueError(
                f"{filename}: unexpected IDX magic number {magic} "
                f"(expected {expected_magic})"
            )
        if size != expected_size:
            raise ValueError(
                f"{filename}: payload has {size} bytes, "
                f"header implies {expected_size}"
            )

    def parse_labels(filename: str) -> np.ndarray:
        with gzip.open(filename, "rb") as fh:
            magic, num_data = struct.unpack(">II", fh.read(8))
            # frombuffer avoids a per-byte Python round trip; copy() because a
            # view over immutable bytes would be read-only.
            labels = np.frombuffer(fh.read(), dtype=np.uint8).copy()
        # 2049 (0x00000801) is the IDX magic for a 1-D unsigned-byte array.
        check_payload(filename, magic, 2049, labels.size, num_data)
        return labels

    def parse_images(filename: str) -> np.ndarray:
        with gzip.open(filename, "rb") as fh:
            magic, num_data, rows, cols = struct.unpack(">IIII", fh.read(16))
            pixels = np.frombuffer(fh.read(), dtype=np.uint8).copy()
        # 2051 (0x00000803) is the IDX magic for a 3-D unsigned-byte array.
        check_payload(filename, magic, 2051, pixels.size, num_data * rows * cols)
        return pixels.reshape(num_data, rows, cols)

    data_path = cfg.dataset.data_path
    for filename in [
        "train-images-idx3-ubyte.gz",
        "train-labels-idx1-ubyte.gz",
        "t10k-images-idx3-ubyte.gz",
        "t10k-labels-idx1-ubyte.gz",
    ]:
        _download(
            url=cfg.dataset.base_url + filename,
            data_path=data_path,
            filename=filename,
        )

    train_images = parse_images(path.join(data_path, "train-images-idx3-ubyte.gz"))
    train_labels = parse_labels(path.join(data_path, "train-labels-idx1-ubyte.gz"))
    test_images = parse_images(path.join(data_path, "t10k-images-idx3-ubyte.gz"))
    test_labels = parse_labels(path.join(data_path, "t10k-labels-idx1-ubyte.gz"))

    return train_images, train_labels, test_images, test_labels


def _sample_sensitives(labels: np.ndarray) -> np.ndarray:
    """Draw a synthetic binary sensitive attribute (0/1 ints) for each label.

    P(sensitive == 1) is 0.8 for digit 2, 0.2 for digit 8 and 0.5 for every
    other digit. Uses NumPy's global RNG (``np.random``).
    """
    n = len(labels)
    # One independent uniform draw per sample for each case, consumed in the
    # same order as the original inline code (2, then 8, then the rest).
    sens_if_two = np.random.rand(n) > 0.2
    sens_if_eight = np.random.rand(n) > 0.8
    sens_other = np.random.rand(n) > 0.5
    return np.where(
        labels == 2,
        sens_if_two,
        np.where(labels == 8, sens_if_eight, sens_other),
    ).astype(int)


def mnist(cfg: DictConfig, permute_train=False):
    """Download, parse and process MNIST data to unit scale and one-hot labels.

    Also attaches a synthetic binary sensitive attribute per sample (see
    ``_sample_sensitives``), one-hot encoded with 2 classes. Train and test
    sensitives are sampled with the same rule.

    Args:
        cfg: Config forwarded to ``mnist_raw``.
        permute_train: If True, shuffle the training set with a fixed seed (0);
            images, labels and sensitives are permuted together.

    Returns:
        ``(train_images, train_labels, train_sensitives,
        test_images, test_labels, test_sensitives)``. Images are flattened to
        ``(N, rows * cols)`` and scaled by 1/255; labels are one-hot over 10
        classes.
    """
    train_images, train_labels, test_images, test_labels = mnist_raw(cfg=cfg)

    train_images = _partial_flatten(train_images) / np.float32(255.0)
    test_images = _partial_flatten(test_images) / np.float32(255.0)

    # Sensitives depend on the integer labels, so sample before one-hot encoding.
    train_sensitives = _one_hot(_sample_sensitives(train_labels), 2)
    test_sensitives = _one_hot(_sample_sensitives(test_labels), 2)

    train_labels = _one_hot(train_labels, 10)
    test_labels = _one_hot(test_labels, 10)

    if permute_train:
        perm = np.random.RandomState(0).permutation(train_images.shape[0])
        train_images = train_images[perm]
        train_labels = train_labels[perm]
        # Sensitives must follow the same permutation to stay row-aligned.
        train_sensitives = train_sensitives[perm]

    return (
        train_images,
        train_labels,
        train_sensitives,
        test_images,
        test_labels,
        test_sensitives,
    )
