import numpy as np
import torch as th
import os
import time
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import configuration as cfg


def evaluate(net, net_input, net_label, params, tensors, pk_batches,
             criterion=None, optimizer=None, testing=False, at=None):
    """
    Run the network over one sequence and optionally update its weights.

    The network is reset and its recurrent state initialized with zeros
    before the first step. When ``testing`` is False every step is
    teacher-forced (the network receives ``net_input[t]``). When ``testing``
    is True, steps t = 0 .. cfg.TEACHER_FORCING_STEPS (inclusive) are
    teacher-forced and every later step runs in closed loop, feeding the
    network its own output from step t - 1. If ``params.active_tuning`` is
    set while testing, the teacher-forced steps are produced by the Active
    Tuning object ``at`` instead of a direct network call.

    :param net: The network; called as ``net.reset(pk_num=...)`` and
        ``net.forward(dyn_in=..., state=...)`` returning ``(output, state)``
    :param net_input: The input sequence, indexed as ``[t, :, feature]``
        (presumably [time, batch, feature] - confirm with the caller)
    :param net_label: The target sequence, passed unchanged to ``criterion``
    :param params: The parameters of the network (``pk_dyn_out_size``,
        ``active_tuning`` and ``device`` are read)
    :param tensors: Holder of the template tensors ``pk_lat_in``,
        ``pk_lstm_c`` and ``pk_lstm_h``; only used via ``th.zeros_like`` to
        build the all-zero initial state
    :param pk_batches: The number of batches for the PKs
    :param criterion: The criterion to measure the error (optional)
    :param optimizer: The optimizer to update the weights (optional; its
        gradients are always zeroed, but a step is only taken when a
        ``criterion`` is given as well)
    :param testing: Bool that determines whether the network is being tested
    :param at: The Active Tuning class object (required when testing with
        ``params.active_tuning`` enabled)
    :return: Tuple ``(mse, net_outputs)``: the criterion value (``None`` if
        no criterion was given) and the network outputs with shape
        ``(seq_len, pk_batches, params.pk_dyn_out_size)``
    :raises ValueError: If active tuning is requested but ``at`` is None
    """

    seq_len = len(net_input)

    # Buffer for the per-step outputs. It is allocated on the default device
    # (normally the CPU), which the .numpy() call in the closed-loop branch
    # relies on.
    net_outputs = th.zeros(size=(seq_len,
                                 pk_batches,
                                 params.pk_dyn_out_size))

    if optimizer:
        # Set the gradients back to zero
        optimizer.zero_grad()

    # Reset the network to clear the previous sequence
    net.reset(pk_num=pk_batches)

    # Start every sequence from an all-zero recurrent state
    state = [th.zeros_like(tensors.pk_lat_in),
             th.zeros_like(tensors.pk_lstm_c),
             th.zeros_like(tensors.pk_lstm_h)]

    # Forward pass over the whole sequence, one step at a time
    for t in range(seq_len):

        if testing and t > cfg.TEACHER_FORCING_STEPS:
            # Closed loop: the output of the previous step is the input now
            dyn_net_in_step = net_outputs[t - 1].detach().numpy()
            pk_dyn_out, state = net.forward(dyn_in=dyn_net_in_step,
                                            state=state)

        elif testing and params.active_tuning:
            # Teacher forcing with active tuning: adapt the (noisy) input
            # such that it produces the desired (noise-free) output. The
            # observation handed to the tuner is feature 0 of step t + 1.
            # NOTE(review): net_input[t + 1] raises IndexError if this branch
            # reaches the last step, i.e. when
            # cfg.TEACHER_FORCING_STEPS >= seq_len - 1 - confirm intended.
            if at is None:
                raise ValueError("Active tuning is enabled but no Active "
                                 "Tuning object (at) was given")
            pk_dyn_out, state = at.tune(
                observation_t=th.tensor(net_input[t + 1, :, 0:1],
                                        device=params.device),
                reset_optimizer=cfg.ADAM_RESET_OPTIMIZER
            )

        else:
            # Plain teacher forcing: feed the given input for this step
            dyn_net_in_step = net_input[t, :, :params.pk_dyn_out_size]
            pk_dyn_out, state = net.forward(dyn_in=dyn_net_in_step,
                                            state=state)

        # Store the output of the network for this sequence step
        net_outputs[t] = pk_dyn_out

    mse = None

    if criterion:
        # Error between the whole output sequence and the targets
        mse = criterion(net_outputs, net_label)

        if optimizer:
            mse.backward()
            optimizer.step()

    return mse, net_outputs


