import numpy as np
import torch as tc


def flatten(x: np.ndarray) -> np.ndarray:
    """
    Flatten a per-node graph representation into one row per sample.

    For the N-body mass-spring system. The last axis of ``x`` is split in half
    into a ``q`` part and a ``p`` part; all node axes are then collapsed so
    that every row reads ``[q_values, p_values]``.

    Args:
        x: array of shape (n_points, *N, 2*dof), where N contains the number
            of nodes in the graph structure along each axis.
    Returns:
        Array of shape (n_points, 2 * dof * prod(N)) where q and p are stacked
        as [q_values, p_values].
    Raises:
        AssertionError: if ``x`` has fewer than 3 dimensions or the size of
            its last dimension is odd.
    """
    assert len(x.shape) > 2, (
        f"Expected shape (n_points, *n_obj, 2*dof) but got shape {tuple(x.shape)}."
    )
    n_points, *node_dims, n_dim = x.shape
    assert n_dim % 2 == 0, "Expected 2*dof in the last dimension"
    dof = n_dim // 2
    # Spell out the row width instead of using -1: reshape cannot infer an
    # unknown dimension when the batch is empty (zero elements).
    width = int(np.prod(node_dims)) * dof
    q_values = x[..., :dof].reshape(n_points, width)
    p_values = x[..., dof:].reshape(n_points, width)
    return np.concatenate([q_values, p_values], axis=-1)

def flatten_TC(x: tc.Tensor) -> tc.Tensor:
    """
    Torch version of 'flatten': collapse the node axes of a graph tensor.

    The last axis of ``x`` is split in half into a ``q`` part and a ``p``
    part; all node axes are then collapsed so that every row reads
    ``[q_values, p_values]``.

    Args:
        x: tensor of shape (n_points, *N, 2*dof), where N contains the number
            of nodes in the graph structure along each axis.
    Returns:
        Tensor of shape (n_points, 2 * dof * prod(N)) where q and p are
        stacked as [q_values, p_values].
    Raises:
        AssertionError: if ``x`` has fewer than 3 dimensions or the size of
            its last dimension is odd.
    """
    assert len(x.shape) > 2, (
        f"Expected shape (n_points, *n_obj, 2*dof) but got shape {tuple(x.shape)}."
    )
    # Star-unpack the node axes so grids with more than one node axis are
    # accepted, matching the rank check above.
    n_points, *node_dims, n_dim = x.shape
    assert n_dim % 2 == 0, "Expected 2*dof in the last dimension"
    dof = n_dim // 2
    # Spell out the row width instead of using -1: reshape cannot infer an
    # unknown dimension when the batch is empty (zero elements).
    width = dof
    for size in node_dims:
        width *= size
    q_values = x[..., :dof].reshape(n_points, width)
    p_values = x[..., dof:].reshape(n_points, width)
    return tc.cat([q_values, p_values], dim=-1)

def unflatten(x: np.ndarray, n_obj: list[int], dof: int) -> np.ndarray:
    """
    'Unflatten' a 2-D array into a per-node graph representation.

    For the N-body mass-spring system. Each row of ``x`` is read as
    ``[q_values, p_values]``; both halves are reshaped onto the node axes and
    joined along the last axis.

    Args:
        x: array of shape (n_points, 2 * dof * prod(n_obj)) where q and p are
            stacked as [q_values, p_values].
        n_obj: number of nodes along each axis. E.g., [Nx] means Nx nodes
            along the x-axis.
        dof: degrees of freedom. E.g., dof=2 gives q_x and q_y for the
            generalized position.
    Returns:
        Reshaped x of shape (n_points, *n_obj, 2*dof).
    Raises:
        AssertionError: if ``x`` is not 2-D or its second dimension is not
            2 * dof * prod(n_obj).
    """
    assert len(x.shape) == 2, (
        f"Expected shape (n_points, n_features) of length 2 but got shape {tuple(x.shape)}."
    )
    n_points, _ = x.shape
    # np.prod returns the float 1.0 for an empty n_obj; coerce to int so the
    # value stays usable as a slice bound.
    total_num_objects = int(np.prod(n_obj))

    assert x.shape[1] == 2 * dof * total_num_objects, (
        f"Expected shape (n_points, {2 * dof * total_num_objects}) but got {x.shape}"
    )

    # Split into q and p values: the first half of every row is q, the rest p
    split = total_num_objects * dof
    q_values = x[:, :split].reshape(n_points, *n_obj, dof)
    p_values = x[:, split:].reshape(n_points, *n_obj, dof)

    # Concatenate along the last axis
    return np.concatenate([q_values, p_values], axis=-1)

def unflatten_TC(x: tc.Tensor, n_obj: list[int], dof: int) -> tc.Tensor:
    """
    Torch version of 'unflatten': restore the node axes of a 2-D tensor.

    Each row of ``x`` is read as ``[q_values, p_values]``; both halves are
    reshaped onto the node axes and joined along the last axis.

    Args:
        x: tensor of shape (n_points, 2 * dof * prod(n_obj)) where q and p are
            stacked as [q_values, p_values].
        n_obj: number of nodes along each axis. E.g., [Nx] means Nx nodes
            along the x-axis.
        dof: degrees of freedom per node.
    Returns:
        Reshaped x of shape (n_points, *n_obj, 2*dof).
    Raises:
        AssertionError: if ``x`` is not 2-D or its second dimension is not
            2 * dof * prod(n_obj).
    """
    assert len(x.shape) == 2, (
        f"Expected shape (n_points, n_features) of length 2 but got shape {tuple(x.shape)}."
    )
    n_points, _ = x.shape
    # np.prod returns the float 1.0 for an empty n_obj; coerce to int so the
    # value stays usable as a slice bound.
    total_num_objects = int(np.prod(n_obj))

    assert x.shape[1] == 2 * dof * total_num_objects, (
        f"Expected shape (n_points, {2 * dof * total_num_objects}) but got {x.shape}"
    )

    # Split into q and p values: the first half of every row is q, the rest p
    split = total_num_objects * dof
    q_values = x[:, :split].reshape(n_points, *n_obj, dof)
    p_values = x[:, split:].reshape(n_points, *n_obj, dof)

    # Concatenate along the last axis
    return tc.cat([q_values, p_values], dim=-1)
