import random

import jax
import jax.numpy as jnp

import numpy as np
import numpy.random as npr

import torch
from torchvision import transforms
from torchvision.datasets import MNIST, USPS

from collections import namedtuple

from ott.core import problems
from ott.geometry.pointcloud import PointCloud

from dataclasses import dataclass

from PIL import Image

from . import utils

import os
# Absolute directory of this module; the samplers below resolve their data
# files (e.g. ``../data/...``) relative to it.
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))

# Container for one batch of paired histograms ``(a, b)`` returned by samplers.
PairData = namedtuple('PairData', ['a', 'b'])
# Variant with two extra fields, ``afull`` and ``bfull`` (not used in this chunk).
PairDataImgs = namedtuple('PairDataImgs', ['a', 'b', 'afull', 'bfull'])

@dataclass
class MNISTPairSampler:
    """Sample random pairs of MNIST digits as normalized 784-bin histograms.

    Each image is flattened to 784 pixel intensities and rescaled to sum to
    one. The ground geometry places the 784 bins on evenly spaced points of
    [0, 1].

    Attributes:
        train: Load the MNIST training split if True, otherwise the test split.
        batch_size: Number of (a, b) pairs returned per call.
        epsilon: Regularization passed as ``epsilon`` to the ``PointCloud``.
        debug: If True, one batch drawn with ``PRNGKey(0)`` at construction is
            returned on every call, ignoring the key.
    """
    train: bool = True
    batch_size: int = 128
    epsilon: float = 1e-2
    debug: bool = False

    def __post_init__(self):
        dataset = MNIST(
            '/tmp/mnist/',
            download=True,
            train=self.train,
        )
        # Scale pixels to [0, 1], flatten each 28x28 image to one row of 784
        # values, then normalize rows so every image is a histogram summing to 1.
        data = jnp.float64(dataset.data) / 255.
        data = data.reshape(-1, 784)
        data = data / data.sum(axis=1, keepdims=True)
        self.data = data

        # Ground space: the 784 bins sit at evenly spaced points of [0, 1]
        # (a 1-D embedding rather than the 2-D 28x28 pixel layout).
        x_grid = jnp.linspace(0, 1, num=784).reshape(784, 1)
        self.geom = PointCloud(x=x_grid, y=x_grid, epsilon=self.epsilon)

        @jax.jit
        def _sample(key):
            # Two independent index draws (with replacement), one per side of
            # the pair; the third split output is intentionally unused.
            k1, k2, _ = jax.random.split(key, num=3)
            I = jax.random.randint(k1, shape=[self.batch_size], minval=0, maxval=len(data))
            J = jax.random.randint(k2, shape=[self.batch_size], minval=0, maxval=len(data))
            return PairData(data[I], data[J])
        self._sample = _sample

        if self.debug:
            # Freeze a single batch so every call returns identical data.
            self._debug_data = self._sample(jax.random.PRNGKey(0))

    def __call__(self, key):
        """Return a ``PairData(a, b)`` batch drawn with ``key`` (ignored in debug mode)."""
        if self.debug:
            return self._debug_data
        return self._sample(key)

