import numpy as np
import torch as tc

from typing import override
from dataclasses import dataclass

from src.utils import flatten, unflatten, stormer_verlet_step
from .nbody_system import NBodySystem

@dataclass
class LennardJones(NBodySystem):
    """N-body system whose particles interact through the 12-6 Lennard-Jones potential.

    Positions and momenta are read from / written to ``self.q`` and ``self.p``,
    each of shape (n_points, n_obj * dof). ``n_points``, ``n_features`` and
    ``dtype`` are not defined in this class (presumably provided by
    NBodySystem -- confirm against the base class).
    """
    n_obj: int                  # number of objects (particles) in the system
    dof: int                    # degrees of freedom

    # properties of the system (for each object)
    mass: float = 1.0
    epsilon: float = 1.0
    sigma: float = 1.0
    cutoff: float = np.inf

    def __post_init__(self):
        assert self.dof == 2, "Lennard-Jones is only supported in 2D for now, other DOFs are not implemented."
        super().__post_init__()

    def L(self) -> np.ndarray:
        """Returns the (n_features, n_features) matrix pairing every position coordinate with its momentum.

        With K = n_obj * dof, entry [k, K + k] is +1 and entry [K + k, k] is -1
        for every k in range(K); all other entries are zero.
        """
        Ls = np.zeros((self.n_features, self.n_features), self.dtype)
        n_coords = self.dof * self.n_obj
        # k enumerates (object, axis) pairs in the flattened order dof*i + j
        for k in range(n_coords):
            Ls[k, n_coords + k] = 1.
            Ls[n_coords + k, k] = -1.
        return Ls

    def flatten(self, x) -> np.ndarray:
        """Applies the module-level ``flatten`` helper from src.utils to x."""
        return flatten(x)

    def unflatten(self, x) -> np.ndarray:
        """Applies the module-level ``unflatten`` helper from src.utils with object dims [n_obj] and dof."""
        return unflatten(x, [self.n_obj], self.dof)

    @override
    def to_array(self, flatten=False) -> np.ndarray:
        """
        Returns array representation of the system.

        q and p are concatenated along the last axis. With flatten=True the result
        has shape (n_points, 2 * n_obj * dof); otherwise it is passed through
        ``self.unflatten`` (expected shape (n_points, n_obj, 2*dof) -- determined
        by the src.utils helper, confirm there).
        """
        assert self.q.shape == (self.n_points, self.n_obj * self.dof)
        assert self.p.shape == (self.n_points, self.n_obj * self.dof)

        x = np.concatenate([self.q, self.p], axis=-1)
        if not flatten:
            x = self.unflatten(x)
        return x

    @override
    def from_array(self, x: np.ndarray):
        """Sets self.q and self.p from x of shape (n_points, n_obj, 2*dof).

        The first dof entries of the last axis are taken as positions, the last dof
        entries as momenta; both are stored reshaped to (n_points, n_obj * dof).
        """
        assert x.shape == (self.n_points, self.n_obj, 2*self.dof), \
               f"Expected shape {(self.n_points, self.n_obj, 2*self.dof)} got {x.shape}"

        q, p = np.split(x, 2, axis=-1) # shape (n_points, n_obj, dof) each
        self.q = q.reshape(self.n_points, self.n_obj * self.dof)
        self.p = p.reshape(self.n_points, self.n_obj * self.dof)

    @override
    def H(self, as_local=False, as_separate=False):
        """Evaluates H_lennard_jones on the current state.

        as_local is not supported (asserted False). With as_separate=True the
        kinetic and potential parts are returned as a tuple instead of their sum.
        """
        assert not as_local
        return H_lennard_jones(self.to_array(), self.mass, self.epsilon, self.sigma, self.cutoff, as_separate)

    @override
    def grad_H(self, flatten=False, noise_scale=0.0, rng=None):
        """Evaluates grad_H_lennard_jones on the current state, optionally perturbed by Gaussian noise.

        Noise of standard deviation noise_scale is added to the state only when an
        rng is given (rng.normal is called); without an rng, noise_scale is ignored.
        With flatten=True the result is passed through ``self.flatten``.
        """
        if rng is not None:
            noise = rng.normal(loc=0.0, scale=noise_scale, size=(self.n_points, self.n_obj, 2*self.dof)).astype(self.dtype)
        else:
            noise = 0.0

        dedx = grad_H_lennard_jones(self.to_array() + noise, self.mass, self.epsilon, self.sigma, self.cutoff)
        if flatten:
            dedx = self.flatten(dedx)
        return dedx

    def integrate(self):
        """Not available for this system; always raises NotImplementedError."""
        raise NotImplementedError("integrate not implemented for LennardJones")

    def edge_index(self, degree_limit=None):
        """Builds one radius graph per stored snapshot using self.cutoff.

        The positions are the first dof entries of the last axis of to_array().
        Returns a list with one entry per snapshot, each being whatever
        edge_index_radius returns for that snapshot (a tensor of shape
        (2, n_edges), or an empty list when no pair is within the cutoff).
        degree_limit is forwarded unchanged.
        """
        q = self.to_array(flatten=False)[..., :self.dof]
        return [
            edge_index_radius(state, self.cutoff, degree_limit=degree_limit)
            for state in q
        ]

