import numpy as np
import plot
from solvers.Solver import Solver
from numpy.fft import fft2, ifft2
from tqdm import tqdm

class KSSolver(Solver):
    """Pseudo-spectral solver for the 2D Kuramoto-Sivashinsky equation (KSE).

    Integrates, on a doubly periodic square of side ``domain_mult * pi``,

        u_t = -Δu - Δ²u - |∇u|²,

    removing the mean of u after every step. The nonlinear term carries no
    factor 1/2; solutions of the ``½|∇u|²`` form are obtained by doubling u.
    Time stepping uses ETDRK4 (exponential time differencing with a
    fourth-order Runge-Kutta stage structure, Kassam & Trefethen 2005), with
    the coefficient functions evaluated by a contour-integral mean.
    """

    name = 'kse'

    def __init__(self, settings):
        """Configure the solver.

        Args:
            settings: mapping forwarded to ``Solver.__init__``; must also
                contain ``'domain_mult'``, the domain side length in units
                of pi.
        """
        super().__init__(settings)
        self.save = True           # Save a video of the trajectory after solving
        self.speed_up = 10         # Keep every speed_up-th frame in the saved video
        self.time_factor = 0.01    # ETDRK4 time step h
        self.timesteps = 250       # Final simulation time (steps = timesteps / time_factor)
        self.grid_dimension = 256  # Hard-coded so it does not have to be changed back and forth
        self.domain_mult = settings['domain_mult']

    def generate_operator(self):
        """Return an identity operator of size ``grid_dimension ** 2``.

        An explicit Euler scheme driven by a linear operator is badly suited to
        the stiff KSE, so the dynamics are integrated with ETDRK4 in
        ``calculate_fields`` instead. This identity appears to be a placeholder
        for the ``Solver`` interface (TODO confirm against the base class).
        """
        # NOTE(review): this is a dense (n**2, n**2) float64 array; for the
        # default n = 256 that is 65536 x 65536 entries (32 GiB if fully
        # materialized). Confirm whether the base class needs it dense.
        return np.eye(self.grid_dimension ** 2)

    @staticmethod
    def _etdrk4_coefficients(linear, h, m=16):
        """Return the ETDRK4 coefficients ``(e, e_2, q, f1, f2, f3)``.

        Follows Kassam & Trefethen (2005): each phi-function combination is
        evaluated as the mean over ``m`` points on the upper half of a unit
        circle centred at every value of ``h * linear``. This avoids the
        catastrophic cancellation of the direct formulas when ``|h * linear|``
        is small. Because ``linear`` is real, the full-circle mean equals the
        real part of the half-circle mean.

        Args:
            linear: real array holding the diagonal linear operator in Fourier
                space (any shape).
            h: time step.
            m: number of contour points.

        Returns:
            Six real arrays with the same shape as ``linear``.
        """
        e = np.exp(h * linear)
        e_2 = np.exp(h * linear / 2)
        # Contour points: roots of unity on the upper half of the unit circle.
        r = np.exp(1j * np.pi * (np.arange(1, m + 1) - 0.5) / m)
        lr = h * linear[..., None] + r
        q = h * np.mean((np.exp(lr / 2) - 1) / lr, axis=-1).real
        f1 = h * np.mean((-4 - lr + np.exp(lr) * (4 - 3 * lr + lr ** 2)) / lr ** 3, axis=-1).real
        f2 = h * np.mean((2 + lr + np.exp(lr) * (-2 + lr)) / lr ** 3, axis=-1).real
        f3 = h * np.mean((-4 - 3 * lr - lr ** 2 + np.exp(lr) * (4 - lr)) / lr ** 3, axis=-1).real
        return e, e_2, q, f1, f2, f3

    def calculate_fields(self):
        """Integrate the KSE from white-noise initial data and store the result.

        Sets ``self.op_powers``, ``self.u`` (shape ``(frames, n * n)``: one
        flattened field per step, including the initial one) and ``self.m_u``.
        If ``self.save`` is set, also writes every ``speed_up``-th frame via
        ``plot.save_tensor``.
        """
        # NOTE(review): np.array([...]) copies self.op; with the dense identity
        # from generate_operator this duplicates a very large array.
        self.op_powers = np.array([self.op])

        n = self.grid_dimension             # Grid points per dimension
        length = self.domain_mult * np.pi   # Side length of the periodic square

        # White-noise initial condition. (A smoother alternative on the
        # physical grid x = y = length * arange(n) / n is
        # cos(x / 7) + cos(y / 7) + 0.1 * noise.)
        u = np.random.randn(n, n)
        v = fft2(u)

        # Angular wavenumbers 2*pi*j/length in FFT ordering. meshgrid's default
        # 'xy' indexing puts kx along axis 1 and ky along axis 0.
        kx = np.fft.fftfreq(n, d=length / (2 * np.pi * n))
        ky = np.fft.fftfreq(n, d=length / (2 * np.pi * n))
        kx, ky = np.meshgrid(kx, ky)

        k2 = kx ** 2 + ky ** 2  # Fourier symbol of -Δ
        k4 = k2 ** 2            # Fourier symbol of Δ²

        print("Precomputing ETDRK4 quantities...")
        h = self.time_factor
        linear = k2 - k4  # Fourier symbol of -Δ - Δ²
        e, e_2, q, f1, f2, f3 = self._etdrk4_coefficients(linear, h)

        # Spectral derivative factors, hoisted out of the per-stage closure.
        ikx = 1j * kx
        iky = 1j * ky

        # Non-linear term of the KSE in Fourier space. Note we use |grad(u)|^2
        # and not the alternate form u * grad(u).
        def nonlinear(freq):
            grad_x = ifft2(ikx * freq).real
            grad_y = ifft2(iky * freq).real
            return -fft2(grad_x ** 2 + grad_y ** 2)

        # Round rather than truncate, so e.g. 0.3 / 0.1 gives 3 rather than 2.
        # The frame count includes the initial condition.
        n_frames = max(int(round(self.timesteps / h)), 1)

        # Preallocate the trajectory instead of stacking a list at the end, so
        # peak memory holds one copy of it rather than two.
        # NOTE(review): with the defaults this is 25000 x 256 x 256 float64
        # values (about 13 GB).
        uu = np.empty((n_frames, n, n))
        uu[0] = u

        bar = tqdm(range(1, n_frames))
        for step in bar:
            nl_v = nonlinear(v)
            a = e_2 * v + q * nl_v
            nl_a = nonlinear(a)
            b = e_2 * v + q * nl_a
            nl_b = nonlinear(b)
            c = e_2 * a + q * (2 * nl_b - nl_v)
            nl_c = nonlinear(c)
            v = e * v + nl_v * f1 + 2 * (nl_a + nl_b) * f2 + nl_c * f3

            # The nonlinear term -|∇u|² has a non-positive mean and the linear
            # symbol vanishes at k = 0, so without this the mean would drift.
            v[0, 0] = 0
            u = ifft2(v).real
            uu[step] = u
            # refresh=False: let tqdm redraw at its own rate, not every step.
            bar.set_postfix({'Energy': np.mean(u ** 2)}, refresh=False)
            v = fft2(u)  # Re-transform the real field to restore Hermitian symmetry

        self.u = uu.reshape(*uu.shape[:-2], -1)
        self.m_u = self.calculate_masked_u_field(self.u)
        if self.save:
            print('Saving video...')
            plot.save_tensor(uu[::self.speed_up], f'kse/test/{self.seed}')
