"""Main training function."""

import time
import sys

import tensorflow as tf
from PIL import Image, ImageDraw

from ncc.algos import core, schedulers
from ncc.utils import logx


def get_permutation_matrix(num_of_symbols, noise_prob):
    """Samples a single noisy-channel permutation matrix.

    Used in the alternative noisy channel implementation. A uniform draw
    ``u`` decides the outcome: if ``u >= noise_prob`` the identity permutation
    is kept, otherwise a random shuffle of ``range(num_of_symbols)`` is used
    (the shuffle itself may happen to be the identity).

    Args:
        num_of_symbols (int): Alphabet size.
        noise_prob: Probability of replacing the identity with a shuffle.

    Returns:
        float32 tensor of shape (num_of_symbols, num_of_symbols) holding the
        dense permutation matrix.
    """
    symbols = tf.range(0, num_of_symbols)
    # Draw the coin before the shuffle so random ops run in the same order.
    keep_identity = tf.random.uniform(()) >= noise_prob
    shuffled = tf.random.shuffle(symbols)
    chosen = tf.where(keep_identity, symbols, shuffled)

    dense = tf.linalg.LinearOperatorPermutation(chosen).to_dense()
    return tf.cast(dense, tf.float32)


@tf.function
def get_permutation_matrix_batch(batch_size, message_length, num_of_symbols,
        noise_prob):
    """Samples one noisy-channel permutation matrix per message symbol.

    Used in the alternative noisy channel implementation. Each of the
    ``batch_size * message_length`` symbol positions gets its own uniform
    draw ``u``, independently of the others:
    - if ``u >= noise_prob``, the identity permutation is used;
    - otherwise, a uniformly random permutation of the ``num_of_symbols``
      symbols is used. That permutation can itself be the identity, so the
      effective corruption probability is slightly below ``noise_prob``.

    This follows the same sampling rule as calling ``get_permutation_matrix``
    once per position. All positions are sampled in one vectorized pass
    instead of a sequential per-position loop, which previously ran on every
    training step.

    Args:
        batch_size (int): Number of messages in the batch.
        message_length (int): Number of symbols per message.
        num_of_symbols (int): Alphabet size.
        noise_prob: Probability (float or scalar float tensor) of replacing
            the identity with a random permutation.

    Returns:
        float32 tensor of shape
        (batch_size, message_length, num_of_symbols, num_of_symbols).
        ``out[b, t]`` is a 0/1 permutation matrix whose row ``i`` has its 1 in
        column ``perm[i]``. This is the layout of
        ``tf.linalg.LinearOperatorPermutation(perm).to_dense()``.
    """
    num_positions = batch_size * message_length

    # One coin per symbol position: True keeps the identity permutation.
    keep_identity = tf.random.uniform((num_positions,)) >= noise_prob

    # The argsort of i.i.d. uniform keys is a random permutation; each row
    # gets an independent one.
    shuffled = tf.argsort(
        tf.random.uniform((num_positions, num_of_symbols)), axis=-1)
    identity = tf.broadcast_to(
        tf.range(num_of_symbols, dtype=shuffled.dtype),
        (num_positions, num_of_symbols))
    permutations = tf.where(keep_identity[:, tf.newaxis], identity, shuffled)

    # one_hot(perm)[i, j] == (j == perm[i]), which matches
    # LinearOperatorPermutation(perm).to_dense().
    matrices = tf.one_hot(permutations, num_of_symbols, dtype=tf.float32)
    return tf.reshape(
        matrices,
        (batch_size, message_length, num_of_symbols, num_of_symbols))


