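""" Kuramoto–Sivashinsky Equation
Simulates the KS equation in one dimension and defines an experiment that infers its non-linearity and forcing coefficients.
Supports PyTorch, TensorFlow and Jax; select the backend via the `phi.<backend>.flow` import statement.
"""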
from typing import Any, Tuple

from matplotlib import pyplot as plt

from experiment import Experiment, train_network, optimize_directly, visualize_training, visualize_dataset, visualize_optimizations, visualize_results, train_supervised, train_surrogate, \
    neural_adjoint, visualize_param_trj, refine_directly, plot_by_dset_size, paper_plot, get_time_to_train, direct_compare, plot_parameter_trajectories, plot_all_curves
from phi.torch.flow import *


def kuramoto_sivashinsky(u: Grid, non_lin=.5, dt=.5, forcing=0.):
    """Advance a 1D Kuramoto–Sivashinsky state by one time step of size `dt`.

    Linear terms are integrated exactly in Fourier space (exponential time differencing).
    The non-linear term gets a second-order Runge–Kutta correction, i.e. the ETDRK2 scheme:
    predictor `a = e^{L dt} û + N(u) (e^{L dt} - 1) / L`, then
    corrector `a + (N(a) - N(u)) (e^{L dt} - 1 - L dt) / (dt L^2)`.
    Because the update is FFT-based, the domain is effectively treated as periodic.

    Args:
        u: Current state on a 1D grid.
        non_lin: Coefficient of the non-linear term. The update evaluates `-non_lin * d/dx(u**2)` in Fourier space,
            so the default 0.5 yields the usual `-u * u_x` term
            (assuming `frequencies` are angular wavenumbers — TODO confirm against `math.fftfreq`).
        dt: Time step.
        forcing: Weight of an additive term built from `initial` (see NOTE(review) at the forcing line).

    Returns:
        `u` with its values replaced by the state after one step.
    """
    # --- Operators in Fourier space ---
    non_lin *= -1.j  # together with `frequencies` below, -1j * k * F(u**2) is the transform of -d/dx(u**2)
    frequencies = math.fftfreq(u.resolution) / u.dx  # FFT sample frequencies scaled by the grid spacing
    lin_op = frequencies ** 2 - (1j * frequencies) ** 4  # Fourier operator for linear terms. You'd think that 1j**4 == 1 but apparently the rounding errors have a major effect here even with FP64...
    inv_lin_op = math.divide_no_nan(1, lin_op)  # Removes f=0 component but there is no noticeable difference
    # ^ 1/L with the 0/0 at f=0 mapped to 0. The non-linear terms below carry a factor `frequencies`, which is 0 at f=0 anyway.
    exp_lin_op = math.exp(lin_op * dt)  # time evolution operator for linear terms in Fourier space
    # --- RK2 for non-linear terms, exponential time-stepping for linear terms ---
    non_lin_current = non_lin * frequencies * math.fft(u.values ** 2)  # N(u) in Fourier space
    u_intermediate = exp_lin_op * math.fft(u.values) + non_lin_current * (exp_lin_op - 1) * inv_lin_op  # intermediate for RK2
    # N(a): back to real space (discarding the imaginary part), square, transform again
    non_lin_intermediate = non_lin * frequencies * math.fft(math.ifft(u_intermediate).real ** 2)
    # ETDRK2 corrector: (N(a) - N(u)) * (e^{L dt} - 1 - L dt) / (dt * L^2)
    u_new = u_intermediate + (non_lin_intermediate - non_lin_current) * (exp_lin_op - 1 - lin_op * dt) * (1 / dt * inv_lin_op ** 2)
    # NOTE(review): u_new holds Fourier coefficients here, but this adds real-space samples of `initial` on u's grid
    # (no FFT, not scaled by dt). Confirm this is intended before relying on the forcing term's physical meaning.
    u_new += forcing * u.with_values(initial).values
    # Back to real space. `.vector['x']` presumably drops a size-1 `vector` dim introduced by the forcing term
    # (`initial` evaluated on grid positions) — TODO confirm.
    return u.with_values(math.ifft(u_new).real.vector['x'])


def initial(x: math.Tensor):
    """Return the profile cos(x) - 0.1 * cos(x/16) * (1 - 2 * sin(x/16)) evaluated at positions `x`.

    `kuramoto_sivashinsky` samples this profile on the simulation grid to build its forcing term.
    """
    slow_phase = x / 16
    envelope = math.cos(slow_phase)
    modulation = 1 - 2 * math.sin(slow_phase)
    return math.cos(x) - 0.1 * envelope * modulation


class KSForced(Experiment):
    """Experiment: infer the non-linearity and forcing coefficients of a 1D Kuramoto–Sivashinsky simulation.

    Parameters are a tensor whose `vector` channel dimension holds the items `non_lin` and `forcing`
    (see `generate_problem`). `forward_process` maps these unbounded values to bounded physical coefficients via `tanh`.
    """

    _RESOLUTION = 128  # grid cells along x; the inverse network's input size must match this
    _DOMAIN_SIZE = 22  # extent of the simulation domain along x
    _DT = .5  # time step passed to `kuramoto_sivashinsky`

    def __init__(self, steps):
        """
        Args:
            steps: Default number of simulation steps used by `forward_process` and the plotting methods.
        """
        self._steps = steps

    def __repr__(self):
        # e.g. "KSForced_t50"; the script section formats this into output paths.
        return f"{self.__class__.__name__}_t{self._steps}"

    def generate_problem(self, batch_dims: Shape, test: bool) -> Tuple[Any, Any, Tensor]:
        """Sample random ground-truth parameters and a random smooth-noise initial state.

        Args:
            batch_dims: Batch dimensions of the sampled problems.
            test: Required by the `Experiment` interface; not used here.

        Returns:
            The initial state twice (slot meaning defined by `Experiment` — TODO confirm), and a tensor with
            `vector` items `non_lin` and `forcing`, each drawn uniformly between -1 and 1.
        """
        non_lin = math.random_uniform(batch_dims, low=-1, high=1)
        forcing = math.random_uniform(batch_dims, low=-1, high=1)
        initial_state = CenteredGrid(Noise(batch_dims, scale=1.5, smoothness=1.4), x=self._RESOLUTION, bounds=Box(x=self._DOMAIN_SIZE))
        return initial_state, initial_state, vec(non_lin=non_lin, forcing=forcing)

    def forward_process(self, initial_state, guess: Tensor, steps=None) -> Any:
        """Simulate the forced KS equation for the parameters encoded in `guess`.

        Args:
            initial_state: Grid to start from.
            guess: Tensor with `vector` items `non_lin` and `forcing` (unbounded).
            steps: Number of steps, or a single-dimension Shape such as `spatial(time=n)` (as the plotting methods
                pass) to get the trajectory from `iterate`. Defaults to the value given at construction.

        Returns:
            Result of `iterate(kuramoto_sivashinsky, ...)`.
        """
        non_lin = .25 * math.tanh(guess['non_lin']) + .5  # maps into (0.25, 0.75)
        forcing = .1 * math.tanh(guess['forcing'])  # maps into (-0.1, 0.1)
        # Explicit None check: `steps or self._steps` would silently replace an explicit 0 with the default.
        steps = self._steps if steps is None else steps
        return iterate(kuramoto_sivashinsky, steps, initial_state, forcing=forcing, non_lin=non_lin, dt=self._DT)

    def get_observations(self, output) -> Tensor:
        """Return the grid values of a simulation output."""
        return output.values

    def loss_function(self, ref_output, guess_output) -> Tensor:
        """L2 loss between reference and predicted outputs."""
        return math.l2_loss(ref_output - guess_output)

    def create_inverse_net(self, lib) -> Any:
        """Network mapping one observed field (resolution `_RESOLUTION`) to the 2 parameters."""
        return conv_classifier(1, (self._RESOLUTION,), 2, blocks=[32, 32, 64, 64], block_sizes=[1, 1, 1, 1], dense_layers=[64, 64], softmax=False)

    def create_forward_net(self, lib) -> Any:
        """Surrogate network: 3 input channels (state + 2 parameters) to 1 output channel in 1 spatial dimension."""
        return u_net(3, 1, in_spatial=1)

    def run_forward_net(self, net, initial_state, guess: Tensor) -> Any:
        """Evaluate the surrogate `net` on the initial state and parameters."""
        # Stack the state with each parameter component along a new `features` channel;
        # `expand_values=True` broadcasts the per-sample parameters over the spatial dimension.
        net_input = stack([initial_state.values, *guess.vector], channel('features'), expand_values=True)
        pred_output = math.native_call(net, net_input)
        return initial_state.with_values(pred_output)

    def plot_process(self, path: str, initial_state, guess):
        """Plot simulated trajectories for `guess` and save them to `path + ".jpg"`."""
        # Channel dims other than `vector` (e.g. several guesses to compare) become batch dims, one plot row each.
        compare_dims = channel(guess).without('vector')
        guess = rename_dims(guess, compare_dims, batch)
        trj = self.forward_process(initial_state, guess, steps=spatial(time=self._steps))
        vis.show(trj, row_dims=compare_dims, overlay=None, size=(6, 6))
        vis.savefig(path+".jpg", close=True)

    def get_plot(self, plot_type: str, initial_state, guess: Tensor, ref_output) -> dict:
        """Return the space-time trajectory for `guess` as a plottable object.

        `plot_type` and `ref_output` are part of the `Experiment` interface and are not used here.
        """
        trj = self.forward_process(initial_state, guess, steps=spatial(time=self._steps))
        trj = trj.with_bounds(trj.bounds.vector['time,x'])  # keep only the `time` and `x` components of the bounds
        return dict(obj=trj)


if __name__ == '__main__':
    # Device of the active compute backend (selected by the `phi.<backend>.flow` import).
    print(backend.default_backend().get_default_device())
    # --- Configurations: one batch dimension per swept quantity ---
    OPTIMIZER = vec(batch('optimizer'), 'adam')
    SIZE = vec(batch('dataset_size'), 2, 4, 8, 16, 32, 64, 128, 256)
    SIZE_P = vec(batch('dataset_size'), 4, 8, 32, 128)  # Paper
    SEED = vec(batch('seed'), 0, 1, 2, 3, 4, 5, 6, 7, 8, 9)
    LEARNING_RATE = vec(batch('lr'), 1e-3)
    ITERATIONS = vec(batch('iterations'), 1000)
    EXPERIMENT = KSForced(steps=50)
    # `nets` is presumably phi's network module, provided by `from phi.torch.flow import *`.
    # --- Optimize & plot ---
    visualize_dataset(SIZE, SEED, experiment=EXPERIMENT)
    train_network(SIZE, SEED, LEARNING_RATE, ITERATIONS, OPTIMIZER, experiment=EXPERIMENT, nets=nets)
    visualize_training('net', SIZE, SEED, LEARNING_RATE, experiment=EXPERIMENT)
    optimize_directly('BFGS', SIZE, SEED, experiment=EXPERIMENT)
    print(f"Time to train: {get_time_to_train('net', SIZE, SEED, LEARNING_RATE, experiment=EXPERIMENT):summary}")
    train_supervised(SIZE, SEED, LEARNING_RATE, ITERATIONS, OPTIMIZER, experiment=EXPERIMENT, nets=nets)
    visualize_training('sup', SIZE, SEED, LEARNING_RATE, experiment=EXPERIMENT)
    train_surrogate(SIZE, SEED, LEARNING_RATE, ITERATIONS, OPTIMIZER, experiment=EXPERIMENT, nets=nets)
    visualize_training('surrogate', SIZE, SEED, LEARNING_RATE, experiment=EXPERIMENT)
    neural_adjoint(SIZE, SEED, LEARNING_RATE, OPTIMIZER, experiment=EXPERIMENT, nets=nets, use_boundary_loss=True)
    visualize_results(SIZE, SEED, LEARNING_RATE, experiment=EXPERIMENT, nets=nets)
    visualize_optimizations('BFGS,sup,surrogate,net', 'loss', SIZE, SEED, LEARNING_RATE, experiment=EXPERIMENT)
    visualize_param_trj('BFGS,sup,surrogate,net', 'non_lin,forcing', SIZE, SEED, LEARNING_RATE, experiment=EXPERIMENT)
    # --- Refine each learned solution with BFGS ---
    refine_directly('net', 'BFGS', SIZE, SEED, LEARNING_RATE, experiment=EXPERIMENT)
    refine_directly('sup', 'BFGS', SIZE, SEED, LEARNING_RATE, experiment=EXPERIMENT)
    refine_directly('surrogate', 'BFGS', SIZE, SEED, LEARNING_RATE, experiment=EXPERIMENT)
    visualize_optimizations('BFGS,sup,surrogate,net', 'loss', SIZE, SEED, LEARNING_RATE, experiment=EXPERIMENT, refined='bfgs')

    plot_by_dset_size('BFGS,net,sup,surrogate', SIZE, SEED, LEARNING_RATE, experiment=EXPERIMENT, refined='bfgs')

    direct_compare('BFGS,net,surrogate,sup', SIZE, SEED, LEARNING_RATE, experiment=EXPERIMENT, refined=None)
    direct_compare('BFGS,net,surrogate,sup', SIZE, SEED, LEARNING_RATE, experiment=EXPERIMENT, refined='bfgs')

    # --- Paper figure ---
    paper_plot('setup, landscape:non_lin, loss-curves, refined-loss-by-n', SIZE, SEED, LEARNING_RATE, experiment=EXPERIMENT, example=2, size=(10, 2.5), param_range=(-3, .5))
    paper_fig = plt.gcf()
    paper_fig.axes[1].set_xlabel("$\\beta$")
    # Save before showing: with a blocking, non-interactive backend the figure is closed once `show()` returns,
    # so saving afterwards can write a blank image.
    vis.savefig(f"~/phi/RP/{EXPERIMENT}/paper_ks.pdf")
    vis.savefig(f"~/phi/RP/{EXPERIMENT}/paper_ks.jpg")
    show()
    plt.close(paper_fig)  # release the figure once displayed

    plot_by_dset_size('BFGS,net,sup,surrogate', SIZE, SEED, LEARNING_RATE, experiment=EXPERIMENT, refined='bfgs', threshold=20)
    plot_parameter_trajectories('BFGS,net,sup,surrogate', 128, 0, LEARNING_RATE, experiment=EXPERIMENT)
    plot_all_curves('BFGS,net,sup,surrogate', 'loss', SIZE_P, SEED, LEARNING_RATE, experiment=EXPERIMENT, refined='bfgs', extend_curves=False, show_std=True)
