import os

import numpy as np
from tqdm import tqdm


def create_spring(n_balls: int, num: float, save_name: str,
                  save_folder: str = 'data', length: int = 5100,
                  sample_freq: int = 100) -> None:
    """Simulate ``num`` spring-system trajectories and save them as ``.npy`` files.

    Each simulation places ``n_balls`` particles in 2-D and couples random
    pairs with springs. Edge weights are drawn from ``[0, 0.5, 1]``; with the
    default probabilities only 0 and 1 occur. Each simulation integrates the
    dynamics for ``length`` steps and keeps every ``sample_freq``-th state
    together with its total energy.

    With ``T_save = length // sample_freq - 1``, three arrays are written to
    ``save_folder``:

    * ``{save_name}_x_{n_balls}_spring.npy``: positions and velocities,
      with shape ``(num, T_save, 4 * n_balls)``. Along the last axis the
      layout is ``[coord-0 positions, coord-0 velocities, coord-1 positions,
      coord-1 velocities]``, with ``n_balls`` entries each.
      NOTE(review): this interleaving comes from concatenating positions and
      velocities on the particle axis. Confirm downstream loaders expect it
      before changing it.
    * ``{save_name}_edge_{n_balls}_spring.npy``: symmetric edge-weight
      matrices, with shape ``(num, n_balls, n_balls)``.
    * ``{save_name}_energy_{n_balls}_spring.npy``: total energy of every
      saved state, with shape ``(num, T_save)``.

    Randomness comes from the global NumPy RNG, so call ``np.random.seed``
    first for reproducible output.

    Args:
        n_balls: Number of particles per simulation.
        num: Number of simulations. It is converted with ``int``, so values
            such as ``5e3`` are accepted.
        save_name: Filename prefix. A trailing ``.npy`` is stripped.
        save_folder: Output directory. It is created if missing.
        length: Integration steps per simulation. Must be a multiple of
            ``sample_freq``.
        sample_freq: Keep one state every ``sample_freq`` steps.

    Raises:
        ValueError: If ``num`` is smaller than 1.
    """
    num = int(num)
    if num < 1:
        # np.stack cannot build an empty dataset; fail before touching disk.
        raise ValueError(f"num must be at least 1, got {num}")

    def generate_dataset(sim, num_sims, length, sample_freq):
        """Run ``num_sims`` simulations and stack their outputs.

        Returns ``(features, edges, energies)`` with shapes
        ``(num_sims, T_save, 2, 2 * n)``, ``(num_sims, n, n)`` and
        ``(num_sims, T_save)``. In ``features``, positions are followed by
        velocities along the last (particle) axis.
        """
        loc_all, vel_all, edges_all, energy_all = [], [], [], []

        for _ in tqdm(range(num_sims), dynamic_ncols=True):
            loc, vel, edges = sim.sample_trajectory(T=length, sample_freq=sample_freq)
            loc_all.append(loc)
            vel_all.append(vel)
            edges_all.append(edges)
            # One energy value per saved time step (rows along axis 0).
            energy_all.append(np.array([
                sim._energy(loc_t, vel_t, edges) for loc_t, vel_t in zip(loc, vel)
            ]))

        return np.concatenate([np.stack(loc_all), np.stack(vel_all)], axis=-1), \
            np.stack(edges_all), np.stack(energy_all)

    class SpringSim:
        """Particles in 2-D whose pairs are coupled by linear springs."""

        def __init__(self, n_balls=5):
            self.n_balls = n_balls

            self.loc_std = 0.5  # std of the Gaussian initial positions
            self.vel_norm = 0.5  # every particle starts with this speed
            self.interaction_strength = 0.1  # scales every edge weight

            self._spring_types = np.array([0., 0.5, 1.])  # candidate edge weights
            self._delta_T = 0.001  # integration step size
            # NOTE(review): not read by this simulator (no force clipping here);
            # kept so the attribute stays available.
            self._max_F = 0.1 / self._delta_T

        @staticmethod
        def _forces(forces_size, loc):
            """Return the spring forces, shape ``(2, n)``, for positions ``loc`` of shape ``(2, n)``.

            Particle ``i`` receives
            ``sum_j forces_size[i, j] * (loc[:, i] - loc[:, j])``.
            """
            n = loc.shape[1]
            return (forces_size.reshape(1, n, n) *
                    np.concatenate((
                        np.subtract.outer(loc[0, :], loc[0, :]).reshape(1, n, n),
                        np.subtract.outer(loc[1, :], loc[1, :]).reshape(1, n, n),
                    ))).sum(axis=-1)

        def _energy(self, loc, vel, edges):
            """Return kinetic plus spring potential energy of one state.

            ``loc`` and ``vel`` are indexed ``[coordinate, particle]``, like
            one row of ``sample_trajectory``'s output.
            """
            K = 0.5 * (vel ** 2).sum()
            U = 0
            n = loc.shape[1]
            for i in range(n):
                for j in range(n):
                    if i != j:
                        r = loc[:, i] - loc[:, j]
                        dist = np.sqrt((r ** 2).sum())
                        # Every unordered pair is visited twice, hence the extra / 2.
                        U += 0.5 * self.interaction_strength * edges[i, j] * (dist ** 2) / 2
            return U + K

        def sample_trajectory(self, T=10000, sample_freq=10, spring_prob=(1. / 2, 0, 1. / 2)):
            """Simulate one trajectory and return ``(loc, vel, edges)``.

            ``loc`` and ``vel`` have shape ``(T // sample_freq - 1, 2, n)``. Row
            ``k`` is the state recorded at step ``(k + 1) * sample_freq``, so the
            initial state is not kept. ``edges`` is the symmetric ``(n, n)``
            edge-weight matrix with a zero diagonal. Its entries are drawn from
            ``self._spring_types`` with probabilities ``spring_prob``.

            The scheme first kicks the velocities once with the initial forces.
            Each step then advances positions with the current velocities and
            velocities with the forces at the new positions, using step size
            ``self._delta_T``.
            """
            n = self.n_balls
            assert (T % sample_freq == 0)
            T_save = int(T / sample_freq - 1)
            counter = 0
            # Sample edges, then mirror the lower triangle to make them symmetric.
            edges = np.random.choice(self._spring_types, size=(n, n), p=spring_prob)
            edges = np.tril(edges) + np.tril(edges, -1).T
            np.fill_diagonal(edges, 0)
            # Gaussian positions; random directions rescaled to speed vel_norm.
            loc = np.zeros((T_save, 2, n))
            vel = np.zeros((T_save, 2, n))
            loc_next = np.random.randn(2, n) * self.loc_std
            vel_next = np.random.randn(2, n)
            v_norm = np.sqrt((vel_next ** 2).sum(axis=0)).reshape(1, -1)
            vel_next = vel_next * self.vel_norm / v_norm
            # Overwritten by the first sample below (counter starts at 0).
            loc[0, :, :], vel[0, :, :] = loc_next, vel_next

            # The couplings depend only on edges, so compute them once.
            forces_size = - self.interaction_strength * edges
            np.fill_diagonal(forces_size, 0)  # no self-interaction

            vel_next += self._delta_T * self._forces(forces_size, loc_next)

            for i in range(1, T):
                loc_next += self._delta_T * vel_next

                if i % sample_freq == 0:
                    loc[counter, :, :], vel[counter, :, :] = loc_next, vel_next
                    counter += 1

                vel_next += self._delta_T * self._forces(forces_size, loc_next)

            return loc, vel, edges

    sim = SpringSim(n_balls=n_balls)

    os.makedirs(save_folder, exist_ok=True)

    print(f"Generating {num} simulations")
    feature_x, edges, energy = generate_dataset(sim, num, length, sample_freq)

    if save_name.endswith('.npy'):
        save_name = save_name[:-4]

    def output_path(kind):
        # e.g. data/val_x_5_spring.npy
        return os.path.join(save_folder, f'{save_name}_{kind}_{n_balls}_spring.npy')

    np.save(output_path('x'), feature_x.reshape(*feature_x.shape[:2], -1))
    np.save(output_path('edge'), edges)
    np.save(output_path('energy'), energy)
    

if __name__ == '__main__':
    # Validation split: 5,000 five-ball simulations from a fixed seed, so the
    # output files are reproducible. To build the training split instead,
    # seed with 42 and call create_spring(5, 5e4, 'train').
    np.random.seed(420)
    create_spring(n_balls=5, num=5e3, save_name='val')




