import numpy as np
import torch as th
import time
import glob
import os
import matplotlib.pyplot as plt
import kernel_variables
import kernel_net
import configuration as cfg
import helper_functions as helpers
import sys
sys.path.append("../../active_tuning")
import active_tuning

# All figures created by this script use matplotlib's dark theme.
plt.style.use('dark_background')


# When the configuration asks for CPU execution, hide every CUDA device:
# devices are ordered by PCI bus id (see issue #152) and none is made visible.
# NOTE(review): these variables only take effect if CUDA has not already been
# initialized by one of the imports above — confirm.
if cfg.DEVICE == "CPU":
    os.environ.update({
        "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
        "CUDA_VISIBLE_DEVICES": "",
    })


def _load_test_inputs(pk_batches, noise_ratio_inference, device):
    """Load the test inputs and the std of the observation noise.

    The noise std is derived from the *training* data so that the noise
    level is relative to the statistics the network was trained on.

    Args:
        pk_batches: Number of PKs; the last data axis is reshaped to this size.
        noise_ratio_inference: Noise std as a fraction of the training std.
        device: Torch device the inputs are placed on.

    Returns:
        ``(data_inputs, std_noise)``: a tensor of shape
        ``(num_sequences, seq_len - 1, pk_batches, 1)`` and the noise std.
    """
    train_data = np.array(np.load("data/train/wave_train.npy"),
                          dtype=np.float32)
    std_noise = np.std(train_data) * noise_ratio_inference

    data = np.array(np.load("data/test/wave_test.npy"), dtype=np.float32)
    data = np.reshape(data, (data.shape[0], data.shape[1], pk_batches))
    data = np.expand_dims(data, axis=-1)

    # Only the inputs are needed: the per-step target is read from them.
    data_inputs = th.tensor(data[:, :-1], device=device)
    return data_inputs, std_noise


def _load_network(lstm_size, noise_ratio_training, rep, pk_batches, device):
    """Build a kernel network and restore the weights of repeat ``rep``.

    ``rep`` is zero-based; the model file name uses ``rep + 1``.

    Returns:
        ``(net, params, tensors)`` with ``net`` already in eval mode.
    """
    params = kernel_variables.KernelParameters(
        pk_batches=pk_batches,
        device=device
    )
    tensors = kernel_variables.KernelTensors(params=params)
    net = kernel_net.KernelNetwork(params=params, tensors=tensors)

    model_path = (
        f"models/LSTM-{lstm_size}_s{noise_ratio_training:.2f}"
        f"_r{rep + 1:02d}.pt"
    )
    net.load_state_dict(th.load(model_path, map_location=params.device))
    net.eval()
    return net, params, tensors


def _build_active_tuner(net, tensors, params, epsilon):
    """Create an ActiveTuning instance that optimizes ``state[2]``.

    Must be called after ``net.reset`` so the tuner starts from the freshly
    reset tensors of the current sequence.
    """
    def at_opt_accessor(out, state):
        # Presumably the LSTM hidden state, given the [lat_in, c, h] order
        # of the initial state below — confirm against kernel_net.
        return state[2]

    return active_tuning.ActiveTuning(
        model=net,
        initial_model_state=[tensors.pk_lat_in, tensors.pk_lstm_c,
                             tensors.pk_lstm_h],
        initial_model_output=tensors.pk_dyn_out,
        tuning_length=params.at_length,
        tf_rate=params.at_tf_rate,
        opt_accessor=at_opt_accessor,
        tuning_cycles=params.at_cycles,
        eta=params.at_eta,
        beta1=params.adam_beta1,
        beta2=params.adam_beta2,
        # Use the caller-supplied epsilon so the argument is actually honored.
        epsilon=epsilon
    )


def _predict_sequence(net, tensors, params, net_inputs, std_noise, washout,
                      use_active_tuning, at_epsilon, pk_batches, device,
                      observations, outputs, targets):
    """Run the network over one test sequence and record its results.

    For ``t < washout`` the network is driven by noisy observations, either
    via teacher forcing or via active tuning. Afterwards it runs in closed
    loop, feeding its own output back as the next input.

    ``observations``, ``outputs`` and ``targets`` are views of shape
    ``(seq_len, pk_batches, pk_dyn_out_size)`` that are written in place.
    """
    # Reset the network to clear the previous sequence.
    net.reset(pk_num=pk_batches)

    # A fresh tuner per sequence, so tuning history does not leak across
    # sequences and it starts from the just-reset tensors.
    at = (_build_active_tuner(net, tensors, params, at_epsilon)
          if use_active_tuning else None)

    # Zero first input; the initial state is zero, or slightly randomized
    # when active tuning is used.
    net_input = th.zeros((pk_batches, params.pk_dyn_in_size), device=device)
    init_std = 0.1 if use_active_tuning else 0.0
    state = [
        (init_std * th.randn_like(tensors.pk_lat_in)).detach(),
        (init_std * th.randn_like(tensors.pk_lstm_c)).detach(),
        (init_std * th.randn_like(tensors.pk_lstm_h)).detach(),
    ]

    for t in range(params.seq_len):
        # The observation the model should produce, plus observation noise.
        target = net_inputs[t, :].detach().clone()
        obs_noisy = target + std_noise * th.randn(target.size(),
                                                  device=device)

        if t < washout and use_active_tuning:
            # Tune the model state; ideally net_output matches target.
            net_output, state = at.tune(
                observation_t=obs_noisy.detach().clone()
            )
            # Carry the tuned output forward so the first closed-loop step
            # continues from it rather than from the zero initial input.
            net_input = net_output.detach().clone()
        else:
            # No gradients needed here; without no_grad the recurrent state
            # would accumulate an ever-growing autograd graph.
            with th.no_grad():
                net_output, state = net.forward(dyn_in=net_input, state=state)
            if t < washout:
                # Teacher forcing: the observation becomes the next input.
                net_input = obs_noisy.detach().clone()
            else:
                # Closed loop: the prediction becomes the next input.
                net_input = net_output.detach().clone()

        observations[t:t + 1] = obs_noisy
        outputs[t:t + 1] = net_output.detach()
        targets[t:t + 1] = target.detach()