@dataclass
class WorldPairSampler:
    """Sample random supply/demand histogram pairs over fixed locations on the globe.

    At construction, ``n_demand`` demand locations are drawn in proportion to
    a population-density raster, and ``n_supply`` supply locations are drawn
    uniformly over raster cells with positive population. Both location sets
    stay fixed; each call only draws new weights over them.

    Attributes:
        batch_size: Number of (supply, demand) histogram pairs per call.
        epsilon: Regularization passed as ``epsilon`` to the ``PointCloud``.
        population_fname: Path of the population-density GeoTIFF (the default
            is resolved relative to this module).
        supply_bernoulli_p: Probability that each supply location is kept
            (non-zero) in a given supply histogram.
        n_demand: Number of demand locations.
        n_supply: Number of supply locations.
        debug: If True, one batch drawn with ``PRNGKey(0)`` at construction is
            returned on every call.
    """
    batch_size: int = 128
    epsilon: float = 1e-3
    population_fname: str = SCRIPT_DIR + '/../data/pop-15min.tif'
    supply_bernoulli_p: float = 0.5
    n_demand: int = 10000
    n_supply: int = 100
    debug: bool = False

    def __post_init__(self):
        import rasterio
        # Using 2020 Tiff data at 15-minute resolution from:
        # https://sedac.ciesin.columbia.edu/data/set/gpw-v4-population-density-adjusted-to-2015-unwpp-country-totals-rev11/data-download#
        # The context manager releases the raster file as soon as band 1 is read.
        with rasterio.open(self.population_fname) as src:
            P = src.read(1)

        # Population: negative cells (presumably no-data markers) count as empty.
        P[P < 0] = 0.
        Pflat = P.ravel()
        Pflat = Pflat / Pflat.max()  # For numerical stability
        Pflat = Pflat / Pflat.sum()
        self.P = P
        self.Pflat = Pflat

        # Uniform over ~landmass: equal mass on every cell with positive population.
        Uflat = Pflat.copy()
        Uflat[Uflat > 0] = 1.
        Uflat /= Uflat.sum()
        self.Uflat = Uflat

        n_rows, n_cols = P.shape

        def sample(p, num_samples, seed=0):
            """Draw ``num_samples`` raster cells from ``p``; return (spherical, euclidean) coords."""
            # A private RandomState yields the same draws as seeding the global
            # NumPy RNG with ``seed``, but leaves the global RNG state untouched.
            rng = npr.RandomState(seed)
            sample_Is = rng.choice(len(p), p=p, size=num_samples)
            # NOTE(review): true division (not ``//``) lets the column index
            # shift theta by up to one cell within its row; confirm intent.
            samples_theta = n_rows - sample_Is / n_cols
            samples_theta = (samples_theta / n_rows) * np.pi
            samples_phi = sample_Is % n_cols
            samples_phi = (samples_phi / n_cols) * 2 * np.pi - np.pi
            samples_spherical = np.vstack((samples_phi, samples_theta)).T
            samples_euclidean = utils.spherical_to_euclidean(samples_spherical)
            return samples_spherical, samples_euclidean

        self.demand_locs_spherical, self.demand_locs_euclidean = sample(Pflat, self.n_demand, seed=0)
        self.supply_locs_spherical, self.supply_locs_euclidean = sample(Uflat, self.n_supply, seed=1)

        self.geom = PointCloud(
            x=self.supply_locs_euclidean, y=self.demand_locs_euclidean,
            epsilon=self.epsilon, cost_fn=utils.SphereDist())

        @jax.jit
        def _sample(key):
            k1, k2, k3, _ = jax.random.split(key, num=4)

            # Demand: uniform random weights over every demand location.
            demand_probs = jax.random.uniform(k1, [self.batch_size, self.n_demand])
            demand_probs /= demand_probs.sum(axis=1, keepdims=True)

            # Supply: random weights on a Bernoulli-selected subset of locations.
            # If a row's mask is all zeros (probability (1-p)**n_supply) the
            # normalization divides by zero and that row becomes NaN.
            mask = jax.random.bernoulli(
                k2, p=self.supply_bernoulli_p, shape=[self.batch_size, self.n_supply])
            supply_probs = mask * jax.random.uniform(k3, [self.batch_size, self.n_supply])
            supply_probs /= supply_probs.sum(axis=1, keepdims=True)

            return PairData(supply_probs, demand_probs)

        self._sample = _sample

        if self.debug:
            # Freeze a single batch so every call returns identical data.
            self._debug_data = self._sample(jax.random.PRNGKey(0))

    def __call__(self, key):
        """Return ``PairData(supply, demand)`` drawn with ``key`` (ignored in debug mode)."""
        if self.debug:
            return self._debug_data
        return self._sample(key)


