import functools
import logging

import numpy as np
from torch.utils.data import DataLoader

import jax
from jax.lib import xla_bridge
from jax.flatten_util import ravel_pytree  # explicit: `import jax` does not guarantee this submodule is loaded
from flax.training import train_state  # Useful dataclass to keep train state
import jax.numpy as jnp
import optax  # Import the optax module for Flax's optimizers

from utils.utils import save_image, tensor_to_img, gkern, logger_info
from utils.data import ImageDataset
from utils.tasks import get_task_batches
from models.utils import model_apply, save_model


# Define the loss function (mean squared error)
def mean_squared_error(predictions, targets):
    """Return the mean of the element-wise squared difference ``(predictions - targets) ** 2``."""
    residual = predictions - targets
    return jnp.mean(jnp.square(residual))


def blur_forward(blur_kernel, x):
    """Convolve ``x`` with the first output slice of ``blur_kernel`` (stride 1, 'SAME' padding).

    ``x`` uses lax's default NCHW layout. ``blur_kernel`` is sliced as ``[..., :1]`` and then
    transposed with ``(3, 2, 0, 1)`` into lax's default OIHW kernel layout, so it must be 4-D.
    Presumably it is stored as (H, W, in, out); confirm against ``gkern``.
    Note that lax convolutions are cross-correlations (the kernel is not flipped).
    """
    kernel_oihw = jnp.transpose(blur_kernel[..., :1], (3, 2, 0, 1))
    return jax.lax.conv_general_dilated(
        lhs=x,
        rhs=kernel_oihw,
        window_strides=(1, 1),
        padding='SAME',
        feature_group_count=1,
        batch_group_count=1,
    )


def apply_mask(mask, x):
    """Apply a (typically binary) mask to ``x`` by element-wise multiplication ``mask * x``."""
    masked = mask * x
    return masked


def apply_meas_op(x, type_operator, mask, blur):
    """Apply the measurement operator selected by ``type_operator[0]`` to ``x``.

    Branch 0 is the identity, 1 multiplies by ``mask`` (``apply_mask``), and 2 convolves with
    ``blur`` (``blur_forward``). ``jax.lax.switch`` clamps an out-of-range index to the valid
    range. It also traces every branch, so ``mask`` and ``blur`` must be usable arrays even when
    their branch is not selected, and all branches must return the same shape/dtype.
    """
    def _identity(v):
        return v

    def _masked(v):
        return apply_mask(mask, v)

    def _blurred(v):
        return blur_forward(blur, v)

    return jax.lax.switch(type_operator[0], (_identity, _masked, _blurred), x)


# Define the unsupervised loss function (mean squared error)
def mean_squared_error_unsupervised(predictions, targets, type_operator, blur, mask):
    """MSE measured after pushing both arguments through the same measurement operator.

    ``predictions`` and ``targets`` are each mapped with ``apply_meas_op`` (operator chosen by
    ``type_operator``, parameterised by ``mask`` / ``blur``), and the squared difference of the
    two measurements is averaged.
    """
    measured_pred, measured_target = (
        apply_meas_op(t, type_operator, mask, blur) for t in (predictions, targets)
    )
    return jnp.mean((measured_pred - measured_target) ** 2)


# Adam optimizer setup
def get_optimizer(learning_rate=0.1, clip_params_norm=1):
    """Build the optax transformation shared by every TrainState in this module.

    Adam runs first. ``eps_root`` is set so that back-propagating *through* the Adam update
    (as the meta-gradient does) stays numerically stable. Its output, i.e. the learning-rate-scaled
    update rather than the raw gradient, is then rescaled to a global norm of at most
    ``clip_params_norm``, since ``optax.chain`` applies transformations in order.
    """
    adam_step = optax.adam(learning_rate=learning_rate, eps_root=1e-6)
    update_clip = optax.clip_by_global_norm(clip_params_norm)
    return optax.chain(adam_step, update_clip)


def init_train_state(model, random_key, shape, learning_rate) -> train_state.TrainState:
    """Initialise ``model`` and wrap its parameters in a new TrainState.

    The model is initialised with ``random_key`` on a dummy all-ones input of ``shape`` and the
    operator index ``jnp.array([0])``. The optimizer comes from ``get_optimizer(learning_rate)``.
    """
    dummy_input = jnp.ones(shape)
    dummy_operator = jnp.array([0])
    initial_params = model.init(random_key, dummy_input, dummy_operator)['params']
    return train_state.TrainState.create(
        apply_fn=model.apply,
        params=initial_params,
        tx=get_optimizer(learning_rate),
    )


