from typing import Any, Tuple

import math

from experiment import Experiment, train_network, optimize_directly, visualize_training, visualize_dataset, visualize_optimizations, visualize_results, train_surrogate, \
    neural_adjoint, train_supervised, visualize_param_trj, refine_directly, plot_by_dset_size, paper_plot, get_time_to_train, direct_compare, plot_all_curves, plot_parameter_trajectories
from phi.math import si2d
from phi.torch.flow import *


def grad_nan_to_0(v):
    """
    Pass `v` through unchanged, but replace non-finite entries of its incoming gradient with 0.

    Several quantities in `physics_step` are NaN on purpose in the forward pass (e.g. for a ball paired with itself).
    Wrapping them with this function keeps those NaNs from propagating into gradients during backpropagation.

    Args:
        v: Value to pass through.

    Returns:
        `v`, unchanged in value; its gradient has every entry where `math.is_finite` is False set to 0.
    """
    def zero_non_finite_grad(fwd_args: dict, _y, dy):
        # `fwd_args` maps the forward function's argument names to values; the cleaned
        # gradient is routed back to its first argument.
        arg_name = list(fwd_args)[0]
        cleaned = math.where(math.is_finite(dy), dy, 0)
        return {arg_name: cleaned}

    identity_with_safe_grad = math.custom_gradient(math.identity, zero_non_finite_grad)
    return identity_with_safe_grad(v)


def physics_step(v: PointCloud, time: float or Tensor, elasticity=.8, friction=.5):
    """
    Advance the ball system by one event: to the earliest upcoming ball-ball collision, or to rest if there is none.

    Between events, each ball moves along a straight line while its speed decays exponentially:
    v(t) = v0 * exp(-k*t),  x(t) = x0 + v0 * (1 - exp(-k*t)) / k,  with k = `friction`.
    The function proceeds as follows:
    - It computes collision times for all ball pairs and takes the earliest one that is at least 1e-3 in the future.
    - It moves every ball to that time.
    - It applies the collision impulse to every pair colliding within 1e-3 of that time.
    If no finite collision time exists, the event time is `INF`. The friction factor then becomes exp(-INF) = 0, so
    balls travel their full remaining distance `v0 / k`, and the returned time is infinite.

    Note: the annotation `float or Tensor` evaluates to `float` at runtime; a `Tensor` is accepted as well.

    Args:
        v: State of the system as `PointCloud` containing `Sphere` elements and holding the corresponding velocities as values.
        time: Simulation time of the state `v`.
        elasticity: Coefficient of restitution for ball-ball collisions (enters the impulse as `1 + elasticity`).
        friction: Velocity decay rate k, unit 1/time.

    Returns:
        v: Next state as `PointCloud`.
        time: Simulation time of the next state, i.e. `time + first_impact_time`.
    """
    # `si2d` creates a second copy of the ball dimension, so the expressions below form all (ball i, ball j)
    # pairs. The diagonal i == j has zero relative velocity and distance and yields NaNs; `grad_nan_to_0`
    # keeps those NaNs out of the gradients.
    other_points = si2d(v.points)
    rel_v = grad_nan_to_0(v.values - si2d(v.values))
    distance = grad_nan_to_0(other_points - v.points)
    # Time of closest approach of the relative straight-line motion: t = (d . v_rel) / |v_rel|^2.
    t_to_closest = grad_nan_to_0(distance.vector * rel_v.vector / math.vec_squared(rel_v))  # assuming linear velocity, NaN on diagonal
    closest = v.points + t_to_closest * rel_v  # NaN on diagonal, grad=0
    pass_by_distance_squared = math.vec_squared(grad_nan_to_0(closest - other_points))  # impact only if < radius_sum ** 2, otherwise impact_offset below is NaN
    radius_sum = v.elements.radius + si2d(v.elements.radius)
    impact_offset = math.sqrt(radius_sum ** 2 - pass_by_distance_squared)  # positive, distance by how much the impact happens before the closest point on the line
    # eps avoids division by zero for pairs with zero relative velocity.
    impact_time_no_friction = grad_nan_to_0(t_to_closest - impact_offset / math.vec_length(rel_v, eps=1e-5))  # assuming linear velocity, NaN
    # Convert the frictionless time to the real time covering the same path length:
    # v0 * (1 - exp(-k*t)) / k = v0 * t_nf  =>  t = -log(1 - k * t_nf) / k  (NaN if k * t_nf > 1: the ball stops first).
    impact_time = - math.log(1 - friction * impact_time_no_friction) / friction
    # Discard collisions in the past or at the current instant (e.g. the pair that just collided).
    impact_time = math.where(impact_time < 1e-3, NAN, impact_time)
    # Earliest collision among the pairs; INF if there is none.
    # NOTE(review): relies on phi's default reduction over non-batch dims — confirm.
    first_impact_time = math.finite_min(impact_time, default=INF)
    friction_factor = math.exp(- first_impact_time * friction)
    has_impact = impact_time <= first_impact_time + 1e-3  # Handle simultaneous collisions in one go
    # Separation vector of the pair at the moment of impact (positions along the straight line, parameterized by frictionless time).
    impact_relative_position = grad_nan_to_0(other_points - (v.points + impact_time_no_friction * rel_v))
    rel_v_at_impact = rel_v * friction_factor
    # Impulse along the line of centers; the factor .5 splits the velocity change equally between both balls (equal masses).
    impulse = -(1 + elasticity) * .5 * (rel_v_at_impact.vector * impact_relative_position.vector) * impact_relative_position / math.vec_squared(impact_relative_position)
    travel_distance = v.values / friction * (1 - friction_factor)  # passes NaN to v_values
    v = v.with_elements(v.elements.at(v.points + travel_distance))  # Update position
    # Sum the impulses from all partners that collide in this event, ignoring non-colliding pairs.
    impact = math.sum(math.where(has_impact, impulse, 0), dual)
    v = v.with_values(v.values * friction_factor + impact)  # Deceleration
    return v, time + first_impact_time