def edge_index_radius(q, cutoff, degree_limit=None):
    """Builds the radius graph of a single configuration.

    Every pair (i, j) with i < j whose distance is strictly below cutoff is a
    candidate edge, stored as [j, i].
    Args:
        q of shape (n_obj, dof)
        cutoff (float)
        degree_limit (int or None): if given, candidates are visited from shortest
            to longest and one is kept only while both of its endpoints have fewer
            than degree_limit kept edges. If None, every candidate is kept.
    Returns:
        edge_index tensor of shape (2, n_edges), or an empty list [] when there is no edge
    """
    n_particles = len(q)

    # candidate edges as (j, i, distance) for all i < j within the cutoff
    candidates = []
    for lo in range(n_particles):
        for hi in range(lo + 1, n_particles):
            separation = r(q[lo], q[hi])
            if separation < cutoff:
                candidates.append((hi, lo, separation))

    if degree_limit is None:
        edges = [[hi, lo] for hi, lo, _ in candidates]
    else:
        kept_per_node = [0] * n_particles
        edges = []
        # shortest candidates get first claim on the limited degree budget
        for hi, lo, _ in sorted(candidates, key=lambda entry: entry[-1]):
            if not (kept_per_node[hi] < degree_limit and kept_per_node[lo] < degree_limit):
                continue
            edges.append([hi, lo])
            kept_per_node[hi] += 1
            kept_per_node[lo] += 1

    if not edges:
        return []

    return tc.tensor(edges).T

def r(q_i, q_j, return_delta=False):
    """Euclidean distance |q_j - q_i| between two positions of shape (dof) each.

    Returns the distance (scalar), or the tuple (distance, q_j - q_i) when
    return_delta is set. Raises ValueError when the distance is numerically
    zero (np.isclose with its default tolerances).
    """
    separation = q_j - q_i
    length = np.linalg.norm(separation)
    if np.isclose(length, 0.0):
        raise ValueError("Particles are overlapping or self-interacting (|q_j - q_i|)")
    return (length, separation) if return_delta else length

def get_equil_r(sig):
    """Returns the equilibrium distance of the 12-6 Lennard-Jones potential given
    sig (scalar) sigma parameter in LJ.

    The potential 4*eps*((sig/r)^12 - (sig/r)^6) has zero derivative where
    r^6 = 2 * sig^6, i.e. at r = 2^(1/6) * sig, so the distance scales
    linearly with sig.
    """
    return 2 ** (1 / 6) * sig

def get_eq_lattice_x0(nx, ny, sig):
    """Builds a rectangular 2D lattice whose spacing along both axes is get_equil_r(sig).

    Args:
        nx (int) number of lattice sites along x
        ny (int) number of lattice sites along y
        sig (float) sigma parameter of Lennard Jones
    Returns:
        positions q of shape (nx * ny, 2); the x coordinate varies fastest
    """
    spacing = get_equil_r(sig)
    grid_x, grid_y = np.meshgrid(np.arange(nx) * spacing, np.arange(ny) * spacing)
    # flatten both coordinate grids and pair them up as (x, y) rows
    return np.stack([grid_x.ravel(), grid_y.ravel()], axis=-1)

def H_lennard_jones(x, mass, eps, sig, cutoff, as_separate=False):
    """Total Hamiltonian (kinetic + potential) of every snapshot in x.

    Args:
        x of shape (n_points, n_obj, 2*dof); positions in the first half of the
        last axis, momenta in the second half
    Returns:
        Hamiltonian of shape (n_points, 1) by default, or the tuple
        (kinetic, potential) with each of shape (n_points, 1) if as_separate
    """
    n_snapshots = len(x)
    positions, momenta = np.split(x, 2, axis=-1)  # (n_points, n_obj, dof) each
    kinetic = np.empty((n_snapshots, 1), dtype=x.dtype)
    potential = np.empty((n_snapshots, 1), dtype=x.dtype)

    # per-object energies are summed into one scalar per snapshot
    for k, (q_k, p_k) in enumerate(zip(positions, momenta)):
        kinetic[k] = kinetic_part(p_k, mass).sum()
        potential[k] = potential_part(q_k, eps, sig, cutoff).sum()

    return (kinetic, potential) if as_separate else kinetic + potential

def kinetic_part(p, mass):
    """Per-object kinetic energy 0.5 * |p|^2 / mass given p of shape (n_obj, dof)
    Outputs kinetic energies of shape (n_obj, 1)
    """
    squared_norm = (p**2).sum(axis=-1, keepdims=True)  # |p|^2 per object, (n_obj, 1)
    return 0.5 * (squared_norm / mass)