def evaluate(
    lstm_size=4, noise_ratio_training=0.0, noise_ratio_inference=0.0,
    repeats=range(10), washout=400, skip_error_length=100,
    use_active_tuning=False, at_tf_rate=0.0, at_tuning_length=15,
    at_tuning_cycles=10, at_eta=0.001, at_beta1=0.9, at_beta2=0.999,
    at_epsilon=1e-8, visualize_results=False):
    """Evaluate trained kernel networks on the test set.

    Args:
        lstm_size: Number of LSTM cells; selects the model files.
        noise_ratio_training: Training noise ratio; selects the model files.
        noise_ratio_inference: Std of the observation noise added during
            evaluation, as a fraction of the training-data std.
        repeats: Zero-based repeat ids; model ``rep`` is loaded from the
            file with suffix ``_r{rep + 1:02d}.pt``.
        washout: Steps driven by observations before closed-loop prediction.
        skip_error_length: Leading time steps excluded from the RMSE.
        use_active_tuning: Use active tuning instead of teacher forcing
            during the washout phase.
        at_tf_rate, at_tuning_length, at_tuning_cycles, at_eta, at_beta1,
            at_beta2, at_epsilon: Active tuning / Adam hyper-parameters.
        visualize_results: Currently unused; kept for interface compatibility.

    Returns:
        A list with one RMSE value per entry of ``repeats``.

    Side effects:
        Overwrites several global flags in the ``configuration`` module.
    """
    # Set globally reachable flags read by the kernel parameter classes.
    cfg.TRAINING = False
    cfg.PK_NUM_LSTM_CELLS = lstm_size
    cfg.AT_LENGTH = at_tuning_length
    cfg.AT_TF_RATE = at_tf_rate
    cfg.AT_CYCLES = at_tuning_cycles
    cfg.AT_ETA = at_eta
    cfg.ADAM_BETA1 = at_beta1
    cfg.ADAM_BETA2 = at_beta2
    cfg.ADAM_EPSILON = at_epsilon
    cfg.TEACHER_FORCING_STEPS = washout

    # GPU if specified in the configuration file, else CPU.
    device = helpers.determine_device()

    # Every PK is processed as a separate batch entry to parallelize.
    pk_batches = cfg.PK_ROWS * cfg.PK_COLS

    data_inputs, std_noise = _load_test_inputs(
        pk_batches, noise_ratio_inference, device
    )
    num_sequences = len(data_inputs)

    cfg.SEQ_LEN = len(data_inputs[0])
    cfg.CLOSED_LOOP_STEPS = cfg.SEQ_LEN - cfg.TEACHER_FORCING_STEPS

    rmse_list = []
    for rep in repeats:
        net, params, tensors = _load_network(
            lstm_size, noise_ratio_training, rep, pk_batches, device
        )

        net_outputs = th.zeros(
            (num_sequences, cfg.SEQ_LEN, pk_batches, params.pk_dyn_out_size),
            device=device
        )
        net_targets = th.zeros_like(net_outputs)
        net_observations = th.zeros_like(net_outputs)

        # Iterate over all test sequences.
        for index in range(num_sequences):
            _predict_sequence(
                net, tensors, params,
                net_inputs=data_inputs[index],
                std_noise=std_noise,
                washout=washout,
                use_active_tuning=use_active_tuning,
                at_epsilon=at_epsilon,
                pk_batches=pk_batches,
                device=device,
                observations=net_observations[index],
                outputs=net_outputs[index],
                targets=net_targets[index],
            )

        # RMSE over all sequences, ignoring the first skip_error_length steps.
        net_outputs_loss = net_outputs[:, skip_error_length:].cpu().numpy()
        net_targets_loss = net_targets[:, skip_error_length:].cpu().numpy()
        rmse = np.sqrt(np.mean(np.square(net_outputs_loss - net_targets_loss)))
        rmse_list.append(rmse)

    return rmse_list


if __name__ == "__main__":

    # Evaluate the models trained with 5% noise on test data with 20% noise,
    # using active tuning during the washout phase, and report the per-repeat
    # RMSE together with its mean and standard deviation.
    # (The original called an undefined `run_testing`; the evaluation entry
    # point in this module is `evaluate`.)
    rmse_list = evaluate(
        noise_ratio_training=0.05,
        noise_ratio_inference=0.2,
        visualize_results=False,
        washout=400,
        use_active_tuning=True,
        at_tuning_length=5,
        at_tuning_cycles=17,
        at_eta=0.0001,
        at_beta1=0.0,
        at_beta2=0.999,
    )

    print(rmse_list)
    print(np.mean(rmse_list), np.std(rmse_list))
