import os
import sys
import pickle
import numpy as np
from tqdm import tqdm

from copy import deepcopy as copy

def _spring_forces(coupling, loc):
    """Return the net pairwise force on every atom.

    Args:
        coupling: (num_atoms, num_atoms) signed pair strengths, zero diagonal.
        loc: (2, num_atoms) positions, one row per coordinate component.

    Returns:
        (2, num_atoms) array whose entry [d, i] is
        ``sum_j coupling[i, j] * (loc[d, i] - loc[d, j])``.
    """
    # diff[d, i, j] = loc[d, i] - loc[d, j]  (same values as np.subtract.outer)
    diff = loc[:, :, np.newaxis] - loc[:, np.newaxis, :]
    return (coupling[np.newaxis] * diff).sum(axis=-1)


def _simulate_ground_truth(loc0, vel0, coupling, n_samples,
                           delta_T, sample_freq, total_t):
    """Integrate the system from (loc0, vel0) and record ``n_samples`` states.

    Each step advances the velocity with the force at the current position and
    then advances the position with the updated velocity. A state is recorded
    every ``sample_freq`` steps; integration stops as soon as ``n_samples``
    states were recorded or after ``total_t - 1`` steps.

    Returns:
        Tuple ``(gt_loc, gt_vel)`` of stacked arrays, each of shape
        (recorded samples, 2, num_atoms).
    """
    # Work on copies: the inputs are views into the evaluated data set.
    loc_next, vel_next = loc0.copy(), vel0.copy()
    gt_loc, gt_vel = [], []

    vel_next += delta_T * _spring_forces(coupling, loc_next)
    for i in range(1, total_t):
        loc_next += delta_T * vel_next

        if i % sample_freq == 0:
            gt_loc.append(loc_next.copy())
            gt_vel.append(vel_next.copy())
            if len(gt_loc) == n_samples:
                break

        vel_next += delta_T * _spring_forces(coupling, loc_next)

    return np.stack(gt_loc), np.stack(gt_vel)


