import numpy as np
import scipy.integrate
solve_ivp = scipy.integrate.solve_ivp

import os
from utils import to_pickle, from_pickle

def potential_energy(state):
    '''Gravitational potential energy U = -sum_{i<j} G m_i m_j / r_ij, with G = 1.

    Args:
        state: array of shape (nbodies, 5, ...) whose second axis holds
            (mass, x, y, vx, vy). Any trailing axes (e.g. time) are batched
            over; a plain (nbodies, 5) snapshot is also accepted.

    Returns:
        U summed over all body pairs, with singleton axes squeezed: a 0-d
        value for a single snapshot, shape (T,) for an (nbodies, 5, T) input.
    '''
    # One accumulator per trailing (batch/time) index. Using shape[2:] instead
    # of shape[2] lets 2-D snapshots work, consistent with kinetic_energy.
    tot_energy = np.zeros((1, 1) + state.shape[2:])
    for i in range(state.shape[0]):
        for j in range(i + 1, state.shape[0]):
            r_ij = ((state[i:i + 1, 1:3] - state[j:j + 1, 1:3])**2).sum(1, keepdims=True)**.5
            m_i = state[i:i + 1, 0:1]
            m_j = state[j:j + 1, 0:1]
            tot_energy += m_i * m_j / r_ij
    # Gravity is attractive, so the potential energy is negative.
    U = -tot_energy.sum(0).squeeze()
    return U


def kinetic_energy(state):
    '''Total kinetic energy T = sum_i 0.5 * m_i * |v_i|^2.

    Args:
        state: array of shape (nbodies, 5, ...) with mass in column 0 and
            velocity (vx, vy) in columns 3:5; trailing axes are batched over.

    Returns:
        T summed over bodies, with singleton axes squeezed.
    '''
    mass = state[:, 0:1]
    speed_sq = np.sum(state[:, 3:5] ** 2, axis=1, keepdims=True)
    per_body = .5 * mass * speed_sq
    return per_body.sum(0).squeeze()


def total_energy(state):
    '''Total mechanical energy E = U + T of ``state``.

    Accepts whatever shapes both potential_energy and kinetic_energy accept.
    '''
    kinetic = kinetic_energy(state)
    potential = potential_energy(state)
    return potential + kinetic


def get_accelerations(state, epsilon=0):
    '''Net gravitational acceleration (G = 1) acting on every body.

    Args:
        state: array of shape (nbodies, properties) with mass in column 0 and
            position (x, y) in columns 1:3.
        epsilon: softening term added to r**3 in the denominator; with the
            default 0, two coincident bodies cause a division by zero.

    Returns:
        Array of shape (nbodies, 2): the summed (ax, ay) on each body.
    '''
    def _acc_on(idx):
        # Every body except ``idx`` pulls on body ``idx``.
        sources = np.delete(state, idx, axis=0)
        offsets = sources[:, 1:3] - state[idx, 1:3]  # vectors from body idx to each source
        dist = (offsets**2).sum(1, keepdims=True)**0.5
        pulls = sources[:, 0:1] * offsets / (dist**3 + epsilon)
        return pulls.sum(0, keepdims=True)

    return np.concatenate([_acc_on(idx) for idx in range(state.shape[0])], axis=0)


def update(t, state):
    '''Right-hand side of the N-body ODE in ``solve_ivp``'s ``fun(t, y)`` form.

    Args:
        t: current time; unused (the dynamics do not depend on it) but part of
            the solver's calling convention.
        state: flat array of length ``5 * nbodies``: the per-body rows
            (mass, x, y, vx, vy) concatenated.

    Returns:
        Flat array of the same length holding d(state)/dt: zero for the mass,
        (vx, vy) for the position and the gravitational acceleration for the
        velocity.
    '''
    bodies = state.reshape(-1, 5)  # [bodies, properties]
    rates = np.zeros_like(bodies)
    velocities, accelerations = bodies[:, 3:5], get_accelerations(bodies)
    rates[:, 1:3] = velocities
    rates[:, 3:5] = accelerations
    return rates.reshape(-1)