class ImageSampler:
    """Hold one RGB image as a set of pixel colors plus a resized, standardized copy.

    ``flat_norm_image`` is an ``(n_pixels, 3)`` array of RGB values scaled to
    [0, 1] (optionally subsampled). ``image_square`` is the image resized to
    ``square_size x square_size`` and standardized with ``mean``/``std``.
    """
    # Per-channel statistics used by normalize_image/unnormalize_image
    # (these values appear to be the usual ImageNet statistics).
    mean = jnp.array([0.485, 0.456, 0.406])
    std = jnp.array([0.229, 0.224, 0.225])

    def __init__(self, image_path, square_size=224, num_rgb_sample=None, key=None):
        """Load ``image_path`` and precompute its flat and square representations.

        Args:
            image_path: Path of the image file; stored as ``self.path``.
            square_size: Side length of the resized ``image_square``.
            num_rgb_sample: If given, keep only this many pixels, drawn with
                replacement using ``key``.
            key: JAX PRNG key; required when ``num_rgb_sample`` is given.

        Raises:
            ValueError: If ``num_rgb_sample`` is given without a ``key``.
        """
        if num_rgb_sample is not None and key is None:
            raise ValueError('`key` is required when `num_rgb_sample` is set')

        self.path = image_path
        # Closing the file deterministically; convert() returns an independent copy.
        with Image.open(image_path) as img:
            self.image = img.convert('RGB')
        # (H, W, 3) -> (H*W, 3): one row per pixel, channels scaled to [0, 1].
        self.flat_norm_image = jnp.asarray(self.image).reshape(-1, 3) / 255.

        if num_rgb_sample is not None:
            I = jax.random.choice(key, len(self.flat_norm_image), shape=[num_rgb_sample])
            self.flat_norm_image = self.flat_norm_image[I]

        image_square = self.image.resize((square_size, square_size))
        self.image_square = self.normalize_image(jnp.asarray(image_square) / 255.)

    def normalize_image(self, image):
        """Standardize an image whose last axis is RGB: ``(image - mean) / std``."""
        return (image - self.mean) / self.std

    def unnormalize_image(self, image):
        """Invert ``normalize_image``: ``image * std + mean``."""
        return image * self.std + self.mean

    def sample(self, key, batch_size):
        """Return ``batch_size`` rows of ``flat_norm_image`` drawn with replacement."""
        I = jax.random.choice(key, len(self.flat_norm_image), shape=[batch_size])
        return self.flat_norm_image[I]


class ImagePairSampler:
    """Draw random ordered pairs of distinct images from a fixed collection."""

    def __init__(self, image_paths, num_rgb_sample=None, key=None):
        """Build one ``ImageSampler`` per path.

        Args:
            image_paths: Iterable of image file paths.
            num_rgb_sample: Forwarded to every ``ImageSampler``.
            key: Optional JAX PRNG key. When given, each sampler receives its
                own fresh subkey; when None, samplers get ``key=None``.
        """
        samplers = []
        for path in image_paths:
            subkey = None
            if key is not None:
                # Split off a fresh subkey per image so no key is consumed twice.
                key, subkey = jax.random.split(key)
            samplers.append(ImageSampler(path, num_rgb_sample=num_rgb_sample, key=subkey))
        self.samplers = samplers

    def sample_image_pair(self, val_pairs):
        """Return a random ordered pair ``(X, Y)`` of two distinct samplers.

        Draws with the stdlib ``random`` module. A pair whose
        ``(X.path, Y.path)`` is in ``val_pairs`` is rejected and redrawn; pass
        ``None`` to accept any pair. If every ordered pair is excluded this
        never returns.
        """
        # Iterative rejection sampling: no recursion depth limit on long rejection runs.
        while True:
            X_sampler, Y_sampler = random.sample(self.samplers, 2)
            if val_pairs is None or (X_sampler.path, Y_sampler.path) not in val_pairs:
                return X_sampler, Y_sampler

    def sample_image_pair_batch(self, batch_size=1, val_pairs=None):
        """Draw ``batch_size`` image pairs and stack their arrays.

        Returns:
            ``(X_samplers, Y_samplers, X_squares, Y_squares, X_fulls, Y_fulls)``:
            the samplers are lists; the rest are arrays stacked along a new
            leading batch axis. Stacking ``flat_norm_image`` requires every
            image to have the same number of pixel rows (e.g. a shared
            ``num_rgb_sample``).
        """
        # TODO: Could clean all these names up, maybe with something like:
        # ImagePair = namedtuple('ImagePair', 'X_square X_flat Y_square Y_flat')
        pairs = [self.sample_image_pair(val_pairs) for _ in range(batch_size)]
        X_samplers = [x for x, _ in pairs]
        Y_samplers = [y for _, y in pairs]
        X_squares = jnp.stack([s.image_square for s in X_samplers])
        Y_squares = jnp.stack([s.image_square for s in Y_samplers])
        X_fulls = jnp.stack([s.flat_norm_image for s in X_samplers])
        Y_fulls = jnp.stack([s.flat_norm_image for s in Y_samplers])
        return X_samplers, Y_samplers, X_squares, Y_squares, X_fulls, Y_fulls