def main(path: str, num_atoms: int, *,
         edge_path: str = 'data/val_edge_5_spring.npy',
         show_progress: bool = True) -> None:
    """Score generated trajectories and write the metrics to ``results.pkl``.

    Three metrics are computed:

    * ``'pred error'``: for every sample and every ``skip_n``-th time step, the
      system is re-simulated from the generated state and the mean squared
      difference to the next ``pred_n`` generated states is stored
      (array of shape (samples, windows)).
    * ``'motum error'``: mean squared deviation of the summed velocities from
      their mean over time.
    * ``'energy error'``: mean squared deviation of kinetic + potential energy
      from its mean over time.

    Args:
        path: Trajectory file. Files ending in ``.pkl`` are unpickled, anything
            else is read with ``np.load``. The array must be reshapeable to
            (batch, n_time, 2, 2, num_atoms), where index 0/1 of the fourth
            axis selects positions/velocities.
        num_atoms: Number of particles per sample.
        edge_path: ``.npy`` file holding one (num_atoms, num_atoms) interaction
            matrix per sample. It is expected to contain as many matrices as
            ``path`` contains samples (not validated here).
        show_progress: Wrap the per-sample loop in a ``tqdm`` progress bar.

    Raises:
        ValueError: If the trajectories are too short to form a single
            evaluation window.

    The result dictionary is pickled to ``results.pkl`` in the directory that
    contains ``path`` (the current directory if ``path`` has no directory
    part).
    """
    inter_strength = 0.1
    delta_T = 0.001
    total_t, sample_freq = 5100, 100
    skip_n, pred_n = 10, 8
    # Keeps every window t+1 .. t+pred_n inside the trajectory and prevents
    # consecutive windows from overlapping.
    assert skip_n > pred_n + 1
    test_edge = np.load(edge_path)

    if path.endswith('.pkl'):
        # NOTE(security): unpickling runs arbitrary code; only pass files that
        # come from a trusted source.
        with open(path, 'rb') as f:
            test_x = pickle.load(f)
    else:
        test_x = np.load(path)

    batch_size, n_time = test_x.shape[:2]
    if n_time <= skip_n:
        raise ValueError(
            f'need more than {skip_n} time steps per trajectory, got {n_time}'
        )

    features = test_x.reshape(batch_size, n_time, 2, 2, num_atoms)

    # locs, vels: (batch, n_time, 2, num_atoms)
    locs, vels = features.transpose(3, 0, 1, 2, 4)

    samples = zip(test_edge, locs, vels)
    if show_progress:
        samples = tqdm(samples, total=len(locs))

    pred_errors = []
    for edges, gen_loc, gen_vel in samples:
        # edges.shape == (num_atoms, num_atoms)
        # gen_loc.shape, gen_vel.shape == (n_time, 2, num_atoms)

        # The coupling matrix only depends on the sample, so build it once.
        coupling = -inter_strength * edges
        np.fill_diagonal(coupling, 0)  # self forces are zero

        sample_error = []
        for t in range(0, n_time - skip_n, skip_n):
            gt_loc, gt_vel = _simulate_ground_truth(
                gen_loc[t], gen_vel[t], coupling,
                pred_n, delta_T, sample_freq, total_t,
            )

            window = slice(t + 1, t + 1 + pred_n)
            residual = np.concatenate([
                gen_loc[window] - gt_loc,
                gen_vel[window] - gt_vel,
            ], axis=-1)
            sample_error.append(np.square(residual).mean())

        pred_errors.append(np.stack(sample_error))
    pred_errors = np.stack(pred_errors)

    # loc_feature, vel_feature: (batch, n_time, num_atoms, 2)
    loc_feature, vel_feature = features.transpose(3, 0, 1, 4, 2)

    # Summed velocity per time step: (batch, n_time, 2).
    motum = vel_feature.sum(axis=2)
    motum_error = np.square(motum - motum.mean(axis=1, keepdims=True)).mean()

    # r2: squared coordinate differences, (batch, n_time, num_atoms, num_atoms, 2)
    r2 = np.square(np.expand_dims(loc_feature, axis=3) - np.expand_dims(loc_feature, axis=2))
    # Weight by the interaction matrix: (batch, num_atoms, num_atoms, 2, n_time)
    p_energy = test_edge.reshape(*test_edge.shape, 1, 1) * r2.transpose(0, 2, 3, 4, 1)
    # Every pair appears twice in the double sum, hence 0.25 instead of 0.5.
    p_energy = 0.25 * inter_strength * p_energy.sum(axis=1).sum(axis=1).sum(axis=1)
    k_energy = 0.5 * np.square(vel_feature).sum(axis=-1).sum(axis=-1)
    total_energy = k_energy + p_energy  # (batch, n_time)
    energy_error = np.square(total_energy - total_energy.mean(axis=1, keepdims=True)).mean()

    # os.path.dirname returns '' for a bare file name, so the result lands in
    # the current directory instead of the file-system root.
    out_path = os.path.join(os.path.dirname(path), 'results.pkl')
    with open(out_path, 'wb') as f:
        pickle.dump({
            'pred error': pred_errors,
            'motum error': motum_error,
            'energy error': energy_error
        }, f)



if __name__ == '__main__':
    root = 'logs'  # only referenced by the commented-out batch driver below
    num_atoms = 5

    # The first command-line argument selects the trajectory file; without
    # one, evaluate the bundled validation set. An explicit length check
    # replaces the former bare ``except`` so unrelated errors (and Ctrl-C)
    # are no longer swallowed.
    if len(sys.argv) > 1:
        path = sys.argv[1]
    else:
        path = 'data/val_x_5_spring.npy'

    main(path, num_atoms)


    # for model_name in os.listdir(root):
    #     for model_hyper in os.listdir(os.path.join(root, model_name)):
    #         for file in os.listdir(os.path.join(root, model_name, model_hyper)):
    #             if file.__contains__('sample'):
    #                 pkl_folder = os.path.join(root, model_name, model_hyper, file)
    #                 if os.path.exists(os.path.join(pkl_folder, 'results.pkl')):
    #                     continue
    #                 pkl_path = os.path.join(pkl_folder, 'samples_all.pkl')
    #                 if os.path.exists(pkl_path):
    #                     main(pkl_path, num_atoms)

