import logging

import numpy as np
from torch.utils.data import DataLoader

import jax
from jax.lib import xla_bridge
from flax.training import train_state  # Useful dataclass to keep train state
import jax.numpy as jnp
import optax  # Import the optax module for Flax's optimizers

from utils.utils import save_image, tensor_to_img, gkern, logger_info
from utils.data import ImageDataset
from utils.tasks import get_task_batches
from models.utils import model_apply, save_model


# Define the loss function (mean squared error)
def mean_squared_error(predictions, targets):
    """Return the mean of the element-wise squared difference of two arrays.

    Args:
        predictions: predicted values.
        targets: reference values, broadcast-compatible with ``predictions``.

    Returns:
        A scalar jax array holding the mean squared error.
    """
    residual = predictions - targets
    return jnp.square(residual).mean()


# Adam optimizer setup
def get_optimizer(learning_rate=0.1, clip_params_norm=1):
    """Build the optax transformation used for training.

    Adam (with ``eps_root=1e-6``) runs first. ``clip_by_global_norm`` then
    rescales Adam's output so its global norm is at most ``clip_params_norm``.
    Because ``optax.chain`` applies transforms in order, this bounds the
    update step, not the raw gradients.

    Args:
        learning_rate: Adam learning rate.
        clip_params_norm: maximum global norm of the update produced by Adam.

    Returns:
        An ``optax.GradientTransformation``.
    """
    adam_step = optax.adam(learning_rate=learning_rate, eps_root=1e-6)
    clip_step = optax.clip_by_global_norm(clip_params_norm)
    return optax.chain(adam_step, clip_step)


def init_train_state(model, random_key, shape, learning_rate) -> train_state.TrainState:
    """Initialise ``model`` and wrap its parameters in a Flax ``TrainState``.

    The model is initialised with an all-ones input of ``shape`` and the
    operator-type array ``[0]``. The optimiser comes from ``get_optimizer``.

    Args:
        model: Flax module exposing ``init`` and ``apply``.
        random_key: PRNG key passed to ``model.init``.
        shape: shape of the dummy input used for initialisation.
        learning_rate: learning rate forwarded to ``get_optimizer``.

    Returns:
        A fresh ``TrainState`` holding the ``'params'`` collection.
    """
    dummy_input = jnp.ones(shape)
    dummy_op_type = jnp.array([0])
    variables = model.init(random_key, dummy_input, dummy_op_type)
    tx = get_optimizer(learning_rate)
    return train_state.TrainState.create(
        apply_fn=model.apply,
        params=variables['params'],
        tx=tx,
    )


def init_optimizer_state(model, params, learning_rate=1e-3):
    """Wrap already-existing network parameters in a new ``TrainState``.

    Unlike ``init_train_state``, the model is not initialised here. The
    optimiser is plain ``optax.adam``, without the ``eps_root`` and global-norm
    clipping that ``get_optimizer`` adds.

    Args:
        model: Flax module; only its ``apply`` is stored.
        params: network parameter pytree (the parameters themselves, not a
            ``TrainState``).
        learning_rate: Adam learning rate.

    Returns:
        A ``TrainState`` with a freshly initialised Adam state.
    """
    adam = optax.adam(learning_rate)
    return train_state.TrainState.create(
        apply_fn=model.apply,
        params=params,
        tx=adam,
    )



