"""
Utilities for generating vector fields.
"""
import numpy as np

import scipy.integrate
from sympy import *
from sympy.core.numbers import Zero

from matplotlib.colors import colorConverter
import matplotlib.pyplot as plt
import matplotlib.animation as animation

from pdPINN.util.data_containers import VectorfieldFuns


def generate_circle_vectorfield() -> VectorfieldFuns:
    """
    Build the symbolic potentials of the vector field used for visualization.

    The scalar potential is the linear ramp ``-x / 3``; the vector potential is
    a Gaussian bump ``exp(-((x/1.5)**2 + (y/1.5)**2))`` centred at the origin.

    :return: a ``VectorfieldFuns`` holding both potentials and the spatial
        symbols ``x`` and ``y``.
    """
    px, py = symbols('x y')
    gaussian_bump = exp(-((px / 1.5) ** 2 + (py / 1.5) ** 2))
    return VectorfieldFuns(scalar_potential=-px / 3,
                           vector_potential=gaussian_bump,
                           x=px, y=py)


def generate_vf_functions(name: str, landing_periodicity, rotation_only=False,
                          divergence_only=False) -> VectorfieldFuns:
    """
    Generate symbolic functions for the named vector field.

    The field is described by a scalar potential (divergent part) and a vector
    potential (rotational part), each multiplied by a per-field constant. A
    time-dependent factor ``periodicity = sin(2*pi*landing_periodicity*t)`` and
    ``w_function = 1.6 * periodicity`` are attached as well.

    :param name: which field to build: ``"inward_spiral"`` or ``"landing"``.
    :param landing_periodicity: frequency (in ``t``) of the sinusoidal factor.
    :param rotation_only: if True, the scalar potential is multiplied by 0.
    :param divergence_only: if True, the vector potential is multiplied by 0.
    :return: a ``VectorfieldFuns`` with the potentials, ``w_function``,
        ``periodicity`` and the spatial symbols ``x`` and ``y``.
    :raises ValueError: if both ``rotation_only`` and ``divergence_only`` are set.
    :raises NotImplementedError: if ``name`` is not a known field.
    """
    # Explicit check instead of ``assert`` so validation survives ``python -O``.
    if rotation_only and divergence_only:
        raise ValueError("rotation_only and divergence_only are mutually exclusive.")

    t = symbols('t')
    periodicity = sin(landing_periodicity * t * (2 * pi))
    w_function = 1.6 * periodicity

    x, y = symbols('x y')
    # Gaussian bump centred at the origin, shared by both fields.
    gaussian = exp(-((x / 1.5) ** 2 + (y / 1.5) ** 2))

    # 0/1 factors that switch off the divergent or the rotational component.
    scalar_mask = int(not rotation_only)
    vector_mask = int(not divergence_only)

    if name == "inward_spiral":
        scalar_potential = .5 * (-gaussian) * scalar_mask
        vector_potential = 3. * gaussian * vector_mask
    elif name == "landing":
        scalar_potential = .5 * (-(x - 2) * (y - 2)) * scalar_mask
        vector_potential = 2. * (-gaussian) * vector_mask
    else:
        raise NotImplementedError(f"unknown vf name: '{name}'.")

    return VectorfieldFuns(scalar_potential=scalar_potential,
                           vector_potential=vector_potential,
                           w_function=w_function,
                           periodicity=periodicity,
                           x=x, y=y)


def generate_vf_functions_time(radius, time_step):
    """
    Build a time-dependent 2-D vector field symbolically and lambdify its parts.

    Divergent part: the negative gradient of ``-(1/20) * u**3 - u/3`` with
    ``u = -x * cos(t*pi/4)``. Rotational part: ``(dA/dy, -dA/dx)`` for a vector
    potential ``A`` made of two opposite-signed quartic-exponential bumps offset
    along ``y``. The flux term integrates the divergence of the combined field,
    times the polar area element ``r``, over a disc of radius ``radius`` centred
    at ``(x, y)`` and over the time window ``[t, t + time_step]``.

    :param radius: radius of the spatial integration disc.
    :param time_step: length of the time integration window.
    :return: tuple of NumPy-lambdified callables, each taking ``(t, x, y)``:
        ``(scalar_potential, vector_potential, combined_field, divergent_field,
        rotational_field, divergence, curl, flux)``. The three field callables
        are lambdified from ``Matrix.tolist()`` of a 2x1 column matrix.
    """
    x, y, t, r, phi, t_delta = symbols('x y t r phi t_delta')

    # Divergent component: substitute the time-rotated x into the cubic potential.
    theta = t * pi / 4
    scalar_pot = (-(1 / 20) * x ** 3 - x / 3).subs(x, x * -cos(theta))

    bump_low = -exp(-((x / 2.5) ** 4 + ((y / 2.5) + 0.5) ** 4))
    bump_high = exp(-((x / 2.5) ** 4 + ((y / 2.5) - 0.5) ** 4))

    field_div = Matrix([-diff(scalar_pot, x), -diff(scalar_pot, y)])

    # Rotational component and its (simplified) scalar curl.
    vector_pot = (bump_low + bump_high)
    field_rot = Matrix([diff(vector_pot, y), -diff(vector_pot, x)])
    curl = simplify(diff(field_rot[1], x) - diff(field_rot[0], y))

    field_comb = field_div + field_rot
    divergence = simplify(diff(field_comb[0], x) + diff(field_comb[1], y))

    # Evaluate the divergence at a polar offset (r, phi) and time offset t_delta,
    # then integrate over the disc and the time window.
    shifted = (divergence
               .subs(x, x + r * cos(phi))
               .subs(y, y + r * sin(phi))
               .subs(t, t + t_delta))
    flux = integrate(shifted * r, (phi, 0, 2 * pi), (r, 0, radius), (t_delta, 0, time_step))

    def to_numpy(expr):
        return lambdify((t, x, y), expr, "numpy")

    flux_np = to_numpy(flux)
    divergence_np = to_numpy(divergence)
    curl_np = to_numpy(curl)
    scalar_pot_np = to_numpy(scalar_pot)
    vector_pot_np = to_numpy(vector_pot)
    field_div_np = to_numpy(field_div.tolist())
    field_rot_np = to_numpy(field_rot.tolist())
    field_comb_np = to_numpy(field_comb.tolist())
    return (scalar_pot_np, vector_pot_np, field_comb_np, field_div_np,
            field_rot_np, divergence_np, curl_np, flux_np)
