import numpy as np
import scipy.integrate
solve_ivp = scipy.integrate.solve_ivp

import os
from utils import to_pickle, from_pickle

# Physical constants shared by the dynamics (`update`) and the energy helpers.
# m1 / m2: masses of the first (upper) and second (lower) bob;
# g: gravitational acceleration (presumably SI units, m/s^2 -- confirm).
m1 = m2 = 1.0
g = 9.8

def potential_energy(states):
    """Return the gravitational potential energy ``m1*g*y1 + m2*g*y2``.

    ``states`` must unpack into exactly 8 entries ordered
    ``(x1, y1, x2, y2, dx1, dy1, dx2, dy2)``; only the heights ``y1`` and
    ``y2`` are used. Entries may be scalars or NumPy arrays (e.g. the rows of
    an ``(8, N)`` array), in which case the result is element-wise.
    """
    _x1, y_first, _x2, y_second, _vx1, _vy1, _vx2, _vy2 = states
    first_bob = m1 * g * y_first
    second_bob = m2 * g * y_second
    return first_bob + second_bob


def kinetic_energy(states):
    """Return the translational kinetic energy of both bobs.

    ``states`` must unpack into exactly 8 entries ordered
    ``(x1, y1, x2, y2, dx1, dy1, dx2, dy2)``; only the four velocity
    components are used. Entries may be scalars or NumPy arrays, in which
    case the result is element-wise.
    """
    _x1, _y1, _x2, _y2, vx_first, vy_first, vx_second, vy_second = states
    first_bob = m1 * (vx_first**2 + vy_first**2)
    second_bob = m2 * (vx_second**2 + vy_second**2)
    return 0.5 * (first_bob + second_bob)


def update(t, state):
    """Time derivative of the polar double-pendulum state.

    Args:
        t: Time. Unused; it is present so the function matches the
            ``fun(t, y)`` signature expected by ``solve_ivp``.
        state: 1-D array ``(l1, l2, t1, t2, dt1, dt2)``: rod lengths, rod
            angles and angular velocities.

    Returns:
        1-D array ``(0, 0, dt1, dt2, ddt1, ddt2)``. The rod lengths do not
        change, so their derivatives are zero.

    The function asserts that ``state`` is one-dimensional. This check is
    skipped under ``python -O``.
    """
    assert len(state.shape) == 1
    l1, l2, t1, t2, dt1, dt2 = state

    diff = t1 - t2
    cos_d, sin_d = np.cos(diff), np.sin(diff)
    # Denominator factor shared by both angular accelerations.
    mass_term = m1 + m2 * sin_d**2

    numer1 = (m2 * g * np.sin(t2) * cos_d
              - m2 * sin_d * (l1 * dt1**2 * cos_d + l2 * dt2**2)
              - (m1 + m2) * g * np.sin(t1))
    numer2 = ((m1 + m2) * (l1 * dt1**2 * sin_d - g * np.sin(t2) + g * np.sin(t1) * cos_d)
              + m2 * l2 * dt2**2 * sin_d * cos_d)

    ddt1 = numer1 / (l1 * mass_term)
    ddt2 = numer2 / (l2 * mass_term)

    derivative = [np.zeros_like(l1), np.zeros_like(l2), dt1, dt2, ddt1, ddt2]
    return np.array(derivative)


def get_orbit(state, n_steps, dt, update_fn=update, **kwargs):
    """Integrate one trajectory with ``scipy.integrate.solve_ivp``.

    Args:
        state: Initial state; it is flattened before integration.
        n_steps: Number of steps. The trajectory is sampled at the
            ``n_steps + 1`` times ``0, dt, ..., n_steps * dt``.
        dt: Sampling interval.
        update_fn: ODE right-hand side ``f(t, y)``; defaults to ``update``.
        **kwargs: Forwarded to ``solve_ivp``. ``rtol`` defaults to 1e-9.

    Returns:
        ``(orbit, orbit_settings)``:
        - ``orbit`` has shape ``(state.size, n_steps + 1)``.
        - ``orbit_settings`` is a snapshot of this function's locals (the
          arguments plus ``t_eval`` and ``t_span``), taken before integration.

    Raises:
        RuntimeError: if ``solve_ivp`` reports failure. Previously this only
            surfaced later as an opaque reshape ``ValueError``.
    """
    kwargs.setdefault('rtol', 1e-9)

    t_eval = np.arange(n_steps + 1) * dt
    t_span = [0, t_eval[-1]]

    # Take the snapshot before binding any other local, so callers see exactly
    # the arguments plus t_eval / t_span.
    orbit_settings = locals()

    y0 = state.flatten()
    path = solve_ivp(fun=update_fn, t_span=t_span, y0=y0,
                     t_eval=t_eval, **kwargs)
    if not path.success:
        raise RuntimeError(f"solve_ivp failed: {path.message}")

    # Use the real state size instead of a hard-coded 6, so custom update_fn
    # dynamics with other state dimensions also work.
    orbit = path['y'].reshape(y0.size, n_steps + 1)
    return orbit, orbit_settings