def LJ(r, eps, sig):
    """Lennard-Jones 12-6 potential at distance r (scalar)
    Outputs potential energy (scalar) as (4 * eps * ( (sig/r)^12 - (sig/r)^6 ))
    """
    ratio = sig / r
    attractive = ratio ** 6        # (sig/r)^6
    repulsive = attractive ** 2    # (sig/r)^12
    return 4 * eps * (repulsive - attractive)

def potential_part(q, eps, sig, cutoff=np.inf):
    """Per-object share of the Lennard-Jones potential energy given q of shape (n_obj, dof)

    Every pair closer than cutoff contributes LJ(distance); that pair energy is
    split evenly between its two particles.
    Outputs potential part of shape (n_obj, 1) representing 'local' potential parts
    """
    n_particles = len(q)
    local_pe = np.zeros((n_particles, 1), dtype=q.dtype)
    for a in range(n_particles):
        for b in range(a + 1, n_particles):
            separation = r(q[a], q[b])
            # NOTE: clipping the distance from below (e.g. at 0.9) is a possible
            # guard against exploding energies; it is not applied here
            if not separation < cutoff:
                continue
            half_pair_energy = LJ(separation, eps, sig) / 2.0
            local_pe[b] += half_pair_energy
            local_pe[a] += half_pair_energy

    return local_pe

# Gradients

def grad_H_lennard_jones(x, mass, eps, sig, cutoff, box_size=None):
    """Gradient of the Hamiltonian for every snapshot in x.

    Args:
        x of shape (n_graphs, n_obj, 2*dof); positions in the first half of the
        last axis, momenta in the second half
        box_size must be None (periodic boundaries raise NotImplementedError)
    Returns:
        gradient of shape (n_graphs, n_obj, 2*dof): dH/dq in the first half of the
        last axis, dH/dp in the second half
    """
    if box_size is not None:
        raise NotImplementedError("Periodic boundary condition is not implemented yet, but box_size was given")

    n_snapshots = len(x)
    positions, momenta = np.split(x, 2, axis=-1)  # (n_graphs, n_obj, dof) each
    dH_dp = np.empty_like(momenta, dtype=x.dtype)    # kinetic energy depends on momenta only
    dH_dq = np.empty_like(positions, dtype=x.dtype)  # potential energy depends on positions only

    for k in range(n_snapshots):
        dH_dp[k] = grad_kinetic_part(momenta[k], mass)
        dH_dq[k] = grad_potential_part(positions[k], eps, sig, cutoff, box_size)

    return np.concatenate((dH_dq, dH_dp), axis=-1)

def grad_kinetic_part(p, mass):
    """Gradient dT/dp of the kinetic energy T = 0.5 * |p|^2 / mass given p of shape (n_obj, dof)
    Outputs gradient of shape (n_obj, dof) as (p / mass)
    """
    dT_dp = p / mass
    return dT_dp

def grad_potential_part(q, eps, sig, cutoff=np.inf, box_size=None):
    """Gradient of the Lennard-Jones potential energy w.r.t. positions q of shape (n_obj, dof)

    For every pair (i, j) closer than cutoff, the chain rule gives
    grad_LJ(dist) * (q_j - q_i) / dist for q_j and the negative of it for q_i.
    box_size must be None (periodic boundaries raise NotImplementedError).
    Outputs gradient of shape (n_obj, dof)
    """
    if box_size is not None:
        raise NotImplementedError("Periodic boundary condition is not implemented yet, but box_size was given")

    n_particles = len(q)
    grad = np.zeros_like(q, dtype=q.dtype)
    for a in range(n_particles):
        for b in range(a + 1, n_particles):
            separation, offset = r(q[a], q[b], return_delta=True)
            # NOTE: clipping the distance from below (e.g. at 0.9) is a possible
            # guard against exploding gradients; it is not applied here
            if not separation < cutoff:
                continue
            dV_dr = grad_LJ(separation, eps, sig)     # scalar
            pair_grad = (offset / separation) * dV_dr  # d LJ / d q_b, of shape (dof)
            grad[b] += pair_grad
            grad[a] -= pair_grad                       # d LJ / d q_a is the opposite vector

    return grad

def grad_LJ(r, eps, sig):
    """Derivative of the Lennard-Jones 12-6 potential with respect to the distance r
    Outputs LJ gradient (scalar because r is scalar) as (4 * eps * (-12 * (sig^12 / r^13) + 6 * (sig^6 / r^7)))
    """
    repulsive_term = (sig**12) / (r**13)   # from d/dr of (sig/r)^12
    attractive_term = (sig**6) / (r**7)    # from d/dr of (sig/r)^6
    bracket = -12.0 * repulsive_term + 6.0 * attractive_term
    return 4 * eps * bracket

