import numpy as np

from sympy.core.expr import Expr
from sympy.core.symbol import Symbol
from sympy import Matrix, lambdify, diff, integrate


class VectorfieldFuns:
    """Build a 2-D vector field from a scalar and a vector (stream) potential.

    The planar field is the sum of a divergent part ``-grad(phi)`` with
    ``phi = scalar_potential`` and a rotational part ``(d psi/dy, -d psi/dx)``
    with ``psi = vector_potential``. Only x- and y-derivatives are taken; the
    symbolic results are turned into NumPy functions of ``(x, y, z)`` with
    SymPy's ``lambdify`` (``z`` defaults to zeros when evaluating).

    Optionally a time-dependent ``w_function`` (the third velocity component
    used by ``vf_3d``; its time integral is exposed as ``altitude``) and a
    ``periodicity`` expression (a time envelope used by ``vf_final``) can be
    supplied.
    """

    def __init__(self, scalar_potential: Expr, vector_potential: Expr,
                 w_function=None,
                 periodicity=None,
                 x: Symbol = Symbol('x'),
                 y: Symbol = Symbol('y'),
                 z: Symbol = Symbol('z'),
                 t: Symbol = Symbol('t'),
                 speed_multiplier: float = 1.5):
        """
        Args:
            scalar_potential: SymPy expression phi; the divergent part of the
                field is -grad(phi).
            vector_potential: SymPy expression psi; the rotational part of the
                field is (d psi/dy, -d psi/dx).
            w_function: optional SymPy expression in ``t`` giving the third
                velocity component. Required by ``vf_3d`` and ``altitude``.
            periodicity: optional SymPy expression in ``t``. Required by
                ``vf_final`` (and therefore ``vf_3d``).
            x, y, z, t: the symbols the expressions above are written in.
            speed_multiplier: constant factor applied by ``vf_final``.
        """
        self.x = x
        self.y = y
        self.z = z
        self.t = t
        # Upper integration limit of the altitude integral of w.
        self.tn = Symbol('tn')
        self.speed_multiplier = speed_multiplier

        self._w = w_function
        self._periodicity = periodicity
        # Compare against None: a legitimately zero expression is falsy.
        if self._w is not None:
            self._altitude = integrate(self._w, (self.t, 0, self.tn))
            self.w = lambdify(self.t, self._w, "numpy")
            # NOTE: this instance attribute shadows the ``altitude`` method.
            self.altitude = lambdify(self.tn, self._altitude, "numpy")
        if self._periodicity is not None:
            self.periodicity = lambdify(self.t, self._periodicity, "numpy")

        self._scalar_potential = scalar_potential
        self._vector_potential = vector_potential

        # Divergent (curl-free) part: -grad(phi).
        self._vf_divergence = Matrix([-diff(self._scalar_potential, self.x),
                                      -diff(self._scalar_potential, self.y)])

        # Rotational (divergence-free) part: (d psi/dy, -d psi/dx).
        self._vf_rotation = Matrix([diff(self._vector_potential, self.y),
                                    -diff(self._vector_potential, self.x)])

        self.__vf_combined = self._vf_divergence + self._vf_rotation

        # div F = dFx/dx + dFy/dy of the divergent part.
        self._divergence = (diff(self._vf_divergence[0], self.x)
                            + diff(self._vf_divergence[1], self.y))

        # z-component of curl F = dFy/dx - dFx/dy of the rotational part.
        self._rotation = (diff(self._vf_rotation[1], self.x)
                          - diff(self._vf_rotation[0], self.y))

        self.__vf_divergence = self.lambdify(self._vf_divergence)
        self.__vf_rotation = self.lambdify(self._vf_rotation)
        self.vf_combined_lambda = self.lambdify(self.__vf_combined)
        self.divergence = self.lambdify(self._divergence)
        self.rotation = self.lambdify(self._rotation)

        self.scalar_potential = self.lambdify(self._scalar_potential)
        self.vector_potential = self.lambdify(self._vector_potential)

    def lambdify(self, expr: Expr):
        """Lambdify ``expr`` as a NumPy function of ``(x, y, z)``.

        A Matrix is converted to a nested list first, so the returned function
        yields a nested list with the same layout (``[[u], [v]]`` for a
        column vector).
        """
        target = expr.tolist() if isinstance(expr, Matrix) else expr
        return lambdify((self.x, self.y, self.z), target, "numpy")

    @staticmethod
    def _broadcast_components(components, *coords):
        """Flatten lambdified Matrix output and broadcast it to the points.

        ``components`` is the nested list returned by a lambdified Matrix.
        Entries that SymPy reduced to a constant come back as plain scalars,
        so every entry is broadcast together with the coordinate arrays to
        obtain one value per evaluation point. Only the component arrays are
        returned.
        """
        flat = [np.asarray(entry) for row in components for entry in row]
        return np.broadcast_arrays(*flat, *coords)[:len(flat)]

    def _point_rows(self, field_func, x, y, z):
        """Evaluate a lambdified field and return one component row per point."""
        if z is None:
            z = np.zeros_like(x)
        parts = self._broadcast_components(field_func(x, y, z), x, y, z)
        # C-order flattening, one column per field component.
        return np.column_stack([np.ravel(part) for part in parts])

    def vf_divergence(self, x, y, z=None):
        """Evaluate the divergent part; returns shape ``(n_points, 2)``.

        ``z`` defaults to zeros shaped like ``x``.
        """
        return self._point_rows(self.__vf_divergence, x, y, z)

    def vf_rotation(self, x, y, z=None):
        """Evaluate the rotational part; returns shape ``(n_points, 2)``.

        ``z`` defaults to zeros shaped like ``x``.
        """
        return self._point_rows(self.__vf_rotation, x, y, z)

    def vf_final(self, x, y, t):
        """Planar field at time ``t``, scaled by the periodicity envelope.

        Computes ``speed_multiplier * F(x, y) * (|periodicity(t / 2)| + 0.05)``.
        The ``+ 0.05`` keeps the scale factor at or above 0.05. Requires the
        ``periodicity`` constructor argument; without it ``self.periodicity``
        does not exist and an AttributeError is raised.

        A non-scalar ``t`` gets a trailing axis so it broadcasts against the
        component axis of the field (e.g. one time per point for 1-D ``x``).
        """
        if not np.isscalar(t):
            t = np.asarray(t)[..., np.newaxis]
        envelope = np.abs(self.periodicity(t / 2)) + 0.05
        return self.speed_multiplier * self._vf_combined(x, y) * envelope

    def _vf_combined(self, x, y, z=None):
        """Evaluate the combined field with components on the last axis.

        Components are broadcast to the evaluation points and stacked on a
        leading axis. Size-1 axes are squeezed and the result is transposed,
        so a 1-D ``x`` of length n > 1 yields shape ``(n, 2)``.
        """
        if z is None:
            z = np.zeros_like(x)
        parts = self._broadcast_components(self.vf_combined_lambda(x, y, z),
                                           x, y, z)
        return np.squeeze(np.stack(parts)).T

    def altitude(self, t, t_0=0):
        """Symbolically integrate ``w`` from ``t_0`` to ``t`` and evaluate it.

        NOTE: when ``w_function`` is given, ``__init__`` assigns a lambdified
        ``altitude(tn)`` instance attribute that shadows this method on the
        instance. This method is only reached on an instance when no
        ``w_function`` was given, or when it is called via the class.

        Raises:
            ValueError: if no ``w_function`` was provided.
        """
        if self._w is None:
            raise ValueError("altitude requires a w_function")
        return integrate(self._w, (self.t, t_0, t)).evalf()

    def vf_3d(self, x, y, z, t):
        """Stack the planar field at time ``t`` with the component w(t).

        Returns an ``(n_points, 3)`` array. ``t`` may be a scalar or hold one
        time per point; w(t) is broadcast to one value per row. ``z`` is
        accepted but not used. Requires both ``periodicity`` and
        ``w_function``.
        """
        uv = self.vf_final(x, y, t).reshape(-1, 2)
        # Broadcast instead of repeat: an array-valued w(t) with one entry
        # per point must not be tiled n_points times.
        w = np.broadcast_to(np.ravel(self.w(t)), (uv.shape[0],))[..., np.newaxis]
        return np.concatenate([uv, w], -1)
