import os

import numpy as np
from tqdm import tqdm
from scipy.integrate import odeint


# Reference scales (SI units) used to nondimensionalise the equations of motion.
# NOTE(review): the values look like the Alpha Centauri A/B example common in
# three-body tutorials (solar mass, ~35.6 AU, ~80-year period); the extra 0.51
# factor on the time scale is a tuning choice whose origin is not shown here.
_G = 6.67408e-11
_M_ND = 1.989e+30
_R_ND = 5.326e+12
_V_ND = 30000
_T_ND = 79.91 * 365 * 24 * 3600 * 0.51

# Dimensionless coefficients: k1 scales the gravitational acceleration,
# k2 scales velocity into position change.
_K1 = _G * _T_ND * _M_ND / (np.square(_R_ND) * _V_ND)
_K2 = _V_ND * _T_ND / _R_ND

_MASSES = (1.0, 1.0, 1.0)       # nondimensional masses m1, m2, m3
_LAST_T = 5                     # integration end time (nondimensional)
_CLIP_N = 50                    # integrator output points per kept snapshot
# Fixed starting triangle for bodies 1, 2, 3 (one row per body).
_INITIAL_POSITIONS = np.array(
    [[0, 0, 0], [1.732, 3, 0], [3.464, 0, 0]], dtype=np.float64
)


def _three_body_rhs(w, t, k1, k2, m1, m2, m3):
    """Return d(state)/dt for the nondimensional three-body problem.

    ``w`` is the flat 18-element state ``[r1, r2, r3, v1, v2, v3]`` (3
    components each); the result uses the same layout. ``t`` is unused (the
    system is autonomous) but is required by ``odeint``'s call signature.
    """
    r1, r2, r3, v1, v2, v3 = w.reshape(6, 3)

    r12 = np.linalg.norm(r2 - r1)
    r13 = np.linalg.norm(r3 - r1)
    r23 = np.linalg.norm(r3 - r2)

    # Pairwise inverse-square attraction, scaled by k1.
    dv1bydt = k1 * m2 * (r2 - r1) / r12**3 + k1 * m3 * (r3 - r1) / r13**3
    dv2bydt = k1 * m1 * (r1 - r2) / r12**3 + k1 * m3 * (r3 - r2) / r23**3
    dv3bydt = k1 * m1 * (r1 - r3) / r13**3 + k1 * m2 * (r2 - r3) / r23**3

    # Position derivatives first, then velocity derivatives (matches ``w``).
    return np.concatenate(
        [k2 * v1, k2 * v2, k2 * v3, dv1bydt, dv2bydt, dv3bydt]
    )


def _generate_traj(n_time: int) -> np.ndarray:
    """Integrate one trajectory and return ``n_time`` snapshots, shape (n_time, 18).

    Velocities are drawn uniformly from [-0.05, 0.05) per component using the
    global ``np.random`` state (one ``random(3)`` call per body, in order).
    """
    velocities = [(np.random.random(3) - 0.5) / 10 for _ in range(3)]
    init_params = np.concatenate([_INITIAL_POSITIONS.ravel(), *velocities])

    # Dense integration grid: _CLIP_N points per kept snapshot plus the endpoint.
    time_span = np.linspace(0, _LAST_T, _CLIP_N * n_time + 1)
    traj = odeint(_three_body_rhs, init_params, time_span, args=(_K1, _K2, *_MASSES))

    # Keep every _CLIP_N-th row, starting at t=0, dropping the final endpoint.
    # Integer slicing avoids float->int truncation of a linspace index.
    return traj[: _CLIP_N * n_time : _CLIP_N]


def create_3body(
    num: int,
    n_time: int,
    save_name: str,
    *,
    save: bool = True,
    save_folder: str = 'data',
    show_progress: bool = True,
) -> np.ndarray:
    """Simulate random three-body trajectories and save them as ``<save_name>_x.npy``.

    Every trajectory starts from the same fixed triangle of positions with
    small random velocities drawn from the global ``np.random`` state, so a
    preceding ``np.random.seed`` makes the output reproducible.

    Args:
        num: Number of trajectories; floats such as ``5e3`` are truncated via ``int``.
        n_time: Number of evenly spaced snapshots kept per trajectory.
        save_name: Output base name; a trailing ``.npy`` is stripped and
            ``_x.npy`` is appended.
        save: If False, nothing is written to disk.
        save_folder: Output directory, created if missing.
        show_progress: Show a tqdm progress bar while simulating.

    Returns:
        float32 array of shape ``(num, n_time, 18)``; the last axis is
        ``[r1, r2, r3, v1, v2, v3]`` with 3 components each.

    Raises:
        ValueError: If ``num`` is less than 1.
    """
    num = int(num)
    if num < 1:
        raise ValueError(f'num must be at least 1, got {num}')

    if save:
        # Create the folder before the (possibly long) simulation so a bad path fails fast.
        os.makedirs(save_folder, exist_ok=True)

    trial_ids = range(num)
    if show_progress:
        trial_ids = tqdm(trial_ids, dynamic_ncols=True)

    # Cast after stacking instead of np.stack(..., dtype=...), which needs NumPy >= 1.24.
    trials_x = np.stack([_generate_traj(n_time) for _ in trial_ids]).astype(np.float32)

    if save:
        if save_name.endswith('.npy'):
            save_name = save_name[:-4]
        np.save(os.path.join(save_folder, save_name + '_x.npy'), trials_x)

    return trials_x

if __name__ == '__main__':
    # Each entry: (split name, RNG seed, trajectory count, snapshots per trajectory).
    # The training split (seed 42, 5e4 trajectories x 10 snapshots) is currently
    # disabled; add ('train', 42, 5e4, 10) at the front of the table to regenerate it.
    dataset_configs = (
        ('val', 420, 5e3, 10),
        ('dense_demo', 42, 5e2, 500),
    )
    for split, seed, count, steps in dataset_configs:
        # Re-seed before every split so each dataset is reproducible on its own.
        np.random.seed(seed)
        create_3body(count, steps, split)
