import numpy as np

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any

from src.utils import hamiltons_eq

@dataclass
class NBodySystem(ABC):
    """
    Abstract container for sampled states of an N-body Hamiltonian system.

    Each sample's state is split into positions ``q`` and momenta ``p``. Each of
    these holds half of the ``n_features`` total features. Subclasses define:
    - the system-specific array layout (``to_array``/``from_array``,
      ``flatten``/``unflatten``);
    - the Hamiltonian ``H`` and its gradient ``grad_H``;
    - the Poisson matrix ``L``.
    """
    n_points: int       # number of samples of the system
    n_features: int     # total number of features per sample (positions and momenta combined)
    q: np.ndarray       # positions of shape (n_points, n_features // 2)
    p: np.ndarray       # momenta of shape (n_points, n_features // 2)

    def __post_init__(self):
        """
        Validate the shapes and dtypes of q and p, then record their shared dtype.

        Sets ``self.dtype`` to the common dtype of ``q`` and ``p``.

        Raises:
            AssertionError: if ``q`` or ``p`` does not have shape
                (n_points, n_features // 2), or if their dtypes differ.
        """
        # NOTE: validation uses ``assert``, so it is skipped when Python runs with -O.
        expected = (self.n_points, self.n_features // 2)
        assert self.q.shape == expected, f"q shape wrong, got {self.q.shape}, expected {expected}"
        assert self.p.shape == expected, f"p shape wrong, got {self.p.shape}, expected {expected}"
        assert self.q.dtype == self.p.dtype, (
            f"q and p dtypes differ, got {self.q.dtype} and {self.p.dtype}"
        )
        self.dtype = self.q.dtype

    @abstractmethod
    def to_array(self, flatten=False) -> np.ndarray:
        """
        Convert the stored data into a system-specific array layout.

        The layout is system dependent rather than the plain (n_points, n_features).
        For example, a MassSpring system would return something like
        (n_points, *n_obj, 2*dof).

        Returns:
            x of shape (n_points, ..., n_features), or (n_points, -1) if ``flatten``.
        """
        pass

    @abstractmethod
    def from_array(self, x: np.ndarray):
        """
        Inverse of ``to_array``: store the data held in ``x`` on this instance.

        The accepted shapes are system specific.
        """
        pass

    @abstractmethod
    def H(self, as_local=False, as_separate=False) -> Any:
        """
        Return the Hamiltonian of the stored samples.

        The meaning of ``as_local`` / ``as_separate`` and the return type are
        defined by the implementing subclass.
        """
        pass

    @abstractmethod
    def grad_H(self, flatten=False, noise_scale=0.0, rng=None) -> Any:
        """
        Return the gradient of the Hamiltonian with respect to the state.

        ``dxdt`` calls this without ``flatten`` and expects an array of shape
        (n_points, *n_obj, 2*dof). ``noise_scale`` / ``rng`` presumably control
        additive measurement noise; the exact semantics are up to the subclass.
        """
        pass

    @abstractmethod
    def L(self) -> np.ndarray:
        """
        Return the Poisson matrix corresponding to this Hamiltonian system.
        """
        pass

    @abstractmethod
    def flatten(self, x) -> np.ndarray:
        """
        Flatten ``x`` into a 2-D array suitable as network input.

        Networks usually expect flattened feature dimensions. Implementing
        classes choose how their features are stacked.
        """
        pass

    @abstractmethod
    def unflatten(self, x) -> np.ndarray:
        """
        Inverse operation of ``self.flatten``.
        """
        pass

    def dxdt(self, flatten=False, noise_scale=0.0, rng=None):
        """
        Compute the time derivative of the state via Hamilton's equations.

        The gradient from ``grad_H`` is passed through ``hamiltons_eq``.
        ``noise_scale`` is applied only when gathering dxdt from the x values
        directly. This simulates, e.g., sensors reading values of x that differ
        from the intended ones.

        Args:
            flatten: if True, return ``self.flatten(dxdt)`` instead of the
                system-specific layout.
            noise_scale: forwarded to ``grad_H``.
            rng: forwarded to ``grad_H``.

        Returns:
            dxdt of shape (n_points, *n_obj, 2*dof), or
            (n_points, np.prod(n_obj) * 2*dof) = (n_points, n_features) when ``flatten``.
        """
        dedx = self.grad_H(noise_scale=noise_scale, rng=rng)  # (n_points, *n_obj, 2*dof)
        dxdt = hamiltons_eq(dedx)                             # (n_points, *n_obj, 2*dof)
        if flatten:
            # The subclass decides the stacking; result is (n_points, n_features).
            dxdt = self.flatten(dxdt)
        return dxdt
