import functools
import logging

import numpy as np
from torch.utils.data import DataLoader

import jax
from jax.lib import xla_bridge
from flax.training import train_state  # Useful dataclass to keep train state
import jax.numpy as jnp
import optax  # Import the optax module for Flax's optimizers

from utils.utils import save_image, tensor_to_img, gkern, logger_info
from utils.data import ImageDataset
from utils.tasks import get_task_batches
from models.utils import model_apply, save_model


# Define the loss function (mean squared error)
def mean_squared_error(predictions, targets):
    """Return the mean of the element-wise squared difference.

    Args:
        predictions: Array of predicted values.
        targets: Array of reference values, broadcast-compatible with
            ``predictions``.

    Returns:
        The scalar mean over all elements of ``(predictions - targets) ** 2``.
    """
    residual = predictions - targets
    return jnp.mean(residual * residual)


# Adam optimizer setup
def get_optimizer(learning_rate=0.1, clip_params_norm=1):
    """Build the optax gradient transformation used by ``init_train_state``.

    The returned chain applies Adam first and then ``clip_by_global_norm``,
    so the clipping acts on Adam's output rather than on the raw gradients.

    Args:
        learning_rate: Learning rate forwarded to ``optax.adam``.
        clip_params_norm: Maximum global norm forwarded to
            ``optax.clip_by_global_norm``.

    Returns:
        The chained optax transformation.
    """
    adam_transform = optax.adam(learning_rate=learning_rate, eps_root=1e-6)
    norm_clipping = optax.clip_by_global_norm(clip_params_norm)
    return optax.chain(adam_transform, norm_clipping)


def init_train_state(model, random_key, shape, learning_rate) -> train_state.TrainState:
    """Initialise the model parameters and wrap them in a ``TrainState``.

    Args:
        model: Module exposing ``init`` and ``apply``.
        random_key: PRNG key passed to ``model.init``.
        shape: Shape of the all-ones dummy input used for initialisation.
        learning_rate: Learning rate passed to ``get_optimizer``.

    Returns:
        A ``TrainState`` holding ``model.apply``, the optimizer from
        ``get_optimizer`` and the ``'params'`` collection of the initialised
        variables.
    """
    # model.init is called with a dummy image and a dummy second argument
    # (a single zero), mirroring how the model is applied during training.
    dummy_image = jnp.ones(shape)
    dummy_op_type = jnp.array([0])
    initial_params = model.init(random_key, dummy_image, dummy_op_type)['params']

    return train_state.TrainState.create(
        apply_fn=model.apply,
        tx=get_optimizer(learning_rate),
        params=initial_params,
    )


def init_optimizer_state(model, params, learning_rate=1e-3):
    """Wrap already-existing parameters in a fresh ``TrainState``.

    Unlike ``init_train_state`` this does not call ``model.init`` and it uses
    a plain ``optax.adam`` (no clipping stage).

    Args:
        model: Module whose ``apply`` method becomes the state's ``apply_fn``.
        params: Parameters of the network (the parameter tree itself, not a
            ``TrainState``).
        learning_rate: Learning rate passed to ``optax.adam``.

    Returns:
        A ``TrainState`` built from ``model.apply``, the Adam optimizer and
        ``params``.
    """
    adam = optax.adam(learning_rate)
    return train_state.TrainState.create(apply_fn=model.apply, tx=adam, params=params)


