from dataclasses import dataclass, field
from typing import override

import numpy as np
import torch as tc

from src.utils import Mesh, flatten, unflatten, symp_euler_step, stormer_verlet_step, runge_kutta_step
from .nbody_system import NBodySystem

@dataclass
class MassSpring(NBodySystem):
    """
    System of identical masses coupled by identical linear springs.

    Objects sit on a grid with `n_obj[k]` objects along axis k: one axis is a chain, two axes
    a lattice. Three axes pass validation, but H / grad_H / edge_index raise
    NotImplementedError for them. Each object has `dof` coordinates. The state is kept in
    `self.q` and `self.p`, each of shape (n_points, prod(n_obj) * dof) (checked in `to_array`).
    """
    n_obj: list[int]    # number of objects in x,y,... dimensions
    dof: int            # degrees of freedom

    # properties of the system (for each object)
    mass: float = 1.0
    spring_constant: float = 1.0
    # default_factory gives every instance its own Mesh instead of one shared default object
    meshing: Mesh = field(default_factory=lambda: Mesh("rectangular"))

    def __post_init__(self):
        super().__post_init__()
        if self.is_chain():
            assert self.meshing.mesh_type == "rectangular", \
                   "For chain only meshing 'rectangular' is valid"
        assert self.is_chain() or self.is_lattice() or self.is_ball(), \
               "Number of objects can be specified on 3 axes maximum."
        assert self.dof <= 3, "DOF of each object can be maximum 3."
        # Zero is rejected as well, so the messages say "positive", not "non-negative".
        assert self.mass > 0.0, "Mass must be positive"
        assert self.spring_constant > 0.0, "Spring constant must be positive"

    def L(self) -> np.ndarray:
        """
        Return the (n_features, n_features) matrix with +1 at (k, n + k) and -1 at (n + k, k).

        Here n = prod(n_obj) * dof and k < n. When n_features == 2 * n this is the block
        matrix [[0, I], [-I, 0]] acting on the flattened state [q, p].
        """
        Ls = np.zeros((self.n_features, self.n_features), self.dtype)
        n_coords = int(np.prod(self.n_obj)) * self.dof
        coord = np.arange(n_coords)
        Ls[coord, n_coords + coord] = 1.
        Ls[n_coords + coord, coord] = -1.
        return Ls

    def flatten(self, x) -> np.ndarray:
        """Flatten a per-object array using `src.utils.flatten`."""
        return flatten(x)

    def unflatten(self, x) -> np.ndarray:
        """Reshape a flat array into per-object form using `src.utils.unflatten`."""
        return unflatten(x, self.n_obj, self.dof)

    @override
    def to_array(self, flatten=False) -> np.ndarray:
        """
        Returns array representation of the system of shape (n_points, *n_obj, 2*dof)
        (or the flat concatenation [q, p] when `flatten` is True).
        """
        assert self.q.shape == (self.n_points, np.prod(self.n_obj) * self.dof)
        assert self.p.shape == (self.n_points, np.prod(self.n_obj) * self.dof)

        x = np.concatenate([self.q, self.p], axis=-1)
        if not flatten:
            x = self.unflatten(x)
        return x

    @override
    def from_array(self, x: np.ndarray):
        """Store `x` of shape (n_points, *n_obj, 2*dof) into the flat `self.q` / `self.p`."""
        assert x.shape == (self.n_points, *self.n_obj, 2*self.dof), \
               f"Expected shape {(self.n_points, *self.n_obj, 2*self.dof)} got {x.shape}"

        q, p = np.split(x, 2, axis=-1)  # shape (n_points, *n_obj, dof) each
        self.q = q.reshape(self.n_points, np.prod(self.n_obj) * self.dof)
        self.p = p.reshape(self.n_points, np.prod(self.n_obj) * self.dof)

    def _system_type(self) -> str:
        """Return the `system` label expected by H_mass_spring / grad_H_mass_spring."""
        if self.is_chain():
            return "chain"
        if self.is_lattice():
            return "lattice"
        raise NotImplementedError("System is not supported")

    @override
    def H(self, as_local=False, as_separate=False):
        """Hamiltonian of every stored snapshot; output shapes are those of H_mass_spring."""
        system = self._system_type()
        return H_mass_spring(self.to_array(), self.n_obj, self.meshing, self.mass, self.spring_constant,
                             system, as_local, as_separate)

    @override
    def grad_H(self, flatten=False, noise_scale=0.0, rng=None):
        """
        Gradient of the Hamiltonian w.r.t. [q, p] at every stored snapshot.

        Args:
            flatten: return the flattened gradient (via `self.flatten`) instead of shape
                (n_points, *n_obj, 2*dof).
            noise_scale: standard deviation of Gaussian noise added to the state before
                differentiating; only used when `rng` is given.
            rng: object with a numpy-style `normal(loc, scale, size)` method; None means no noise.
        """
        system = self._system_type()

        noise = 0.0
        if rng is not None:
            noise = rng.normal(loc=0.0, scale=noise_scale,
                               size=(self.n_points, *self.n_obj, 2*self.dof)).astype(self.dtype)

        dedx = grad_H_mass_spring(self.to_array() + noise, self.dof, self.meshing,
                                  self.mass, self.spring_constant, system)
        return self.flatten(dedx) if flatten else dedx

    def is_chain(self) -> bool:
        """
        Whether the system is coupled along a single dimension.
        """
        return len(self.n_obj) == 1

    def is_lattice(self) -> bool:
        """
        Whether the system is coupled along two dimensions.
        """
        return len(self.n_obj) == 2

    def is_ball(self) -> bool:
        """
        Whether the system is coupled along three dimensions.
        """
        return len(self.n_obj) == 3

    def integrate(self, len_traj, delta_t, grad_H_, fixed_indices=None, gravity=None,
                  force_node_idx=None, force_vec=None, force_snap_idx=None, integration_method="stormer_verlet"):
        """
        Fill the trajectory in self.q / self.p by stepping forward from the first snapshot.

        Snapshot k (1 <= k < len_traj) is computed from snapshot k-1 by one step of the selected
        integrator and written back into self.q / self.p. Nothing is returned.

        Args:
            len_traj: number of snapshots; self.q and self.p must already hold len_traj rows.
            delta_t: time step.
            grad_H_: gradient callable forwarded to the step function.
            fixed_indices, gravity, force_node_idx, force_vec, force_snap_idx: forwarded
                unchanged to the step function from src.utils.
            integration_method: "symp_euler", "stormer_verlet" or "runge_kutta".
        Raises:
            NotImplementedError: for an unknown `integration_method`. This is checked before
                stepping, so it is also reported when len_traj == 1.
        """
        assert len(self.q) == len_traj and len(self.p) == len_traj

        steppers = {
            "symp_euler": symp_euler_step,
            "stormer_verlet": stormer_verlet_step,
            "runge_kutta": runge_kutta_step,
        }
        if integration_method not in steppers:
            raise NotImplementedError(f"integration method {integration_method} is not implemented yet.")
        step = steppers[integration_method]

        report_every = max(len_traj // 10, 1)
        for idx_snap in range(1, len_traj):
            if idx_snap % report_every == 0:
                print(f"-> Integrating: {idx_snap}/{len_traj}")
            x = self.to_array()                 # of shape (len_traj, *n_obj, 2*dof)
            x_prev = x[idx_snap - 1]            # of shape (*n_obj, 2*dof)
            x[idx_snap] = step(x_prev, delta_t, grad_H_, gravity, force_node_idx, force_vec,
                               force_snap_idx, idx_snap, fixed_indices, self.n_obj, self.dof)
            # Write back every step so self.q / self.p hold the trajectory integrated so far.
            self.from_array(x)

    def edge_index(self):
        """Graph connectivity (2, n_edges) of the chain or lattice; see edge_index_mass_spring."""
        if self.is_chain() or self.is_lattice():
            return edge_index_mass_spring(self.n_obj, self.meshing)
        raise NotImplementedError("System is not supported")

def edge_index_mass_spring(n_obj: list[int], mesh: Mesh) -> tc.Tensor:
    """
    Build the connectivity of a mass-spring system as a (2, n_edges) torch tensor.

    The tensor is intended for scatter operations. Dispatch is on the number of axes in
    `n_obj`: one axis builds a chain, two axes build a lattice.

    Raises:
        AssertionError: if `n_obj` is empty.
        NotImplementedError: for three or more axes.
    """
    n_axes = len(n_obj)
    assert n_axes > 0
    if n_axes == 1:
        (Nx,) = n_obj
        return edge_index_chain(Nx)
    if n_axes == 2:
        Nx, Ny = n_obj
        return edge_index_lattice(Nx, Ny, mesh)
    raise NotImplementedError("System is not supported")

def edge_index_chain(Nx: int) -> tc.Tensor:
    """
    Edges between neighbouring nodes of a 1D chain with `Nx` nodes.

    Each neighbouring pair (i, i+1) contributes one column [i+1, i]. Row 0 therefore holds the
    higher index and row 1 the lower index. The result has shape (2, Nx - 1).
    """
    assert Nx > 1
    pairs = [[node + 1, node] for node in range(Nx - 1)]  # higher -> lower index
    return tc.tensor(pairs).T

def edge_index_lattice(Nx: int, Ny: int, mesh: Mesh) -> tc.Tensor:
    """
    Edges between axis-aligned neighbours of an Nx-by-Ny lattice.

    Node (i, j) has flat index i * Ny + j. The (2, n_edges) result lists all x-direction edges
    first, then all y-direction edges, each group ordered by i and then by j.
    - x-direction edge: column [(i+1, j), (i, j)], i.e. from the higher to the lower index.
    - y-direction edge: column [(i, j), (i, j+1)], i.e. from the lower to the higher index.

    NOTE(review): `mesh` is not used, so any diagonal springs of the mesh get no edges. Also the
    two directions follow opposite index conventions, although earlier comments claimed both
    point "towards lower index". Confirm both behaviours are intended.
    """
    assert Nx > 1 and Ny > 1

    def node(i, j):
        return i * Ny + j

    x_edges = [[node(i + 1, j), node(i, j)] for i in range(Nx - 1) for j in range(Ny)]
    y_edges = [[node(i, j), node(i, j + 1)] for i in range(Nx) for j in range(Ny - 1)]
    return tc.tensor(x_edges + y_edges).T

def H_mass_spring(x, n_obj, meshing: Mesh, mass, spring_constant, system="chain", as_local=False, as_separate=False):
    """
    Hamiltonian of a mass-spring chain or 2D lattice.

    Kinetic energy per node is 0.5 * |p|**2 / mass. Every spring joining nodes a and b stores
    0.5 * spring_constant * |q_a - q_b|**2. That energy is split equally between the two
    endpoints, so energies can also be reported per node.

    Args:
        x: states of shape (n_points, *n_obj, 2*dof); the last axis is [q, p].
        n_obj: objects per axis, [N] for a chain or [Nx, Ny] for a lattice.
        meshing: mesh deciding which lattice springs exist (not used for chains).
        mass: mass of every object.
        spring_constant: stiffness of every spring.
        system: "chain" or "lattice".
        as_local: return per-node energies of shape (n_points, *n_obj, 1) instead of totals.
        as_separate: return (kinetic, potential) instead of their sum.
    Returns:
        Hamiltonian of shape (n_points, 1) by default. With `as_separate`, a
        (kinetic, potential) pair. With `as_local`, per-node values instead of sums.
    Raises:
        NotImplementedError: for an unknown `system`, or a lattice mesh without x/y springs.
    """
    n_points = len(x)
    q, p = np.split(x, 2, axis=-1)  # (n_points, *n_obj, dof) each

    kinetic = (p**2).sum(axis=-1, keepdims=True)  # (n_points, *n_obj, 1)
    kinetic = 0.5 * (kinetic / mass)               # (n_points, *n_obj, 1)

    def add_springs(end_a, end_b):
        # Potential of every spring joining q[end_a] with q[end_b], shared half/half by its
        # two ends. The squared difference is symmetric, so the order of the ends is irrelevant.
        energy = 0.5 * np.sum((q[end_a] - q[end_b])**2, axis=-1, keepdims=True) * spring_constant
        potential[end_a] += 0.5 * energy
        potential[end_b] += 0.5 * energy

    if system == "chain":
        [N] = n_obj
        potential = np.zeros((n_points, N, 1))
        add_springs(np.s_[:, :-1], np.s_[:, 1:])                # (i) -- (i+1)
    elif system == "lattice":
        [Nx, Ny] = n_obj
        potential = np.zeros((n_points, Nx, Ny, 1))
        if not meshing.has_xy_springs():
            raise NotImplementedError(f"Mesh type {meshing} is not implemented yet")
        add_springs(np.s_[:, :-1], np.s_[:, 1:])                # along x: (i, j) -- (i+1, j)
        add_springs(np.s_[:, :, :-1], np.s_[:, :, 1:])          # along y: (i, j) -- (i, j+1)
        if meshing.has_down_diag():
            add_springs(np.s_[:, :-1, :-1], np.s_[:, 1:, 1:])   # (i, j) -- (i+1, j+1)
        if meshing.has_up_diag():
            add_springs(np.s_[:, 1:, :-1], np.s_[:, :-1, 1:])   # (i+1, j) -- (i, j+1)
        # An all-zero potential is valid (no relative displacement), so it is not an error.
    else:
        raise NotImplementedError()

    if as_local:
        # per-node energies, (n_points, *n_obj, 1) each
        return (kinetic, potential) if as_separate else kinetic + potential

    node_axes = tuple(range(1, 1 + len(n_obj)))  # sum over nodes -> (n_points, 1)
    if as_separate:
        return np.sum(kinetic, axis=node_axes), np.sum(potential, axis=node_axes)
    return np.sum(kinetic + potential, axis=node_axes)

def _accumulate_spring_grad(dedq, q, spring_constant, head, tail):
    """
    Add, in place, the gradient of sum(0.5 * spring_constant * |q[head] - q[tail]|**2) to dedq.

    `head` and `tail` are index tuples (e.g. ``np.s_[:, 1:]`` / ``np.s_[:, :-1]``) selecting
    equally shaped views: the two endpoints of every spring in one spring family.
    """
    spring_grad = spring_constant * (q[head] - q[tail])
    dedq[head] += spring_grad   # d/dq_head of 0.5*k*|q_head - q_tail|^2
    dedq[tail] -= spring_grad   # d/dq_tail is the negative of the above


def grad_H_mass_spring(x, dof, meshing: Mesh, mass, spring_constant, system="chain"):
    """
    Gradient of the mass-spring Hamiltonian with respect to the state x = [q, p].

    This uses the same energy as H_mass_spring: 0.5 * |p|**2 / mass per node, plus
    0.5 * spring_constant * |q_a - q_b|**2 per spring. Every spring family is differentiated
    with one vectorised update. Boundary nodes and axes with a single node therefore need no
    special cases.

    Args:
        x: states of shape (n_points, *n_obj, 2*dof); the last axis is [q, p].
        dof: coordinates per object.
        meshing: mesh deciding which lattice springs exist (not used for chains).
        mass: mass of every object.
        spring_constant: stiffness of every spring.
        system: "chain" or "lattice".
    Returns:
        dedx of shape (n_points, *n_obj, 2*dof) holding [dH/dq, dH/dp].
    Raises:
        NotImplementedError: for an unknown `system`, or a lattice mesh without x/y springs.
    """
    q, p = np.split(x, 2, axis=-1)  # (n_points, *n_obj, dof) each

    dedx = np.zeros_like(x, dtype=x.dtype)  # (n_points, *n_obj, 2*dof)
    dedx[..., dof:] = p / mass              # dH/dp = p / m

    dedq = np.zeros_like(q, dtype=q.dtype)  # (n_points, *n_obj, dof)

    if system == "chain":
        _accumulate_spring_grad(dedq, q, spring_constant, np.s_[:, 1:], np.s_[:, :-1])
    elif system == "lattice":
        if not meshing.has_xy_springs():
            raise NotImplementedError(f"Mesh type {meshing} is not implemented yet")
        # along x: (i+1, j) -- (i, j)
        _accumulate_spring_grad(dedq, q, spring_constant, np.s_[:, 1:], np.s_[:, :-1])
        # along y: (i, j+1) -- (i, j)
        _accumulate_spring_grad(dedq, q, spring_constant, np.s_[:, :, 1:], np.s_[:, :, :-1])
        if meshing.has_down_diag():
            # (i+1, j+1) -- (i, j)
            _accumulate_spring_grad(dedq, q, spring_constant, np.s_[:, 1:, 1:], np.s_[:, :-1, :-1])
        if meshing.has_up_diag():
            # (i, j+1) -- (i+1, j)
            _accumulate_spring_grad(dedq, q, spring_constant, np.s_[:, :-1, 1:], np.s_[:, 1:, :-1])
        # An all-zero dH/dq is valid (e.g. equilibrium or uniform translation), so it is not
        # treated as an error.
    else:
        raise NotImplementedError()

    dedx[..., :dof] = dedq
    return dedx
