import os
import pickle
import sys
import warnings
from itertools import combinations

import numpy as np
from scipy.integrate import odeint
from tqdm import tqdm

# Silence every warning process-wide; this runs at import time. Presumably meant
# to hide integrator / numpy runtime-warning noise during long evaluation runs.
# NOTE(review): it also hides warnings for any other code that imports this module.
warnings.filterwarnings('ignore')

# Physical reference scales used to non-dimensionalise the three-body system.
# G is in SI units (m^3 kg^-1 s^-2); the other scales are presumably SI as well.
_G = 6.67408e-11
_M_ND = 1.989e+30                          # reference mass
_R_ND = 5.326e+12                          # reference length
_V_ND = 30000                              # reference speed
_T_ND = 79.91 * 365 * 24 * 3600 * 0.51     # reference time

# Coefficients of the non-dimensional equations of motion used below:
#   dv_i/dt = K1 * sum_j m_j (r_j - r_i) / |r_j - r_i|^3,    dr_i/dt = K2 * v_i
_K1 = _G * _T_ND * _M_ND / (np.square(_R_ND) * _V_ND)
_K2 = _V_ND * _T_ND / _R_ND


def _three_body_equations(w, t, k1, k2, m1, m2, m3):
    """Return d(state)/dt for the non-dimensional three-body problem.

    ``w`` is a flat length-18 state ``[r1, r2, r3, v1, v2, v3]`` (six 3-vectors).
    ``t`` is unused but required by the :func:`scipy.integrate.odeint` signature.
    The result has the same layout: position derivatives first, then velocity
    derivatives.
    """
    r1, r2, r3, v1, v2, v3 = w.reshape(6, 3)

    r12 = np.linalg.norm(r2 - r1)
    r13 = np.linalg.norm(r3 - r1)
    r23 = np.linalg.norm(r3 - r2)

    dv1 = k1 * m2 * (r2 - r1) / r12**3 + k1 * m3 * (r3 - r1) / r13**3
    dv2 = k1 * m1 * (r1 - r2) / r12**3 + k1 * m3 * (r3 - r2) / r23**3
    dv3 = k1 * m1 * (r1 - r3) / r13**3 + k1 * m2 * (r2 - r3) / r23**3

    return np.concatenate([k2 * v1, k2 * v2, k2 * v3, dv1, dv2, dv3])


def _load_samples(sample_path: str):
    """Load samples from a ``.pkl`` pickle, otherwise via ``np.load``."""
    if sample_path.endswith('.pkl'):
        # NOTE: pickle.load can execute arbitrary code; only load trusted files.
        with open(sample_path, 'rb') as f:
            return pickle.load(f)
    return np.load(sample_path)


def _momentum_error(data_x) -> float:
    """Mean squared deviation of the summed body velocities from their time mean.

    All masses are 1, so the velocity sum is the total momentum, which the
    pairwise-symmetric forces of the ODE conserve.
    """
    vel = data_x[:, :, 9:].reshape(*data_x.shape[:2], 3, 3)  # (batch, n_time, body, xyz)
    momentum = vel.sum(axis=2).transpose(0, 2, 1)            # (batch, xyz, n_time)
    return np.square(momentum - momentum.mean(axis=-1, keepdims=True)).mean()


def _trajectory_errors(sample, pred_n, pred_time, time_choose, ode_args):
    """Squared error between short ODE rollouts and the sample itself.

    For each start index ``i``, ``sample[i]`` is integrated over ``pred_time``.
    The states at grid indices ``time_choose[1:]`` are compared with
    ``sample[i + 1:i + pred_n]``.

    Returns:
        Array of shape ``(n_windows, pred_n - 1, n_features)``. The first axis is
        empty when the sample is too short for any window.
    """
    n_time, n_features = sample.shape
    errors = []
    # NOTE(review): i stops at n_time - pred_n - 1, so the last valid window
    # (i = n_time - pred_n) is skipped. Kept to preserve the output shape
    # downstream consumers expect; confirm whether this is intentional.
    for i in range(n_time - pred_n):
        traj = odeint(_three_body_equations, sample[i], pred_time, args=ode_args)[time_choose]
        errors.append(np.square(traj[1:] - sample[i + 1:i + pred_n]))
    if not errors:
        # np.stack([]) would raise; return a correctly-shaped empty array instead.
        return np.empty((0, pred_n - 1, n_features))
    return np.stack(errors)