@dataclass
class USPSPairSampler:
    """Sample random pairs of USPS digits as normalized pixel histograms.

    Attributes:
        train: Use the training split if True, otherwise the test split.
        batch_size: Number of (a, b) pairs returned per call.
        epsilon: Regularization passed as ``epsilon`` to the ``PointCloud``.
        debug: If True, one batch drawn with ``PRNGKey(0)`` at construction is
            returned on every call.
        reshape: If True, load ``data/usps28.npy`` (presumably USPS resized to
            28x28; first 7291 rows are train, the rest test) instead of the
            torchvision ``USPS`` dataset.
    """
    train: bool = True
    batch_size: int = 128
    epsilon: float = 1e-2
    debug: bool = False
    reshape: bool = False

    @staticmethod
    def _to_histograms(data):
        """Flatten every sample to one row and normalize each row to sum to one.

        Args:
            data: Array of shape ``(n, ...)``; all trailing axes are flattened.

        Returns:
            ``(histograms, dim)`` where ``histograms`` has shape ``(n, dim)``.
        """
        data = jnp.float64(data)
        dim = int(np.prod(data.shape[1:]))
        data = data.reshape(-1, dim)
        return data / data.sum(axis=1, keepdims=True), dim

    def __post_init__(self):
        if self.reshape:
            dataset = np.load(SCRIPT_DIR + '/../data/usps28.npy')
            data = dataset[:7291, :] if self.train else dataset[7291:, :]
        else:
            dataset = USPS(
                '/tmp/usps/',
                download=True,
                train=self.train,
            )
            # The torchvision Dataset object is not an array; its pixel array
            # is exposed as ``.data``.
            data = dataset.data / 255.

        data, dim = self._to_histograms(data)
        self.data = data

        # Ground space: the ``dim`` bins sit at evenly spaced points of [0, 1].
        x_grid = jnp.linspace(0, 1, num=dim).reshape(dim, 1)
        self.geom = PointCloud(x=x_grid, y=x_grid, epsilon=self.epsilon)

        @jax.jit
        def _sample(key):
            # Independent index draws (with replacement) for the two sides;
            # the third split output is intentionally unused.
            k1, k2, _ = jax.random.split(key, num=3)
            I = jax.random.randint(k1, shape=[self.batch_size], minval=0, maxval=len(data))
            J = jax.random.randint(k2, shape=[self.batch_size], minval=0, maxval=len(data))
            return PairData(data[I], data[J])
        self._sample = _sample

        if self.debug:
            # Freeze a single batch so every call returns identical data.
            self._debug_data = self._sample(jax.random.PRNGKey(0))

    def __call__(self, key):
        """Return a ``PairData(a, b)`` batch drawn with ``key`` (ignored in debug mode)."""
        if self.debug:
            return self._debug_data
        return self._sample(key)


@dataclass
class DoodlePairSampler:
    """Sample random pairs of doodle images as normalized pixel histograms.

    Every ``.npy`` file under ``data/doodles/`` is one category. Its first
    100000 rows form the training split and the remainder the test split.

    Attributes:
        train: Use the training split if True, otherwise the test split.
        batch_size: Number of (a, b) pairs returned per call.
        epsilon: Regularization passed as ``epsilon`` to the ``PointCloud``.
        debug: If True, one batch drawn with ``PRNGKey(0)`` at construction is
            returned on every call.
    """
    train: bool = True
    batch_size: int = 128
    epsilon: float = 1e-2
    debug: bool = False

    @staticmethod
    def _load_split(directory, train, n_train=100000):
        """Concatenate the train or test rows of every ``.npy`` file in ``directory``.

        Files are visited in sorted order so the row layout (and therefore what
        a given PRNG key samples) does not depend on filesystem listing order.
        Entries whose names do not end in ``.npy`` are skipped. Assumes each
        file holds a 2-D ``(n, dim)`` array with a common ``dim``.
        """
        chunks = []
        for fname in sorted(os.listdir(directory)):
            if not fname.endswith('.npy'):
                continue
            category_data = np.load(os.path.join(directory, fname))
            chunks.append(category_data[:n_train, :] if train else category_data[n_train:, :])
        return np.concatenate(chunks)

    def __post_init__(self):
        data = self._load_split(os.path.join(SCRIPT_DIR, '..', 'data', 'doodles'), self.train)
        dim = data.shape[1]
        # Scale to [0, 1] and normalize rows into histograms summing to one.
        data = jnp.float64(data) / 255.
        data = data / data.sum(axis=1, keepdims=True)
        self.data = data

        # Ground space: the ``dim`` bins sit at evenly spaced points of [0, 1].
        x_grid = jnp.linspace(0, 1, num=dim).reshape(dim, 1)
        self.geom = PointCloud(x=x_grid, y=x_grid, epsilon=self.epsilon)

        @jax.jit
        def _sample(key):
            # Independent index draws (with replacement) for the two sides;
            # the third split output is intentionally unused.
            k1, k2, _ = jax.random.split(key, num=3)
            I = jax.random.randint(k1, shape=[self.batch_size], minval=0, maxval=len(data))
            J = jax.random.randint(k2, shape=[self.batch_size], minval=0, maxval=len(data))
            return PairData(data[I], data[J])

        self._sample = _sample

        if self.debug:
            # Freeze a single batch so every call returns identical data.
            self._debug_data = self._sample(jax.random.PRNGKey(0))

    def __call__(self, key):
        """Return a ``PairData(a, b)`` batch drawn with ``key`` (ignored in debug mode)."""
        if self.debug:
            return self._debug_data
        return self._sample(key)