def get_orbit(state, n_steps, dt, update_fn=update, **kwargs):
    '''Integrate an N-body system with ``solve_ivp`` and sample it on a fixed grid.

    Args:
        state: initial state of shape (nbodies, 5), rows (mass, x, y, vx, vy).
        n_steps: number of steps; the trajectory is sampled at the
            ``n_steps + 1`` times ``0, dt, ..., n_steps * dt``.
        dt: sampling interval.
        update_fn: ODE right-hand side ``f(t, y)`` on the flattened state.
        **kwargs: forwarded to ``solve_ivp``; ``rtol`` defaults to 1e-9.

    Returns:
        (orbit, orbit_settings): ``orbit`` has shape (nbodies, 5, n_steps + 1);
        ``orbit_settings`` is a dict of this call's arguments and derived values
        (nbodies, t_eval, t_span, ...) plus 'mass' (= state[:, 0]).

    Raises:
        RuntimeError: if the solver reports failure (its output would be
            truncated and could not be reshaped).
    '''
    kwargs.setdefault('rtol', 1e-9)

    nbodies = state.shape[0]
    t_eval = np.arange(n_steps + 1) * dt
    t_span = [0, t_eval[-1]]

    # Copy the locals so the settings are a true snapshot; on older CPython the
    # dict returned by locals() can be re-synced by tracers/debuggers.
    orbit_settings = dict(locals())
    orbit_settings['mass'] = state[:, 0]

    path = solve_ivp(fun=update_fn, t_span=t_span, y0=state.flatten(),
                     t_eval=t_eval, **kwargs)
    if not path.success:
        raise RuntimeError(f'solve_ivp failed: {path.message}')

    # y is (nbodies * 5, n_steps + 1) in the row-major order of state.flatten()
    orbit = path['y'].reshape(nbodies, 5, n_steps + 1)
    return orbit, orbit_settings


def random_config():
    '''Sample a random, nearly circular two-body configuration.

    Returns:
        Array of shape (2, 5) with rows (mass, x, y, vx, vy). Both bodies have
        unit mass, sit at ``+pos`` and ``-pos`` about a common center, and move
        with opposite velocities plus a small shared drift of the center.
    '''
    state = np.zeros((2, 5))
    state[:, 0] = 1.0  # mass

    # center position (std 0.0, so this is always the origin)
    center_pos = np.random.normal(0., 0.0, 2)

    # center velocity: small random drift of the pair as a whole
    center_vel = np.random.normal(0., 0.01, 2)

    # body 0 at distance r from the center (separation 2r), random angle
    angle = np.random.uniform(0., 2 * np.pi)
    r = np.random.uniform(0.5, 1.0)
    pos = r * np.array([np.cos(angle), np.sin(angle)])

    # Velocity roughly perpendicular to the radius (angle +/- pi/2, plus noise).
    # Use a uniform draw so clockwise and counter-clockwise orbits are equally
    # likely; a standard-normal draw > 0.5 would pick +pi/2 only ~31% of the time.
    direction = (np.random.rand() > 0.5) * 1.0 - 0.5  # +0.5 or -0.5
    vel_angle = angle + np.pi * (direction + np.random.normal(0, 0.05))
    # Circular speed for two unit masses separated by 2r (G=1) is 1 / (2 sqrt(r));
    # the ~5% speed noise makes the circular orbits SLIGHTLY elliptical.
    vel_length = 1. / (2 * r**0.5) * np.random.normal(1., 0.05)
    vel = vel_length * np.array([np.cos(vel_angle), np.sin(vel_angle)])

    state[0, 1:3] = pos + center_pos
    state[1, 1:3] = -pos + center_pos
    state[0, 3:5] = vel + center_vel
    state[1, 3:5] = -vel + center_vel
    return state


def sample_orbits(n_steps, trials, dt=0.01, nbodies=2, verbose=False, **kwargs):
    '''Simulate ``trials`` random near-circular 2-body orbits.

    Args:
        n_steps: integration steps per trial; each trajectory has
            ``n_steps + 1`` samples spaced ``dt`` apart.
        trials: number of independent trajectories (may be 0).
        dt: sampling interval passed to ``get_orbit``.
        nbodies: number of bodies; must be 2 because ``random_config`` only
            generates two-body configurations.
        verbose: print a header line before sampling.
        **kwargs: forwarded to ``get_orbit`` (and from there to ``solve_ivp``).

    Returns:
        dict with:
          'u', 'dudt': arrays of shape (trials, n_steps + 1, 4 * nbodies),
              grouped by coordinate (all x's, all y's, all vx's, all vy's);
          'energies': output of ``get_energies(u)``;
          'dt', 't_eval': the sampling interval and sample times;
          'meta': the arguments of this call.

    Raises:
        ValueError: if ``nbodies != 2``.
    '''
    # Copy, so later locals (u, du, ...) can never leak into the metadata.
    orbit_settings = dict(locals())
    if nbodies != 2:
        raise ValueError(f'random_config only generates 2-body systems, got nbodies={nbodies}')
    if verbose:
        print("Making a dataset of near-circular 2-body orbits:", flush=True)

    n_samples = n_steps + 1
    u, du = [], []
    for _ in range(trials):
        state = random_config()
        # orbit: [nbodies, 5 (mass, x, y, vx, vy), n_samples]
        orbit, _ = get_orbit(state, n_steps=n_steps, dt=dt, **kwargs)
        # [n_samples, nbodies * 5]: one flat state per time step, as update() expects
        batch = orbit.transpose(2, 0, 1).reshape(-1, nbodies * 5)
        # drop the (constant) mass column -> [n_samples, nbodies * 4]
        u.append(orbit[:, 1:, :].transpose(2, 0, 1).reshape(n_samples, nbodies * 4))
        du.append(np.array([update(None, s) for s in batch])
                  .reshape(n_samples, nbodies, 5)[:, :, 1:].reshape(n_samples, nbodies * 4))

    # per-body groups -> per-coordinate groups: qqpp,qqpp -> qqqqpppp
    u = np.array(u).reshape(trials, n_samples, nbodies, 4).transpose(0, 1, 3, 2).reshape(trials, n_samples, 4 * nbodies)
    du = np.array(du).reshape(trials, n_samples, nbodies, 4).transpose(0, 1, 3, 2).reshape(trials, n_samples, 4 * nbodies)

    energies = get_energies(u)

    data = {
        'u': u,
        'dudt': du,
        'energies': energies,
        # Same values get_orbit reports; computed here so trials == 0 works.
        'dt': dt,
        't_eval': np.arange(n_samples) * dt,
        'meta': orbit_settings,
    }
    return data


