# Code taken from GeoTDM: https://github.com/hanjq17/GeoTDM

import numpy as np
import matplotlib.pyplot as plt
import time
from tqdm import tqdm
import numpy as np
import argparse
from joblib import Parallel, delayed
import warnings


class SpringSim(object):
    """Simulate balls connected by random springs in ``self.dim`` (3) dimensions.

    Every pair of balls gets a spring constant drawn from ``_spring_types``.
    The system is advanced with a fixed time step ``_delta_T`` and the state is
    recorded every ``sample_freq`` steps by :meth:`sample_trajectory`.
    """

    def __init__(
        self,
        n_balls=5,
        box_size=5.0,
        loc_std=0.5,
        vel_norm=0.5,
        interaction_strength=0.1,
        noise_var=0.0,
    ):
        """Store the simulation parameters.

        :param n_balls: number of balls.
        :param box_size: half-width of the box used by :meth:`_clamp`.
        :param loc_std: standard deviation of the initial positions.
        :param vel_norm: norm of every ball's initial velocity vector.
        :param interaction_strength: global factor applied to all spring forces.
        :param noise_var: factor multiplying the standard-normal observation
            noise (it scales the noise directly, i.e. acts as a std-dev).
        """
        self.n_balls = n_balls
        self.box_size = box_size
        self.loc_std = loc_std
        self.vel_norm = vel_norm
        self.interaction_strength = interaction_strength
        self.noise_var = noise_var

        self._spring_types = np.array([0.0, 0.5, 1.0])
        self._delta_T = 0.001
        # Force components are clipped to +-_max_F.
        self._max_F = 0.1 / self._delta_T
        self.dim = 3

    def _energy(self, loc, vel, edges):
        """Return kinetic plus spring potential energy of one frame.

        :param loc: (dim, N) positions.
        :param vel: (dim, N) velocities.
        :param edges: (N, N) spring constants.
        :return: ``0.5 * sum(vel**2)`` plus the sum over ordered pairs ``i != j``
            of ``0.25 * interaction_strength * edges[i, j] * |r_i - r_j|**2``.
        """
        kinetic = 0.5 * (vel**2).sum()
        diff = loc[:, :, np.newaxis] - loc[:, np.newaxis, :]
        sq_dist = (diff**2).sum(axis=0)
        # The diagonal contributes nothing because the self-distance is zero.
        potential = 0.25 * self.interaction_strength * (edges * sq_dist).sum()
        return potential + kinetic

    def _clamp(self, loc, vel):
        """Reflect positions and velocities elastically at the box walls.

        Both arrays are modified in place and also returned.

        :param loc: (dim, N) location at one time stamp
        :param vel: (dim, N) velocity at one time stamp
        :return: location and velocity after hitting the walls and bouncing
            back elastically
        """
        assert np.all(loc < self.box_size * 3)
        assert np.all(loc > -self.box_size * 3)

        over = loc > self.box_size
        loc[over] = 2 * self.box_size - loc[over]
        assert np.all(loc <= self.box_size)
        vel[over] = -np.abs(vel[over])

        under = loc < -self.box_size
        loc[under] = -2 * self.box_size - loc[under]
        assert np.all(loc >= -self.box_size)
        vel[under] = np.abs(vel[under])

        return loc, vel

    def _l2(self, A, B):
        """
        Input: A is a Nxd matrix
               B is a Mxd matrix
        Output: dist is a NxM matrix where dist[i,j] is the square norm
            between A[i,:] and B[j,:]
        i.e. dist[i,j] = ||A[i,:]-B[j,:]||^2
        """
        A_norm = (A**2).sum(axis=1).reshape(A.shape[0], 1)
        B_norm = (B**2).sum(axis=1).reshape(1, B.shape[0])
        dist = A_norm + B_norm - 2 * A.dot(B.transpose())
        return dist

    def _forces(self, loc, coupling):
        """Return the clipped net force on every ball.

        :param loc: (dim, N) positions.
        :param coupling: (N, N) signed pair strengths with a zero diagonal.
        :return: (dim, N) array, ``sum_j coupling[i, j] * (r_i - r_j)`` with each
            component clipped to ``[-_max_F, _max_F]``.
        """
        # diff[d, i, j] = loc[d, i] - loc[d, j], for any number of axes.
        diff = loc[:, :, np.newaxis] - loc[:, np.newaxis, :]
        force = (coupling[np.newaxis, :, :] * diff).sum(axis=-1)
        np.clip(force, -self._max_F, self._max_F, out=force)
        return force

    def sample_trajectory(self, T=10000, sample_freq=10, spring_prob=[1.0 / 2, 0, 1.0 / 2]):
        """Sample a random spring system and integrate it for ``T`` steps.

        :param T: number of integration steps; must be a multiple of
            ``sample_freq``.
        :param sample_freq: a frame is stored every ``sample_freq`` steps.
        :param spring_prob: probabilities of the entries of ``_spring_types``.
        :return: ``(loc, vel, edges, charges)`` where ``loc`` and ``vel`` have
            shape ``(T / sample_freq - 1, dim, n_balls)``, ``edges`` is the
            symmetric ``(n_balls, n_balls)`` spring matrix with zero diagonal
            and ``charges`` is an all-zero ``(n_balls, 1)`` placeholder.
        """
        n = self.n_balls
        assert T % sample_freq == 0
        T_save = int(T / sample_freq - 1)
        counter = 0

        # Sample edges: mirror the lower triangle to get a symmetric matrix.
        edges = np.random.choice(self._spring_types, size=(n, n), p=spring_prob)
        edges = np.tril(edges) + np.tril(edges, -1).T
        np.fill_diagonal(edges, 0)

        # Initialize location and velocity
        loc = np.zeros((T_save, self.dim, n))
        vel = np.zeros((T_save, self.dim, n))
        loc_next = np.random.randn(self.dim, n) * self.loc_std
        vel_next = np.random.randn(self.dim, n)
        v_norm = np.sqrt((vel_next**2).sum(axis=0)).reshape(1, -1)
        vel_next = vel_next * self.vel_norm / v_norm
        # NOTE: ``counter`` starts at 0, so this frame is overwritten by the
        # first sampled step below; the stored frames therefore correspond to
        # steps sample_freq, 2 * sample_freq, ..., T - sample_freq.
        loc[0, :, :], vel[0, :, :] = self._clamp(loc_next, vel_next)

        # The coupling matrix does not change during the run, so build it once.
        coupling = -self.interaction_strength * edges
        np.fill_diagonal(coupling, 0)  # self forces are zero

        vel_next += self._delta_T * self._forces(loc_next, coupling)
        # run leapfrog
        for i in range(1, T):
            loc_next += self._delta_T * vel_next

            if i % sample_freq == 0:
                loc[counter, :, :], vel[counter, :, :] = loc_next, vel_next
                counter += 1

            vel_next += self._delta_T * self._forces(loc_next, coupling)

        # Add noise to observations
        loc += np.random.randn(T_save, self.dim, n) * self.noise_var
        vel += np.random.randn(T_save, self.dim, n) * self.noise_var
        return loc, vel, edges, np.zeros(shape=(n, 1))