@dataclass
class RandomSampler:
    """Sample pairs of synthetic histograms on a 1-D grid of ``dim`` points in [0, 1].

    ``type`` selects the generator:
      * ``'gauss'``: each histogram is a discretized Gaussian bump with mean
        drawn from [0.3, 0.7] and standard deviation from [0.1, 0.3].
      * ``'uniform'``: each histogram keeps only the uniform draws above 0.95
        (sparse random spikes) and zeros the rest.

    Attributes:
        batch_size: Number of (a, b) pairs returned per call.
        epsilon: Regularization passed as ``epsilon`` to the ``PointCloud``.
        type: ``'gauss'`` or ``'uniform'``.
        dim: Number of histogram bins.
        debug: If True, one batch drawn with ``PRNGKey(0)`` at construction is
            returned on every call.

    Raises:
        ValueError: At construction, if ``type`` is not a known generator.
    """
    batch_size: int = 128
    epsilon: float = 1e-2
    type: str = 'uniform'
    dim: int = 784
    debug: bool = False

    def __post_init__(self):
        # Validate eagerly: an unknown type would otherwise only fail when the
        # jitted sampler is first traced.
        if self.type not in ('gauss', 'uniform'):
            raise ValueError(f"unknown RandomSampler type {self.type!r}; expected 'gauss' or 'uniform'")

        x_grid = jnp.linspace(0, 1, num=self.dim).reshape(self.dim, 1)
        self.geom = PointCloud(x=x_grid, y=x_grid, epsilon=self.epsilon)

        @jax.jit
        def _sample(key):
            a, b = self._draw_pair(key)
            return PairData(a, b)

        self._sample = _sample

        if self.debug:
            # Freeze a single batch so every call returns identical data.
            self._debug_data = self._sample(jax.random.PRNGKey(0))

    def _draw_pair(self, key):
        """Draw one batch ``(a, b)``; each array has shape ``(batch_size, dim)`` with rows summing to 1."""
        x = jnp.linspace(0, 1, num=self.dim)

        if self.type == 'gauss':
            k1, k2, k3, k4, _ = jax.random.split(key, num=5)

            def bumps(k_mean, k_std):
                mean = jax.random.uniform(k_mean, minval=.3, maxval=.7, shape=(self.batch_size,))
                std = jax.random.uniform(k_std, minval=.1, maxval=.3, shape=(self.batch_size,))
                # Broadcast (batch, 1) against (1, dim) instead of looping over rows.
                h = jnp.exp(-(x[None, :] - mean[:, None]) ** 2 / (2 * std[:, None] ** 2))
                return h / h.sum(axis=1, keepdims=True)

            return bumps(k1, k2), bumps(k3, k4)

        if self.type == 'uniform':
            k1, k2, _ = jax.random.split(key, num=3)

            def spikes(k):
                u = jax.random.uniform(k, minval=0, maxval=1, shape=(self.batch_size, self.dim))
                u = jnp.where(u > .95, u, 0)
                # A row with no draw above the threshold would normalize to 0/0
                # (NaN); that is likely for small ``dim``, so such rows fall
                # back to a uniform histogram.
                u = jnp.where(u.sum(axis=1, keepdims=True) > 0, u, 1.)
                return u / u.sum(axis=1, keepdims=True)

            return spikes(k1), spikes(k2)

        raise ValueError(f"unknown RandomSampler type {self.type!r}; expected 'gauss' or 'uniform'")

    def __call__(self, key):
        """Return a ``PairData(a, b)`` batch drawn with ``key`` (ignored in debug mode)."""
        if self.debug:
            return self._debug_data
        return self._sample(key)

