import numpy as np
import torch as th
import time
import glob
import os
import matplotlib.pyplot as plt
import kernel_variables
import kernel_net
import configuration as cfg
import helper_functions as helpers
import wave_plot
import sys
sys.path.append("../../active_tuning")
import active_tuning

#plt.style.use('dark_background')


# When the configuration file asks for CPU execution, make no CUDA device
# visible to this process (devices are enumerated in PCI bus order).
if cfg.DEVICE == "CPU":
    os.environ.update({
        "CUDA_DEVICE_ORDER": "PCI_BUS_ID",  # see issue #152
        "CUDA_VISIBLE_DEVICES": "",
    })


def _model_path(lstm_size, noise_ratio_training, rep):
    """Return the weight-file path for the zero-based repetition ``rep``.

    The file name encodes the LSTM size, the training noise ratio (two
    decimals) and the one-based, zero-padded repetition number, e.g.
    ``models/LSTM-4_s0.05_r01.pt`` for ``rep == 0``.
    """
    return (f"models/LSTM-{lstm_size}_s{noise_ratio_training:.2f}"
            f"_r{str(rep + 1).zfill(2)}.pt")


def _load_network(params, tensors, model_path):
    """Create a KernelNetwork, restore its weights and switch it to eval."""
    net = kernel_net.KernelNetwork(params=params, tensors=tensors)
    net.load_state_dict(th.load(model_path, map_location=params.device))
    net.eval()
    return net


def _to_numpy(tensor):
    """Detach ``tensor`` and return its values as a CPU numpy array."""
    return tensor.detach().cpu().numpy()


def _show_results(targets, outputs, observations, skip_error_length):
    """Print one sequence's RMSE, then show its activity plots and animation.

    All tensor arguments hold the values of a single test sequence (time
    along the first axis).
    """
    print(th.sqrt(th.mean((outputs[skip_error_length:]
                           - targets[skip_error_length:]) ** 2)))

    targets = targets.detach().cpu()
    outputs = outputs.detach().cpu()
    observations = observations.detach().cpu()

    # Plot the wave activity
    # NOTE(review): all four panels receive identical arguments; confirm
    # whether they were meant to show different kernels.
    fig, axes = plt.subplots(2, 2, figsize=[10, 10], sharex="all")
    for i in range(2):
        for j in range(2):
            helpers.plot_kernel_activity(
                ax=axes[i, j],
                label=targets,
                net_out=outputs,
                net_in=observations,
                make_legend=(i == 0 and j == 0)
            )
    fig.suptitle('Model ' + cfg.MODEL_NAME, fontsize=12)
    plt.show()
    plt.close()

    # Visualize and animate the propagation of the 2d wave. The returned
    # object is kept in a local; presumably a matplotlib animation that must
    # stay referenced while plt.show() runs.
    anim = helpers.animate_2d_wave(targets, outputs, observations)
    plt.show()
    plt.close()