def sample_linear_trajectory(states: PointCloud, times: Tensor, time_dim: math.Shape, velocity_threshold=.1, friction=.5):
    """
    Resample an event-based trajectory (keyframes from `physics_step`) onto evenly spaced sample times.

    The sample times run from 0 to the time at which the fastest initial velocity has decayed to `velocity_threshold`,
    i.e. log(v_max / velocity_threshold) / friction. Each sample is computed in closed form from the latest keyframe
    whose time is <= the sample time.

    NOTE(review): assumes `times` increases along `keys` and `times.keys[0]` <= 0, so every sample finds a keyframe.
    `friction` should match the value used by `physics_step` to reproduce the simulated motion.

    Args:
        states: Keyframe states stacked along the batch dimension `keys`.
        times: Keyframe times along `keys`.
        time_dim: Dimension along which the evenly spaced samples are laid out.
        velocity_threshold: Speed at which the motion is considered finished.
        friction: Velocity decay rate, unit 1/time.

    Returns:
        PointCloud of spheres at the sampled positions, holding the sampled velocities as values.
        Sample times along `time_dim`.
    """
    # End time: when the fastest initial speed has decayed to the threshold.
    initial_speed = math.max(math.vec_length(states.keys[0].values))
    t_end = math.log(initial_speed / velocity_threshold) / friction
    key_idx = math.range(batch('keys'), states.keys.size)
    sample_t = math.linspace(0, t_end, time_dim)
    # Index of the latest keyframe at or before each sample time (-1 marks keyframes that lie in the future).
    last_key = math.max(math.where(times <= sample_t, key_idx, -1), 'keys')
    v0 = states.values.keys[last_key]
    x0 = states.points.keys[last_key]
    elapsed = sample_t - times.keys[last_key]
    # Exponential velocity decay since that keyframe.
    decay = math.exp(-elapsed * friction)
    positions = x0 + v0 / friction * (1 - decay)
    velocities = v0 * decay
    return PointCloud(Sphere(positions, radius=states.elements.radius[{'keys': 0}]), velocities), sample_t