def init_optimizer_state(model, params, learning_rate=1e-3):
    """Wrap existing network parameters in a fresh TrainState.

    ``params`` is the parameter pytree itself (not a TrainState); no model initialisation is
    performed. The optimizer state is built from ``get_optimizer(learning_rate)``.
    """
    return train_state.TrainState.create(
        apply_fn=model.apply,
        params=params,
        tx=get_optimizer(learning_rate),
    )



# Vectorised ``model_apply``: the first argument is shared across the mapped axis, and the
# remaining three arguments are mapped along their leading axis 0.
batched_model_apply = jax.vmap(
    model_apply,
    in_axes=(None, 0, 0, 0),
)


@jax.jit
def step_single_state_supervised(
        state: train_state.TrainState, params_ref, image_input: jnp.ndarray, image_target: jnp.ndarray,
        type_operator: jnp.ndarray, reg_param=1.0
):
    """Take one optimizer step on a supervised MSE loss plus a proximal penalty.

    loss = MSE(state.apply_fn(params, image_input, type_operator), image_target)
           + reg_param * MSE(flatten(params), flatten(params_ref))

    Args:
        state: TrainState whose ``params`` are optimised.
        params_ref: Parameter pytree (same structure as ``state.params``) that the penalty pulls
            towards. It is not updated here, but gradients can still flow into it when this step
            is differentiated from outside.
        image_input: Network input batch.
        image_target: Targets compared directly with the prediction.
        type_operator: Operator index array forwarded to ``state.apply_fn``.
        reg_param: Weight of the proximal penalty.

    Returns:
        ``(new_state, loss_value)`` where ``loss_value`` is evaluated at the parameters before
        the update.
    """
    # The reference vector does not depend on the differentiated params, so flatten it once.
    # ``ravel_pytree`` is imported explicitly: ``jax.flatten_util`` is not guaranteed to be
    # loaded by a plain ``import jax``.
    flat_ref, _ = ravel_pytree(params_ref)

    def loss_fn(params):
        pred = state.apply_fn({'params': params}, image_input, type_operator)
        flat_params, _ = ravel_pytree(params)
        data_term = mean_squared_error(pred, image_target)
        proximal_term = mean_squared_error(flat_params, flat_ref)
        return data_term + reg_param * proximal_term, pred

    (loss_value, _pred), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
    return state.apply_gradients(grads=grads), loss_value


@jax.jit
def step_single_state_unsupervised(
        state: train_state.TrainState, params_ref, image_input: jnp.ndarray, image_target: jnp.ndarray,
        type_operator: jnp.ndarray, blur, mask, reg_param=1.0,
):
    """Take one optimizer step on a measurement-space MSE loss plus a proximal penalty.

    loss = MSE(A(pred), A(image_target)) + reg_param * MSE(flatten(params), flatten(params_ref)),
    where ``A`` is the operator chosen by ``type_operator`` (see ``apply_meas_op``) with ``blur``
    and ``mask``.

    Args:
        state: TrainState whose ``params`` are optimised.
        params_ref: Parameter pytree (same structure as ``state.params``) that the penalty pulls
            towards.
        image_input: Network input batch.
        image_target: Target images, compared only after applying the measurement operator.
        type_operator: Operator index array forwarded to the model and the loss.
        blur, mask: Measurement-operator parameters.
        reg_param: Weight of the proximal penalty.

    Returns:
        ``(new_state, loss_value)`` where ``loss_value`` is evaluated at the parameters before
        the update.
    """
    # Flatten the fixed reference once. ``ravel_pytree`` is imported explicitly because
    # ``jax.flatten_util`` is not guaranteed to be loaded by a plain ``import jax``.
    flat_ref, _ = ravel_pytree(params_ref)

    def loss_fn(params):
        pred = state.apply_fn({'params': params}, image_input, type_operator)
        flat_params, _ = ravel_pytree(params)
        data_term = mean_squared_error_unsupervised(pred, image_target, type_operator, blur, mask)
        proximal_term = mean_squared_error(flat_params, flat_ref)
        return data_term + reg_param * proximal_term, pred

    (loss_value, _pred), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
    return state.apply_gradients(grads=grads), loss_value