class ChargedParticlesSim(object):
    """Simulate charged balls with inverse-square pair forces in 3 dimensions.

    Every ball gets a charge drawn from ``_charge_types``; the pair strength is
    the product of the two charges. The system is advanced with a fixed time
    step ``_delta_T`` and recorded every ``sample_freq`` steps by
    :meth:`sample_trajectory`.
    """

    def __init__(
        self,
        n_balls=5,
        box_size=5.0,
        loc_std=1.0,
        vel_norm=0.5,
        interaction_strength=1.0,
        noise_var=0.0,
    ):
        """Store the simulation parameters.

        :param n_balls: number of balls.
        :param box_size: half-width of the box used by :meth:`_clamp`.
        :param loc_std: standard deviation of the initial positions for five
            balls; it is rescaled by ``(n_balls / 5) ** (1 / 3)``.
        :param vel_norm: norm of every ball's initial velocity vector.
        :param interaction_strength: global factor applied to all pair forces.
        :param noise_var: factor multiplying the standard-normal observation
            noise (it scales the noise directly, i.e. acts as a std-dev).
        """
        self.n_balls = n_balls
        self.box_size = box_size
        self.loc_std = loc_std * (float(n_balls) / 5.0) ** (1 / 3)
        self.vel_norm = vel_norm
        self.interaction_strength = interaction_strength
        self.noise_var = noise_var

        self._charge_types = np.array([-1.0, 0.0, 1.0])
        self._delta_T = 0.001
        # Force components are clipped to +-_max_F.
        self._max_F = 0.1 / self._delta_T
        self.dim = 3

    def _l2(self, A, B):
        """
        Input: A is a Nxd matrix
               B is a Mxd matrix
        Output: dist is a NxM matrix where dist[i,j] is the square norm
            between A[i,:] and B[j,:]
        i.e. dist[i,j] = ||A[i,:]-B[j,:]||^2
        """
        A_norm = (A**2).sum(axis=1).reshape(A.shape[0], 1)
        B_norm = (B**2).sum(axis=1).reshape(1, B.shape[0])
        dist = A_norm + B_norm - 2 * A.dot(B.transpose())
        return dist

    def _energy(self, loc, vel, edges):
        """Return kinetic plus pairwise potential energy of one frame.

        :param loc: (dim, N) positions.
        :param vel: (dim, N) velocities.
        :param edges: (N, N) array of charge products.
        :return: ``0.5 * sum(vel**2)`` plus the sum over ordered pairs ``i != j``
            of ``0.5 * interaction_strength * edges[i, j] / |r_i - r_j|``.
        """
        kinetic = 0.5 * (vel**2).sum()
        diff = loc[:, :, np.newaxis] - loc[:, np.newaxis, :]
        dist = np.sqrt((diff**2).sum(axis=0))
        off_diag = ~np.eye(loc.shape[1], dtype=bool)
        # Coincident balls give an infinite term; silence the warning only.
        with np.errstate(divide="ignore"):
            pair_terms = edges[off_diag] / dist[off_diag]
        potential = 0.5 * self.interaction_strength * pair_terms.sum()
        return potential + kinetic

    def _clamp(self, loc, vel):
        """Reflect positions and velocities elastically at the box walls.

        Both arrays are modified in place and also returned.

        :param loc: (dim, N) location at one time stamp
        :param vel: (dim, N) velocity at one time stamp
        :return: location and velocity after hitting the walls and bouncing
            back elastically
        """
        assert np.all(loc < self.box_size * 3)
        assert np.all(loc > -self.box_size * 3)

        over = loc > self.box_size
        loc[over] = 2 * self.box_size - loc[over]
        assert np.all(loc <= self.box_size)
        vel[over] = -np.abs(vel[over])

        under = loc < -self.box_size
        loc[under] = -2 * self.box_size - loc[under]
        assert np.all(loc >= -self.box_size)
        vel[under] = np.abs(vel[under])

        return loc, vel

    def _pair_strengths(self, loc, edges):
        """Return ``interaction_strength * edges / |r_i - r_j|**3`` with zero diagonal.

        This is the force size up to a 1/|r| factor, because the result is
        later multiplied by the unnormalized separation vector.
        """
        # The diagonal divides by a zero self-distance; it is reset right after.
        with np.errstate(divide="ignore"):
            dist_cubed = np.power(self._l2(loc.transpose(), loc.transpose()), 3.0 / 2.0)
            strengths = self.interaction_strength * edges / dist_cubed
        np.fill_diagonal(strengths, 0)  # self forces are zero
        return strengths

    def _net_force(self, loc, strengths):
        """Return the (dim, N) net force, each component clipped to ``+-_max_F``."""
        # diff[d, i, j] = loc[d, i] - loc[d, j], for any number of axes.
        diff = loc[:, :, np.newaxis] - loc[:, np.newaxis, :]
        force = (strengths[np.newaxis, :, :] * diff).sum(axis=-1)
        np.clip(force, -self._max_F, self._max_F, out=force)
        return force

    def sample_trajectory(self, T=10000, sample_freq=10, charge_prob=[1.0 / 2, 0, 1.0 / 2]):
        """Sample random charges and integrate the system for ``T`` steps.

        :param T: number of integration steps; must be a multiple of
            ``sample_freq``.
        :param sample_freq: a frame is stored every ``sample_freq`` steps.
        :param charge_prob: probabilities of the entries of ``_charge_types``.
        :return: ``(loc, vel, edges, charges)`` where ``loc`` and ``vel`` have
            shape ``(T / sample_freq - 1, dim, n_balls)``, ``charges`` is
            ``(n_balls, 1)`` and ``edges = charges @ charges.T``.
        :raises AssertionError: if ``T`` is not a multiple of ``sample_freq`` or
            if any initial pair strength is not larger than 1e-10 in magnitude
            (for instance when a zero charge is sampled).
        """
        n = self.n_balls
        assert T % sample_freq == 0
        T_save = int(T / sample_freq - 1)
        diag_mask = ~np.eye(n, dtype=bool)
        counter = 0

        # Sample edges
        charges = np.random.choice(self._charge_types, size=(n, 1), p=charge_prob)
        edges = charges.dot(charges.transpose())

        # Initialize location and velocity
        loc = np.zeros((T_save, self.dim, n))
        vel = np.zeros((T_save, self.dim, n))
        loc_next = np.random.randn(self.dim, n) * self.loc_std
        vel_next = np.random.randn(self.dim, n)
        v_norm = np.sqrt((vel_next**2).sum(axis=0)).reshape(1, -1)
        vel_next = vel_next * self.vel_norm / v_norm
        # NOTE: ``counter`` starts at 0, so this frame is overwritten by the
        # first sampled step below; the stored frames therefore correspond to
        # steps sample_freq, 2 * sample_freq, ..., T - sample_freq.
        loc[0, :, :], vel[0, :, :] = self._clamp(loc_next, vel_next)

        strengths = self._pair_strengths(loc_next, edges)
        # Sanity check on the initial configuration only.
        assert np.abs(strengths[diag_mask]).min() > 1e-10
        vel_next += self._delta_T * self._net_force(loc_next, strengths)

        # run leapfrog
        for i in range(1, T):
            loc_next += self._delta_T * vel_next

            if i % sample_freq == 0:
                loc[counter, :, :], vel[counter, :, :] = loc_next, vel_next
                counter += 1

            strengths = self._pair_strengths(loc_next, edges)
            vel_next += self._delta_T * self._net_force(loc_next, strengths)

        # Add noise to observations
        loc += np.random.randn(T_save, self.dim, n) * self.noise_var
        vel += np.random.randn(T_save, self.dim, n) * self.noise_var
        return loc, vel, edges, charges


