from functools import partial
from typing import Any, Tuple

import flax
import jax
import jax.numpy as jnp
import numpy as np


class JaxRNG(object):
    """A convenient stateful Jax RNG wrapper. Can be used to wrap RNG inside
    pure function.
    """

    @classmethod
    def from_seed(cls, seed):
        return cls(jax.random.PRNGKey(seed))

    def __init__(self, rng):
        self.rng = rng

    def __call__(self, keys=None):
        if keys is None:
            self.rng, split_rng = jax.random.split(self.rng)
            return split_rng
        elif isinstance(keys, int):
            split_rngs = jax.random.split(self.rng, num=keys + 1)
            self.rng = split_rngs[0]
            return tuple(split_rngs[1:])
        else:
            split_rngs = jax.random.split(self.rng, num=len(keys) + 1)
            self.rng = split_rngs[0]
            return {key: val for key, val in zip(keys, split_rngs[1:])}


def wrap_function_with_rng(rng):
    """To be used as decorator, automatically bookkeep a RNG for the wrapped function."""

    def wrap_function(function):
        def wrapped(*args, **kwargs):
            nonlocal rng
            rng, split_rng = jax.random.split(rng)
            return function(split_rng, *args, **kwargs)

        return wrapped

    return wrap_function


def init_rng(seed):
    global jax_utils_rng
    jax_utils_rng = JaxRNG.from_seed(seed)


def next_rng(*args, **kwargs):
    global jax_utils_rng
    return jax_utils_rng(*args, **kwargs)


def get_metrics(metrics, unreplicate=False):
    if unreplicate:
        metrics = flax.jax_utils.unreplicate(metrics)
    metrics = jax.device_get(metrics)
    return {key: float(val) for key, val in metrics.items()}


def extend_and_repeat(tensor, axis, repeat):
    return jnp.repeat(jnp.expand_dims(tensor, axis), repeat, axis=axis)


def mse_loss(val, target):
    return jnp.mean(jnp.square(val - target))


def cross_entropy_loss(logits, labels, smoothing_factor=0.0):
    num_classes = logits.shape[-1]
    if labels.dtype == jnp.int32 or labels.dtype == jnp.int64:
        labels = jax.nn.one_hot(labels, num_classes)
    if smoothing_factor > 0.0:
        labels = labels * (1.0 - smoothing_factor) + smoothing_factor / num_classes
    logp = jax.nn.log_softmax(logits, axis=-1)
    return -jnp.mean(jnp.sum(logp * labels, axis=-1))


def accumulated_gradient(state, accumulated_grads, accumulated_steps, grads, apply_steps, apply_fn=None):
    accumulated_grads = jax.tree_map(
        lambda x, y: x + y,
        accumulated_grads,
        jax.tree_map(lambda x: x / float(apply_steps), grads),
    )
    accumulated_steps = (accumulated_steps + 1) % apply_steps
    if apply_fn is None:
        apply_fn = lambda s, g: s.apply_gradients(grads=g)
    state = jax.lax.cond(
        accumulated_steps == 0,
        lambda: apply_fn(state, accumulated_grads),
        lambda: state,
    )
    accumulated_grads = jax.lax.cond(
        accumulated_steps == 0,
        lambda: jax.tree_map(jnp.zeros_like, accumulated_grads),
        lambda: accumulated_grads,
    )
    return state, accumulated_grads, accumulated_steps


@partial(jax.pmap, axis_name="pmap", donate_argnums=0)
def sync_state_across_devices(state):
    i = jax.lax.axis_index("pmap")

    def select(x):
        return jax.lax.psum(jnp.where(i == 0, x, jnp.zeros_like(x)), "pmap")

    return jax.tree_map(select, state)


def get_random_bounding_box(
    image_shape: Tuple[int, int], lambda_cutmix: float, margin: float = 0.0
) -> Tuple[int, int, int, int]:
    ratio = np.sqrt(1 - lambda_cutmix)
    img_h, img_w = image_shape
    cut_h, cut_w = int(img_h * ratio), int(img_w * ratio)
    margin_y, margin_x = int(margin * cut_h), int(margin * cut_w)
    cy = np.random.randint(0 + margin_y, img_h - margin_y)
    cx = np.random.randint(0 + margin_x, img_w - margin_x)
    y_min = np.clip(cy - cut_h // 2, 0, img_h)
    y_max = np.clip(cy + cut_h // 2, 0, img_h)
    x_min = np.clip(cx - cut_w // 2, 0, img_w)
    x_max = np.clip(cx + cut_w // 2, 0, img_w)
    return y_min, y_max, x_min, x_max


def label_smoothing_fn(labels, smoothing_factor):
    num_classes = labels.shape[-1]
    labels = labels * (1.0 - smoothing_factor) + smoothing_factor / num_classes
    return labels


def mixup_cutmix(
    images: jnp.ndarray,
    labels: jnp.ndarray,
    rng: Any,
    num_classes: int,
    mixup_alpha: float = 1.0,
    cutmix_alpha: float = 0.0,
    switch_prob: float = 0.5,
    label_smoothing: float = 0.0,
    image_format: str = "NHWC",
):
    if len(labels.shape) == 1:
        labels = jax.nn.one_hot(labels, num_classes)

    if "N" not in image_format:
        raise ValueError('Mixup requires "N" to be in "image_format".')

    batch_size = labels.shape[0]

    if cutmix_alpha > 0 and (mixup_alpha <= 0 or np.random.rand() < switch_prob):
        do_mixup = False
        do_cutmix = True
    elif mixup_alpha > 0:
        do_mixup = True
        do_cutmix = False
    else:
        return images, labels

    if do_mixup:
        weight = jax.random.beta(rng, mixup_alpha, mixup_alpha)
        weight *= jnp.ones((batch_size, 1))

        # Mixup inputs.
        # Shape calculations use np to avoid device memory fragmentation:
        image_weight_shape = np.ones((images.ndim))
        image_weight_shape[image_format.index("N")] = batch_size
        image_weight = jnp.reshape(weight, image_weight_shape.astype(jnp.int32))
        reverse = tuple(slice(images.shape[i]) if d != "N" else slice(-1, None, -1) for i, d in enumerate(image_format))
        mixed_images = image_weight * images + (1.0 - image_weight) * images[reverse]
        label_weight = weight

    elif do_cutmix:
        if image_format not in {"NHWC", "NTHWC"}:
            raise ValueError("Cutmix is only supported for inputs in format" f" NHWC or NTHWC. Got {image_format}.")
        cutmix_lambda = np.random.beta(cutmix_alpha, cutmix_alpha)

        y_min, y_max, x_min, x_max = get_random_bounding_box(images.shape[-3:-1], cutmix_lambda)
        image_mask = np.ones(images.shape)
        if image_format == "NHWC":
            image_mask[:, y_min:y_max, x_min:x_max, :] = 0.0
        else:
            image_mask[:, :, y_min:y_max, x_min:x_max, :] = 0.0
        height, width = images.shape[-3], images.shape[-2]

        mixed_images = images * image_mask + jnp.flip(images, axis=0) * (1.0 - image_mask)
        box_area = (y_max - y_min) * (x_max - x_min)
        label_weight = 1.0 - box_area / float(height * width)

    # Mixup label
    if label_smoothing > 0:
        labels = label_smoothing_fn(labels, label_smoothing)

    mixed_labels = label_weight * labels + (1.0 - label_weight) * labels[::-1]

    return mixed_images, mixed_labels
