from typing import Sequence

import jax
from jax import lax
from jax import numpy as jnp
from jax import random


# https://github.com/google/jax/discussions/9213
def random_pad_crop(rng, im: jnp.ndarray, fill, padding: Sequence[int], size: Sequence[int]):
    """Pad a batch symmetrically, then randomly crop every example.

    Axis 0 of ``im`` is the batch axis. ``padding[i]`` and ``size[i]`` both
    refer to axis ``i + 1``.

    Args:
        rng: PRNG key forwarded to ``random_crop``.
        im: Batched array.
        fill: Border value forwarded to ``pad``.
        padding: Amount added on both sides of axes 1, 2, ...
        size: Crop size along those same axes.

    Returns:
        The result of ``random_crop`` applied to the padded batch.
    """
    assert len(padding) == len(size)
    # The batch axis is never padded, so the padded axes start at 1.
    padded_axes = tuple(range(1, 1 + len(padding)))
    padded = pad(im, fill, padding, axis=padded_axes)
    return random_crop(rng, padded, size)


def pad(im: jnp.ndarray, fill, padding: Sequence[int], axis: Sequence[int]):
    """Pad ``im`` symmetrically with ``fill`` along the given axes.

    No axis is treated specially here; the caller chooses which axes to pad.

    Args:
        im: Array to pad.
        fill: Scalar border value. It is converted to ``im.dtype`` because
            ``lax.pad`` rejects a padding value whose dtype differs from the
            operand's (for example an int ``0`` with a float image).
        padding: ``padding[i]`` elements are added before and after
            ``axis[i]``.
        axis: Axes to pad; negative values count from the end.

    Returns:
        The padded array.

    Raises:
        ValueError: If an axis is out of range or is given more than once.
    """
    assert len(padding) == len(axis)
    assert len(padding) <= len(im.shape)
    axis = _canonical_tuple(axis, im.ndim)
    # One (low, high, interior) triple per dimension; dimensions that are not
    # listed in ``axis`` keep the no-op triple.
    config = [(0, 0, 0)] * im.ndim
    for ax, amount in zip(axis, padding):
        config[ax] = (amount, amount, 0)
    # Match the operand dtype so Python scalars and mismatched arrays work.
    fill = jnp.asarray(fill, dtype=im.dtype)
    return lax.pad(im, fill, config)


def random_crop(rng, im: jnp.ndarray, size: Sequence[int]):
    """Cut a randomly placed window out of every batch element.

    Axis 0 of ``im`` is the batch axis and ``size[i]`` applies to axis
    ``i + 1``. Axes beyond ``len(size)`` are kept whole. A separate start
    offset is drawn for every example and every cropped axis.

    Args:
        rng: PRNG key; it is split into one key per cropped axis.
        im: Batched array.
        size: Window size along axes 1, 2, ...

    Returns:
        The batch of cropped examples.
    """
    batch_len = im.shape[0]
    example_shape = im.shape[1:]
    assert len(size) <= len(example_shape)
    axis_keys = random.split(rng, len(size))
    offsets = []
    for key, crop_len, full_len in zip(axis_keys, size, example_shape):
        # maxval is exclusive, so the last valid start (full_len - crop_len)
        # can still be drawn.
        starts = jax.random.randint(
            key, (batch_len,), minval=0, maxval=full_len - crop_len + 1)
        offsets.append(starts)
    # Map over the batch axis of the data and of every offset vector; the
    # window size is shared by all examples.
    batched_crop = jax.vmap(crop, in_axes=(0, 0, None))
    return batched_crop(im, offsets, size)


def crop(im, offsets: Sequence[int], sizes: Sequence[int]):
    """Slice a window out of a single (unbatched) array.

    The leading ``len(offsets)`` axes are cut to ``sizes`` starting at
    ``offsets``; every remaining axis is kept in full.

    Args:
        im: Array without a batch axis.
        offsets: Start index for each leading axis.
        sizes: Window length for each leading axis.

    Returns:
        The slice produced by ``lax.dynamic_slice``.
    """
    n_cropped = len(offsets)
    assert n_cropped == len(sizes)
    starts = []
    extents = []
    for dim, full_len in enumerate(im.shape):
        if dim < n_cropped:
            starts.append(offsets[dim])
            extents.append(sizes[dim])
        else:
            # Axis is not cropped: take it whole.
            starts.append(0)
            extents.append(full_len)
    return lax.dynamic_slice(im, starts, extents)


def random_flip(rng, im, axis):
    """Independently flip each batch element along ``axis`` at random.

    Axis 0 of ``im`` is the batch axis and cannot be the flip axis. One
    Bernoulli draw (default probability of ``random.bernoulli``) is made per
    example to decide whether it is flipped.

    Args:
        rng: PRNG key for the per-example draws.
        im: Batched array.
        axis: Axis of the batched array to flip; may be negative.

    Returns:
        Array of the same shape with some examples reversed along ``axis``.

    Raises:
        ValueError: If ``axis`` is out of range for ``im``.
    """
    batched_axis = _canonical_axis(axis, im.ndim)
    assert batched_axis > 0
    flip_mask = random.bernoulli(rng, shape=(im.shape[0],))
    # Inside vmap the batch axis is gone, so the axis index shifts down by one.
    per_example_axis = batched_axis - 1
    flip_each = jax.vmap(cond_flip, in_axes=(0, 0, None))
    return flip_each(flip_mask, im, per_example_axis)


def cond_flip(cond, im, axis):
    """Return ``im`` reversed along ``axis`` if ``cond`` holds, else ``im``.

    Operates on a single array (no batch axis). Both candidates are computed
    and ``lax.select`` picks one of them.
    """
    flipped = jnp.flip(im, axis=axis)
    return lax.select(cond, flipped, im)


def _canonical_tuple(axis, ndim):
    """Normalise one axis or several axes to a tuple of non-negative indices.

    Args:
        axis: A single axis or an iterable of axes; negatives are allowed.
        ndim: Number of dimensions the axes refer to.

    Returns:
        Tuple of canonical axis indices, in the order given.

    Raises:
        ValueError: If an axis is out of range or two entries resolve to the
            same axis.
    """
    canonical = tuple(_canonical_axis(ax, ndim) for ax in _to_tuple(axis))
    # A set drops duplicates, so a length mismatch means a repeated axis.
    if len(set(canonical)) != len(canonical):
        raise ValueError('same axis specified twice', canonical)
    return canonical


def _to_tuple(x):
    """Coerce ``x`` to a tuple: iterables are expanded, anything else wrapped."""
    try:
        as_tuple = tuple(x)
    except TypeError:
        # tuple() could not iterate x (e.g. a plain int): use it as the
        # sole element.
        as_tuple = (x,)
    return as_tuple


def _canonical_axis(axis, ndim):
    """Resolve a possibly negative axis to its non-negative index.

    Args:
        axis: Axis index; negative values count from the end.
        ndim: Number of dimensions the axis refers to.

    Returns:
        The equivalent index in ``[0, ndim)``.

    Raises:
        ValueError: If the resolved index falls outside ``[0, ndim)``.
    """
    resolved = axis
    if axis < 0:
        resolved = axis + ndim
    if not 0 <= resolved < ndim:
        raise ValueError(f'axis invalid for ndim {ndim}', axis)
    return resolved