class GravitySim(object):
    """Softened gravitational N-body simulation in three dimensions.

    NOTE(review): ``loc_std`` and ``vel_norm`` are stored on the instance but
    ``sample_trajectory`` draws positions and velocities from a unit normal
    without reading them.
    """

    def __init__(
        self,
        n_balls=100,
        loc_std=1,
        vel_norm=0.5,
        interaction_strength=1,
        noise_var=0,
        dt=0.001,
        softening=0.1,
    ):
        """Store the simulation parameters.

        :param n_balls: number of particles.
        :param loc_std: stored only (see class note).
        :param vel_norm: stored only (see class note).
        :param interaction_strength: passed as ``G`` to ``compute_acceleration``.
        :param noise_var: factor multiplying the standard-normal noise added to
            the recorded frames.
        :param dt: integration time step.
        :param softening: softening length added (squared) to squared distances.
        """
        self.dim = 3

        self.n_balls = n_balls
        self.loc_std = loc_std
        self.vel_norm = vel_norm
        self.interaction_strength = interaction_strength
        self.noise_var = noise_var
        self.dt = dt
        self.softening = softening

    def _pairwise_offsets(self, pos):
        """Return ``[dx, dy, dz]``; each is (N, N) with entry [i, j] = r_j - r_i."""
        offsets = []
        for axis in range(3):
            column = pos[:, axis : axis + 1]
            offsets.append(column.T - column)
        return offsets

    def compute_acceleration(self, pos, mass, G, softening):
        """Return the (N, 3) acceleration of every particle.

        :param pos: (N, 3) positions.
        :param mass: (N, 1) masses.
        :param G: gravitational constant.
        :param softening: softening length.
        """
        dx, dy, dz = self._pairwise_offsets(pos)

        # Softened squared separation, turned into 1/r^3 wherever it is positive.
        inv_cubed = dx**2 + dy**2 + dz**2 + softening**2
        positive = inv_cubed > 0
        inv_cubed[positive] = inv_cubed[positive] ** (-1.5)

        # One (N, 1) column per axis, packed side by side.
        columns = [G * (delta * inv_cubed) @ mass for delta in (dx, dy, dz)]
        return np.hstack(columns)

    def _energy(self, pos, vel, mass, G):
        """Return ``(kinetic, potential, total)`` energy of one frame.

        The potential uses the plain (unsoftened) inverse distance.
        """
        kinetic = 0.5 * np.sum(mass * vel**2)

        dx, dy, dz = self._pairwise_offsets(pos)
        inv_dist = np.sqrt(dx**2 + dy**2 + dz**2)
        separated = inv_dist > 0
        inv_dist[separated] = 1.0 / inv_dist[separated]

        # Strict upper triangle only, so each pair is counted once.
        pair_terms = np.triu(-(mass * mass.T) * inv_dist, 1)
        potential = G * np.sum(pair_terms)

        return kinetic, potential, kinetic + potential

    def sample_trajectory(self, T=10000, sample_freq=10):
        """Integrate a random system for ``T`` kick-drift-kick steps.

        :param T: number of integration steps; must be a multiple of
            ``sample_freq``.
        :param sample_freq: a frame is stored every ``sample_freq`` steps,
            starting with the initial state.
        :return: ``(pos, vel, force, mass)``; the first three have shape
            ``(T / sample_freq, n_balls, 3)`` and ``mass`` is an all-ones
            ``(n_balls, 1)`` array.
        """
        assert T % sample_freq == 0

        n_frames = int(T / sample_freq)
        n_bodies = self.n_balls
        frame_shape = (n_frames, n_bodies, self.dim)

        pos_save = np.zeros(frame_shape)
        vel_save = np.zeros(frame_shape)
        force_save = np.zeros(frame_shape)

        # Unit masses; positions and velocities from a standard normal.
        mass = np.ones((n_bodies, 1))
        pos = np.random.randn(n_bodies, self.dim)
        vel = np.random.randn(n_bodies, self.dim)

        # Move to the centre-of-mass frame.
        vel -= np.mean(mass * vel, 0) / np.mean(mass)

        acc = self.compute_acceleration(pos, mass, self.interaction_strength, self.softening)

        for step in range(T):
            frame, remainder = divmod(step, sample_freq)
            if remainder == 0:
                pos_save[frame] = pos
                vel_save[frame] = vel
                force_save[frame] = acc * mass

            # half kick, drift, recompute accelerations, half kick
            vel += acc * self.dt / 2.0
            pos += vel * self.dt
            acc = self.compute_acceleration(
                pos, mass, self.interaction_strength, self.softening
            )
            vel += acc * self.dt / 2.0

        # Observation noise, drawn in the order pos, vel, force.
        pos_save += np.random.randn(*frame_shape) * self.noise_var
        vel_save += np.random.randn(*frame_shape) * self.noise_var
        force_save += np.random.randn(*frame_shape) * self.noise_var
        return pos_save, vel_save, force_save, mass


