import numpy as np


class WaveGenerator():
    """
    Simulate the 2D wave equation u_tt = c^2 * (u_xx + u_yy) on a grid.

    The initial field is a sum of Gaussian impulses centred on random grid
    points; it is then advanced with an explicit central-difference
    (leapfrog) scheme. Field values outside the grid are treated as 0.0.

    Note: the explicit scheme is only stable when
    velocity * dt * sqrt(1 / dx**2 + 1 / dy**2) <= 1.
    """

    def __init__(self, width, height, waves=1, wave_width_x=0.5,
                 wave_width_y=0.5, amplitude=0.5, velocity=3.0, damp=1.0,
                 dx=1, dy=1, dt=0.1):
        """
        :param width: Number of grid points in x-direction (axis 1 of field)
        :param height: Number of grid points in y-direction (axis 2 of field)
        :param waves: Number of Gaussian impulses summed into the initial field
        :param wave_width_x: Variance of each impulse in x-direction
        :param wave_width_y: Variance of each impulse in y-direction
        :param amplitude: Peak amplitude of each impulse
        :param velocity: Wave propagation speed c
        :param damp: Factor multiplied onto every computed time step
            (1.0 leaves the field undamped)
        :param dx: Physical grid spacing in x-direction
        :param dy: Physical grid spacing in y-direction
        :param dt: Time step size
        """
        self.dx = dx
        self.dy = dy
        self.dt = dt

        self.width = width
        self.height = height

        self.waves = waves
        self.wave_width_x = wave_width_x
        self.wave_width_y = wave_width_y
        self.amplitude = amplitude
        self.velocity = velocity
        self.damp = damp

        # Filled by generate_wave(); shape (sequence_length, width, height)
        self.field = None

    def generate_wave(self, sequence_length):
        """
        Simulate the wave field for a number of time steps.

        All impulses are summed into the initial field first, and the field is
        then evolved once. The update is linear, so this equals superposing
        the individually evolved waves.
        :param sequence_length: Number of time steps, including t = 0
        :return: Array of shape (sequence_length, width, height)
        """
        self.field = np.zeros([sequence_length, self.width, self.height])
        if sequence_length == 0:
            return self.field

        # Coordinate grids that broadcast to (width, height); f() and g() use
        # only element-wise operations, so they can be evaluated on them
        xs = np.arange(self.width)[:, np.newaxis]
        ys = np.arange(self.height)[np.newaxis, :]

        for _ in range(self.waves):
            # Draw each coordinate from its own axis so the impulse centre
            # always lies inside non-square fields
            start_pt = (np.random.randint(0, self.width),
                        np.random.randint(0, self.height))
            self.field[0] += self.f(xs, ys, self.wave_width_x,
                                    self.wave_width_y, self.amplitude,
                                    start_pt)

        c2dt2 = (self.velocity ** 2) * (self.dt ** 2)

        # Virtual field at t = -dt (see dt_u); needed to start the scheme
        init_velocity = self.g(xs, ys, self.wave_width_x, self.wave_width_y,
                               self.amplitude)
        u_prev = (self.field[0] - self.dt * init_velocity
                  + 0.5 * c2dt2 * self._laplacian(self.field[0]))

        for t in range(sequence_length - 1):
            u_t = self.field[t]
            self.field[t + 1] = self.damp * \
                (c2dt2 * self._laplacian(u_t) + 2 * u_t - u_prev)
            u_prev = u_t

        return self.field

    def _laplacian(self, u):
        """
        Array counterpart of dxx_u + dyy_u for a whole (width, height) slice.
        Values outside the grid are taken as zero via constant padding.
        :param u: Field slice of shape (width, height)
        :return: Discrete Laplacian of u, same shape as u
        """
        padded = np.pad(u, 1, mode="constant")
        centre = 2 * u
        dxx = (padded[2:, 1:-1] - centre + padded[:-2, 1:-1]) \
            / np.square(self.dx)
        dyy = (padded[1:-1, 2:] - centre + padded[1:-1, :-2]) \
            / np.square(self.dy)
        return dxx + dyy

    def f(self, _x, _y, _varx, _vary, _a, _start_pt):
        """
        Function to set the initial activity of the field. We use the Gaussian
        bell curve to initialize the field smoothly. Works element-wise, so
        _x and _y may also be broadcastable numpy arrays.
        :param _x: The x-coordinate of the field (j, running over width)
        :param _y: The y-coordinate of the field (i, running over height)
        :param _varx: The variance in x-direction
        :param _vary: The variance in y-direction
        :param _a: The amplitude of the wave
        :param _start_pt: (x, y) centre of the impulse
        :return: The initial activity at (x, y)
        """
        x_part = ((_x - _start_pt[0])**2) / (2 * _varx)
        y_part = ((_y - _start_pt[1])**2) / (2 * _vary)
        return _a * np.exp(-(x_part + y_part))

    def g(self, _x, _y, _varx, _vary, _a):
        """
        Initial velocity u_t(x, y, 0) of the field. It is currently zero
        everywhere, i.e. the impulses start at rest.
        :param _x: The x-coordinate of the field (j, running over width)
        :param _y: The y-coordinate of the field (i, running over height)
        :param _varx: The variance in x-direction
        :param _vary: The variance in y-direction
        :param _a: The amplitude of the wave
        :return: The initial change over time in the field at (x, y)
        """
        return 0.0

    def u(self, _t, _x, _y):
        """
        Function to calculate the field activity in time step t+1 at (x, y)
        from time steps t and t-1 (leapfrog update).
        :param _t: The current time step
        :param _x: The x-coordinate of the field (j, running over width)
        :param _y: The y-coordinate of the field (i, running over height)
        :return: The field activity at position x, y in time step t+1.
        """
        # Compute changes in x- and y-direction
        dxxu = self.dxx_u(_t, _x, _y)
        dyyu = self.dyy_u(_t, _x, _y)

        # Get the activity at x and y in time step t
        u_t = self.field[_t, _x, _y]

        # In the first step there is no stored field at t-1 yet; use the
        # virtual value at t = -dt instead
        if _t == 0:
            u_t_1 = self.dt_u(_x, _y)
        else:
            u_t_1 = self.field[_t - 1, _x, _y]

        # Incorporate the changes in x- and y-direction and return the activity
        return self.damp * \
            (((self.velocity**2) * (self.dt ** 2)) * (dxxu + dyyu)
             + 2 * u_t - u_t_1)

    def dxx_u(self, _t, _x, _y):
        """
        The second derivative of u to x (central difference). Neighbours are
        one grid index away; self.dx is the physical spacing between them.
        Points outside the field count as 0.0.
        :param _t: The current time step
        :param _x: The x-coordinate of the field (j, running over width)
        :param _y: The y-coordinate of the field (i, running over height)
        :return: Approximation of u_xx at (t, x, y)
        """
        left = self.field[_t, _x - 1, _y] if _x > 0 else 0.0
        right = self.field[_t, _x + 1, _y] if _x < self.width - 1 else 0.0

        ut_dx = right - 2*self.field[_t, _x, _y] + left
        return ut_dx / np.square(self.dx)

    def dyy_u(self, _t, _x, _y):
        """
        The second derivative of u to y (central difference). Neighbours are
        one grid index away; self.dy is the physical spacing between them.
        Points outside the field count as 0.0.
        :param _t: The current time step
        :param _x: The x-coordinate of the field (j, running over width)
        :param _y: The y-coordinate of the field (i, running over height)
        :return: Approximation of u_yy at (t, x, y)
        """
        above = self.field[_t, _x, _y - 1] if _y > 0 else 0.0
        below = self.field[_t, _x, _y + 1] if _y < self.height - 1 else 0.0

        ut_dy = below - 2*self.field[_t, _x, _y] + above
        return ut_dy / np.square(self.dy)

    def dt_u(self, _x, _y):
        """
        Value of the field at the virtual time step t = -dt, only required in
        the very first time step. Uses the second-order Taylor expansion
        u(-dt) = u(0) - dt * u_t(0) + dt^2 / 2 * u_tt(0), with the initial
        velocity u_t(0) = g and u_tt = c^2 * (u_xx + u_yy). This avoids
        reading u(dt), which has not been computed yet at that point.
        :param _x: The x-coordinate of the field (j, running over width)
        :param _y: The y-coordinate of the field (i, running over height)
        :return: The value of the field at t = -dt, x, y
        """
        init_velocity = self.g(_x, _y, self.wave_width_x, self.wave_width_y,
                               self.amplitude)
        c2dt2 = (self.velocity ** 2) * (self.dt ** 2)
        laplacian = self.dxx_u(0, _x, _y) + self.dyy_u(0, _x, _y)
        return (self.field[0, _x, _y] - self.dt * init_velocity
                + 0.5 * c2dt2 * laplacian)