# Training function
def perform_steps_single_task(rng, model, params, params_ref, x_input_task, x_target_task, type_operator_task, num_inner_steps=100,
                              batch_size=32, shuffle=False, debug=False, reg_param=1.0, supervised=True, blur=None, mask=None,
                              learning_rate=1e-4):
    """Adapt ``params`` to a single task with ``num_inner_steps`` optimizer steps.

    A fresh TrainState is built from ``params`` with ``init_optimizer_state``. Each step takes a
    mini-batch of at most ``batch_size`` samples from the task data and calls
    ``step_single_state_supervised`` (``supervised=True``) or ``step_single_state_unsupervised``.

    Args:
        rng: PRNG key; only consumed when ``shuffle`` is True.
        model: Model whose ``apply`` becomes the TrainState's ``apply_fn``.
        params: Starting parameters.
        params_ref: Reference parameters for the proximal penalty weighted by ``reg_param``.
        x_input_task, x_target_task: Task inputs/targets, indexed along axis 0.
        type_operator_task: Operator index array forwarded to the model and the loss.
        num_inner_steps: Number of optimizer steps (the Python loop is unrolled when traced).
        batch_size: Samples per step: the leading slice of the data, taken after a random
            permutation when ``shuffle`` is True.
        shuffle: Draw a new permutation of the samples at every step.
        debug: Accepted for API compatibility; currently unused.
        reg_param: Weight of the proximal penalty.
        supervised: Select the supervised or the measurement-space (unsupervised) loss.
        blur, mask: Measurement-operator parameters, used only by the unsupervised loss.
        learning_rate: Inner-loop learning rate (default 1e-4).

    Returns:
        ``(state, losses)``: the adapted TrainState and an array holding the loss of every step,
        each evaluated before that step's update.
    """
    state_single = init_optimizer_state(model, params, learning_rate=learning_rate)
    loss_list = []

    for _ in range(num_inner_steps):
        if shuffle:
            rng, subrng = jax.random.split(rng)
            # Only the first ``batch_size`` permuted indices are used, so gather just those rows.
            batch_idx = jax.random.permutation(subrng, x_target_task.shape[0])[:batch_size]
            x_target_batch = x_target_task[batch_idx]
            x_input_batch = x_input_task[batch_idx]
        else:
            # Without shuffling every step sees the same leading slice of the task data.
            x_target_batch = x_target_task[:batch_size]
            x_input_batch = x_input_task[:batch_size]

        if supervised:
            state_single, loss_value = step_single_state_supervised(
                state_single, params_ref, x_input_batch, x_target_batch, type_operator_task,
                reg_param=reg_param)
        else:
            state_single, loss_value = step_single_state_unsupervised(
                state_single, params_ref, x_input_batch, x_target_batch, type_operator_task,
                blur, mask, reg_param=reg_param)
        loss_list.append(loss_value)

    return state_single, jnp.array(loss_list)