def _energy_error(sample, m) -> float:
    """Scaled mean squared deviation of the total energy from its time mean.

    Both energy terms are divided by the reference mass ``_M_ND``, i.e. the
    physical energy is ``-G (m M)^2 / (d R)`` + ``0.5 (m M) (v V)^2``
    divided by ``M``. The deviation is scaled by ``1e8`` before squaring.
    """
    pos = sample[:, :9].reshape(-1, 3, 3)  # (n_time, body, xyz)
    vel = sample[:, 9:].reshape(-1, 3, 3)

    # Distance between every pair of bodies at every step: (3 pairs, n_time).
    body_pos = pos.transpose(1, 0, 2)
    distances = np.stack([
        np.linalg.norm(a - b, axis=-1) for a, b in combinations(body_pos, 2)
    ])

    grav_energy = (-_G * (m**2) * _M_ND / (distances * _R_ND)).sum(axis=0)
    kin_energy = 0.5 * m * np.square(vel * _V_ND).sum(axis=-1).sum(axis=-1)
    total_energy = grav_energy + kin_energy
    return np.square((total_energy - total_energy.mean()) / 1e8).mean()


def eval_sample(sample_path: str) -> dict:
    """Evaluate generated three-body samples for physical consistency.

    ``sample_path`` holds an array of shape ``(batch, n_time, 18)``. Its last
    axis is three body positions followed by three body velocities, xyz each.
    The results are pickled to ``results.pkl`` in the same directory as
    ``sample_path`` and also returned.

    Returns:
        dict with
        - ``'traj error'``: squared errors of ODE rollouts started from each
          step, shape ``(batch, n_time - pred_n, pred_n - 1, 18)`` with
          ``pred_n = 4``.
        - ``'energy error'``: per-sample scaled energy drift, shape ``(batch,)``.
        - ``'momentum error'``: scalar momentum drift over all samples.
    """
    data_x = _load_samples(sample_path)
    momentum_error = _momentum_error(data_x)

    n_time = data_x.shape[1]
    clip_n = 50        # ODE output points per sample step
    pred_n = 3 + 1     # each rollout is compared on the pred_n - 1 following steps
    last_t = 5         # ODE time spanned by all n_time sample steps
    m1, m2, m3, m = 1.0, 1.0, 1.0, 1.0

    # Each sample step spans clip_n ODE grid intervals. Only steps up to
    # pred_n - 1 are ever compared, so the rollout need not integrate further.
    pred_time = np.linspace(0, last_t, clip_n * n_time + 1)[:clip_n * (pred_n - 1) + 1]
    time_choose = np.arange(pred_n) * clip_n   # grid indices of steps 0..pred_n-1
    ode_args = (_K1, _K2, m1, m2, m3)

    traj_errors, energy_errors = [], []
    for sample in tqdm(data_x, dynamic_ncols=True):
        traj_errors.append(_trajectory_errors(sample, pred_n, pred_time, time_choose, ode_args))
        energy_errors.append(_energy_error(sample, m))

    results = {
        'traj error': np.stack(traj_errors),
        'energy error': np.stack(energy_errors),
        'momentum error': momentum_error,
    }

    # os.path keeps a bare file name in the current directory; the previous
    # string join wrote such results to '/results.pkl' (the filesystem root).
    out_path = os.path.join(os.path.dirname(sample_path), 'results.pkl')
    with open(out_path, 'wb') as f:
        pickle.dump(results, f)
    return results


if __name__ == '__main__':
    # Evaluate the samples file given as the first command-line argument,
    # falling back to the hard-coded development run when none is given.
    _DEFAULT_SAMPLE_PATH = 'logs/3body---PhyGRU---hidden_size-512--bidirectional-True--n_layers-3--t_embed_size-128---2024_06_17__07_46_29/ode_sample-2024_06_17__13_55_52/samples_all.pkl'
    eval_sample(sys.argv[1] if len(sys.argv) > 1 else _DEFAULT_SAMPLE_PATH)