def test_model(state_model_test, outer_state, results_folder='results_denoising', task_id=0, grayscale=False,
               path_test_data='/pth/to/Set3C/'):
    """Run the model on every image of a test set and save input/output images.

    For each image a task batch is built with ``get_task_batches`` (using a
    fixed blur kernel and a fixed random binary mask), the model is applied
    through ``model_apply`` and both the degraded input and the model output
    are written as PNG files.

    Args:
        state_model_test: First argument forwarded to ``model_apply``.
        outer_state: Second argument forwarded to ``model_apply``.
        results_folder: Prefix prepended verbatim to every output file name.
            It is concatenated as a plain string, so it must end with a path
            separator to be treated as a directory.
            NOTE(review): the default has no trailing separator, so files are
            named ``results_denoisingx_test_...`` — confirm this is intended.
        task_id: Forwarded to ``get_task_batches`` as ``select_task``.
        grayscale: Forwarded to ``ImageDataset``.
        path_test_data: Folder containing the test images. Defaults to the
            previously hard-coded location.

    Returns:
        None. Results are written to disk via ``save_image``.
    """
    dataset_test = ImageDataset(path_test_data, grayscale=grayscale)
    # One image per batch and nothing dropped: every test image is processed.
    dataloader = DataLoader(dataset_test, batch_size=1, shuffle=False, drop_last=False)

    blur_kernel = gkern(kernlen=7, std=0.5)

    # Fixed seed so the mask (entries drawn from {0, 1} with equal probability)
    # is identical across calls. The same key is also handed to
    # get_task_batches for every image.
    rng = jax.random.PRNGKey(0)
    mask = jax.random.choice(rng, 2, shape=(1, 1, 256, 256), p=np.asarray([0.5, 0.5]))

    for im_index, im in enumerate(dataloader):
        _, x_input, type_operators_batch_tasks = get_task_batches(rng, im, x_target_batch=None,
                                                                  kernel=blur_kernel, mask=mask,
                                                                  select_task=task_id)

        x_output = model_apply(state_model_test, outer_state, x_input, type_operators_batch_tasks)

        file_prefix = results_folder + 'x_test_' + str(im_index)
        save_image(tensor_to_img(x_output), file_prefix + '_output_meta.png')
        save_image(tensor_to_img(x_input), file_prefix + '_input.png')


# # Training function
# def train_basic(rng, dataset_train, state, batch_size=1,
#                 epochs=2, debug=True, results_folder='results/', log_filename='training.log',
#                 grayscale=False):
#
#
#     trainin_device = xla_bridge.get_backend().platform
#
#     # Create a DataLoader to iterate through the dataset
#     dataloader_train = DataLoader(dataset_train, batch_size=batch_size, shuffle=True,
#                                   drop_last=True)  # We drop last because I cannot yet handle batches with varying sizes
#
#     logname = 'trainlog'
#     logger_info(logname, log_path=log_filename)
#     logger = logging.getLogger(logname)
#
#     logger.info('Starting training on {}'.format(str(trainin_device)))
#
#     losses_outer = []
#     losses_inner = []
#
#     kernel = gkern(kernlen=7, std=0.5)
#
#     @jax.jit
#     def train_step(
#             state: train_state.TrainState, input_batch: tuple
#     ):
#         image_target, meas_input, mask, backproj = input_batch
#         op_type = jnp.array([0])  # A blanket variable for now
#
#         def loss_fn(params):
#             print(backproj.shape)
#             u = jax.lax.conv(backproj,  # lhs = NCHW image tensor
#                              jnp.transpose(params['weight_0'], [3, 2, 0, 1]),
#                              (1, 1),  # window strides
#                              'SAME')
#             pred, _ = state.apply_fn({'params': params}, meas_input, op_type, backproj, u)
#             loss = mean_squared_error(pred, image_target)
#             return loss, pred
#
#         gradient_fn = jax.value_and_grad(loss_fn, has_aux=True)
#         (loss_value, pred), grads = gradient_fn(state.params)
#         state = state.apply_gradients(grads=grads)
#         return state, loss_value
#
#     for epoch in range(epochs):
#
#         losses_epoch = []
#
#         for i, train_batch in enumerate(dataloader_train):
#             rng, subrng = jax.random.split(rng)
#             image_target, meas_input, backproj, mask = train_batch
#             train_batch = (jnp.array(image_target), jnp.array(meas_input), jnp.array(backproj), jnp.array(mask))
#
#             # print('train_batch shapes')
#             # for _ in range(4):
#             #     print(train_batch[_].shape)
#
#             # # 2. Perform the outer training step
#             state, loss_value = train_step(state, train_batch)
#             losses_outer.append([epoch, i, loss_value])
#             losses_epoch.append(loss_value.item())
#
#         loss_epoch = jnp.array(losses_epoch).mean()
#         logger.info('Epoch {:d} \t Avg train loss: {:4e}'.format(epoch + 1, loss_epoch))
#
#         if debug and ((epoch % 20 == 0 and epoch < 200) or epoch % 200 == 0):
#
#             x_target, meas, mask, backproj = train_batch
#             x_target, meas, mask, backproj = jnp.array(x_target), jnp.array(meas), jnp.array(mask), jnp.array(backproj)
#
#             save_model(state, results_folder + 'ckpt.ckpt')
#
#             u = jax.lax.conv(backproj,  # lhs = NCHW image tensor
#                              jnp.transpose(state.params['weight_0'], [3, 2, 0, 1]),
#                              (1, 1),  # window strides
#                              'SAME')
#             x_output, _ = state.apply_fn({'params': state.params}, meas, jnp.array([0]), backproj, u)
#
#             save_image(tensor_to_img(x_output[0]), results_folder+'x_train_output_epoch_'+str(epoch)+'.png')
#             save_image(tensor_to_img(backproj[0]),  results_folder+'x_train_input_epoch_'+str(epoch)+'.png')
#             save_image(tensor_to_img(x_target[0]),  results_folder+'x_train_target_epoch_'+str(epoch)+'.png')
#
#     jnp_loss = jnp.array(losses_outer)
#     jnp_loss_epoch = jnp.array(losses_epoch)
#
#     info = {'losses_outer': jnp_loss,
#             'losses_epoch': jnp_loss_epoch}
#
#     return state, info