def test_model(state_model_test, outer_state, results_folder='results_denoising', task_id=0, grayscale=False,
               path_test_data='/pth/to/Set3C/'):
    """Run the model on every test image and save its input and output as PNGs.

    For each image, ``get_task_batches`` builds the model input using the kernel
    ``gkern(kernlen=7, std=0.5)`` and a random binary mask. The input then goes
    through ``model_apply``. For image ``k`` this writes
    ``<results_folder>/x_test_<k>_output_meta.png`` and
    ``<results_folder>/x_test_<k>_input.png``.

    Args:
        state_model_test: first state argument forwarded to ``model_apply``.
        outer_state: second state argument forwarded to ``model_apply``.
        results_folder: output directory. Paths are built with ``os.path.join``,
            so a trailing separator is optional. Previously the default, which
            has no trailing slash, produced files like
            ``results_denoisingx_test_0_input.png`` in the working directory.
        task_id: forwarded to ``get_task_batches`` as ``select_task``.
        grayscale: forwarded to ``ImageDataset``.
        path_test_data: directory of the test images (previously hard-coded).
    """
    import os

    dataset_test = ImageDataset(path_test_data, grayscale=grayscale)
    # batch_size=1 with drop_last=False: every test image is processed exactly once, in order.
    dataloader = DataLoader(dataset_test, batch_size=1, shuffle=False, drop_last=False)

    blur_kernel = gkern(kernlen=7, std=0.5)

    # Fixed key, so the mask is identical across runs. The same key is also
    # reused for every get_task_batches call below.
    rng = jax.random.PRNGKey(0)
    # NOTE(review): mask is hard-coded to 256x256; this assumes test images have that size. Confirm.
    mask = jax.random.choice(rng, 2, shape=(1, 1, 256, 256), p=np.asarray([0.5, 0.5]))

    for im_index, im in enumerate(dataloader):
        _, x_input, type_operators_batch_tasks = get_task_batches(rng, im, x_target_batch=None,
                                                                  kernel=blur_kernel, mask=mask,
                                                                  select_task=task_id)

        x_output = model_apply(state_model_test, outer_state, x_input, type_operators_batch_tasks)

        stem = os.path.join(results_folder, 'x_test_' + str(im_index))
        save_image(tensor_to_img(x_output), stem + '_output_meta.png')
        save_image(tensor_to_img(x_input), stem + '_input.png')


# Training function
def train_basic(rng, dataset_train, state, batch_size=1,
                epochs=2, debug=True, results_folder='results/', log_filename='training.log',
                grayscale=False):
    """Train ``state`` for ``epochs`` epochs with a supervised MSE loss.

    Each step first computes ``u = conv(backproj, params['weight_0'])``. It then
    runs ``state.apply_fn({'params': params}, meas, [0], backproj, u)`` and
    minimises the MSE between the prediction and the target image.

    Args:
        rng: PRNG key. Not used by the current implementation; kept for
            interface compatibility.
        dataset_train: dataset yielding 4-tuples. The training step reads them
            positionally as ``(target, measurement, mask, backproj)``; ``mask``
            is not used here.
        state: Flax ``TrainState`` whose params contain ``'weight_0'``.
        batch_size: DataLoader batch size; an incomplete last batch is dropped.
        epochs: number of passes over ``dataset_train``.
        debug: when True, at epochs 0, 20, ..., 180 and every 200th epoch, save
            a checkpoint plus output/input/target images for the last batch.
        results_folder: prefix for checkpoint and image paths. It is
            concatenated directly, so it should end with a separator.
        log_filename: log file passed to ``logger_info``.
        grayscale: unused; kept for interface compatibility.

    Returns:
        ``(state, info)``. ``info['losses_outer']`` stacks one
        ``[epoch, batch_index, loss]`` row per step. ``info['losses_epoch']``
        holds the per-step losses of the last epoch (empty when ``epochs == 0``).
    """
    training_device = jax.default_backend()

    # drop_last=True: batches whose size differs from batch_size are not supported yet.
    dataloader_train = DataLoader(dataset_train, batch_size=batch_size, shuffle=True,
                                  drop_last=True)

    logname = 'trainlog'
    logger_info(logname, log_path=log_filename)
    logger = logging.getLogger(logname)

    logger.info('Starting training on {}'.format(str(training_device)))

    def forward(apply_fn, params, meas_input, backproj):
        """Model forward pass shared by the training step and the debug snapshot."""
        # lax.conv defaults to NCHW inputs and OIHW kernels. The transpose reorders
        # weight_0 into OIHW (presumably it is stored HWIO, the Flax convention; confirm).
        u = jax.lax.conv(backproj,
                         jnp.transpose(params['weight_0'], [3, 2, 0, 1]),
                         (1, 1),  # window strides
                         'SAME')
        op_type = jnp.array([0])  # A blanket variable for now
        pred, _ = apply_fn({'params': params}, meas_input, op_type, backproj, u)
        return pred

    @jax.jit
    def train_step(state: train_state.TrainState, input_batch: tuple):
        image_target, meas_input, _mask, backproj = input_batch

        def loss_fn(params):
            pred = forward(state.apply_fn, params, meas_input, backproj)
            return mean_squared_error(pred, image_target), pred

        (loss_value, _pred), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
        return state.apply_gradients(grads=grads), loss_value

    losses_outer = []
    # Initialised before the loop so that epochs == 0 returns an empty array instead of raising.
    losses_epoch = []
    last_batch = None

    for epoch in range(epochs):
        losses_epoch = []

        for i, raw_batch in enumerate(dataloader_train):
            # NOTE(review): the previous loop labelled positions 2/3 as (backproj, mask),
            # while train_step reads them as (mask, backproj). Conversion keeps the positional
            # order, so train_step's reading is what is actually used. Confirm against ImageDataset.
            last_batch = tuple(jnp.array(t) for t in raw_batch)

            state, loss_value = train_step(state, last_batch)
            losses_outer.append([epoch, i, loss_value])
            losses_epoch.append(loss_value.item())

        loss_epoch = jnp.array(losses_epoch).mean() if losses_epoch else float('nan')
        logger.info('Epoch {:d} \t Avg train loss: {:4e}'.format(epoch + 1, loss_epoch))

        snapshot_epoch = (epoch % 20 == 0 and epoch < 200) or epoch % 200 == 0
        # last_batch is None only when the loader yields no batches (dataset smaller than batch_size).
        if debug and snapshot_epoch and last_batch is not None:
            x_target, meas, _mask, backproj = last_batch

            save_model(state, results_folder + 'ckpt.ckpt')

            x_output = forward(state.apply_fn, state.params, meas, backproj)

            save_image(tensor_to_img(x_output[0]), results_folder+'x_train_output_epoch_'+str(epoch)+'.png')
            save_image(tensor_to_img(backproj[0]), results_folder+'x_train_input_epoch_'+str(epoch)+'.png')
            save_image(tensor_to_img(x_target[0]), results_folder+'x_train_target_epoch_'+str(epoch)+'.png')

    info = {'losses_outer': jnp.array(losses_outer),
            'losses_epoch': jnp.array(losses_epoch)}

    return state, info