class Billiards(Experiment):
    """
    Two-ball billiards inverse problem.

    The cue ball starts at (0, .5) and ball '1' at x=1 with a random y per problem instance. The parameters to find
    are the cue-ball velocity components `cue_vx`, `cue_vy`. They should be chosen so that the position of ball '1'
    after the simulated events matches the target (`_TARGET_X`, `_TARGET_Y`).
    """

    # Target position for ball '1'. This is the single source of truth for the desired output, the loss and the plots.
    _TARGET_X = 2
    _TARGET_Y = .5

    def __init__(self, ball_radius=.2):
        """
        Args:
            ball_radius: Radius used for both balls in the simulation and in the setup plot.
        """
        self._ball_radius = ball_radius

    def __repr__(self):
        # Includes the radius so that experiments with different radii get distinct names (e.g. in output paths).
        return f"{self.__class__.__name__}_r{self._ball_radius}"

    def generate_problem(self, batch_dims: Shape, test: bool) -> Tuple[Any, Any, Tensor]:
        """
        Create random problem instances, one per entry of `batch_dims`.

        Args:
            batch_dims: Batch dimensions of the generated problems.
            test: Not used; training and test problems come from the same distribution.

        Returns:
            `initial_state`: `PointCloud` of the two balls. Its values are the velocities: (1, 0) for the cue ball,
                (0, 0) for ball '1'. `_simulate` adds the guess on top of these values.
                NOTE(review): if `PointCloud + Tensor` offsets values, the effective cue velocity is
                (1 + cue_vx, cue_vy) — confirm this offset is intended.
            The same `initial_state` object again.
            Random initial guess with `cue_vx` ~ U(0, 3) and `cue_vy` ~ U(-2, 2).
        """
        cue_pos = vec(x=0, y=.5)
        ball_pos = vec(x=1, y=math.random_uniform(*batch_dims))
        balls = Sphere(stack([cue_pos, ball_pos], instance(ball='cue,1')), radius=self._ball_radius)
        initial_v = expand(stack({'cue': vec(x=1, y=0), '1': vec(x=0, y=0)}, instance(balls)), batch_dims)
        random_guess = vec(cue_vx=math.random_uniform(batch_dims, low=0, high=3), cue_vy=math.random_uniform(batch_dims, low=-2, high=2))
        initial_state = PointCloud(balls, values=initial_v)
        return initial_state, initial_state, random_guess

    @staticmethod
    def _simulate(initial_state: PointCloud, guess: Tensor, max_steps: int):
        """
        Apply the cue velocity from `guess` and run `max_steps` event steps of `physics_step`.

        Returns the states and event times stacked along the batch dimension `keys`.
        """
        cue_v0 = vec(x=guess['cue_vx'], y=guess['cue_vy'])
        # Ball '1' receives no extra velocity.
        state0 = initial_state + stack([cue_v0, vec(x=0, y=0)], instance(initial_state))
        return iterate(physics_step, batch(keys=max_steps), state0, 0)

    def forward_process(self, initial_state: PointCloud, guess: Tensor, max_steps=2) -> PointCloud:
        """
        Simulate the shot described by `guess`.

        Args:
            initial_state: Start configuration as returned by `generate_problem`.
            guess: Tensor with `vector` items `cue_vx`, `cue_vy`.
            max_steps: Number of simulated events (collisions, or coming to rest).

        Returns:
            State after the last simulated event.
        """
        trajectory, _ = self._simulate(initial_state, guess, max_steps)
        return trajectory.keys[-1]

    def get_observations(self, output: PointCloud) -> Tensor:
        """Position of ball '1' in `output`."""
        return output.points.ball['1']

    def get_desired(self, output) -> Tensor:
        """Target position of ball '1', expanded to the batch dimensions of `output`."""
        return expand(vec(x=self._TARGET_X, y=self._TARGET_Y), batch(output))

    def loss_function(self, ref_output: PointCloud, guess_output: PointCloud) -> Tensor:
        """
        L2 distance between ball '1' in `guess_output` and the target.

        `ref_output` is not used because the target is fixed.
        """
        return math.l2_loss(self.get_observations(guess_output) - self.get_desired(guess_output))

    @staticmethod
    def _fourier_encode(ball_pos0: Tensor, other: Tensor) -> Tensor:
        """
        Build network input features from the initial position of ball '1' and a second 2-vector.

        The two inputs are concatenated along `vector` and mapped to sin/cos at 4 frequencies in [1, 10]. The result
        is flattened into the channel dimension `input`, giving 4 * 2 * 4 = 32 features. This assumes both inputs are
        2-vectors, matching the `(2 + 2) * 8` input size of the networks.
        """
        unencoded_input = math.concat([ball_pos0, other], 'vector')
        fourier_k = math.linspace(1, 10, channel(k=4))
        return math.flatten(math.concat([math.sin(unencoded_input * fourier_k), math.cos(unencoded_input * fourier_k)], 'k'), channel('input'))

    def create_inverse_net(self, lib) -> Any:
        # (ball_pos0, target) -> cue_velocity; input size matches `_fourier_encode`.
        return lib.dense_net((2 + 2) * 8, 2, layers=[128, 128, 128], activation='Sigmoid')

    def run_inverse_net(self, net, initial_state, observed: Tensor, sol_shape: Shape):
        """Predict the cue velocity from the initial position of ball '1' and the `observed` (desired) position."""
        fourier_encoding = self._fourier_encode(initial_state.points.ball['1'], observed)
        return math.native_call(net, fourier_encoding, spatial_dim=sol_shape.spatial, channel_dim=sol_shape.channel)

    def create_forward_net(self, lib) -> Any:
        # (ball_pos0, cue_velocity) -> pos_final; input size matches `_fourier_encode`.
        return lib.dense_net((2 + 2) * 8, 2, layers=[128, 128, 128], activation='Sigmoid')

    def run_forward_net(self, net, initial_state, guess: Tensor) -> Any:
        """
        Surrogate of `forward_process`: predict the final position of ball '1' from its start position and the cue velocity.

        The cue ball's position in the returned `PointCloud` is a placeholder (0, 0).
        """
        fourier_encoding = self._fourier_encode(initial_state.points.ball['1'], guess)
        final_ball_pos = math.native_call(net, fourier_encoding, channel_dim=channel(vector='x,y'))
        return PointCloud(stack({'cue': vec(x=0, y=0), '1': final_ball_pos}, instance(initial_state)))

    def plot_process(self, path: str, initial_state, guess, max_steps=2):
        """
        Render an animation of the simulated shot(s) to `path + "sequence.mp4"`.

        Channel dimensions of `guess` other than `vector` (e.g. several guesses) are simulated as batch dimensions.
        They are then shown as overlays.
        """
        extra_channels = channel(guess).without('vector')
        guess = rename_dims(guess, extra_channels, batch)
        v_keys, key_times = self._simulate(initial_state, guess, max_steps)
        # Resample the event keyframes onto 50 evenly spaced frames for the animation.
        v_trj, _ = sample_linear_trajectory(v_keys, key_times, batch(time=50))
        v_trj = rename_dims(v_trj, extra_channels, channel)
        plot(v_trj.elements, vec(x=[self._TARGET_X], y=[self._TARGET_Y]), overlay='args,overlay', animate='time', title="", same_scale='x,y')
        vis.savefig(path + "sequence.mp4")

    def get_plot(self, plot_type: str, initial_state, guess: Tensor, ref_output) -> dict:
        """
        Return the plot arguments for `plot_type`.

        Only 'setup' is supported. It produces a schematic with two balls at fixed illustrative positions, a
        `PointCloud` of illustrative velocities and a small square marking the target. The other arguments are unused.

        Raises:
            NotImplementedError: For any other `plot_type`.
        """
        pos = wrap([(0, .5), (1, 0)], instance(ball='cue,ball'), channel(vector='x,y'))
        spheres = Sphere(pos, radius=self._ball_radius)
        vel = wrap([(1, -.5), (0, 0)], instance(ball='cue,ball'), channel(vector='x,y')) / 2
        col = wrap([5, 6], instance(pos))
        target = Cuboid(expand(vec(x=self._TARGET_X, y=self._TARGET_Y), instance(t='target')), vec(x=.06, y=.06))
        if plot_type == 'setup':
            return dict(obj=vis.overlay(spheres, PointCloud(pos, vel), target),
                        color=vis.overlay(col, col, 7))
        else:
            raise NotImplementedError(plot_type)