def downscale_op(data, sr_factor=4):
    """Subsample the last two axes of a 4-D array by keeping every n-th entry.

    Args:
        data: Array indexed with four axes; the first two are left untouched.
        sr_factor: Stride applied along the third and fourth axes.

    Returns:
        ``data`` with axes 2 and 3 strided by ``sr_factor`` (no filtering is
        applied before subsampling).
    """
    every_nth = slice(None, None, sr_factor)
    return data[:, :, every_nth, every_nth]

def non_batched_downscale_op(data, sr_factor=4):
    """Take every second entry of the first 32 rows and columns of a 4-D array.

    Args:
        data: Array indexed with four axes; the first two are left untouched.
        sr_factor: Accepted for signature parity with ``downscale_op`` but not
            used: the window (0..32) and the stride (2) are fixed.
            NOTE(review): possibly hard-coded so the function can be called
            with a traced ``sr_factor`` inside a jitted step — confirm before
            wiring the argument in.

    Returns:
        ``data[:, :, 0:32:2, 0:32:2]``.
    """
    fixed_window = slice(0, 32, 2)
    return data[:, :, fixed_window, fixed_window]


# 0-based step indices at which a checkpoint and preview images are written
# when ``debug`` is enabled.
_DEBUG_STEPS = frozenset((0, 1, 10, 20, 50, 100, 200, 1000, 10000))


def _updnet_forward(apply_fn, params, image_input, image_init):
    """Apply the network once, starting from the estimate ``image_init``."""
    # The extra network input ``u`` is image_init convolved with the first
    # weight tensor (assumes NCHW images and an HWIO 'weight_0', hence the
    # transpose to the layout expected by lax.conv — TODO confirm).
    u = jax.lax.conv(image_init,
                     jnp.transpose(params['weight_0'], [3, 2, 0, 1]),
                     (1, 1),  # window strides
                     'SAME')
    op_type = jnp.array([0])  # placeholder operator id
    pred, _ = apply_fn({'params': params}, image_input, op_type, image_init, u)
    return pred


def _make_train_step(supervised):
    """Return the (un-jitted) optimisation step for the chosen data term."""
    # Imported explicitly: ``import jax`` alone does not guarantee that the
    # ``jax.flatten_util`` submodule is loaded.
    from jax.flatten_util import ravel_pytree

    def train_step(state, input_batch, params_ref, sr_factor, reg_param):
        image_target, image_input = input_batch
        image_init = 0 * image_target  # all-zero initial estimate
        flat_params_ref = ravel_pytree(params_ref)[0]

        def loss_fn(params):
            pred = _updnet_forward(state.apply_fn, params, image_input, image_init)
            if supervised:
                # Compare the prediction with the full-resolution target.
                data_term = mean_squared_error(pred, image_target)
            else:
                # Compare prediction and target only after the fixed downscaling.
                data_term = mean_squared_error(
                    non_batched_downscale_op(pred, sr_factor=sr_factor),
                    non_batched_downscale_op(image_target, sr_factor=sr_factor))
            # Penalise the distance between current and reference parameters.
            proximity = mean_squared_error(ravel_pytree(params)[0], flat_params_ref)
            return data_term + reg_param * proximity, pred

        (loss_value, _), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
        return state.apply_gradients(grads=grads), loss_value

    return train_step


