from solvers.Solver import Solver

import numpy as np


class WaveSolver(Solver):
    """Solver that advances a stacked state ``[u, u_dot]`` with a linear operator.

    Each timestep's state is the concatenation of the ``u`` row and the
    ``u_dot`` row. A step multiplies that stacked vector by the matrix built in
    ``generate_operator``.
    Assumes ``u`` and ``u_dot`` are 2-D arrays with one row per timestep and
    ``grid_dimension ** 2`` columns each -- TODO confirm against ``Solver``.
    """

    name = 'wave'

    def __init__(self, settings):
        super().__init__(settings)
        # The initial (first-row) u_dot field is all zeros.
        self.u_dot[0, :] = np.zeros(self.u[0].shape)

    def generate_operator(self):
        """Build the one-step update matrix for the stacked state ``[u, u_dot]``.

        With ``n = grid_dimension ** 2`` the returned ``(2n, 2n)`` matrix is
        ``I + gamma * M`` with the block layout ``M = [[0, I], [L, 0]]``, where
        ``L = generate_laplacian(a)``. Multiplying it with ``[u, u_dot]``
        gives ``[u + gamma * u_dot, u_dot + gamma * (L @ u)]``.

        ``a`` is an ``(n, n)`` diagonal matrix whose diagonal holds the absolute
        values of ``n`` standard-normal draws taken right after seeding with
        ``op_seed``. The operator is therefore reproducible for a given seed.

        Returns:
            The ``(2n, 2n)`` NumPy update matrix.
        """
        n_cells = self.grid_dimension ** 2

        # NOTE: this reseeds NumPy's *global* RNG as a side effect; any later
        # np.random calls (here or elsewhere) observe the resulting state.
        np.random.seed(self.op_seed)
        coefficients = np.diag(np.abs(np.random.randn(n_cells)))

        laplacian = self.generate_laplacian(coefficients)
        matrix = np.zeros((2 * n_cells, 2 * n_cells))
        matrix[n_cells:, :n_cells] = laplacian  # bottom-left block: L
        matrix[:n_cells, n_cells:] = np.identity(n_cells)  # top-right block: I
        return np.identity(2 * n_cells) + self.gamma * matrix

    def calculate_fields_alt(self):
        """Compute every timestep directly from precomputed operator powers.

        This method always rebuilds ``self.op``. It then fills ``self.op_powers``
        via ``generate_operator_powers`` and wraps it with ``tensor_wrapper``.
        Row ``k`` of the stacked state is set to ``mult(op_powers[k], x0)``,
        where ``x0`` is the initial stacked state (row 0). Results are written
        to ``u``, ``u_dot`` and ``m_u``.
        """
        self.op = self.generate_operator()

        self.generate_operator_powers()
        self.op_powers = self.tensor_wrapper(self.op_powers)

        stacked = np.hstack((self.u, self.u_dot))
        # Snapshot the initial state: row 0 is itself overwritten at k == 0.
        # Reading stacked[0] afterwards would fold op_powers[0] into every
        # later row unless op_powers[0] happens to be exactly the identity.
        initial_state = stacked[0].copy()
        for k in range(self.timesteps):
            stacked[k, :] = self.mult(self.op_powers[k], initial_state)
        self._store_fields(stacked)

    def calculate_fields(self):
        """Advance the stacked state one timestep at a time.

        ``self.op`` is built only when it is ``None``. Row 0 keeps the initial
        state. For ``i`` in ``1 .. timesteps - 1``, row ``i`` is set to
        ``mult(op, row i - 1)``. Results are written to ``u``, ``u_dot`` and
        ``m_u``.
        """
        if self.op is None:
            self.op = self.generate_operator()

        stacked = np.hstack((self.u, self.u_dot))
        for i in range(1, self.timesteps):
            stacked[i, :] = self.mult(self.op, stacked[i - 1, :])
        self._store_fields(stacked)

    def _store_fields(self, stacked):
        """Split ``stacked`` into ``u`` / ``u_dot`` columns and refresh ``m_u``."""
        n_cells = self.grid_dimension ** 2
        self.u = stacked[:, :n_cells]
        self.u_dot = stacked[:, n_cells:]
        self.m_u = self.calculate_masked_u_field(self.u)