from typing import Any, Tuple

from matplotlib import pyplot as plt

from experiment import Experiment, train_network, optimize_directly, visualize_training, visualize_dataset, train_supervised, train_surrogate, neural_adjoint, plot_by_dset_size, refine_directly, visualize_param_trj, paper_plot, \
    get_time_to_train, direct_compare, plot_parameter_trajectories, plot_all_curves, plot_all_optimizations
from phi.torch.flow import *


class Fluid(Experiment):
    """Fluid experiment: recover fan parameters from the resulting flow field.

    A "fan" (a soft sphere of added velocity) is placed into a random,
    incompressible initial velocity field. The field is then advanced with
    `fluid_step` around three fixed box obstacles. The unknowns are the
    vector components `pos_x`, `strength_x` and `strength_y`.
    """

    def __init__(self, resolution: Shape, dt=1, steps=16, background_str=.5):
        """
        Args:
            resolution: Grid resolution. Its `y` size is halved in
                `get_observations` and `create_inverse_net`.
            dt: Time increment forwarded to every `fluid_step` call.
            steps: Default number of `fluid_step` iterations in `forward_process`.
            background_str: Factor multiplied onto the random initial velocity.
        """
        self.resolution = resolution
        self.dt = dt
        self.steps = steps
        self.background_str = background_str
        # Three fixed boxes, given in the coordinates of the 100 x 100 domain used in generate_problem().
        self.obstacles = Obstacle(union(Box(x=(40, 60), y=(80, 90)), Box(x=(70, 80), y=(50, 60)), Box(x=(20, 30), y=(50, 60))))

    def __repr__(self):
        resolution_str = 'x'.join(str(s) for s in self.resolution.sizes)
        return f"{self.__class__.__name__}_{resolution_str}_t{self.steps}_bkg{self.background_str}"

    def generate_problem(self, batch_dims: Shape, test: bool) -> Tuple[Any, Any, Tensor]:
        """Sample random initial states and ground-truth fan parameters.

        Args:
            batch_dims: Shape over which the noise field and the parameters are sampled.
            test: Not used by this experiment.

        Returns:
            `((v0, p0), (v0_masked, p0), params)`. `v0` and `p0` are the outputs of
            `fluid.make_incompressible`, `v0_masked` is `v0` multiplied by a
            `y > 50` mask, and `params` has the components `pos_x`,
            `strength_x` and `strength_y`.
        """
        # NOTE: the order of the random draws below is kept fixed so that seeded runs stay reproducible.
        v0 = StaggeredGrid(Noise(batch_dims), {'x': 0, 'y': (ZERO_GRADIENT, 0)}, Box(x=100, y=100), self.resolution) * self.background_str
        v0, p0 = fluid.make_incompressible(v0, self.obstacles)  # to get initial pressure
        pos_x = math.random_uniform(batch_dims, low=-1, high=1)
        strength_x = math.random_uniform(batch_dims, low=-8, high=8)
        strength_y = math.random_uniform(batch_dims, low=0, high=12)
        v0_masked = v0 * v0.with_values(lambda x, y: y > 50)
        return (v0, p0), (v0_masked, p0), vec(pos_x=pos_x, strength_x=strength_x, strength_y=strength_y)

    def _add_fan(self, v0, guess: Tensor):
        """Return `v0` plus the fan velocity described by `guess`.

        The fan is a soft `Sphere` of radius 18 at `y=25`. Its x position is
        `pos_x * 15 + 50`, clipped to `[35, 65]`. The sphere mask is scaled by
        the vector `(strength_x, strength_y)`.
        """
        fan_pos = vec(x=math.clip(guess['pos_x'] * 15 + 50, 35, 65), y=25)
        fan_v = vec(x=guess['strength_x'], y=guess['strength_y'])
        v0 += field.resample(Sphere(fan_pos, radius=18), to=v0, soft=True) * fan_v
        return v0

    def forward_process(self, initial_state, guess: Tensor, steps=None) -> Tensor:
        """Add the fan to the initial velocity and simulate.

        Args:
            initial_state: `(v0, p0)` pair of initial velocity and pressure.
            guess: Fan parameters with components `pos_x`, `strength_x`, `strength_y`.
            steps: Iteration argument for `iterate`. Any falsy value
                (including `None`) falls back to `self.steps`.

        Returns:
            The values of the resulting velocity, sampled at cell centers.
        """
        v0, p0 = initial_state
        v0 = self._add_fan(v0, guess)
        # Non-finite gradient entries flowing back into the initial velocity are replaced by 0.
        v0 = v0.with_values(filter_nan_grad(v0.values))
        out_v, _ = iterate(fluid_step, steps or self.steps, v0, p0, obstacles=self.obstacles, dt=self.dt)
        return out_v.at_centers().values

    def get_observations(self, output) -> Tensor:
        """Return the slice of `output` along `y` starting at half the `y` resolution."""
        return output.y[self.resolution.get_size('y') // 2:]

    def create_inverse_net(self, lib) -> Any:
        """Build a `conv_classifier` with 2 input channels and 3 outputs.

        The input resolution is `self.resolution` with the `y` size halved,
        which matches the slice taken in `get_observations`. `lib` is not used.
        """
        observed_resolution = self.resolution.with_dim_size('y', self.resolution.get_size('y') // 2)
        return conv_classifier(2, observed_resolution.sizes, 3, [16, 32, 32, 32], [1, 1, 1, 1], [64, 64], softmax=False)

    def create_forward_net(self, lib) -> Any:
        """Build the surrogate `u_net(5, 2, activation='Sigmoid')`. `lib` is not used."""
        # Presumably 5 inputs = 2 velocity components + 3 fan parameters, matching the concat in run_forward_net() - confirm.
        return u_net(5, 2, activation='Sigmoid')

    def run_forward_net(self, net, initial_state, guess: Tensor) -> Tensor:
        """Evaluate the surrogate network on the observed part of the initial velocity.

        The centered initial velocity is reduced with `get_observations`,
        concatenated with `guess` (expanded over the spatial dims) along
        `vector`, and passed to `net`.

        Returns:
            A zero block (network output times 0) followed by the network
            output, concatenated along `y`.
        """
        v0, p0 = initial_state
        v0_data = self.get_observations(v0.at_centers().values)
        upper_half = math.native_call(net, concat([v0_data, expand(guess, spatial(v0_data))], 'vector'))
        return concat([upper_half * 0, upper_half], 'y')

    def plot_process(self, path: str, initial_state, guess):
        """Simulate with a `time` batch shape as `steps`, plot it animated over `time` and save to `path + ".mp4"`."""
        v_trj = self.forward_process(initial_state, guess, steps=batch(time=self.steps))
        vis.plot(v_trj, animate='time')
        vis.savefig(path + ".mp4")

    def get_plot(self, plot_type: str, initial_state, guess: Tensor, ref_output) -> dict:
        """Build plot arguments showing the curl of the velocity over the obstacles.

        Args:
            plot_type: `'initial'` shows the curl after adding the fan and running one
                `fluid_step`. `'final'` shows the curl of `ref_output`, written
                into the initial velocity grid and padded by -1.
            initial_state: `(v0, p0)` pair of initial velocity and pressure.
            guess: Fan parameters, only used for `'initial'`.
            ref_output: Velocity values, only used for `'final'`.

        Returns:
            A dict with the keys `obj` and `color`, or `None` for any other `plot_type`.
        """
        v0, p0 = initial_state
        if plot_type == 'initial':
            v0 = self._add_fan(v0, guess)
            v1, _ = fluid_step(v0, p0, obstacles=self.obstacles, dt=self.dt)
            curl = field.curl(v1)
        elif plot_type == 'final':
            final = v0.with_values(ref_output)
            curl = field.pad(field.curl(final), -1)
        else:
            # Other plot types are not handled here; keep returning None as before.
            return None
        return dict(obj=vis.overlay(curl, self.obstacles.geometry), color=vis.overlay(0, '#909090'))


@jit_compile(auxiliary_args='obstacles')
def fluid_step(v, p, obstacles, dt=1.):
    """Advance the flow by one step: self-advection, then a pressure projection.

    Args:
        v: Velocity field; it is advected by itself with `advect.semi_lagrangian`.
        p: Previous pressure, used as the initial guess `x0` of the pressure solve.
        obstacles: Obstacles forwarded to `fluid.make_incompressible`.
        dt: Time increment of the advection step.

    Returns:
        `(velocity, pressure)` as returned by `fluid.make_incompressible`.
    """
    advected = advect.semi_lagrangian(v, v, dt)
    # The gradient solve starts from zero and uses looser tolerances than the forward solve.
    backward_solve = Solve(x0=p*0, abs_tol=1e-3, rel_tol=1e-3, suppress=[NotConverged, Diverged])
    # NotConverged / Diverged are suppressed for both solves.
    forward_solve = Solve(x0=p, suppress=[NotConverged, Diverged], gradient_solve=backward_solve)
    projected, pressure = fluid.make_incompressible(advected, obstacles, forward_solve)
    return projected, pressure


def filter_nan_grad(v):
    """Return `v` through an identity whose gradient has non-finite entries replaced by 0.

    The forward function is `math.identity`. The custom backward function
    returns the incoming gradient `dy` for the first forward argument, with
    every entry where `math.is_finite(dy)` is false set to 0.
    """
    def zero_non_finite(fwd_args: dict, _y, dy):
        first_arg = next(iter(fwd_args))
        finite_mask = math.is_finite(dy)
        return {first_arg: math.where(finite_mask, dy, 0)}

    guarded_identity = math.custom_gradient(math.identity, zero_non_finite)
    return guarded_identity(v)


# Script entry point: runs the full pipeline of the `experiment` module for the Fluid experiment.
# The calls below are order-dependent (each visualize/refine call follows the training call for the same method name).
if __name__ == '__main__':
    # --- Sweep configuration: each setting is wrapped with `vec` along its own named batch dimension ---
    OPTIMIZER = vec(batch('optimizer'), 'adam')
    SIZE = vec(batch('dataset_size'), 4, 8, 16, 32, 64, 128)
    # SIZE = vec(batch('dataset_size'), 4, 16, 64, 128)  # shown in paper
    SEED = vec(batch('seed'), 0, 1, 2, 3)
    LEARNING_RATE = vec(batch('lr'), 1e-3)
    ITERATIONS = vec(batch('iterations'), 3000)
    # 64 x 64 grid, 56 simulation steps, random background velocity scaled by 0.5.
    EXPERIMENT = Fluid(spatial(x=64, y=64), steps=56, background_str=.5)
    # --- Optimize & plot ---
    # NOTE(review): `nets` is not defined in this file; presumably it comes from the star import of phi.torch.flow - confirm.
    visualize_dataset(SIZE, SEED, experiment=EXPERIMENT)
    # Method 'net', followed by the direct 'BFGS' optimization.
    train_network(SIZE, SEED, LEARNING_RATE, ITERATIONS, OPTIMIZER, experiment=EXPERIMENT, nets=nets)
    visualize_training('net', SIZE, SEED, LEARNING_RATE, experiment=EXPERIMENT)
    refine_directly('net', 'BFGS', SIZE, SEED, LEARNING_RATE, experiment=EXPERIMENT)
    optimize_directly('BFGS', SIZE, SEED, experiment=EXPERIMENT)
    print(f"Time to train: {get_time_to_train('net', SIZE, SEED, LEARNING_RATE, experiment=EXPERIMENT):summary}")
    # Method 'sup'.
    train_supervised(SIZE, SEED, LEARNING_RATE, ITERATIONS, OPTIMIZER, experiment=EXPERIMENT, nets=nets)
    visualize_training('sup', SIZE, SEED, LEARNING_RATE, experiment=EXPERIMENT)
    refine_directly('sup', 'BFGS', SIZE, SEED, LEARNING_RATE, experiment=EXPERIMENT)
    # Method 'surrogate' (surrogate training followed by neural_adjoint with the boundary loss enabled).
    train_surrogate(SIZE, SEED, LEARNING_RATE, ITERATIONS, OPTIMIZER, experiment=EXPERIMENT, nets=nets)
    neural_adjoint(SIZE, SEED, LEARNING_RATE, OPTIMIZER, experiment=EXPERIMENT, nets=nets, use_boundary_loss=True)
    visualize_training('surrogate', SIZE, SEED, LEARNING_RATE, experiment=EXPERIMENT)
    refine_directly('surrogate', 'BFGS', SIZE, SEED, LEARNING_RATE, experiment=EXPERIMENT)
    # Comparison plots across all four methods.
    visualize_param_trj('BFGS,sup,surrogate,net', SIZE, SEED, LEARNING_RATE, experiment=EXPERIMENT)
    # NOTE(review): this call passes the literals 128 and 1e-3 instead of SIZE / LEARNING_RATE - confirm this is intended.
    plot_all_optimizations('BFGS,net,sup,surrogate', 'loss', 128, SEED, 1e-3, experiment=EXPERIMENT, refined='bfgs')

    direct_compare('BFGS,net,surrogate,sup', SIZE, SEED, LEARNING_RATE, experiment=EXPERIMENT, refined='bfgs')

    plot_by_dset_size('BFGS,net,sup,surrogate', SIZE, SEED, LEARNING_RATE, experiment=EXPERIMENT, refined='bfgs', threshold=20)
    plot_parameter_trajectories('BFGS,net,sup,surrogate', 128, 0, LEARNING_RATE, experiment=EXPERIMENT)
    plot_all_curves('BFGS,net,sup,surrogate', 'loss', SIZE, SEED, LEARNING_RATE, experiment=EXPERIMENT, refined='bfgs')

    # Paper figure: four panels; the y-range of the fourth axes (index 3) is fixed to (0, 130) before saving as PDF and JPG.
    paper_plot('initial, final, loss-curves, refined-loss-by-n', SIZE, SEED, LEARNING_RATE, experiment=EXPERIMENT, example=9, size=(10, 2.5), curves_batch_size=64)
    plt.gcf().axes[3].set_ylim((0, 130))
    show()
    vis.savefig(f"~/phi/RP/{EXPERIMENT}/paper_fluid.pdf")
    vis.savefig(f"~/phi/RP/{EXPERIMENT}/paper_fluid.jpg", close=True)