def _save_debug_snapshot(state, image_target, downscaled_image, results_folder, step):
    """Write a checkpoint plus output/target preview images for ``step``."""
    save_model(state, results_folder + 'ckpt_step_'+str(step)+'.ckpt')

    x_output = _updnet_forward(state.apply_fn, state.params, downscaled_image, 0 * image_target)

    save_image(tensor_to_img(x_output[0]), results_folder+'x_train_output_step_'+str(step)+'.png')
    save_image(tensor_to_img(image_target[0]), results_folder+'x_train_target_step_'+str(step)+'.png')


def train_single_step(rng, dataset_train, state, params_ref, batch_size=1,
                num_steps=2, debug=True, results_folder='results/', log_filename='training.log',
                grayscale=False, sr_factor=2, supervised=1, reg_param=1e-3):
    """Fine-tune ``state`` for ``num_steps`` optimisation steps.

    Each batch provides an image that serves as the target; the network input
    is that image subsampled with ``downscale_op``. The loss is a data term
    (supervised: MSE between prediction and target; otherwise: MSE between
    their ``non_batched_downscale_op`` versions) plus ``reg_param`` times the
    MSE between the flattened current parameters and ``params_ref``.

    Args:
        rng: PRNG key. Kept for interface compatibility; not consumed.
        dataset_train: Dataset wrapped in a shuffled ``DataLoader`` whose
            batches unpack into two items; the second one is the image.
        state: ``TrainState`` to update. Its params must contain 'weight_0'.
        params_ref: Reference parameter tree used by the proximity penalty.
        batch_size: Batch size of the ``DataLoader`` (incomplete last batches
            are dropped).
        num_steps: Total number of optimisation steps to run. Fewer steps are
            run only if the data loader yields no batches.
        debug: If true, save a checkpoint and preview images at the step
            indices listed in ``_DEBUG_STEPS``.
        results_folder: String prefix (expected to end with a path separator)
            for checkpoints and preview images.
        log_filename: Log file path passed to ``logger_info``.
        grayscale: Unused.
        sr_factor: Subsampling stride used to build the network input.
        supervised: Truthy selects the supervised data term, falsy the
            downscaled one.
        reg_param: Weight of the parameter-proximity penalty.

    Returns:
        Tuple ``(state, info)`` where ``info`` has the keys 'losses_outer'
        (rows of ``[batch_index, loss]``) and 'losses_epoch' (one loss per
        step).
    """
    training_device = jax.default_backend()

    # drop_last=True: batches of varying size are not handled.
    dataloader_train = DataLoader(dataset_train, batch_size=batch_size, shuffle=True,
                                  drop_last=True)

    logname = 'trainlog'
    logger_info(logname, log_path=log_filename)
    logger = logging.getLogger(logname)

    logger.info('Starting training on {}'.format(str(training_device)))

    train_step = jax.jit(_make_train_step(supervised))

    losses_outer = []
    losses_epoch = []
    running_loss = 0.0  # host-side sum, so the average needs no device work

    step = 0
    # Every pass over a non-empty loader performs at least one step, so
    # num_steps passes are always enough to reach num_steps steps.
    for _ in range(num_steps):
        for i, batch in enumerate(dataloader_train):
            _, image_input = batch
            image_input = jnp.array(image_input)
            downscaled_image = downscale_op(image_input, sr_factor=sr_factor)

            # The full-resolution image is the target, its subsampled version the input.
            state, loss_value = train_step(state, (image_input, downscaled_image),
                                           params_ref, sr_factor, reg_param)

            loss_scalar = loss_value.item()
            losses_outer.append([i, loss_value])
            losses_epoch.append(loss_scalar)
            running_loss += loss_scalar

            logger.info('Step {:d} \t Avg train loss: {:4e}'.format(step+1, running_loss / len(losses_epoch)))

            if debug and step in _DEBUG_STEPS:
                # Preview on the very batch that was just trained on: the
                # full-resolution image is the target, its subsampled version
                # the network input.
                _save_debug_snapshot(state, image_input, downscaled_image, results_folder, step)

            step += 1
            # ``>=`` so exactly num_steps steps run, however many batches a
            # pass over the loader contains.
            if step >= num_steps:
                break
        if step >= num_steps:
            break

    info = {'losses_outer': jnp.array(losses_outer),
            'losses_epoch': jnp.array(losses_epoch)}

    return state, info