def random_config(max_angle=0.5, max_velocity=0.1, rod_length_mean=1.0, rod_length_var=0.1):
    """Draw a random polar initial state ``(l1, l2, t1, t2, dt1, dt2)``.

    Every component is sampled independently and uniformly, using the global
    ``np.random`` state:

    * rod lengths from ``[rod_length_mean - rod_length_var,
      rod_length_mean + rod_length_var)``. ``rod_length_var`` is therefore a
      half-width, not a variance.
    * angles from ``[-max_angle, max_angle)``.
    * angular velocities from ``[-max_velocity, max_velocity)``.

    Returns:
        1-D NumPy array of length 6.
    """
    low_len = rod_length_mean - rod_length_var
    high_len = rod_length_mean + rod_length_var
    # Draw order matters: it fixes which random numbers each component uses.
    pieces = [
        np.random.uniform(low_len, high_len, (1,)),             # upper rod
        np.random.uniform(low_len, high_len, (1,)),             # lower rod
        np.random.uniform(-max_angle, max_angle, (2,)),         # t1, t2
        np.random.uniform(-max_velocity, max_velocity, (2,)),   # dt1, dt2
    ]
    return np.concatenate(pieces, axis=0)


def sample_orbits(n_steps, trials, dt=0.01, verbose=False, max_angle=0.5, max_velocity=0.1, **kwargs):
    """Simulate ``trials`` random double-pendulum orbits in polar coordinates.

    Args:
        n_steps: Integration steps per orbit; each orbit has ``n_steps + 1``
            samples.
        trials: Number of orbits. Each one starts from ``random_config``.
        dt: Sampling interval.
        verbose: If true, print a header line.
        max_angle, max_velocity: Forwarded to ``random_config``.
        **kwargs: Forwarded to ``get_orbit``, and from there to ``solve_ivp``.
            If ``update_fn`` is among them, it is also used to compute
            ``'dudt'``.

    Returns:
        A dict with the following keys:
        - ``'u'``: states, shape ``(trials, n_steps + 1, state_dim)``.
        - ``'dudt'``: their time derivatives.
        - ``'dt'`` and ``'t_eval'``.
        - ``'meta'``: the call arguments.
        With ``trials == 0`` the arrays have shape ``(0, n_steps + 1, 6)``.
    """
    # Must stay the first statement: the snapshot is stored as data['meta'] and
    # should contain only the call arguments.
    orbit_settings = locals()
    if verbose:
        print("Making a dataset of double pendulum:")

    # Compute 'dudt' with the same right-hand side that get_orbit integrates,
    # so derivatives stay consistent with 'u' when update_fn is overridden.
    update_fn = kwargs.get('update_fn', update)

    u, du = [], []
    for _ in range(trials):
        state = random_config(max_angle=max_angle, max_velocity=max_velocity)
        # orbit: (state_dim, n_steps + 1) -> batch: (n_steps + 1, state_dim)
        orbit, _settings = get_orbit(state, n_steps=n_steps, dt=dt, **kwargs)
        batch = orbit.transpose(1, 0)

        u.append(batch)
        du.append(np.array([update_fn(None, s) for s in batch]))

    if u:
        u = np.array(u)
        du = np.array(du)
    else:
        # np.array([]) would be 1-D. Keep the (trials, time, state) layout so
        # downstream code that transposes three axes still works. The width 6
        # matches the state produced by random_config.
        u = np.empty((0, n_steps + 1, 6))
        du = np.empty((0, n_steps + 1, 6))

    data = {
        'u': u,
        'dudt': du,
        # Built from the arguments rather than from the last get_orbit call,
        # so this also works when trials == 0. get_orbit uses the same dt and
        # t_eval = arange(n_steps + 1) * dt.
        'dt': dt,
        't_eval': np.arange(n_steps + 1) * dt,
        'meta': orbit_settings,
    }
    return data