def make_orbits_dataset(trials=1000, test_trials=10, steps=500, test_steps=10000, **kwargs):
    '''Build a training set plus a held-out test set of 2-body orbits.

    The training split (``trials`` trajectories of ``steps`` steps) is sampled
    first, then the test split (``test_trials`` x ``test_steps``), which is
    stored under the ``'test'`` key of the returned training dict. Remaining
    keyword arguments are forwarded to both ``sample_orbits`` calls.
    '''
    train = sample_orbits(trials=trials, n_steps=steps, **kwargs)
    train['test'] = sample_orbits(trials=test_trials, n_steps=test_steps, **kwargs)
    return train


def get_dataset(name, save_dir, **kwargs):
    '''Load the cached dataset ``name`` from ``save_dir``, rebuilding it on failure.

    The cache lives at ``{save_dir}/dataset/{name}-dataset.pkl``. If loading
    fails for any ordinary error, the RNG is reseeded with 99, the dataset is
    regenerated with ``make_orbits_dataset(**kwargs)`` and written back.
    '''
    dataset_dir = f'{save_dir}/dataset'
    os.makedirs(dataset_dir, exist_ok=True)
    path = f'{dataset_dir}/{name}-dataset.pkl'

    # NOTE(review): from_pickle presumably unpickles the file; only point this
    # at trusted, locally generated caches.
    try:
        data = from_pickle(path)
    except Exception:  # missing/corrupt cache; KeyboardInterrupt still propagates
        np.random.seed(99)  # fixed seed so a rebuilt dataset is reproducible
        print("Had a problem loading data from {}. Rebuilding dataset...".format(path), flush=True)
        data = make_orbits_dataset(**kwargs)
        to_pickle(data, path)
    else:
        print("Successfully loaded data from {}".format(path))
    return data


def get_energies(state):
    '''Potential, kinetic and total energy of unit-mass, coordinate-grouped states.

    Args:
        state: array of shape (..., 4 * nbodies) laid out as
            (x_1..x_n, y_1..y_n, vx_1..vx_n, vy_1..vy_n), the 'qqqqpppp'
            layout produced by sample_orbits. Every body is given mass 1.

    Returns:
        dict with 'potential_energy', 'kinetic_energy' and 'energy', each of
        shape ``state.shape[:-1]``.

    Raises:
        ValueError: if the last axis is not a positive multiple of 4.
    '''
    shape = state.shape[:-1]
    width = state.shape[-1]
    # Infer the body count instead of hard-coding 2.
    nbodies = width // 4
    if nbodies == 0 or 4 * nbodies != width:
        raise ValueError(f'last axis must have length 4 * nbodies, got {width}')
    # b x (qqqqpppp) -> b x bodies x (x, y, vx, vy)
    per_body = state.reshape(-1, 4, nbodies).transpose(0, 2, 1)
    # prepend a unit-mass column, then move the batch axis last:
    # b x bodies x (m, x, y, vx, vy) -> bodies x (m, x, y, vx, vy) x b
    masses = np.ones((per_body.shape[0], nbodies, 1))
    full = np.concatenate([masses, per_body], axis=2).transpose(1, 2, 0)
    ret = {}
    ret['potential_energy'] = potential_energy(full).reshape(shape)
    ret['kinetic_energy'] = kinetic_energy(full).reshape(shape)
    ret['energy'] = ret['potential_energy'] + ret['kinetic_energy']
    return ret


