from typing import Any, Tuple

import matplotlib.pyplot as plt

from experiment import Experiment, train_network, optimize_directly, visualize_training, visualize_dataset, visualize_optimizations, visualize_results, random_guessing, plot_by_dset_size, visualize_param_trj, refine_directly, neural_adjoint, train_surrogate, train_supervised, paper_plot, \
    get_time_to_train, direct_compare, plot_parameter_trajectories, plot_all_curves
from phi.torch.flow import *


class WavePacket(Experiment):
    """Inverse problem: recover the centre position of a noisy 1D wave packet.

    The single unknown is `pos`, sampled with `math.random_uniform`. It is mapped linearly to a
    packet centre in cell units. The observation is a Gaussian-windowed sine centred there, plus
    additive Gaussian noise.

    Args:
        noise_level: Multiplier applied to standard-normal noise added to the signal.
        resolution: Spatial shape of the signal grid, e.g. `spatial(x=256)`.
    """

    def __init__(self, noise_level: float, resolution: Shape):
        self._noise_level = noise_level
        self._resolution = resolution
        # Integer cell coordinates along the spatial dimension; the packet is evaluated on these.
        self._x = math.range(self._resolution)

    def __repr__(self):
        """Return an identifier encoding the resolution and noise level, e.g. `WavePacket_256_noise0.10`."""
        return f"{self.__class__.__name__}_{'x'.join(str(s) for s in self._resolution.sizes)}_noise{self._noise_level:.2f}"

    def generate_problem(self, batch_dims: Shape, test: bool) -> Tuple[Any, Any, Tensor]:
        """Sample ground-truth parameters and the noise realization.

        Args:
            batch_dims: Batch shape of the sampled ground-truth `pos` values.
            test: Not used by this experiment.

        Returns:
            A tuple `(noise, zeros_like(noise), vec(pos=x_gt))` with these parts:
            - The additive noise, presumably passed back as `initial_state` to `forward_process`, which adds it to the clean packet.
            - A zero tensor of the same shape.
            - The ground-truth parameter.
        """
        x_gt = math.random_uniform(batch_dims)
        # Sampled with the spatial shape only, so one noise realization is shared by all batch entries.
        noise = self._noise_level * math.random_normal(self._resolution)
        return noise, math.zeros_like(noise), vec(pos=x_gt)

    def forward_process(self, initial_state, guess: Tensor) -> Tensor:
        """Render the wave packet for `guess` on top of `initial_state`.

        Args:
            initial_state: Additive background signal (the noise produced by `generate_problem`).
            guess: Parameter tensor with a `pos` item along its `vector` dimension.

        Returns:
            `initial_state + sin(0.7 * (x - x0)) * exp(-(x - x0)**2 / 20**2)`, where
            `x0 = (0.1 + 0.8 * pos) * N` and N is the number of grid cells. For `pos` in [0, 1],
            the centre therefore lies between 10% and 90% of the domain.
        """
        x0 = (0.1 + 0.8 * guess['pos']) * self._resolution.size
        offset = self._x - x0
        # Sine carrier (wavenumber 0.7 per cell) under a Gaussian envelope of width 20 cells.
        signal = initial_state + math.sin(offset * 0.7) * math.exp(-offset ** 2 / 20 ** 2)
        return signal

    def get_observations(self, output) -> Tensor:
        """The full forward-process output is observed directly."""
        return output

    def loss_function(self, ref_output, guess_output) -> Tensor:
        """Return `math.l2_loss` of the difference between reference and predicted outputs."""
        return math.l2_loss(ref_output - guess_output)

    def create_inverse_net(self, lib) -> Any:
        """Build the network that maps an observed signal to a `pos` estimate.

        `lib` is not used; the `conv_classifier` in module scope is called directly. The final
        dense layer has size 1 and softmax is disabled, so the output is a single unbounded value.
        """
        return conv_classifier(1, self._resolution.sizes, 1, (16, 16, 16, 16, 16), (1, 1, 1, 1, 1), dense_layers=[64, 32, 1], softmax=False, activation='ReLU', batch_norm=True)

    def create_forward_net(self, lib) -> Any:
        """Build a 1D U-Net surrogate with 2 input channels and 1 output channel (`lib` is not used)."""
        return u_net(2, 1, in_spatial=1)

    def run_forward_net(self, net, initial_state, guess: Tensor) -> Any:
        """Evaluate the surrogate `net` on the stacked `(initial_state, guess)` channels.

        `expand_values=True` broadcasts the spatially constant `guess` to the shape of
        `initial_state` before concatenating along `vector`.
        """
        x = concat([initial_state, guess], 'vector', expand_values=True)
        return math.native_call(net, x)

    def plot_process(self, path: str, initial_state, guess):
        """Plot the signal produced by `guess` and save it to `path + "_signal.jpg"`.

        Non-`vector` channel dimensions of `guess` (e.g. several candidate guesses) are treated as
        batch dimensions. Each candidate then produces its own signal and its own title label.
        """
        forward = math.map_types(self.forward_process, channel(guess).without('vector'), batch)
        signal = forward(initial_state, guess)
        from phi.vis._vis_base import index_label
        # One "label=pos " entry per candidate; concatenated into a single title string.
        labels = [f"{index_label(idx)}={guess.vector['pos'][idx]:.2f} "
                  for idx in channel(guess).without('vector').meshgrid(names=True)]
        title = "".join(labels)
        vis.show(signal, title=title)
        vis.savefig(path + "_signal.jpg")

    def get_plot(self, plot_type: str, initial_state, guess: Tensor, ref_output) -> dict:
        """Return plot data for `plot_type`.

        Only `'waveform'` is supported. It returns the reference output under a `signal` channel
        label, drawn in grey.

        Raises:
            NotImplementedError: For any other `plot_type`.
        """
        if plot_type == 'waveform':
            return dict(obj=expand(ref_output, channel(c=['signal'])), color='#919191')
        else:
            raise NotImplementedError