def transform_to_Catesian(data):
    """Convert a polar double-pendulum dataset to Cartesian coordinates, in place.

    Input:
        ``data['u']`` and ``data['dudt']`` are 3-D arrays whose last axis holds
        the polar state ``(l1, l2, t1, t2, dt1, dt2)`` and its time derivative.

    Effect:
        Both entries are overwritten with the Cartesian state
        ``(x1, y1, x2, y2, dx1, dy1, dx2, dy2)`` and its derivative, and
        ``data['energies']`` is added.

    Returns:
        The same, mutated, dict.

    Rod lengths are treated as constant, so no length-derivative terms appear.
    """
    l1, l2, t1, t2, dt1, dt2 = data['u'].transpose(2, 0, 1)
    _dl1, _dl2, _dth1, _dth2, ddt1, ddt2 = data['dudt'].transpose(2, 0, 1)

    s1, c1 = np.sin(t1), np.cos(t1)
    s2, c2 = np.sin(t2), np.cos(t2)

    # Positions. The pivot is at the origin; at angle 0 a bob hangs at y = -l.
    x1, y1 = l1 * s1, -l1 * c1
    x2, y2 = x1 + l2 * s2, y1 - l2 * c2

    # Velocities: first time derivatives of the positions.
    dx1, dy1 = l1 * c1 * dt1, l1 * s1 * dt1
    dx2, dy2 = dx1 + l2 * c2 * dt2, dy1 + l2 * s2 * dt2

    # Accelerations: second time derivatives of the positions.
    ddx1 = l1 * (-s1 * dt1**2 + c1 * ddt1)
    ddy1 = l1 * (+c1 * dt1**2 + s1 * ddt1)
    ddx2 = ddx1 + l2 * (-s2 * dt2**2 + c2 * ddt2)
    ddy2 = ddy1 + l2 * (+c2 * dt2**2 + s2 * ddt2)

    cartesian_state = [x1, y1, x2, y2, dx1, dy1, dx2, dy2]
    cartesian_deriv = [dx1, dy1, dx2, dy2, ddx1, ddy1, ddx2, ddy2]
    data['u'] = np.stack(cartesian_state, axis=-1)
    data['dudt'] = np.stack(cartesian_deriv, axis=-1)
    data['energies'] = get_energies(data['u'])
    return data

def make_orbits_dataset(trials=1000, test_trials=10, steps=500, test_steps=10000, **kwargs):
    """Build Cartesian training and test sets of random double-pendulum orbits.

    The training orbits are sampled before the test orbits. This order fixes
    which random numbers each set consumes. Both sets are then converted to
    Cartesian coordinates. The test set is nested in the returned training
    dict under the ``'test'`` key.

    Extra ``kwargs`` are forwarded to ``sample_orbits`` for both sets.
    """
    train_polar = sample_orbits(trials=trials, n_steps=steps, **kwargs)
    test_polar = sample_orbits(trials=test_trials, n_steps=test_steps, **kwargs)

    dataset = transform_to_Catesian(train_polar)
    test_set = transform_to_Catesian(test_polar)
    dataset['test'] = test_set
    return dataset


def get_dataset(name, save_dir, **kwargs):
    """Load a cached dataset, or build and cache it if loading fails.

    Args:
        name: Dataset name. The cache file is
            ``{save_dir}/dataset/{name}-dataset.pkl``.
        save_dir: Root directory. Its ``dataset`` sub-directory is created if
            missing.
        **kwargs: Forwarded to ``make_orbits_dataset`` when rebuilding.

    Returns:
        The loaded or freshly built dataset.
    """
    dataset_dir = f'{save_dir}/dataset'
    # exist_ok avoids the race between an exists() check and makedirs().
    os.makedirs(dataset_dir, exist_ok=True)
    path = f'{dataset_dir}/{name}-dataset.pkl'

    try:
        # NOTE(review): from_pickle presumably unpickles the file; only point
        # this at caches you trust.
        data = from_pickle(path)
    except Exception:
        # Any ordinary load failure triggers a rebuild. Unlike the previous
        # bare `except:`, KeyboardInterrupt / SystemExit are not swallowed.
        # Seed only here, so a rebuilt dataset is reproducible.
        np.random.seed(99)
        print("Had a problem loading data from {}. Rebuilding dataset...".format(path))
        data = make_orbits_dataset(**kwargs)
        to_pickle(data, path)
    else:
        print("Successfully loaded data from {}".format(path))
    return data


def get_energies(state):
    """Compute potential, kinetic and total energy of Cartesian states.

    Args:
        state: Array whose last axis has 8 entries
            ``(x1, y1, x2, y2, dx1, dy1, dx2, dy2)``. Leading axes are
            arbitrary. The size of the last axis is checked with ``assert``.

    Returns:
        Dict with ``'potential_energy'``, ``'kinetic_energy'`` and ``'energy'``
        (their sum), each shaped like ``state.shape[:-1]``.
    """
    assert state.shape[-1] == 8, state.shape
    batch_shape = state.shape[:-1]
    # One row per coordinate, so the energy helpers can unpack 8 rows.
    rows = state.reshape(-1, 8).T
    pe = potential_energy(rows).reshape(batch_shape)
    ke = kinetic_energy(rows).reshape(batch_shape)
    return {
        'potential_energy': pe,
        'kinetic_energy': ke,
        'energy': pe + ke,
    }