# Demo: simulate one gravity system and plot the x/y tracks of its particles.
# Runs only when the file is executed as a script; the module-level dataset
# generation code further down in this file is executed in that case as well.
if __name__ == "__main__":
    from tqdm import tqdm

    color_map = "summer"
    cmap = plt.get_cmap(color_map)

    # Fixed seed so the demo plot is repeatable.
    np.random.seed(43)

    sim = GravitySim(n_balls=100, loc_std=1)

    # Time the simulation; sample_freq=1 stores every one of the 5000 steps.
    t = time.time()
    loc, vel, force, mass = sim.sample_trajectory(T=5000, sample_freq=1)

    print("Simulation time: {}".format(time.time() - t))
    plt.figure()
    axes = plt.gca()
    axes.set_xlim([-4.0, 4.0])  # type: ignore
    axes.set_ylim([-4.0, 4.0])  # type: ignore
    # for i in range(loc.shape[-2]):
    #     plt.plot(loc[:, i, 0], loc[:, i, 1], alpha=0.1, linewidth=1)
    #     plt.plot(loc[0, i, 0], loc[0, i, 1], 'o')

    # Only the frames from index `offset` onwards are drawn.
    offset = 4000
    N_frames = loc.shape[0] - offset
    N_particles = loc.shape[-2]

    for i in tqdm(range(N_particles)):
        # One colour per particle, spread over the colour map.
        color = cmap(i / N_particles)
        # for j in range(loc.shape[0]-2):
        for j in range(offset, offset + N_frames):
            # Draw the segment from frame j to frame j + 1 (the slice holds a
            # single point at the last frame); alpha grows with j so later
            # segments are more opaque.
            plt.plot(
                loc[j : j + 2, i, 0],
                loc[j : j + 2, i, 1],
                alpha=0.2 + 0.7 * ((j - offset) / N_frames) ** 4,
                linewidth=1,
                color=color,
            )
        # Mark the final position of the particle.
        plt.plot(loc[-1, i, 0], loc[-1, i, 1], "o", markersize=3, color=color)
    plt.axis("off")
    # plt.figure()
    # energies = [sim._energy(loc[i, :, :], vel[i, :, :], mass, sim.interaction_strength) for i in
    #             range(loc.shape[0])]
    # plt.plot(energies)
    plt.show()