def determine_device():
    """
    Pick the torch device on which tensor computations should run.

    :return: ``th.device("cuda")`` when CUDA is available on this system,
        otherwise ``th.device("cpu")``
    """
    if th.cuda.is_available():
        return th.device("cuda")
    return th.device("cpu")


def save_model_to_file(lstm_size, noise_ratio, rep, net, directory="models/"):
    """
    Write the model weights (``net.state_dict()``) to a ``.pt`` file.

    The file is named ``LSTM-<lstm_size>_s<noise_ratio>_r<rep + 1>.pt``,
    with ``noise_ratio`` formatted to two decimals and ``rep + 1``
    zero-padded to two characters, e.g. ``LSTM-16_s0.10_r01.pt``. Only the
    weights are stored; the configuration is encoded in the file name alone.
    :param lstm_size: The size of the LSTM layer (used in the file name)
    :param noise_ratio: The noise ratio of the model (used in the file name)
    :param rep: The zero-based repetition index of the currently trained
        model (stored as ``rep + 1`` in the file name)
    :param net: The actual model; its ``state_dict()`` is saved
    :param directory: Target directory, created if missing
        (default ``"models/"``)
    :return: The path of the written file
    """

    os.makedirs(directory, exist_ok=True)

    model_name = (f"LSTM-{lstm_size}_s{noise_ratio:.2f}"
                  f"_r{str(rep + 1).zfill(2)}.pt")
    path = os.path.join(directory, model_name)

    # Save model weights to file
    th.save(net.state_dict(), path)

    return path


def plot_kernel_activity(ax, label, net_out, net_in=None, make_legend=False,
                         kernel=None):
    """
    This function displays the wave activity of a single kernel over time.

    Plots feature 0 of one kernel for the target, the network output and
    (optionally) the network input, adds a vertical marker at
    ``cfg.TEACHER_FORCING_STEPS`` and uses the kernel index as the title.
    :param ax: The matplotlib axes the activity shall be displayed in
    :param label: The target (ground truth), indexed as [t, kernel, feature]
    :param net_out: The network output, indexed like ``label``
    :param net_in: The network input, indexed like ``label`` (optional)
    :param make_legend: Boolean that indicates whether a legend is created
    :param kernel: Index of the kernel to plot. If None (default), a kernel
        is drawn at random via ``np.random`` from the first half of the
        PK_ROWS x PK_COLS grid.
    :return: The index of the plotted kernel
    """

    if kernel is None:
        # Random kernel from the first half of the grid. Keep at least one
        # candidate: np.random.randint raises when low == high (1x1 grid).
        half_grid = (cfg.PK_ROWS * cfg.PK_COLS) // 2
        kernel = np.random.randint(low=0, high=max(half_grid, 1))

    if net_in is not None:
        ax.plot(range(len(net_in)), net_in[:, kernel, 0],
                label='Network input', color='green')
    ax.plot(range(len(label)), label[:, kernel, 0],
            label='Target', color='deepskyblue')
    ax.plot(range(len(net_out)), net_out[:, kernel, 0],
            label='Network output', color='red', linestyle='dashed')

    # Vertical marker over the full axes height; unlike a line built from
    # the current y-ticks it is never empty or clipped to a few ticks.
    ax.axvline(x=cfg.TEACHER_FORCING_STEPS, color='white',
               linestyle='dotted', label='End of teacher forcing')
    ax.set_title(str(kernel))
    if make_legend:
        ax.legend()

    return kernel