if __name__ == '__main__':
    # Script driver for the Billiards experiment. It generates data, trains and refines the methods
    # 'net', 'sup' and 'surrogate', runs a direct BFGS baseline and produces the plots.
    # The stages run in this order. Later stages presumably read results written by earlier ones
    # (TODO confirm in `experiment`), so do not reorder them.
    # `nets` is not defined in this file; presumably it is provided by `from phi.torch.flow import *`.
    print(backend.default_backend().get_default_device())
    OPTIMIZER = vec(batch('optimizer'), 'adam')
    # SIZE = vec(batch('dataset_size'), 4, 16, 64, 256)  # shown in paper
    SIZE = vec(batch('dataset_size'), 4, 8, 16, 32, 64, 128, 256)
    SEED = vec(batch('seed'), 0, 1, 2, 3)
    ITERATIONS = vec(batch('iterations'), 2000)
    EXPERIMENT = Billiards()
    # --- Optimize & plot ---
    # Method 'net': learning rate 1e-4. It is refined with BFGS, and the pure BFGS baseline runs here too.
    LEARNING_RATE = vec(batch('lr'), 1e-4)
    visualize_dataset(SIZE, SEED, experiment=EXPERIMENT)
    train_network(SIZE, SEED, LEARNING_RATE, ITERATIONS, OPTIMIZER, experiment=EXPERIMENT, nets=nets)
    visualize_training('net', SIZE, SEED, LEARNING_RATE, experiment=EXPERIMENT)
    refine_directly('net', 'BFGS', SIZE, SEED, LEARNING_RATE, experiment=EXPERIMENT)
    optimize_directly('BFGS', SIZE, SEED, experiment=EXPERIMENT)
    print(f"Time to train: {get_time_to_train('net', SIZE, SEED, LEARNING_RATE, experiment=EXPERIMENT):summary}")

    # Methods 'sup' and 'surrogate' both use learning rate 1e-3.
    LEARNING_RATE = vec(batch('lr'), 1e-3)
    train_supervised(SIZE, SEED, LEARNING_RATE, ITERATIONS, OPTIMIZER, experiment=EXPERIMENT, nets=nets)
    visualize_training('sup', SIZE, SEED, LEARNING_RATE, experiment=EXPERIMENT)
    refine_directly('sup', 'BFGS', SIZE, SEED, LEARNING_RATE, experiment=EXPERIMENT)

    train_surrogate(SIZE, SEED, LEARNING_RATE, ITERATIONS, OPTIMIZER, experiment=EXPERIMENT, nets=nets)
    neural_adjoint(SIZE, SEED, LEARNING_RATE, OPTIMIZER, experiment=EXPERIMENT, nets=nets, use_boundary_loss=True)
    visualize_training('surrogate', SIZE, SEED, LEARNING_RATE, experiment=EXPERIMENT)
    refine_directly('surrogate', 'BFGS', SIZE, SEED, LEARNING_RATE, experiment=EXPERIMENT)
    # From here on, LEARNING_RATE maps each method name to the rate it was trained with above.
    LEARNING_RATE = dict(net=1e-4, sup=1e-3, surrogate=1e-3)
    visualize_results(SIZE, SEED, LEARNING_RATE, experiment=EXPERIMENT, nets=nets)
    visualize_param_trj('BFGS,sup,surrogate,net', SIZE, SEED, learning_rate=LEARNING_RATE, experiment=EXPERIMENT)
    visualize_optimizations('BFGS,sup,surrogate,net', 'loss', SIZE, SEED, learning_rate=LEARNING_RATE, experiment=EXPERIMENT, refined='bfgs')

    direct_compare('BFGS,net,surrogate,sup', SIZE, SEED, LEARNING_RATE, experiment=EXPERIMENT, refined='bfgs')

    paper_plot('setup, landscape:cue_vy, loss-curves, refined-loss-by-n', SIZE, SEED, LEARNING_RATE, experiment=EXPERIMENT, example=2, size=(10, 2.5), param_range=(-3, .5))
    # NOTE: `plt` is not imported in this file; import matplotlib.pyplot as plt before re-enabling the next line.
    # plt.gcf().axes[1].set_xlabel("Cue $v_y$")
    show()
    # Save the figure just created by `paper_plot` (under the experiment's name) as PDF and JPG; the JPG save also closes it.
    vis.savefig(f"~/phi/RP/{EXPERIMENT}/paper_billiards.pdf")
    vis.savefig(f"~/phi/RP/{EXPERIMENT}/paper_billiards.jpg", close=True)

    # NOTE(review): identical to the SIZE defined at the top of this block, so this rebinding is redundant.
    SIZE = vec(batch('dataset_size'), 4, 8, 16, 32, 64, 128, 256)
    plot_by_dset_size('BFGS,net,sup,surrogate', SIZE, SEED, LEARNING_RATE, experiment=EXPERIMENT, refined='bfgs', threshold=20)
    plot_parameter_trajectories('BFGS,net,sup,surrogate', 128, 0, LEARNING_RATE, experiment=EXPERIMENT)
    # The loss-curve plot uses only the dataset sizes shown in the paper.
    SIZE = vec(batch('dataset_size'), 4, 16, 64, 256)  # shown in paper
    plot_all_curves('BFGS,net,sup,surrogate', 'loss', SIZE, SEED, LEARNING_RATE, experiment=EXPERIMENT, refined='bfgs', extend_curves=False, show_std=True)