def test_model(state_model_test, outer_state, inner_state, results_folder='results_denoising', grayscale=False,
               supervised=True, path_test_data='/pth/to/Set3C/'):
    """Run the meta and the task-adapted models on every test image and save PNGs.

    For each image, ``get_task_batches`` builds one degraded input per task. The meta model
    (``outer_state``, shared by all tasks) and the adapted models (``inner_state``, one per task,
    mapped along axis 0) are applied via ``model_apply``. For every task the meta output, the
    input and the adapted output are saved.

    Args:
        state_model_test: First argument forwarded to ``model_apply``.
        outer_state: Meta state, shared across tasks.
        inner_state: Per-task adapted states, stacked along axis 0.
        results_folder: Prefix prepended verbatim to every file name (not joined as a path).
            NOTE(review): the default has no trailing separator, so files land next to the
            folder with a ``results_denoising`` name prefix.
        grayscale: Forwarded to ``ImageDataset``.
        supervised: Forwarded to ``get_task_batches``.
        path_test_data: Directory of the test images (default kept from the original hard-coded
            placeholder).
    """
    dataset_test = ImageDataset(path_test_data, grayscale=grayscale)
    # With batch_size=1 there is never a partial batch, so drop_last has no effect here.
    dataloader = DataLoader(dataset_test, batch_size=1, shuffle=False, drop_last=False)

    blur_kernel = gkern(kernlen=7, std=0.5)

    # outer_state is broadcast to every task; inner_state holds one state per task.
    batched_model_forward = jax.vmap(model_apply, in_axes=(None, None, 0, 0))
    batched_model_forward_params = jax.vmap(model_apply, in_axes=(None, 0, 0, 0))

    # Fixed key: the mask and the generated tasks are identical across calls.
    # NOTE(review): the same key also feeds get_task_batches for every image; confirm intended.
    rng = jax.random.PRNGKey(0)
    mask = jax.random.choice(rng, 2, shape=(1, 1, 256, 256), p=np.asarray([0.5, 0.5]))

    for im_index, im in enumerate(dataloader):
        _, x_input_batch_tasks, type_operators_batch_tasks = get_task_batches(rng, im, x_target_batch=None,
                                                                              kernel=blur_kernel, mask=mask,
                                                                              supervised=supervised)

        x_output_meta = batched_model_forward(state_model_test, outer_state, x_input_batch_tasks,
                                              type_operators_batch_tasks)
        x_output_inner = batched_model_forward_params(state_model_test, inner_state, x_input_batch_tasks,
                                                      type_operators_batch_tasks)

        prefix = results_folder + 'x_test_' + str(im_index)
        for task in range(x_input_batch_tasks.shape[0]):
            # The meta output differs per task (different inputs), so the file name must include
            # the task index; otherwise every task overwrites the same file.
            save_image(tensor_to_img(x_output_meta[task]), prefix + '_output_meta_task_' + str(task) + '.png')
            save_image(tensor_to_img(x_input_batch_tasks[task]),
                       prefix + '_input_' + str(task) + '_task_' + str(task) + '.png')
            save_image(tensor_to_img(x_output_inner[task]),
                       prefix + '_output_' + str(task) + '_task_' + str(task) + '.png')


