from typing import Any, Dict, List, Tuple

import numpy as np
import sklearn
import torchvision

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

import path_config


def _is_between_clopen(x: np.ndarray, lower: float, upper: float) -> np.ndarray:
    return (lower <= x) & (x < upper)


def _make_splinters(
    num_obs: int, noise: float, num_classes: int
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Generate a (roughly) balanced data set according to the "splinter process"
    """
    assert num_classes >= 2, "Classification is not interesting unless num_classes >= 2"

    x_center = np.random.uniform(size=(num_obs, 2))
    x_center11 = 2 * (x_center - 0.5)
    x_center11_atan2 = np.arctan2(x_center11[:, 0], x_center11[:, 1]) + np.pi

    y = np.full((num_obs,), np.nan)

    for idx in range(num_classes):
        # idx = 0
        min_radians = (idx + 0) / num_classes * (2 * np.pi)
        max_radians = (idx + 1) / num_classes * (2 * np.pi)
        # print("[{}, {}]".format(min_radians, max_radians))
        rows = _is_between_clopen(x_center11_atan2, min_radians, max_radians)
        y[rows] = idx

    x = x_center + noise * np.random.normal(size=(num_obs, 2))
    assert np.all(np.isfinite(x)) and np.all(np.isfinite(y))

    return x, y


def _make_checkerboard(num_obs: int,
                       noise: float,
                       num_classes: int) -> Tuple[np.ndarray, np.ndarray]:

    nr = 3
    nc = 2
    dims = np.array([nr, nc])
    total_squares = np.prod(dims)

    # stride = 3
    dim = 2
    low = -1 * np.ones((dim, ))
    high = +1 * np.ones((dim, ))

    diameter = high - low
    height = diameter[0]
    width = diameter[1]

    x_center = np.random.uniform(low=low,
                                 high=high,
                                 size=(num_obs, 2))

    classes = np.reshape(np.arange(total_squares), (nr, nc)) % num_classes
    y = np.full((num_obs,), np.nan)
    for c_idx in range(nc):
        # c_idx = 0
        c_l = low[0] + width * (c_idx + 0) / nc
        c_u = low[0] + width * (c_idx + 1) / nc
        for r_idx in range(nr):
            # r_idx = 0
            r_l = low[1] + height * (r_idx + 0) / nr
            r_u = low[1] + height * (r_idx + 1) / nr

            # print("[{:+.3f}, {:+.3f}] x [{:+.3f}, {:+.3f}]".format(r_l, r_u, c_l, c_u))

            is_row_r = _is_between_clopen(x_center[:, 0], c_l, c_u)
            is_row_c = _is_between_clopen(x_center[:, 1], r_l, r_u)
            rows = is_row_r & is_row_c
            # print(np.mean(rows))
            y[rows] = classes[r_idx, c_idx]

    x = x_center + noise * np.random.normal(size=(num_obs, dim))

    assert np.all(np.isfinite(x)) and np.all(np.isfinite(y))

    if False:
        import matplotlib.pyplot as plt
        # plt.plot(x_center[:, 0], x_center[:, 1], ".")

        for idx in range(num_classes):
            # idx = 0
            idx_row = (idx == y)
            plt.plot(x[idx_row, 0],
                     x[idx_row, 1], ".")

    return x, y


def _make_spheres(num_obs: int,
                  noise: float,
                  num_classes: int,
                  dim: int) -> Tuple[np.ndarray, np.ndarray]:
    """
    Generate points in the cube [-1, 1)^dim labelled by radius band.

    The Euclidean radii of the sampled points are split at their empirical
    quantiles into ``num_classes`` bands of (roughly) equal population; the
    innermost band is label 0 and the outermost is ``num_classes - 1``.
    Gaussian noise with standard deviation ``noise`` is added to the
    coordinates *after* labelling.

    Returns:
        ``(x, y)`` where ``x`` has shape ``(num_obs, dim)`` and ``y`` has shape
        ``(num_obs,)`` with float labels.
    """
    corner_lo = np.full((dim, ), -1.0)
    corner_hi = np.full((dim, ), 1.0)

    centres = np.random.uniform(low=corner_lo, high=corner_hi, size=(num_obs, dim))
    radius = np.linalg.norm(centres, axis=1)
    # Half the cube's diagonal bounds every point's distance from the origin.
    assert np.all(radius <= np.linalg.norm(corner_hi - corner_lo) / 2)

    # Band boundaries at equally spaced quantiles; the top boundary is made
    # open-ended so the largest radius still lands in the last half-open band.
    edges = np.quantile(radius, np.linspace(0, 1, num_classes + 1))
    edges[-1] = np.inf

    labels = np.full((num_obs, ), np.nan)
    for label, (inner, outer) in enumerate(zip(edges[:-1], edges[1:])):
        labels[_is_between_clopen(radius, inner, outer)] = label

    x = centres + noise * np.random.normal(size=(num_obs, dim))
    assert np.all(np.isfinite(x)) and np.all(np.isfinite(labels))
    return x, labels


def vec(x: np.ndarray) -> np.ndarray:
    """Flatten *x* (row-major) into a single column of shape ``(N, 1)``."""
    column_shape = (-1, 1)
    return np.reshape(x, column_shape)


def _make_simple_image(num_obs: int,
                       noise: float,
                       num_classes: int,
                       dims: List[int]) -> Tuple[np.ndarray, np.ndarray]:
    """
    Generate small synthetic images whose bright region encodes the class.

    Every pixel of a ``height x width`` grid is assigned to one of
    ``num_classes`` equal-angle sectors around the image centre; pixels whose
    angle matches no sector keep sector id 0. An image of class ``k`` has
    intensity 0.5 on the pixels of sector ``k`` and 0 elsewhere, plus
    ``noise * Uniform[0, 1)`` per pixel, and that 2-D image is repeated in
    every channel. Labels are laid out in near-equal contiguous blocks and
    then shuffled together with the images.

    Args:
        dims: ``(n_channels, height, width)``; only ``n_channels == 3`` is
            accepted.

    Returns:
        ``(x, y)``: ``x`` of shape ``(num_obs, n_channels, height, width)`` and
        float labels ``y`` of shape ``(num_obs,)``.
    """
    n_channels = dims[0]
    assert 3 == n_channels, "This is only tested for n_channels == 3"
    height, width = dims[1], dims[2]

    # Per-pixel coordinates on the (height, width) grid: `horizontal` runs from
    # -1 (left column) to +1 (right column); `vertical` from +1 (top row) to
    # -1 (bottom row).
    horizontal = np.ones((height, 1)) @ vec(np.linspace(-1, +1, width)).T
    vertical = np.flipud(vec(np.linspace(-1, +1, height))) @ np.ones((width, 1)).T
    pixel_angle = np.arctan2(horizontal, vertical) + np.pi

    sector_map = np.zeros((height, width))
    for k in range(num_classes):
        lo_rad = (k + 0) / num_classes * (2 * np.pi)
        hi_rad = (k + 1) / num_classes * (2 * np.pi)
        sector_map += k * _is_between_clopen(pixel_angle, lo_rad, hi_rad).astype(float)

    y = np.floor(np.arange(num_obs) * num_classes / num_obs)
    noise_data = noise * np.random.uniform(size=(num_obs, height, width))

    # Image i lights up the pixels of sector y[i]; the result is copied into
    # each channel.
    lit = sector_map[np.newaxis, :, :] == y[:, np.newaxis, np.newaxis]
    one_channel = lit * .5 + noise_data
    x = np.repeat(one_channel[:, np.newaxis, :, :], n_channels, axis=1)

    order = np.random.permutation(num_obs)
    x = x[order]
    y = y[order]

    assert np.all(np.isfinite(x)) and np.all(np.isfinite(y))
    return x, y


def generate_dataset(data_par: Dict[str, Any],
                     is_train: bool) -> Tuple[np.ndarray, np.ndarray]:
    """
    Build the ``(x, y)`` data set described by ``data_par``.

    ``data_par`` must contain the keys ``generating_function_name``,
    ``n_samples``, ``noise``, ``num_classes`` and ``input_dims``. All of them
    are read up front, so a missing key raises ``KeyError`` for every
    generator.

    * ``"mnist"``: loads the MNIST split selected by ``is_train`` (downloading
      into ``path_config.get_paths()["cached_datasets"]`` if needed). Pixels
      are divided by 255, standardised with fixed mean/std constants, and
      flattened to one row per image. ``n_samples``, ``noise``,
      ``num_classes`` and ``input_dims`` are not used for MNIST.
    * ``"simple_image"``, ``"checkerboard"``, ``"spheres"``, ``"splinters"``,
      ``"moons"``: draw a fresh random synthetic sample; ``is_train`` is not
      used.

    Raises:
        ValueError: if ``generating_function_name`` is not recognised.
    """
    generating_function_name = data_par["generating_function_name"]
    n_samples = data_par["n_samples"]
    noise = data_par["noise"]
    num_classes = data_par["num_classes"]
    input_dims = data_par["input_dims"]

    if "mnist" == generating_function_name:
        paths = path_config.get_paths()
        data_path = paths["cached_datasets"]

        transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
        dataset = torchvision.datasets.MNIST(
            data_path,
            train=is_train,
            download=True,
            transform=transform)
        # NOTE(review): the raw `.data` / `.targets` tensors are read directly,
        # so `transform` is presumably never applied on this path.
        x_torch = dataset.data
        y_torch = dataset.targets
        x = x_torch.detach().numpy()
        y = y_torch.detach().numpy()

        # Fixed standardisation constants (the commonly quoted MNIST
        # training-set mean/std), then one flattened row per image.
        mean = 0.1307
        std = 0.3081
        x = x / 255
        x = x - mean
        x = x / std
        x = np.reshape(x, (x.shape[0], -1))
    else:
        if "simple_image" == generating_function_name:
            x, y = _make_simple_image(n_samples, noise, num_classes, input_dims)
        elif "checkerboard" == generating_function_name:
            x, y = _make_checkerboard(n_samples, noise, num_classes)
        elif "spheres" == generating_function_name:
            x, y = _make_spheres(n_samples, noise, num_classes, input_dims[0])
        elif "splinters" == generating_function_name:
            x, y = _make_splinters(n_samples, noise, num_classes)
        elif "moons" == generating_function_name:
            # A bare `import sklearn` does not guarantee the `datasets`
            # subpackage is loaded, so import the generator explicitly.
            from sklearn.datasets import make_moons
            x, y = make_moons(n_samples, noise=noise)
        else:
            raise ValueError("Unknown data generating function: {}".format(generating_function_name))
    return x, y


if __name__ == "__main__":
    num_obs = 1000
    num_classes = 3
    noise = 0.1
    dim = 3

    # Smoke run: generate a 3-D "spheres" data set.
    x, y = _make_spheres(num_obs, noise, num_classes, dim)

    # Set to True to visualise the points. This needs the project-local
    # `plotting` module for its COLORS palette.
    show_plot = False
    if show_plot:
        import plotting
        COLORS = plotting.COLORS

        fig = plt.figure()
        # `Axes3D(fig)` no longer attaches itself to the figure on recent
        # matplotlib releases (it would draw nothing); request a 3-D subplot.
        ax = fig.add_subplot(projection="3d")
        c = [COLORS[int(label)] for label in y]
        ax.scatter(x[:, 0], x[:, 1], x[:, 2], c=c)
        plt.show()