def view_as_real(x):
    """Split a complex array into its real and imaginary parts along the last axis.

    The output's last axis is twice as long as the input's. Its first half holds
    the real parts and its second half the imaginary parts. Unlike
    ``torch.view_as_real``, no new trailing axis is created.
    """
    real_part = jnp.real(x)
    imag_part = jnp.imag(x)
    return jnp.concatenate((real_part, imag_part), axis=-1)


def apply_meas_op_asreal(x, mask):
    """Apply a masked-Fourier measurement operator and return a real-valued array.

    Steps:
    1. Orthonormal FFT over the last two axes.
    2. ``fftshift`` to centre the spectrum.
    3. Multiply by ``mask``.
    4. ``ifftshift`` to undo the shift.
    5. Split real/imag parts with ``view_as_real`` (concatenated on the last axis).

    No inverse FFT is applied, so the result stays in the frequency domain.
    """
    spatial_axes = (-2, -1)
    spectrum = jnp.fft.fftn(x, axes=spatial_axes, norm='ortho')  # type: ignore
    centred = jnp.fft.fftshift(spectrum, axes=spatial_axes)
    masked = jnp.fft.ifftshift(centred * mask, axes=spatial_axes)
    return view_as_real(masked)


def train_single_step(rng, dataset_train, state, params_ref, batch_size=1,
                num_steps=2, debug=True, results_folder='results/', log_filename='training.log',
                grayscale=False, supervised=True, reg_param=1e-3):
    """Fine-tune ``state`` for exactly ``num_steps`` optimisation steps.

    The loss is a data term plus ``reg_param`` times the MSE between the
    flattened current parameters and the flattened ``params_ref``:

    * ``supervised=True``: the data term is the MSE between prediction and
      target. The model is fed zeroed ``backproj`` and ``u`` inputs.
    * ``supervised=False``: the data term is the MSE between
      ``apply_meas_op_asreal`` of the prediction and of the target, using the
      batch ``mask``. The model gets the real ``backproj`` and ``u``.

    The DataLoader is re-iterated (and reshuffled) as needed until
    ``num_steps`` steps have run. The previous ``step > num_steps`` stop
    condition ran ``num_steps + 1`` steps when the loader had two or more
    batches, but only ``num_steps`` with a single batch. Debug snapshots use the
    same forward pass as the active training mode.

    Args:
        rng: PRNG key. Not used by the current implementation.
        dataset_train: dataset yielding 4-tuples. They are read positionally as
            ``(target, measurement, mask, backproj)``.
        state: Flax ``TrainState`` whose params contain ``'weight_0'``.
        params_ref: reference parameter pytree for the regulariser.
        batch_size: DataLoader batch size; an incomplete last batch is dropped.
        num_steps: number of gradient steps. A non-positive value runs none.
        debug: save a checkpoint and images at steps 1, 10, 20, 50, 100 and 200
            (when reached).
        results_folder: prefix for checkpoint and image paths. It is
            concatenated directly, so it should end with a separator.
        log_filename: log file passed to ``logger_info``.
        grayscale: unused; kept for interface compatibility.
        supervised: selects the data term (see above).
        reg_param: weight of the parameter-distance regulariser.

    Returns:
        ``(state, info)``. ``info['losses_outer']`` stacks ``[batch_index, loss]``
        rows, where batch_index restarts at each pass over the loader.
        ``info['losses_epoch']`` holds every step's loss.
    """
    import itertools
    # jax.flatten_util is a submodule that `import jax` does not guarantee to load.
    from jax.flatten_util import ravel_pytree

    training_device = jax.default_backend()

    # drop_last=True: batches whose size differs from batch_size are not supported yet.
    dataloader_train = DataLoader(dataset_train, batch_size=batch_size, shuffle=True,
                                  drop_last=True)

    logname = 'trainlog'
    logger_info(logname, log_path=log_filename)
    logger = logging.getLogger(logname)

    logger.info('Starting training on {}'.format(str(training_device)))

    def forward(apply_fn, params, meas_input, backproj):
        """Model forward pass (CASE uPDNET), shared by training and debug snapshots."""
        # lax.conv defaults to NCHW inputs and OIHW kernels. The transpose reorders
        # weight_0 into OIHW (presumably it is stored HWIO, the Flax convention; confirm).
        u = jax.lax.conv(backproj,
                         jnp.transpose(params['weight_0'], [3, 2, 0, 1]),
                         (1, 1),  # window strides
                         'SAME')
        if supervised:
            # Supervised mode feeds zeroed backprojection / u to the model.
            backproj, u = backproj * 0, u * 0
        op_type = jnp.array([0])  # A blanket variable for now
        pred, _ = apply_fn({'params': params}, meas_input, op_type, backproj, u)
        return pred

    def param_distance(params, ref):
        """MSE between the flattened current and reference parameter pytrees."""
        return mean_squared_error(ravel_pytree(params)[0], ravel_pytree(ref)[0])

    def train_step_impl(state, input_batch, params_ref, reg_param):
        image_target, meas_input, mask, backproj = input_batch

        def loss_fn(params):
            pred = forward(state.apply_fn, params, meas_input, backproj)
            if supervised:
                data_term = mean_squared_error(pred, image_target)
            else:
                data_term = mean_squared_error(apply_meas_op_asreal(pred, mask),
                                               apply_meas_op_asreal(image_target, mask))
            return data_term + reg_param * param_distance(params, params_ref), pred

        (loss_value, _pred), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
        return state.apply_gradients(grads=grads), loss_value

    train_step = jax.jit(train_step_impl)

    def batch_stream():
        # At most num_steps passes, so an empty loader cannot loop forever.
        for _ in range(num_steps):
            yield from enumerate(dataloader_train)

    total_steps = max(num_steps, 0)
    losses_outer = []
    losses_epoch = []
    running_sum = 0.0

    for step, (i, raw_batch) in enumerate(itertools.islice(batch_stream(), total_steps)):
        if step == 0:
            # Checkpoint of the initial, not-yet-updated state.
            save_model(state, results_folder + 'ckpt_step_' + str(step) + '.ckpt')

        train_batch = tuple(jnp.array(t) for t in raw_batch)
        state, loss_value = train_step(state, train_batch, params_ref, reg_param)

        loss_float = loss_value.item()
        losses_outer.append([i, loss_value])
        losses_epoch.append(loss_float)

        # Running mean over all steps so far, O(1) per step.
        running_sum += loss_float
        loss_epoch = running_sum / len(losses_epoch)
        logger.info('Step {:d} \t Avg train loss: {:4e}'.format(step+1, loss_epoch))

        if debug and step in (1, 10, 20, 50, 100, 200):
            x_target, meas, _mask, backproj = train_batch

            save_model(state, results_folder + 'ckpt_step_'+str(step)+'.ckpt')

            x_output = forward(state.apply_fn, state.params, meas, backproj)

            save_image(tensor_to_img(x_output[0]), results_folder+'x_train_output_step_'+str(step)+'.png')
            save_image(tensor_to_img(backproj[0]), results_folder+'x_train_input_step_'+str(step)+'.png')
            save_image(tensor_to_img(x_target[0]), results_folder+'x_train_target_step_'+str(step)+'.png')

    info = {'losses_outer': jnp.array(losses_outer),
            'losses_epoch': jnp.array(losses_epoch)}

    return state, info