# Training function
def train_outer(rng, dataset_train, dataset_test, model, meta_state, meta_state_test, num_inner_steps=1, batch_size=32,
                batch_size_inner=32,
                epochs=2, debug=True, mask=None, results_folder='results/', log_filename='training.log',
                reg_param_inner=0.01, grayscale=False, tasks=None, supervised=True):
    """Meta-train ``meta_state`` by differentiating through a per-task inner adaptation loop.

    Each outer step pairs one batch of ``dataset_test`` with the next batch of ``dataset_train``
    and expands both into per-task batches with ``get_task_batches``. It then adapts the meta
    parameters to every task on the train batch (``num_inner_steps`` steps of
    ``perform_steps_single_task``, vmapped over tasks). Finally it updates ``meta_state`` with the
    gradient of the mean test-batch MSE of the adapted parameters.

    Args:
        rng: PRNG key for task generation and the inner loops.
        dataset_train, dataset_test: Datasets whose loader batches unpack as ``(input, target)``.
        model: Model used to build the inner TrainStates.
        meta_state: TrainState holding the meta parameters being trained.
        meta_state_test: Passed to ``test_model`` as ``state_model_test`` during debug output.
        num_inner_steps: Inner optimizer steps per task.
        batch_size: Batch size of both data loaders.
        batch_size_inner: Samples per inner step.
        epochs: Number of passes over ``dataset_test``.
        debug: If True, save checkpoints/images and run ``test_model`` at epochs 0, 20, ..., 180
            and every 200th epoch.
        mask: Mask forwarded to task generation and to the unsupervised measurement operator.
        results_folder: Prefix prepended verbatim to every written file name.
        log_filename: Log file passed to ``logger_info``.
        reg_param_inner: Weight of the inner-loop proximal penalty.
        grayscale: Forwarded to ``test_model``.
        tasks: Forwarded to ``get_task_batches``.
        supervised: Select the supervised or the measurement-space inner loss.

    Returns:
        ``(meta_state, inner_state, info)``. ``inner_state`` holds the per-task states obtained by
        adapting the final meta parameters on the last train task batch. ``info`` has
        ``'losses_outer'`` (rows ``[epoch, iteration, loss]``) and ``'losses_epoch'`` (per-step
        losses of the last epoch only).

    Raises:
        ValueError: if no outer step was taken (``epochs < 1`` or ``dataset_test`` has no full
            batch), or if ``dataset_train`` has no full batch.
    """

    def non_batched_inner_step(rng, model, x_input, x_target, type_operator, params, params_ref, num_inner_steps=1,
                               batch_size=2,
                               shuffle=False, debug=False, reg_param=0.1,
                               supervised=True, blur=None, mask=None):
        # Adapt ``params`` to a single task, forwarding every option (including ``shuffle``).
        inner_state, losses = perform_steps_single_task(rng, model, params, params_ref, x_input, x_target,
                                                        type_operator, num_inner_steps=num_inner_steps,
                                                        batch_size=batch_size, shuffle=shuffle, debug=debug,
                                                        reg_param=reg_param, supervised=supervised,
                                                        blur=blur, mask=mask)
        return inner_state, losses

    # A single Gaussian kernel serves the inner blur operator and the task generation below.
    blur_kernel = gkern(kernlen=7, std=0.5)

    non_batched_inner_step_ = functools.partial(non_batched_inner_step, num_inner_steps=num_inner_steps,
                                                batch_size=batch_size_inner, shuffle=False,
                                                debug=False, reg_param=reg_param_inner,
                                                supervised=supervised, mask=mask, blur=blur_kernel)
    # Map over the task axis of (x_input, x_target, type_operator); rng, model and params are shared.
    perform_step_batch = jax.vmap(non_batched_inner_step_, in_axes=(None, None, 0, 0, 0, None, None))

    def batched_inner_step(rng, model, x_input_batch, x_target_batch, type_operator, params):
        # The inner loop starts from ``params`` and is also regularised towards ``params``.
        return perform_step_batch(rng, model, x_input_batch, x_target_batch, type_operator, params, params)

    def iterate_forever(loader):
        """Yield batches from ``loader`` endlessly, starting a new pass whenever one ends."""
        while True:
            empty_pass = True
            for batch in loader:
                empty_pass = False
                yield batch
            if empty_pass:
                raise ValueError('dataset_train yields no full batch of size {}'.format(batch_size))

    training_device = jax.default_backend()

    # Create a DataLoader to iterate through the dataset
    dataloader_train = DataLoader(dataset_train, batch_size=batch_size, shuffle=True,
                                  drop_last=True)  # We drop last because I cannot yet handle batches with varying sizes
    dataloader_test = DataLoader(dataset_test, batch_size=batch_size, shuffle=True,
                                 drop_last=True)  # We drop last because I cannot yet handle batches with varying sizes

    logname = 'trainlog'
    logger_info(logname, log_path=log_filename)
    logger = logging.getLogger(logname)

    logger.info('Starting training on {}'.format(str(training_device)))

    losses_outer = []
    losses_epoch = []
    train_batch = None
    test_batch = None
    inner_rng = rng
    # One persistent iterator, instead of rebuilding (and reshuffling) the loader every step.
    train_batches = iterate_forever(dataloader_train)

    @jax.jit
    def outer_step(meta_state: train_state.TrainState, test_batch: tuple, train_batch: tuple, inner_rng):
        # ``inner_rng`` is an explicit argument: a closed-over key would be frozen at trace time.

        def loss_fn_nonbatched(params, image_input, image_target, type_operator):
            pred = meta_state.apply_fn({'params': params}, image_input, type_operator)
            return mean_squared_error(pred, image_target)

        # Mapped over tasks, including the per-task adapted params.
        batched_loss_fn = jax.vmap(loss_fn_nonbatched, in_axes=(0, 0, 0, 0))

        def batched_outer_loss(meta_params):
            # First, adapt the meta params to every task on the train batch.
            image_train_target, image_train_input, type_operator_train = train_batch
            inner_state, _inner_losses = batched_inner_step(inner_rng, model, image_train_input,
                                                            image_train_target, type_operator_train,
                                                            meta_params)

            # Second, score the adapted params on the test batch.
            image_test_target, image_test_input, type_operator_test = test_batch
            loss_values = batched_loss_fn(inner_state.params, image_test_input, image_test_target,
                                          type_operator_test)
            return loss_values.mean()

        # Compute outer loss, grads and step
        loss_value, grads = jax.value_and_grad(batched_outer_loss, has_aux=False)(meta_state.params)
        meta_state = meta_state.apply_gradients(grads=grads)
        return meta_state, loss_value

    for epoch in range(epochs):

        losses_epoch = []

        for i, (x_input_batch_test, x_target_batch_test) in enumerate(dataloader_test):
            rng, task_rng, inner_rng = jax.random.split(rng, 3)

            # 1. Select the next batch from the training dataset
            x_input_batch_train, x_target_batch_train = next(train_batches)

            # 2. Generate the tasks for training and testing. Both use the same key, presumably so
            #    that task k gets the same operator in both batches (TODO confirm in get_task_batches).
            train_batch = get_task_batches(task_rng, x_input_batch_train, x_target_batch_train, blur_kernel,
                                           mask=mask, tasks=tasks, supervised=supervised)
            test_batch = get_task_batches(task_rng, x_input_batch_test, x_target_batch_test, blur_kernel,
                                          mask=mask, tasks=tasks, supervised=supervised)

            # 3. Perform the outer training step
            meta_state, loss_value = outer_step(meta_state, test_batch, train_batch, inner_rng)
            losses_outer.append([epoch, i, loss_value])
            losses_epoch.append(loss_value.item())
            logger.info('Iter {:d} \t Avg train loss: {:4e}'.format(i, loss_value.item()))

        loss_epoch = jnp.array(losses_epoch).mean()
        logger.info('Epoch {:d} \t Avg train loss: {:4e}'.format(epoch + 1, loss_epoch))

        if debug and train_batch is not None and ((epoch % 20 == 0 and epoch < 200) or epoch % 200 == 0):

            save_model(meta_state, results_folder + 'ckpt_meta.ckpt')

            x_target_test_tasks, x_input_test_tasks, type_operators_test_tasks = test_batch
            x_target_train_tasks, x_input_train_tasks, type_operators_train_tasks = train_batch

            inner_state, _ = batched_inner_step(inner_rng, model, x_input_train_tasks, x_target_train_tasks,
                                                type_operators_train_tasks,
                                                meta_state.params)

            save_model(inner_state, results_folder + 'ckpt_inner.ckpt')

            test_model(meta_state_test, meta_state, inner_state, results_folder=results_folder, grayscale=grayscale,
                       supervised=supervised)

            x_output_inner = batched_model_apply(meta_state, inner_state, x_input_train_tasks,
                                                 type_operators_train_tasks)

            for task_id in range(type_operators_test_tasks.shape[0]):
                x_output_meta = meta_state.apply_fn({'params': meta_state.params}, x_input_test_tasks[task_id], type_operators_test_tasks[task_id])
                save_image(tensor_to_img(x_output_meta[0]), results_folder+'x_train_output_meta_epoch_'+str(epoch)+'_task_'+str(task_id)+'.png')
                save_image(tensor_to_img(x_input_test_tasks[task_id, 0]), results_folder+'x_train_input_meta_epoch_'+str(epoch)+'_task_'+str(task_id)+'.png')
                save_image(tensor_to_img(x_target_test_tasks[task_id, 0]), results_folder+'x_train_target_meta_epoch_'+str(epoch)+'_task_'+str(task_id)+'.png')

                save_image(tensor_to_img(x_output_inner[task_id, 0]), results_folder+'x_train_output_inner_epoch_'+str(epoch)+'_task_'+str(task_id)+'.png')
                save_image(tensor_to_img(x_input_train_tasks[task_id, 0]),  results_folder+'x_train_input_inner_epoch_'+str(epoch)+'_task_'+str(task_id)+'.png')
                save_image(tensor_to_img(x_target_train_tasks[task_id, 0]), results_folder+'x_train_target_inner_epoch_'+str(epoch)+'_task_'+str(task_id)+'.png')

    if train_batch is None:
        raise ValueError('train_outer ran no outer step: need epochs >= 1 and at least one full batch '
                         'of size {} in dataset_test'.format(batch_size))

    # Adapt the final meta parameters on the most recent train task batch (always defined here,
    # independent of the debug branch).
    x_target_train_tasks, x_input_train_tasks, type_operators_train_tasks = train_batch
    inner_state, _ = batched_inner_step(rng, model, x_input_train_tasks, x_target_train_tasks,
                                        type_operators_train_tasks,
                                        meta_state.params)

    info = {'losses_outer': jnp.array(losses_outer),
            'losses_epoch': jnp.array(losses_epoch)}

    return meta_state, inner_state, info