def animate_2d_wave(net_label, net_outputs, net_inputs=None):
    """
    This function visualizes the spatio-temporally expanding wave.

    Each sequence is reshaped to (time, PK_ROWS, PK_COLS, features) and
    feature 0 is shown as a heatmap: network output, ground truth and, if
    given, network input side by side. Frames are drawn by ``animate``.
    :param net_label: The corresponding labels (presumably shaped
        (time, PK_ROWS * PK_COLS, features) - confirm with the caller)
    :param net_outputs: The network output, same layout as ``net_label``
    :param net_inputs: The network inputs, same layout (optional)
    :return: The ``animation.FuncAnimation`` object of the 2d wave; keep a
        reference to it while the animation should run
    """

    def to_grid(seq):
        # Bring the data into a format that can be displayed as heatmap:
        # split the flat kernel axis into a PK_ROWS x PK_COLS grid
        return np.reshape(seq, [len(seq),
                                cfg.PK_ROWS,
                                cfg.PK_COLS,
                                len(seq[0, 0])])

    data = to_grid(net_outputs)
    net_label = to_grid(net_label)

    num_axes = 2
    if net_inputs is not None:
        net_inputs = to_grid(net_inputs)
        num_axes = 3

    # Crop window of the grid. `animate` applies the same bounds to rows and
    # columns, so use the larger dimension to avoid cutting off part of a
    # non-square grid.
    gs1 = 0
    gs2 = max(cfg.PK_ROWS, cfg.PK_COLS)

    # First set up the figure, the axes, and the plot elements we want to
    # animate
    fig, axes = plt.subplots(1, num_axes, figsize=[6 * num_axes, 6], dpi=100)

    # Left: the network output, initialized with its first frame
    im1 = axes[0].imshow(data[0, gs1:gs2, gs1:gs2, 0], vmin=-0.8, vmax=0.8,
                         cmap='Blues')

    # Caption showing the current phase and time step (updated by animate)
    txt1 = axes[0].text(0, axes[0].get_yticks()[0], 't = 0', fontsize=20,
                        color='white')
    axes[0].set_title("Network Output")

    # Middle: the true data
    im2 = axes[1].imshow(net_label[0, gs1:gs2, gs1:gs2, 0], vmin=-0.8,
                         vmax=0.8, cmap='Blues')
    axes[1].set_title("Ground Truth")

    # Right (optional): the network input
    im3 = None
    if net_inputs is not None:
        im3 = axes[2].imshow(net_inputs[0, gs1:gs2, gs1:gs2, 0], vmin=-0.8,
                             vmax=0.8, cmap="Blues")
        axes[2].set_title("Network Input")

    anim = animation.FuncAnimation(fig, animate, frames=len(data),
                                   fargs=(cfg.TEACHER_FORCING_STEPS, data, im1,
                                          im2, im3, txt1, gs1, gs2, net_label,
                                          net_inputs),
                                   interval=1)

    return anim


def animate(_i, _teacher_forcing_steps, _data, _im1, _im2, _im3, _txt1, _gs1,
            _gs2, _net_label, _net_inputs):
    """
    Draw frame ``_i`` of the 2d wave animation (FuncAnimation callback).

    :param _i: The frame / time step index
    :param _teacher_forcing_steps: Frame from which on the caption reads
        "Closed loop prediction"; playback pauses 0.5 s at this frame
    :param _data: The network outputs, indexed as [t, row, col, feature]
    :param _im1: Image artist showing the network output
    :param _im2: Image artist showing the ground truth
    :param _im3: Image artist showing the network input, or None
    :param _txt1: Text artist for the phase / time step caption
    :param _gs1: Start of the crop window (rows and columns)
    :param _gs2: End of the crop window (rows and columns)
    :param _net_label: The ground truth, indexed like ``_data``
    :param _net_inputs: The network inputs, indexed like ``_data``, or None
    :return: The image artists ``(_im1, _im2, _im3)``
    """

    # Pause the simulation briefly when switching from teacher forcing to
    # closed loop prediction; slow down the first 150 frames slightly
    if _i == _teacher_forcing_steps:
        time.sleep(0.5)
    elif _i < 150:
        time.sleep(0.01)

    # Set the pixel values of the images to the data of timestep _i. The
    # ground truth and input are only updated while they have a frame _i,
    # as they may be shorter than the output sequence.
    _im1.set_array(_data[_i, _gs1:_gs2, _gs1:_gs2, 0])
    if _i < len(_net_label):
        _im2.set_array(_net_label[_i, _gs1:_gs2, _gs1:_gs2, 0])
    if _im3 is not None and _i < len(_net_inputs):
        _im3.set_array(_net_inputs[_i, _gs1:_gs2, _gs1:_gs2, 0])

    # Display the current timestep in text form in the plot
    if _i < _teacher_forcing_steps:
        _txt1.set_text('Teacher forcing, t = ' + str(_i))
    else:
        _txt1.set_text('Closed loop prediction, t = ' + str(_i))

    return _im1, _im2, _im3