# Silence all Python warnings for the dataset generation below.
warnings.filterwarnings("ignore")

"""
nbody_small:   python3 -u generate_dataset.py --simulation=charged --num-train 10000 --seed 43 --suffix small
gravity_small: python3 -u generate_dataset.py --simulation=gravity --num-train 10000 --seed 43 --suffix small
"""

# Command-line interface of the dataset generator. Parsing happens at module
# level because `args`, `sim` and `suffix` are read by the functions below.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--simulation", type=str, default="gravity", help="What simulation to generate."
)
parser.add_argument(
    "--num-train", type=int, default=500, help="Number of training simulations to generate."
)
parser.add_argument(
    "--num-valid", type=int, default=300, help="Number of validation simulations to generate."
)
parser.add_argument(
    "--num-test", type=int, default=300, help="Number of test simulations to generate."
)
parser.add_argument("--length", type=int, default=10000, help="Length of trajectory.")
parser.add_argument(
    "--length_test", type=int, default=10000, help="Length of test set trajectory."
)
parser.add_argument(
    "--sample-freq", type=int, default=100, help="How often to sample the trajectory."
)
parser.add_argument(
    "--n_balls", type=int, default=75, help="Number of balls in the simulation."
)
parser.add_argument("--seed", type=int, default=42, help="Random seed.")
parser.add_argument("--initial_vel", type=int, default=1, help="consider initial velocity")
parser.add_argument("--suffix", type=str, default="", help="add a suffix to the name")
parser.add_argument(
    "--n_workers", type=int, default=16, help="number of workers for parallelization"
)