def _save_debug_observations(images, labels):
    """Saves each image of a batch as ``observation_test_{idx}.png``.

    The label row is drawn as text in the top-left corner of each image.

    Args:
        images: Batch of images (tensor). Presumably scaled to [-1, 1], judging
            by the rescaling below — TODO confirm against the dataset.
        labels: Batch of label rows (tensor), one row per image.

    Returns:
        List of the annotated ``PIL.Image`` objects, in batch order.
    """
    observations = []
    for idx, (im, label) in enumerate(zip(images.numpy(), labels.numpy())):
        # Map [-1, 1] back to [0, 255] for PIL.
        im = ((im + 1.0) * 255.0 / 2).astype('uint8')
        pil_image = Image.fromarray(im)
        im_draw = ImageDraw.Draw(pil_image)
        im_draw.text((0, 0), str(label))
        pil_image.save(rf'observation_test_{idx}.png')
        observations.append(pil_image)
    return observations


def noisy_channel(
        ds,
        ds_info,
        total_epochs,
        straight_through,
        kl_coef,
        noise,
        tau,
        tau_min,
        sender_kwargs,
        receiver_kwargs,
        gradient_trick=schedulers.periodic_scheduler,
        softmax_fn=core.gumbel_softmax,
        new_noise=0.0,
        ent_coef=0.0,
        learned_tau=True,
        learning_rate=1e-3,
        smooth_coef=0.995,
        log_every=100,
        eval_every=500,
        save_every=None,
        eval_cb=None,
        log_only_average=True,
        dump_debug_observations=False
):
    """Runs the training pipeline.

    Each training step does the following:
    - samples the schedulable parameters;
    - trains the sender and receiver on one batch;
    - stores statistics in the logger;
    - periodically evaluates, dumps logs and saves weights.

    Args:
        ds (tf.data.Dataset): The dataset used for training in the experiment.
            Each batch is read as a dict with an 'image' key plus one key per
            name in ``ds_info['features_list_train'][0]``.
        ds_info (dict): Dictionary with the following information:
            'message_length', 'sender_alphabet_size', 'receiver_alphabet_size'
            and features related dicts.
        total_epochs (int): Total number of passes over ``ds``.
        straight_through: Either a constant (bool) or a callable deciding
            whether to use straight-through.
        kl_coef (float): KL coefficient.
        noise: Either a constant (float) or a callable scheduling the noise
            probability. The value is assigned to ``sender.noise_probability``.
        tau (float): The initial value of tau, which is passed to
            ``softmax_fn``. When ``learned_tau`` is False it may also be a
            callable schedule.
        tau_min (float): The hard lower constraint on the value of tau. It
            must be set when ``learned_tau`` is True and None otherwise.
        sender_kwargs (dict): Keyword args for the sender.
        receiver_kwargs (dict): Keyword args for the receiver.
        gradient_trick: Either a constant (bool) or a callable deciding
            whether to use the gradient trick.
        softmax_fn (callable): Function providing probabilities.
        new_noise: Noise probability (float or callable schedule) used in the
            alternative, permutation-based noise implementation.
        ent_coef (float): Entropy coefficient.
        learned_tau (bool): Whether tau is learned by the optimizer.
        learning_rate (float): Learning rate.
        smooth_coef (float): Exponential averaging coefficient for
            'AccuracySmooth'.
        log_every (int): How often (in training steps) to dump logs. A falsy
            value disables log dumping.
        eval_every (int): How often (in training steps) to call ``eval_cb``.
            A falsy value disables evaluation.
        save_every (int): How often (in training steps) to save the neural
            networks weights. A falsy value disables periodic saving.
        eval_cb (callable): Evaluation function, called as
            ``eval_cb(model_inference, training_steps)``. It must return a
            dict of values to log. If None, evaluation is skipped.
        log_only_average (bool): If true, log only averages of statistics;
            otherwise also std, max and min.
        dump_debug_observations (bool): If true, dump the first batch of
            observations to PNG files for debugging.

    Note:
        Once more than 100 evaluations have run (the first evaluation step
        with ``training_steps > 100 * eval_every``), weights are saved under
        'data_out/' and the process is terminated with ``sys.exit(0)``.
    """
    logger = logx.EpochLogger()
    logger.save_config(locals())

    noise_schedule_fn = schedulers.callable_or_constant(noise, True)
    new_noise_schedule_fn = schedulers.callable_or_constant(new_noise, True)
    straight_through_fn = \
        schedulers.callable_or_constant(straight_through, True)
    gradient_trick_fn = schedulers.callable_or_constant(gradient_trick, True)

    # Build a sender and receiver.
    sender = core.Sender(**sender_kwargs)
    receiver = core.Receiver(**receiver_kwargs)

    # Choose whether tau is learned.
    if learned_tau:
        tau_var = tf.Variable(tau, dtype='float32', trainable=True)
        assert tau_min is not None, 'Please set minimal tau'
    else:
        assert tau_min is None, 'Please keep configs clean'
        tau_min = 0.0  # for the sake of avoiding if's
        tau_fn = schedulers.callable_or_constant(tau)
        # Assigned from the schedule in the training loop, never optimized.
        tau_var = tf.Variable(0.0, dtype='float32', trainable=False)

    optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

    @tf.function
    def learn_on_batch(images, labels, straight_through, gradient_trick,
            new_noise_probab):
        """Runs one optimization step and returns scalar training stats.

        ``straight_through`` and ``gradient_trick`` are Python values, so each
        distinct combination is traced into its own graph.
        ``new_noise_probab`` is passed as a scalar float32 tensor so that
        changing it does not trigger a retrace.
        """
        # Shape (B, message_length, receiver_alphabet) per the original note.
        # pylint: disable=no-value-for-parameter
        labels_one_hot = tf.one_hot(labels, ds_info['receiver_alphabet_size'])

        # Single gradient call below, so a non-persistent tape suffices.
        with tf.GradientTape() as tape:
            # Clip tau.
            tau_clipped = tf.math.maximum(tau_var, tau_min)

            # Sender output.
            _, noisy_logits, latent_cross_entropy, \
                latent_entropy = sender(images, gradient_trick)

            # Generated message (either 1-hot or soft).
            message = softmax_fn(noisy_logits, tau_clipped,
                                 straight_through=straight_through)

            batch_size = message.shape[0]
            message_length = message.shape[1]
            num_of_symbols = message.shape[2]

            # Alternative noise: permute the symbols of each message position.
            noise_matrix \
                = get_permutation_matrix_batch(batch_size, message_length,
                                               num_of_symbols, new_noise_probab)

            message = tf.linalg.matvec(noise_matrix, message)
            # Receiver output.
            preds, preds_logits = receiver(message)

            # Cross_entropy loss.
            loss_ent = tf.nn.softmax_cross_entropy_with_logits(
                labels=labels_one_hot, logits=preds_logits)
            loss_ent = tf.reduce_mean(loss_ent)

            # Clamp at 0: for high noise the KL term is sometimes ~-2e-7.
            loss_kl = tf.maximum(0., kl_coef * latent_cross_entropy)
            loss_entropy = tf.maximum(0.0, ent_coef * latent_entropy)
            loss = loss_ent + loss_kl + loss_entropy

        # Compute gradients and do updates.
        trainable_variables = sender.trainable_variables \
                              + receiver.trainable_variables
        if learned_tau:
            trainable_variables.append(tau_var)

        grads = tape.gradient(loss, trainable_variables)
        optimizer.apply_gradients(zip(grads, trainable_variables))

        # Average accuracy.
        accuracy = tf.reduce_mean(tf.cast(tf.equal(preds, labels), 'float32'))

        return dict(loss=loss,
                    loss_kl=loss_kl,
                    accuracy=accuracy,
                    entropy=latent_entropy
                    )

    @tf.function
    def model_inference(images, noisy=False):
        """Returns (logits, preds); uses the noisy logits when `noisy`."""
        # Sender output.
        signal_logits, noisy_logits, _, _ = sender(images)

        logits = noisy_logits if noisy else signal_logits

        # Generate message.
        message = softmax_fn(logits, tau_var)

        # Receiver output.
        preds, _ = receiver(message)

        return logits, preds

    # Initialize statistics.
    last = time.time()
    accuracy_smooth = 0
    accuracy_st = 0
    accuracy_nst = 0

    # Main loop
    training_steps = 0
    for epoch in range(total_epochs):

        for batch in ds:
            # Set schedulable parameters (distinct names so the function's
            # own `straight_through` / `gradient_trick` args are not clobbered).
            current_noise_probability = noise_schedule_fn(epoch, training_steps)
            use_straight_through = straight_through_fn(epoch, training_steps)
            use_gradient_trick = gradient_trick_fn(epoch, training_steps)
            # A Python float argument would make tf.function retrace
            # `learn_on_batch` for every new value of the schedule; a tensor
            # keeps a single trace.
            new_noise_probab = tf.cast(
                new_noise_schedule_fn(epoch, training_steps), tf.float32)

            # NOTE(review): if core.Sender reads this attribute as a plain
            # Python value inside a traced function, updates after the first
            # trace may not take effect — confirm in core.Sender.
            sender.noise_probability = current_noise_probability
            if not learned_tau:
                tau_var.assign(tau_fn(epoch, training_steps))

            # Unpack a batch.
            images = batch['image']
            labels = tf.stack([batch[feature_name]
                for feature_name in ds_info['features_list_train'][0]])
            labels = tf.transpose(labels, perm=[1, 0])

            # Train on a batch.
            results = learn_on_batch(images, labels, use_straight_through,
                                     use_gradient_trick, new_noise_probab)

            # Report results
            results = {key: value.numpy() for key, value in results.items()}
            accuracy_smooth = smooth_coef * accuracy_smooth + (
                        1 - smooth_coef) * results['accuracy']

            # Accuracy broken down by straight-through vs. not.
            if use_straight_through:
                accuracy_st = results['accuracy']
            else:
                accuracy_nst = results['accuracy']
            logger.store(
                Loss=results['loss'],
                LossKL=results['loss_kl'],
                Accuracy=results['accuracy'],
                AccuracyST=accuracy_st,
                AccuracyNST=accuracy_nst,
                Entropy=results['entropy'],
                AccuracySmooth=accuracy_smooth,
                NoiseProbability=current_noise_probability,
                Tau=max(tau_var.numpy(), tau_min)
            )

            # Run evaluation (skipped when no callback is configured).
            if (eval_cb is not None and eval_every
                    and training_steps % eval_every == 0):
                # Test the performance of the model
                eval_dict = eval_cb(model_inference, training_steps)
                for key, value in eval_dict.items():
                    logger.log_tabular(key, value)
                # Hard stop after 100 evaluations: save weights and exit the
                # process. NOTE(review): the eval values logged just above are
                # not dumped before exiting.
                if training_steps / eval_every > 100:
                    sender.save_weights('data_out/checkpoint_sender')
                    receiver.save_weights('data_out/checkpoint_reciver')
                    sys.exit(0)

            # Log progress.
            if log_every and training_steps % log_every == 0:
                # Average wall time per step since the previous dump.
                now = time.time()
                time_between_logs = (now - last) / log_every
                last = now

                if dump_debug_observations and training_steps == 0:
                    # Hack way of passing list to pass through logger api
                    logger.log_tabular(
                        'Observations',
                        _save_debug_observations(images, labels))

                # Log info about epoch.
                logger.log_tabular('TrainingStep', training_steps)
                logger.log_tabular('Epoch', epoch)
                logger.log_tabular('TimeBetweenLogs', time_between_logs)
                logger.log_tabular('Loss', average_only=log_only_average)
                logger.log_tabular('LossKL', average_only=log_only_average)
                logger.log_tabular('Entropy', average_only=log_only_average)
                logger.log_tabular('Accuracy', average_only=log_only_average)
                logger.log_tabular('AccuracySmooth',
                                   average_only=log_only_average)
                logger.log_tabular('LearningRate',
                                   learning_rate, average_only=True)
                logger.log_tabular('NoiseProbability', average_only=True)
                logger.log_tabular('Tau', average_only=True)
                logger.dump_tabular()

            # Save model
            if save_every and (training_steps + 1) % save_every == 0:
                sender.save_weights('data_out/checkpoint_sender')
                receiver.save_weights('data_out/checkpoint_reciver')

            training_steps += 1
