import torch as tc
import numpy as np
from torch import Tensor
from numpy import ndarray
from torch import pi as PI
from torch.autograd import grad
from tqdm import tqdm

from typing import override
from dataclasses import dataclass

from src.utils import flatten, flatten_TC, unflatten, unflatten_TC, symp_euler_step, stormer_verlet_step, runge_kutta_step, animate_2D
from .nbody_system import NBodySystem

@dataclass
class Pendulum():
    """
    Planar chain pendulum whose links all share one length and one mass.

    The methods below convert per-link angles to cartesian coordinates,
    evaluate kinetic/potential energy in several equivalent formulations,
    differentiate the Hamiltonian with autograd, and integrate/animate
    trajectories.
    """
    # length used for every link when converting angles to x/y offsets
    link_length: float = 1.0
    # mass used for every bob in the energy expressions
    mass: float = 1.0
    # gravitational constant multiplying the potential energy
    gravity: float = 1.0
    # y-coordinate of the anchor; link heights are subtracted from it
    origin: float = 10.0
    def wrap_angles(self, theta: Tensor) -> Tensor:
        if theta.is_neg():
            return theta % (-PI)
        return theta % PI

    def relative_to_absolute_angles(self, q: Tensor) -> Tensor:
        """
        Args:
            q       relative angles of shape (B, N, 1)
        Returns:
            theta   absolute angles of shape (B, N, 1)
        """
        # theta = self.wrap_angles(q).cumsum(dim=1)
        # return self.wrap_angles(theta)
        theta = q.cumsum(dim=1)
        return theta

    def angles_to_x(self, theta: Tensor) -> Tensor:
        """
        Args:
            theta   absolute angles of shape (B, N, 1)
        Returns:
            x       cartesian x-coordinate of shape (B, N, 1)
        """
        return (self.link_length * theta.sin()).cumsum(dim=1)

    def angles_to_x_dot(self, theta: Tensor, theta_dot: Tensor) -> Tensor:
        """
        Args:
            theta       absolute angles of shape (B, N, 1)
            theta_dot   angular velocities of shape (B, N, 1)
        Returns:
            x_dot   cartesian x-coordinate velocity of shape (B, N, 1)
        """
        return (self.link_length * theta.cos() * theta_dot).cumsum(dim=1)

    def angles_to_y(self, theta: Tensor) -> Tensor:
        """
        Args:
            theta   absolute angles of shape (B, N, 1)
        Returns:
            y       cartesian y-coordinate of shape (B, N, 1)
        """
        return self.origin - (self.link_length * theta.cos()).cumsum(dim=1)

    def angles_to_y_dot(self, theta: Tensor, theta_dot: Tensor) -> Tensor:
        """
        Args:
            theta       absolute angles of shape (B, N, 1)
            theta_dot   angular velocities of shape (B, N, 1)
        Returns:
            y_dot   cartesian y-coordinate velocity of shape (B, N, 1)
        """
        return (self.link_length * theta.sin() * theta_dot).cumsum(dim=1)

    # --- kinetic energy double check

    def kinetic(self, q: Tensor, q_dot: Tensor) -> Tensor:
        """
        Args:
            q           relative angles of shape (B, N, 1)
            q_dot       angular velocities of shape (B, N, 1)
        Returns:
            kinetic     energy of shape (B, 1)
        """
        theta = self.relative_to_absolute_angles(q)
        x_dot = self.angles_to_x_dot(theta, q_dot)
        y_dot = self.angles_to_y_dot(theta, q_dot)
        return 0.5 * self.mass * (x_dot**2 + y_dot**2).sum(dim=1)

    def kinetic_vector(self, q: Tensor, q_dot: Tensor) -> Tensor:
        _, n_obj, _ = q.shape
        theta = self.relative_to_absolute_angles(q)
        val = tc.scalar_tensor(0.0, dtype=q.dtype, device=q.device)
        for i in range(n_obj):
            v = tc.scalar_tensor(0.0, dtype=q.dtype, device=q.device)
            for k in range(i + 1):
                for j in range(i + 1):
                    v = v + self.link_length**2 * (theta[:, k, :] - theta[:, j, :]).cos() * q_dot[:, k, :] * q_dot[:, j, :]
            val = val + self.mass * v
        return 0.5 * val

    def prepare_d_matrix(self, theta: Tensor) -> Tensor:
        # theta of shape (B, N, 1)
        _, n_obj, _ = theta.shape
        M_down = self.mass * tc.arange(n_obj, 0, -1, device=theta.device, dtype=theta.dtype)
        theta_k = theta.squeeze(-1).unsqueeze(2)      # (B, N, 1)
        theta_j = theta.squeeze(-1).unsqueeze(1)      # (B, 1, N)
        cosd = (theta_k - theta_j).cos()  # (B, N, N)
        idx_k = tc.arange(n_obj, device=theta.device).view(n_obj, 1)
        idx_j = tc.arange(n_obj, device=theta.device).view(1, n_obj)
        Mdown_mat = M_down[tc.maximum(idx_k, idx_j)]  # (N, N)
        Mmat = self.mass * self.link_length**2 * cosd * Mdown_mat
        return Mmat

    def kinetic_tensor(self, q: Tensor, q_dot: Tensor) -> Tensor:
        theta = self.relative_to_absolute_angles(q)
        Dq = self.prepare_d_matrix(theta)
        q_dot_T = q_dot.transpose(1, 2)
        ke = q_dot_T.bmm(Dq).bmm(q_dot).squeeze(-1)
        # lhs = tc.bmm(q_dot_T, Dq)
        # ke = tc.bmm(lhs, q_dot).squeeze(-1)
        return 0.5 * ke

    def p_from_qdot(self, q: Tensor, q_dot: Tensor) -> Tensor:
        theta = self.relative_to_absolute_angles(q)
        Dq = self.prepare_d_matrix(theta)
        p = tc.bmm(Dq, q_dot)
        return p

    def qdot_from_p(self, q: Tensor, p: Tensor) -> Tensor:
        # solve Dq * q_dot = p
        theta = self.relative_to_absolute_angles(q)
        Dq = self.prepare_d_matrix(theta)
        q_dot = tc.linalg.solve(Dq, p)
        return q_dot                            # (B, N, 1)

    def kinetic_hamil(self, q: Tensor, p: Tensor) -> Tensor:
        theta = self.relative_to_absolute_angles(q)
        Dq = self.prepare_d_matrix(theta)
        Dq_inv_p = tc.linalg.solve(Dq, p)
        p_T = p.transpose(1, 2)
        ke = p_T.bmm(Dq_inv_p).squeeze(-1)
        return 0.5 * ke

    # --- potential energy double check

    def potential(self, q: Tensor) -> Tensor:
        """
        Args:
            q           relative angles of shape (B, N, 1)
        Returns:
            potential   energy of shape (B, 1)
        """
        theta = self.relative_to_absolute_angles(q)
        y = self.angles_to_y(theta)
        return self.mass * self.gravity * y.sum(dim=1)

    def potential_vector(self, q: Tensor) -> Tensor:
        """
        Args:
            q           relative angles of shape (B, N, 1)
        Returns:
            potential   energy of shape (B, 1)
        """
        _, n_obj, _ = q.shape
        pe = tc.scalar_tensor(0.0, dtype=q.dtype, device=q.device)
        theta = self.relative_to_absolute_angles(q)
        for k in range(n_obj):
            total_mass = (n_obj - k) * self.mass
            pe = pe + total_mass * self.gravity * (self. origin - self.link_length * theta[:, k, :].cos())
        return pe

    def potential_tensor(self, q: Tensor) -> Tensor:
        """
        Args:
            q           relative angles of shape (B, N, 1)
        Returns:
            potential   energy of shape (B, 1)
        """
        _, n_obj, _ = q.shape
        theta = self.relative_to_absolute_angles(q)
        M = self.mass * tc.arange(n_obj, 0, -1, device=q.device, dtype=q.dtype)
        pe = tc.einsum("n,bnd->bd", M, self.origin - self.link_length * theta.cos())
        pe = self.gravity * pe
        return pe
        # return -pe

    def potential_hamil(self, q: Tensor) -> Tensor:
        return self.potential_tensor(q)

    # --- Hamiltonian double check

    def energy(self, q: Tensor, q_dot: Tensor, as_separate=False) -> Tensor | tuple[Tensor, Tensor]:
        p = self.p_from_qdot(q, q_dot)
        kin = self.kinetic_hamil(q, p)
        pot = self.potential_hamil(q)
        # pot = self.potential_hamil(q)
        if as_separate:
            return kin, pot
        return kin + pot

    def dedx(self, q: Tensor, q_dot: Tensor) -> Tensor:
        """
        Args:
            q           relative angles of shape (B, N, 1)
            q_dot       angular velocities of shape (B, N, 1)
        Returns:
            dedx        dedx = [dedq, dedp] of shape (B, N, 2)
        """
        dedq, dedp = self.dedq(q, q_dot), self.dedp(q, q_dot)
        return tc.cat([dedq, dedp], dim=-1)

    def dedq(self, q: Tensor, q_dot: Tensor) -> Tensor:
        p = self.p_from_qdot(q, q_dot)
        q.requires_grad_(True)
        kin = self.kinetic_hamil(q, p)
        pot = self.potential_hamil(q)
        e = kin + pot
        dedq = grad(e.sum(), q)[0]
        q.requires_grad_(False)
        return dedq

    def dedp(self, q: Tensor, q_dot: Tensor) -> Tensor:
        p = self.p_from_qdot(q, q_dot)
        p.requires_grad_(True)
        kin = self.kinetic_hamil(q, p)
        pot = self.potential_hamil(q)
        e = kin + pot
        dedp = grad(e.sum(), p)[0]
        p.requires_grad_(False)
        return dedp

    def dqdt(self, q: Tensor, q_dot: Tensor) -> Tensor:
        dedp = self.dedp(q, q_dot)
        return dedp

    def dpdt(self, q: Tensor, q_dot: Tensor) -> Tensor:
        dedq = self.dedq(q, q_dot)
        return -dedq

    def dxdt(self, q: Tensor, q_dot: Tensor) -> Tensor:
        dqdt, dpdt = self.dqdt(q, q_dot), self.dpdt(q, q_dot)
        return tc.cat([dqdt, dpdt], dim=-1)

    def L(self, n_obj, dof, dtype=np.float64) -> np.ndarray:
        n_features = n_obj * 2*dof
        Ls = np.zeros((n_features, n_features), dtype)
        for i in range(n_obj):
            for j in range(dof):
                Ls[dof*i + j, dof*n_obj + dof*i + j] =  1.
                Ls[dof*n_obj + dof*i + j, dof*i + j] = -1.
        return Ls

    def edge_index(self, n_obj: int) -> Tensor:
        assert n_obj > 1
        edge_index = []
        for i in range(n_obj - 1):
            edge_index.append([i + 1, i]) # towards lower index by default
        edge_index = tc.tensor(edge_index).T
        return edge_index

    def get_x0(self, n_obj, q_lims, qdot_lims, dtype, seed) -> ndarray:
        assert len(n_obj) == 1
        n_obj = n_obj[0]
        rng = np.random.default_rng(seed)
        q_min, q_max = q_lims
        qdot_min, qdot_max = qdot_lims
        q = tc.from_numpy(rng.uniform(q_min, q_max, size=(n_obj, 1)).astype(dtype))
        q_dot = tc.from_numpy(rng.uniform(qdot_min, qdot_max, size=(n_obj, 1)).astype(dtype))
        p = self.p_from_qdot(q.unsqueeze(0), q_dot.unsqueeze(0)).squeeze(0).cpu().numpy()
        return np.concatenate([q, p], axis=-1)

    def integrate(self, x0, len_traj, delta_t, grad_H_, integration_method="stormer_verlet") -> ndarray:
        """
        Roll out a trajectory by repeatedly applying a one-step integrator.

        Args:
            x0                  [q0, p0] ndarray of shape (N, 2D)
            len_traj            number of states in the returned trajectory,
                                including x0
            delta_t             step size handed to the stepper
            grad_H_             gradient callable handed to the stepper
            integration_method  one of "symp_euler", "stormer_verlet",
                                "runge_kutta"
        Returns:
            traj                Trajectory ndarray of shape (len_traj, N, 2D)
        Raises:
            NotImplementedError if a step is attempted with an unknown
                                integration method
        """
        n_obj, n_dim = x0.shape
        dof = n_dim // 2
        steppers = {
            "symp_euler": symp_euler_step,
            "stormer_verlet": stormer_verlet_step,
            "runge_kutta": runge_kutta_step,
        }
        states = [x0]
        for _ in tqdm(range(1, len_traj)):
            # resolved per step so an unknown method only fails once a step is needed
            step_fn = steppers.get(integration_method)
            if step_fn is None:
                raise NotImplementedError(f"integration method {integration_method} is not implemented yet.")
            states.append(step_fn(states[-1], delta_t, grad_H_, None, None, None, None, None, None, n_obj, dof))
        return np.stack(states, axis=0)

    def animate(self, traj, framing_length=1, is_cartesian=False):
        """
        Render a trajectory to "pendulum.mp4" via `animate_2D`.

        Angle trajectories are first converted to cartesian bob positions.
        An extra node for the anchor (x = 0, y = origin) is prepended to
        every frame so the first link is drawn as well.

        Args:
            traj            relative angles or cartesian of shape (n_steps, n_obj, n_dim)
            framing_length  forwarded to `animate_2D`
            is_cartesian    True if `traj` already holds (x, y) positions
        """
        traj = np.asarray(traj)
        if not is_cartesian:
            assert traj.shape[-1] == 1
            theta = self.relative_to_absolute_angles(tc.from_numpy(traj))
            xy = tc.cat([self.angles_to_x(theta), self.angles_to_y(theta)], dim=-1)
            q = xy.cpu().numpy()
        else:
            assert traj.shape[-1] == 2
            q = traj

        n_steps, n_obj, n_dim = q.shape
        # anchor node: zero everywhere except the last axis, which sits at the origin height
        anchor = np.zeros((n_steps, 1, n_dim), dtype=q.dtype)
        anchor[..., -1] = self.origin
        q_and_origin = np.concatenate([anchor, q], axis=1)
        edges = self.edge_index(n_obj + 1).cpu().numpy()

        print("-> inputting q", q_and_origin.shape)
        print("-> inputting edge_index", edges.shape)
        animate_2D(q_and_origin, edges, framing_length=framing_length, filename="pendulum.mp4")