args = parser.parse_args()

# `--initial_vel 0` replaces the initial speed by a negligible value.
initial_vel_norm = 0.5
if not args.initial_vel:
    initial_vel_norm = 1e-16

if args.simulation == "springs":
    # Forward the initial-velocity setting here too, so that the "_initvel"
    # part of the file suffix describes the data for every simulation type.
    # The default (0.5) equals SpringSim's own default.
    sim = SpringSim(noise_var=0.0, n_balls=args.n_balls, vel_norm=initial_vel_norm)
    suffix = "_springs"
elif args.simulation == "charged":
    sim = ChargedParticlesSim(noise_var=0.0, n_balls=args.n_balls, vel_norm=initial_vel_norm)
    suffix = "_charged"
elif args.simulation == "gravity":
    # NOTE(review): GravitySim stores vel_norm but its sample_trajectory does
    # not read it, so --initial_vel has no effect on gravity data.
    sim = GravitySim(noise_var=0.0, n_balls=args.n_balls, vel_norm=initial_vel_norm)
    suffix = "_gravity"
else:
    raise ValueError("Simulation {} not implemented".format(args.simulation))

# File-name suffix, e.g. "_gravity75_initvel1" followed by the user suffix.
suffix += str(args.n_balls) + "_initvel%d" % args.initial_vel + args.suffix
np.random.seed(args.seed)