def run_testing(
    lstm_size=4, noise_ratio_training=0.0, noise_ratio_inference=0.0,
    repeats=range(10), washout=400, skip_error_length=100,
    use_active_tuning=False, at_tf_rate=0.0, at_tuning_length=15,
    at_tuning_cycles=10, at_eta=0.001, at_beta1=0.9, at_beta2=0.999,
    at_epsilon=1e-8, visualize_results=False, create_plot=False,
    create_video_frames=False):
    """Evaluate trained kernel networks on the wave test data.

    For every repetition in ``repeats`` the matching model file is loaded
    and every test sequence is processed. During the first ``washout`` steps
    the noisy observation is fed as the next input (teacher forcing),
    optionally with active tuning running on a second model instance. After
    that the network's own output is fed back as its next input (closed
    loop).

    Args:
        lstm_size: Written to ``cfg.PK_NUM_LSTM_CELLS``; also part of the
            model file name.
        noise_ratio_training: Only used to select the model file.
        noise_ratio_inference: Std of the Gaussian noise added to every
            observation, relative to the std of the training data.
        repeats: Sized iterable of zero-based repetition indices; repetition
            ``rep`` loads ``models/LSTM-<size>_s<noise>_r<rep + 1>.pt``.
        washout: Number of initial teacher-forced (and tuned) time steps.
        skip_error_length: Number of initial time steps excluded from the
            RMSE.
        use_active_tuning: Run active tuning during the washout phase.
        at_tf_rate, at_tuning_length, at_tuning_cycles, at_eta, at_beta1,
        at_beta2: Written to ``cfg`` and forwarded (via the kernel
            parameters) to ``active_tuning.ActiveTuning``.
        at_epsilon: Passed directly to ``active_tuning.ActiveTuning`` as
            ``epsilon``.
        visualize_results: Print the per-sequence RMSE and show plots and
            the 2D wave animation for every sequence.
        create_plot: Call ``wave_plot.generate_plot`` on the first 200 steps.
        create_video_frames: Call ``wave_plot.create_video_frames`` on the
            washout steps.

    Returns:
        A list with one RMSE per repetition. Each RMSE is computed over all
        test sequences, from ``skip_error_length`` onwards.
    """
    # Set globally reachable flags for the testing run
    cfg.TRAINING = False
    cfg.PK_NUM_LSTM_CELLS = lstm_size
    cfg.AT_LENGTH = at_tuning_length
    cfg.AT_TF_RATE = at_tf_rate
    cfg.AT_CYCLES = at_tuning_cycles
    cfg.AT_ETA = at_eta
    cfg.ADAM_BETA1 = at_beta1
    cfg.ADAM_BETA2 = at_beta2
    cfg.TEACHER_FORCING_STEPS = washout

    # Set device on GPU if specified in the configuration file, else CPU
    device = helpers.determine_device()

    # Every PK is processed as a separate batch entry to parallelize
    # computation
    pk_batches = cfg.PK_ROWS * cfg.PK_COLS

    # The observation noise is scaled by the spread of the training data
    train_data = np.array(np.load("data/train/wave_train.npy"),
                          dtype=np.float32)
    std_noise = np.std(train_data) * noise_ratio_inference
    del train_data

    # The 40-row setup is evaluated on the large data set; load only the
    # file that is actually used
    if cfg.PK_ROWS == 40:
        test_file = "data/large/wave_large.npy"
    else:
        test_file = "data/test/wave_test.npy"
    data = np.array(np.load(test_file), dtype=np.float32)

    # Reshape to (sequences, time, pk_batches, 1). The stored layout is
    # presumably (sequences, time, PK_ROWS, PK_COLS); TODO confirm.
    data = np.reshape(data, (data.shape[0], data.shape[1], pk_batches))
    data = np.expand_dims(data, axis=-1)

    # Only the inputs are needed (targets are not required)
    data_inputs = th.tensor(data[:, :-1], device=device)
    num_sequences = len(data_inputs)

    cfg.SEQ_LEN = data_inputs.shape[1]
    cfg.CLOSED_LOOP_STEPS = cfg.SEQ_LEN - cfg.TEACHER_FORCING_STEPS

    # One RMSE score per repetition / model
    rmse_list = []

    for rep in repeats:

        # Set up the parameter and tensor classes
        params = kernel_variables.KernelParameters(
            pk_batches=pk_batches,
            device=device
        )
        tensors = kernel_variables.KernelTensors(params=params)

        model_path = _model_path(lstm_size, noise_ratio_training, rep)
        net = _load_network(params, tensors, model_path)

        num_trainable = sum(
            p.numel() for p in net.parameters() if p.requires_grad
        )
        print("Trainable model parameters:", num_trainable)

        # Active tuning runs on a second instance restored from the same
        # file. It is gated on use_active_tuning, because the washout branch
        # below calls at.tune whenever use_active_tuning is set.
        at = None
        if use_active_tuning:
            net_at = _load_network(params, tensors, model_path)
            at = active_tuning.ActiveTuning(
                model=net_at,
                initial_model_state=[tensors.pk_lat_in, tensors.pk_lstm_c,
                                     tensors.pk_lstm_h],
                initial_model_output=tensors.pk_dyn_out,
                tuning_length=params.at_length,
                tf_rate=params.at_tf_rate,
                # Optimize the third state entry (the lstm_h slot in the
                # [lat_in, lstm_c, lstm_h] layout above)
                opt_accessor=lambda out, state: state[2],
                tuning_cycles=params.at_cycles,
                eta=params.at_eta,
                beta1=params.adam_beta1,
                beta2=params.adam_beta2,
                # at_epsilon is not written to cfg, so pass it directly
                epsilon=at_epsilon
            )

        """
        TESTING
        """

        # Buffers of shape (sequences, time, pk_batches, pk_dyn_out_size)
        net_outputs = th.zeros(
            (num_sequences, cfg.SEQ_LEN, pk_batches, params.pk_dyn_out_size),
            device=device
        )
        net_outputs_at = th.zeros_like(net_outputs)
        net_targets = th.zeros_like(net_outputs)
        net_observations = th.zeros_like(net_outputs)

        # Iterate over all test sequences
        for index in range(num_sequences):

            print("Processing index " + str(index) + "/" + str(num_sequences)
                  + " in rep " + str(rep) + "/" + str(len(repeats)))

            # Reset the network to clear the previous sequence
            net.reset(pk_num=pk_batches)

            net_inputs = data_inputs[index]

            # The very first network input is zero
            net_input = th.zeros((pk_batches, params.pk_dyn_in_size),
                                 device=device)

            # With active tuning the regular network also starts from a
            # small random state; otherwise it starts from zeros
            init_std = 0.1 if use_active_tuning else 0.0
            pk_lat_in = (init_std * th.randn_like(tensors.pk_lat_in)).detach()
            pk_lstm_c = (init_std * th.randn_like(tensors.pk_lstm_c)).detach()
            pk_lstm_h = (init_std * th.randn_like(tensors.pk_lstm_h)).detach()
            state = [pk_lat_in, pk_lstm_c, pk_lstm_h]

            # NOTE(review): `at` is created once per repetition and is not
            # reset between sequences; confirm that carrying its internal
            # state over to the next sequence is intended.

            # Stored when no tuned output exists. Defined up front so the
            # storage below also works when washout == 0. In closed loop the
            # last tuned output is kept.
            net_output_at = th.tensor(0.0)

            time_start = time.time()

            for t in range(params.seq_len):

                # The observation the model should produce, plus noise
                target = net_inputs[t, :].detach().clone()
                obs_noisy = target + std_noise * th.randn(target.size(),
                                                          device=device)

                teacher_forcing = t < washout

                if teacher_forcing and use_active_tuning:
                    print("Tuning time step " + str(t + 1) + "/"
                          + str(washout))
                    # In case of an ideal prediction, the output equals target
                    net_output_at, _ = at.tune(
                        observation_t=obs_noisy.detach().clone()
                    )

                # The regular network is only evaluated here. Skip graph
                # construction so the recurrent state does not keep the
                # autograd history of the whole sequence alive.
                with th.no_grad():
                    net_output, state = net.forward(
                        dyn_in=net_input, state=state
                    )

                if teacher_forcing:
                    # Teacher forcing: next input is the noisy observation
                    net_input = obs_noisy.detach().clone()
                else:
                    # Closed loop: next input is the current output
                    net_input = net_output.detach().clone()

                # Store the current network output, observation and target
                net_outputs[index, t:t+1] = net_output.detach()
                net_outputs_at[index, t:t+1] = net_output_at.detach()
                net_observations[index, t:t+1] = obs_noisy
                net_targets[index, t:t+1] = target.detach()

            print("Took " + str(time.time() - time_start) + " seconds.")

            if create_plot:
                wave_plot.generate_plot(
                    targets=_to_numpy(net_targets[index, :200]),
                    net_outputs=_to_numpy(net_outputs[index, :200]),
                    net_outputs_at=_to_numpy(net_outputs_at[index, :200]),
                    observations=_to_numpy(net_observations[index, :200])
                )

            if create_video_frames:
                wave_plot.create_video_frames(
                    targets=_to_numpy(net_targets[index, :washout]),
                    net_outputs=_to_numpy(net_outputs[index, :washout]),
                    net_outputs_at=_to_numpy(net_outputs_at[index, :washout]),
                    observations=_to_numpy(net_observations[index, :washout])
                )

            if visualize_results:
                _show_results(
                    targets=net_targets[index],
                    outputs=net_outputs[index],
                    observations=net_observations[index],
                    skip_error_length=skip_error_length
                )

        # RMSE over all sequences, ignoring the first skip_error_length steps
        outputs_loss = net_outputs[:, skip_error_length:].cpu().numpy()
        targets_loss = net_targets[:, skip_error_length:].cpu().numpy()
        rmse = np.sqrt(np.mean(np.square(outputs_loss - targets_loss)))
        rmse_list.append(rmse)

        print(rmse)

    return rmse_list


if __name__ == "__main__":

    # Evaluate the default-size models trained with a noise ratio of 0.05 on
    # observations whose noise std equals the training-data std. Active
    # tuning runs during a 200-step washout.
    settings = dict(
        noise_ratio_training=0.05,
        noise_ratio_inference=1.0,
        washout=200,
        use_active_tuning=True,
        at_tuning_length=7,
        at_tuning_cycles=30,
        at_eta=0.00005,
        at_beta1=0.0,
        at_beta2=0.999,
        visualize_results=True,
        create_plot=False,
        create_video_frames=True,
    )
    rmse_list = run_testing(**settings)

    # Per-repetition scores, followed by their mean and standard deviation
    print(rmse_list)
    print(np.mean(rmse_list), np.std(rmse_list))
