import numpy as np
import jax
import jax.numpy as jnp


# Deblurring
def blur_operation(x, blur_kernel):
    """Convolve an image batch with a blur kernel, one group per input channel.

    Args:
        x: image tensor in NCHW layout (the default ``dimension_numbers`` of
            ``conv_general_dilated``, spelled out explicitly below).
        blur_kernel: 4-D kernel whose axes are permuted with ``[3, 2, 0, 1]``
            to obtain the OIHW layout, i.e. it is expected in HWIO order.
            Because ``feature_group_count`` equals the channel count of ``x``,
            its I axis presumably has size 1 (depthwise blur) — TODO confirm
            against callers.

    Returns:
        The convolution result in NCHW layout. Unit strides with ``'SAME'``
        padding keep the spatial size of ``x``.
    """
    # HWIO -> OIHW, the rhs layout used by the NCHW/OIHW/NCHW convention.
    rhs_oihw = jnp.transpose(blur_kernel, [3, 2, 0, 1])
    groups = x.shape[1]  # one convolution group per channel of x
    return jax.lax.conv_general_dilated(
        x,
        rhs_oihw,
        window_strides=(1, 1),
        padding='SAME',
        dimension_numbers=('NCHW', 'OIHW', 'NCHW'),
        feature_group_count=groups,
        batch_group_count=1,
    )


def get_task_batches(rng, x_input_batch, x_target_batch=None, kernel=None, mask=None, debug=False, select_task=None,
                     tasks=None, supervised=True):
    r"""Build (target, input) batches for the denoising, TV, deblurring and inpainting tasks.

    Args:
        rng: JAX PRNG key. It is split once into independent sub-keys, so no
            key is consumed twice for unrelated random draws.
        x_input_batch: clean images of shape (N, C, H, W); anything accepted
            by ``jnp.array`` (e.g. NumPy arrays or CPU torch tensors).
        x_target_batch: optional targets for the TV task, expected to have
            the same shape as ``x_input_batch``; defaults to ``x_input_batch``.
        kernel: blur kernel forwarded to ``blur_operation``. Required: the
            deblurring pair is always computed, even when another
            ``select_task`` is requested.
        mask: optional inpainting mask, multiplied into ``x_input_batch``.
            If ``None``, a random {0, 1} mask of shape (N, 1, H, W) is drawn
            with P(1) = 0.9.
        debug: unused; kept for backward compatibility.
        select_task: ``None`` to return every task stacked, otherwise one of
            0 (denoising), 1 (TV), 2 (deblurring), 3 (inpainting).
        tasks: only compared with ``None``; any non-``None`` value forces the
            single-task path driven by ``select_task``.
        supervised: if True, the denoising/TV targets are not noised and the
            TV input is the clean ``x_input_batch``. If False, both targets
            get their own additive noise and the TV input is noised as well.

    Returns:
        ``(x_target_batch, x_input_batch, type_operators_batch)``. When all
        tasks are returned, the image arrays are stacked along a new leading
        task axis (T=4, order: denoising, TV, deblurring, inpainting) and
        ``type_operators_batch`` is ``[[0], [0], [2], [1]]``. For a single
        task, the arrays keep the (N, C, H, W) layout and
        ``type_operators_batch`` has one entry.

    Raises:
        NotImplementedError: on the single-task path when ``select_task`` is
            not one of 0, 1, 2, 3.
    """
    noise_std = 0.05

    x_input_batch = jnp.array(x_input_batch)
    if x_target_batch is None:
        x_target_batch = x_input_batch
    else:
        x_target_batch = jnp.array(x_target_batch)

    # One sub-key per independent random quantity. Reusing a JAX key for
    # several draws makes them share random bits (e.g. the inpainting mask
    # would be correlated with the denoising noise).
    key_input_noise, key_target_noise, key_mask = jax.random.split(rng, 3)

    # Denoising and TV deliberately share one input-noise realisation (and,
    # when unsupervised and shapes match, one target-noise realisation), so
    # both tasks see identically corrupted data.
    # NOTE(review): confirm this sharing between den and TV is intended.
    input_noise = noise_std * jax.random.normal(key_input_noise, x_input_batch.shape)

    # Denoising
    x_input_batch_den = x_input_batch + input_noise
    if supervised:
        x_target_batch_den = x_input_batch
    else:
        x_target_batch_den = x_input_batch + noise_std * jax.random.normal(key_target_noise, x_input_batch.shape)

    # TV
    if supervised:
        x_target_batch_tv = x_target_batch
        x_input_batch_tv = x_input_batch
    else:
        x_target_batch_tv = x_target_batch + noise_std * jax.random.normal(key_target_noise, x_target_batch.shape)
        x_input_batch_tv = x_input_batch + input_noise

    # Deblurring: vmap over the batch so each image is convolved as a
    # (1, C, H, W) tensor, then keep only the first C output channels.
    x_target_batch_deblur = x_input_batch
    batched_blur_op = jax.vmap(blur_operation, in_axes=(0, None))
    blurred = batched_blur_op(x_input_batch[:, None], kernel)
    x_input_batch_deblur = blurred[:, 0, :x_input_batch.shape[1], ...]

    # Inpainting: a single mask channel broadcast over all image channels.
    x_target_batch_inpaint = x_input_batch
    if mask is None:
        mask_shape = (x_input_batch.shape[0], 1, x_input_batch.shape[-2], x_input_batch.shape[-1])
        mask = jax.random.choice(key_mask, 2, shape=mask_shape, p=np.asarray([0.1, 0.9]))
    x_input_batch_inpaint = x_input_batch * mask

    if select_task is None and tasks is None:
        # Supervised and unsupervised modes stack the four tasks identically.
        x_target_batch = jnp.stack([x_target_batch_den, x_target_batch_tv,
                                    x_target_batch_deblur, x_target_batch_inpaint])
        x_input_batch = jnp.stack([x_input_batch_den, x_input_batch_tv,
                                   x_input_batch_deblur, x_input_batch_inpaint])
        type_operators_batch = jnp.array([[0], [0], [2], [1]])
    elif select_task == 0:
        x_target_batch = x_target_batch_den
        x_input_batch = x_input_batch_den
        type_operators_batch = jnp.array([0])
    elif select_task == 1:
        x_target_batch = x_target_batch_tv
        x_input_batch = x_input_batch_tv
        type_operators_batch = jnp.array([0])
    elif select_task == 2:
        x_target_batch = x_target_batch_deblur
        x_input_batch = x_input_batch_deblur
        type_operators_batch = jnp.array([2])
    elif select_task == 3:
        x_target_batch = x_target_batch_inpaint
        x_input_batch = x_input_batch_inpaint
        type_operators_batch = jnp.array([1])
    else:
        raise NotImplementedError

    return x_target_batch, x_input_batch, type_operators_batch
