"""
In this file we generate wave data to train a network. Waves are generated
using the two-dimensional wave equation.
"""

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import os
import wave_generator

# ---------------------------------------------------------------------------
# GLOBAL VARIABLES
# ---------------------------------------------------------------------------

# --- Data sets -------------------------------------------------------------
# For every data set: the number of sequences to generate and the number of
# time steps per sequence. Lengths are written as N + 1 (presumably N steps
# plus one additional frame -- confirm against the code consuming the data).

# Training set
train_samples = 200
train_sequence_length = 80 + 1

# Test set
test_samples = 20
test_sequence_length = 400 + 1

# Long test set
test_long_samples = 10
test_long_sequence_length = 2000 + 1

# "Large" set
large_samples = 1
large_sequence_length = 2000 + 1

# --- Simulation ------------------------------------------------------------
# NOTE(review): the simulation and wave parameters below are not referenced
# by the code visible in this file; verify whether they are still needed.
dt = 0.1     # Temporal step size
dx = dy = 1  # Spatial step size in x- and y-direction

# --- Field -----------------------------------------------------------------
# Size of the simulated field in pixels (an alternative value of 16 was
# noted next to both settings).
width = 40
height = 40

# --- Wave ------------------------------------------------------------------
wave_width_x = 0.5  # Width of the wave in x-direction
wave_width_y = 0.5  # Width of the wave in y-direction
amplitude = 0.34    # Amplitude of the wave
velocity = 3.0      # The velocity of the wave
waves = 1           # Number of waves propagating simultaneously per sequence
damp = 1.0          # How much a wave is dampened (decaying over time)

# --- Other -----------------------------------------------------------------
save_data = True    # Shall the generated data be saved to file
visualize = False   # Create a plot and animation of the wave


#############
# FUNCTIONS #
#############

def animate(_t, im, field):
    """Frame callback for the animation: display time step ``_t``.

    Args:
        _t: Index of the frame (first axis of ``field``) to display.
        im: Image object whose ``set_array`` method receives the frame.
        field: Array indexed as ``field[time, :, :]``.

    Returns:
        The same ``im`` object that was passed in.
    """
    current_frame = field[_t, :, :]
    im.set_array(current_frame)
    return im

def _show_sample(field, sequence_length):
    """Plot one position's activity over time, then animate the whole field."""
    # Plot the wave activity at one position
    fig, ax = plt.subplots(1, 1, figsize=[8, 2])
    ax.plot(range(sequence_length), field[:, 5, 5])
    ax.set_xlabel("Time")
    ax.set_ylabel("Wave amplitude")
    # Show the complete plotted sequence rather than a fixed 0..299 window;
    # the upper limit is kept >= 1 so both limits never coincide.
    ax.set_xlim([0, max(sequence_length - 1, 1)])
    plt.tight_layout()
    plt.show()

    # Animate the overall wave
    fig, ax = plt.subplots(1, 1, figsize=[6, 6])
    im = ax.imshow(field[0, :, :], vmin=-0.6, vmax=0.6, cmap='Blues')
    # The animation object has to stay referenced until plt.show() returns,
    # otherwise it may be garbage-collected and the animation stops.
    _anim = animation.FuncAnimation(fig,
                                    animate,
                                    frames=sequence_length,
                                    fargs=(im, field),
                                    interval=200)
    plt.axis('off')
    plt.tight_layout()
    plt.show()


def generate_samples(number_of_samples, sequence_length, wave_generator):
    """Generate several wave sequences and stack them into one array.

    If the module-level flag ``visualize`` is true, every generated sample
    is additionally plotted and animated (blocking until the windows are
    closed).

    Args:
        number_of_samples: Number of sequences to generate.
        sequence_length: Number of time steps per sequence; forwarded to
            ``generate_wave`` as ``sequence_length``.
        wave_generator: Object providing
            ``generate_wave(sequence_length=...)``. Inside this function the
            parameter shadows the imported module of the same name.

    Returns:
        ``np.array`` built from the list of generated fields, i.e. with the
        sample index as first axis.
    """
    # Collect the generated fields here
    samples = []

    # Create the desired number of samples (counted from one for printing)
    for sample_number in range(1, number_of_samples + 1):

        # Print a progress statement to the console for every 20 generated
        # samples
        if sample_number % 20 == 0:
            print("Generating file {}/{}".format(sample_number,
                                                 number_of_samples))

        # Let the generator produce one complete wave sequence
        field = wave_generator.generate_wave(sequence_length=sequence_length)
        samples.append(field)

        if visualize:
            _show_sample(field, sequence_length)

    return np.array(samples)


##########
# SCRIPT #
##########

# Set up paths to store the data
path_train = "data/train/"
path_test = "data/test/"
path_test_long = "data/test_long/"
path_large = "data/large/"

# Set up the wave generator
# NOTE: this rebinds the name `wave_generator` from the imported module to a
# WaveGenerator instance; the module is not reachable under that name
# afterwards.
wave_generator = wave_generator.WaveGenerator(width=width, height=height)

# Create the data
print("Generation of train data...")
train_data = generate_samples(number_of_samples=train_samples,
                              sequence_length=train_sequence_length,
                              wave_generator=wave_generator)

print("Generation of test data...")
test_data = generate_samples(number_of_samples=test_samples,
                             sequence_length=test_sequence_length,
                             wave_generator=wave_generator)

print("Generation of long test data...")
test_long_data = generate_samples(number_of_samples=test_long_samples,
                                  sequence_length=test_long_sequence_length,
                                  wave_generator=wave_generator)

# Report the actual field size instead of a hard-coded "40x40"
print("Generation of large ({}x{}) test data...".format(width, height))
large_data = generate_samples(number_of_samples=large_samples,
                              sequence_length=large_sequence_length,
                              wave_generator=wave_generator)

# File names (np.save appends the ".npy" extension)
file_name_train = "wave_train"
file_name_test = "wave_test"
file_name_test_long = "wave_test_long"
file_name_large = "wave_large"

# Only touch the file system if saving was requested via `save_data`
if save_data:
    # Create target directories if not yet existing
    os.makedirs(path_train, exist_ok=True)
    os.makedirs(path_test, exist_ok=True)
    os.makedirs(path_test_long, exist_ok=True)
    os.makedirs(path_large, exist_ok=True)

    # Write the samples to file; generate_samples already returns arrays
    np.save(os.path.join(path_train, file_name_train), train_data)
    np.save(os.path.join(path_test, file_name_test), test_data)
    np.save(os.path.join(path_test_long, file_name_test_long),
            test_long_data)
    np.save(os.path.join(path_large, file_name_large), large_data)

print("Done.")