if __name__ == '__main__':
    # --- Experiment and hyperparameter sweeps; each sweep is a `vec` along its own batch dimension ---
    EXPERIMENT = WavePacket(noise_level=0.1, resolution=spatial(x=256))
    OPTIMIZER = vec(batch('optimizer'), 'adam')
    SIZE = vec(batch('dataset_size'), 2, 4, 8, 16, 32, 64, 128, 256)
    SEED = vec(batch('seed'), 0, 1, 2, 3, 4)
    LEARNING_RATE = vec(batch('lr'), 1e-2)
    ITERATIONS = vec(batch('iterations'), 5000)

    # NOTE(review): `nets` is not defined in this file. It presumably comes from
    # `from phi.torch.flow import *` as the network-library module (cf. the `lib` parameter of
    # `WavePacket.create_inverse_net`). Confirm this; otherwise the calls below raise NameError.

    # --- Baselines and dataset ---
    random_guessing(SIZE, SEED, experiment=EXPERIMENT)
    visualize_dataset(SIZE, SEED, experiment=EXPERIMENT)

    # --- 'net': train, inspect, refine with BFGS; plus the pure-BFGS baseline ---
    train_network(SIZE, SEED, LEARNING_RATE, ITERATIONS, OPTIMIZER, experiment=EXPERIMENT, nets=nets)
    visualize_training('net', SIZE, SEED, LEARNING_RATE, experiment=EXPERIMENT)
    refine_directly('net', 'BFGS', SIZE, SEED, LEARNING_RATE, experiment=EXPERIMENT)
    optimize_directly('BFGS', SIZE, SEED, experiment=EXPERIMENT)
    time_to_train = get_time_to_train('net', SIZE, SEED, LEARNING_RATE, experiment=EXPERIMENT)
    print(f"Time to train: {time_to_train:summary}")

    # --- 'sup': train, inspect, refine with BFGS ---
    train_supervised(SIZE, SEED, LEARNING_RATE, ITERATIONS, OPTIMIZER, experiment=EXPERIMENT, nets=nets)
    visualize_training('sup', SIZE, SEED, LEARNING_RATE, experiment=EXPERIMENT)
    refine_directly('sup', 'BFGS', SIZE, SEED, LEARNING_RATE, experiment=EXPERIMENT)

    # --- 'surrogate': train surrogate, run neural adjoint, inspect, refine with BFGS ---
    train_surrogate(SIZE, SEED, LEARNING_RATE, ITERATIONS, OPTIMIZER, experiment=EXPERIMENT, nets=nets)
    neural_adjoint(SIZE, SEED, LEARNING_RATE, OPTIMIZER, experiment=EXPERIMENT, nets=nets, use_boundary_loss=True)
    visualize_training('surrogate', SIZE, SEED, LEARNING_RATE, experiment=EXPERIMENT)
    refine_directly('surrogate', 'BFGS', SIZE, SEED, LEARNING_RATE, experiment=EXPERIMENT)

    # --- Cross-method comparisons ---
    visualize_results(SIZE, SEED, LEARNING_RATE, experiment=EXPERIMENT, nets=nets)
    visualize_param_trj('BFGS,sup,surrogate,net', SIZE, SEED, LEARNING_RATE, experiment=EXPERIMENT)
    refine_directly('guessing', 'BFGS', SIZE, SEED, LEARNING_RATE, experiment=EXPERIMENT)
    visualize_optimizations('BFGS,guessing,sup,surrogate,net', 'loss', SIZE, SEED, LEARNING_RATE, experiment=EXPERIMENT, refined='bfgs')
    plot_by_dset_size('BFGS,sup,surrogate,net', SIZE, SEED, LEARNING_RATE, experiment=EXPERIMENT, refined='bfgs', threshold=10)
    direct_compare('BFGS,net,surrogate,sup', SIZE, SEED, LEARNING_RATE, experiment=EXPERIMENT, refined='bfgs')
    plot_by_dset_size('BFGS,net,sup,surrogate', SIZE, SEED, LEARNING_RATE, experiment=EXPERIMENT, refined='bfgs', threshold=20)
    plot_parameter_trajectories('BFGS,net,sup,surrogate', 128, 0, LEARNING_RATE, experiment=EXPERIMENT)

    # --- Paper figures, restricted to the dataset sizes shown in the paper ---
    PAPER_SIZE = vec(batch('dataset_size'), 2, 8, 32, 128)
    plot_all_curves('BFGS,net,sup,surrogate', 'loss', PAPER_SIZE, SEED, LEARNING_RATE, experiment=EXPERIMENT, refined='bfgs', extend_curves=True, show_std=False)
    paper_plot('waveform, landscape:pos, loss-curves, refined-loss-by-n', PAPER_SIZE, SEED, LEARNING_RATE, experiment=EXPERIMENT, example=0, size=(10, 2.5))
    # Relabel the x-axis of the second panel (presumably the `landscape:pos` panel).
    plt.gcf().axes[1].set_xlabel("$t_0$")
    show()
    vis.savefig(f"~/phi/RP/{EXPERIMENT}/paper_wavepacket.pdf")
    vis.savefig(f"~/phi/RP/{EXPERIMENT}/paper_wavepacket.jpg", close=True)