def para_comp(length, sample_freq, seed=None):
    """Run one simulation with the module-level ``sim`` and return its result.

    :param length: number of integration steps (``T``).
    :param sample_freq: sampling interval passed to ``sample_trajectory``.
    :param seed: optional seed for NumPy's global generator, applied before the
        simulation starts. joblib worker processes do not share the parent's
        random state, so a seed has to be set inside the worker to make the
        result repeatable.
    :return: the 4-tuple returned by ``sim.sample_trajectory``.
    """
    if seed is not None:
        np.random.seed(seed)
    loc, vel, edges, charges = sim.sample_trajectory(T=length, sample_freq=sample_freq)
    return loc, vel, edges, charges


def generate_dataset(num_sims, length, sample_freq):
    """Generate ``num_sims`` trajectories in parallel.

    One seed per simulation is drawn from NumPy's global generator in the
    calling process (which is seeded from ``--seed``) and handed to the worker,
    so the output depends only on the seed and not on the number of workers.

    :return: four tuples ``(loc_all, vel_all, edges_all, charges_all)``, each
        holding one entry per simulation (empty tuples if ``num_sims`` is 0).
    """
    seeds = np.random.randint(0, 2**31 - 1, size=num_sims)
    # With a single worker joblib runs in this process, where the per-simulation
    # seeding would disturb the global stream; restore it afterwards.
    rng_state = np.random.get_state()
    results = Parallel(n_jobs=args.n_workers)(
        delayed(para_comp)(length, sample_freq, int(seed)) for seed in tqdm(seeds)
    )
    np.random.set_state(rng_state)

    if not results:
        return (), (), (), ()

    loc_all, vel_all, edges_all, charges_all = zip(*results)
    return loc_all, vel_all, edges_all, charges_all


# Generate the train/valid/test splits, then write them to the working directory.
if __name__ == "__main__":
    # (split name, progress message, number of simulations, trajectory length)
    split_specs = (
        ("train", "Generating {} training simulations", args.num_train, args.length),
        ("valid", "Generating {} validation simulations", args.num_valid, args.length),
        ("test", "Generating {} test simulations", args.num_test, args.length_test),
    )
    # Order of the four values returned by generate_dataset.
    array_kinds = ("loc", "vel", "edges", "charges")

    # All splits are generated before anything is written.
    generated = []
    for split_name, message, num_sims, length in split_specs:
        print(message.format(num_sims))
        loc_all, vel_all, edges_all, charges_all = generate_dataset(
            num_sims, length, args.sample_freq
        )
        generated.append((split_name, (loc_all, vel_all, edges_all, charges_all)))

    # One file per array kind and split: "<kind>_<split><suffix>.npy".
    for split_name, arrays in generated:
        for kind, data in zip(array_kinds, arrays):
            np.save(kind + "_" + split_name + suffix + ".npy", data)
