import numpy as np
import scipy.integrate
solve_ivp = scipy.integrate.solve_ivp

import os
from utils import to_pickle, from_pickle

# Parameters of the FitzHugh-Nagumo system as integrated by `update`:
#   dV/dt = (-D(V) - W + I) / C
#   dW/dt = (V - R * W - E) / L
L = 1 / 0.08  # divisor of dW/dt (= 12.5)
R = 0.8       # coefficient of W in dW/dt
C = 1.0       # divisor of dV/dt
E = -0.7      # constant offset subtracted in dW/dt
n_state = 3   # length of the raw state vector (V, W, I)


def D(v):
    """Cubic nonlinearity of the V equation: ``v**3 / 3 - v``.

    Works element-wise on NumPy arrays as well as on scalars.
    """
    return v**3 / 3 - v


def dDdV(v):
    """Derivative of ``D`` with respect to ``v``: ``v**2 - 1``."""
    return v**2 - 1

def update(t, state):
    """Right-hand side of the FitzHugh-Nagumo ODE in ``solve_ivp`` form.

    ``state`` unpacks as (V, W, I). ``t`` is unused (the system is autonomous),
    and the input I is held constant, so its derivative is always 0.0.
    Returns a NumPy array (dV/dt, dW/dt, 0.0).
    """
    v, w, i_in = state
    rate_v = (-D(v) - w + i_in) / C
    rate_w = (v - w * R - E) / L
    return np.array([rate_v, rate_w, 0.0])


def update2(t, state):
    """Time derivatives of the 5-component state (V_C, V_L, I_C, I_L, I).

    Only V_C (playing the role of V) and I_L (playing the role of W) enter the
    dynamics; V_L and I_C are carried along but not read. ``t`` is unused and
    the input I is constant. Returns a NumPy array of the five derivatives.
    """
    v_c, _v_l, _i_c, i_l, i_in = state
    rate_v = (-D(v_c) - i_l + i_in) / C
    rate_w = (v_c - i_l * R - E) / L
    return np.array([
        rate_v,                          # dV_C/dt
        rate_v - rate_w * R,             # dV_L/dt = dV/dt - dV_R/dt
        -dDdV(v_c) * rate_v - rate_w,    # dI_C/dt
        rate_w,                          # dI_L/dt
        0.0,                             # dI/dt: input held constant
    ])


def get_orbit(state, n_steps, dt, update_fn=update, **kwargs):
    """Integrate ``update_fn`` from ``state`` and sample it on a uniform grid.

    Args:
        state: initial state; flattened before integration.
        n_steps: number of steps of size ``dt``. The orbit has ``n_steps + 1``
            samples, including t = 0.
        dt: sampling interval.
        update_fn: ODE right-hand side with signature ``f(t, y)``.
        **kwargs: forwarded to ``solve_ivp``. ``rtol`` defaults to 1e-9 when
            not given.

    Returns:
        ``(orbit, orbit_settings)``. ``orbit`` has shape
        ``(state_size, n_steps + 1)``. ``orbit_settings`` is a dict holding
        the call arguments plus ``t_eval`` and ``t_span``.

    Raises:
        RuntimeError: if ``solve_ivp`` reports that integration failed.
    """
    kwargs.setdefault('rtol', 1e-9)

    t_eval = np.arange(n_steps + 1) * dt
    t_span = [0, t_eval[-1]]
    y0 = np.asarray(state).flatten()

    # Explicit dict instead of locals(): same keys, but independent of the
    # interpreter-version-specific semantics of locals().
    orbit_settings = {
        'state': state,
        'n_steps': n_steps,
        'dt': dt,
        'update_fn': update_fn,
        'kwargs': kwargs,
        't_eval': t_eval,
        't_span': t_span,
    }

    path = solve_ivp(fun=update_fn, t_span=t_span, y0=y0,
                     t_eval=t_eval, **kwargs)
    if not path.success:
        # Otherwise the truncated solution would fail later with an opaque
        # reshape error.
        raise RuntimeError(f'solve_ivp failed: {path.message}')

    # Size from the actual state so right-hand sides of any dimension
    # (e.g. the 5-component update2) work, not just the 3-component default.
    orbit = path.y.reshape(y0.size, n_steps + 1)
    return orbit, orbit_settings


##### INITIALIZE THE FITZHUGH-NAGUMO SYSTEM #####
def random_config():
    """Draw a random initial state (V, W, I) as a float array of shape (3,).

    The components are sampled uniformly, in this order, from NumPy's global
    random generator: V in [-1.5, 1.5), W in [0.0, 2.0), I in [0.7, 1.1).
    """
    bounds = ((-1.5, 1.5), (0.0, 2.0), (0.7, 1.1))
    draws = [np.random.uniform(low, high) for low, high in bounds]
    return np.array(draws)


def sample_orbits(n_steps, trials, dt=0.1, verbose=False, **kwargs):
    """Simulate ``trials`` orbits starting from random initial states.

    Args:
        n_steps: steps per orbit; each orbit has ``n_steps + 1`` samples.
        trials: number of orbits to simulate (0 yields empty arrays).
        dt: sampling interval.
        verbose: print a short banner when True.
        **kwargs: forwarded to ``get_orbit`` (and from there to ``solve_ivp``).

    Returns:
        dict with
          'u':      states, shape (trials, n_steps + 1, state_size);
          'dudt':   the right-hand side evaluated at every state in 'u';
          'dt':     the sampling interval;
          't_eval': sample times, shape (n_steps + 1,);
          'meta':   the arguments this function was called with.
    """
    # Explicit dict instead of locals(): under a tracer (e.g. pdb) on
    # CPython < 3.13 the locals() dict could later pick up u/du/etc.
    meta = {
        'n_steps': n_steps,
        'trials': trials,
        'dt': dt,
        'verbose': verbose,
        'kwargs': kwargs,
    }
    if verbose:
        print("Making a dataset of FitzHugh-Nagumo:")

    # Same grid get_orbit integrates on; computed here so the result does not
    # depend on the last loop iteration and trials == 0 works.
    t_eval = np.arange(n_steps + 1) * dt
    # Differentiate with the same right-hand side that was integrated.
    rhs = kwargs.get('update_fn', update)

    u, du = [], []
    for _ in range(trials):
        state = random_config()
        orbit, _ = get_orbit(state, n_steps=n_steps, dt=dt, **kwargs)
        batch = orbit.T  # (#t_eval, #state)
        u.append(batch)
        du.append(np.array([rhs(None, s) for s in batch]))

    if u:
        u = np.array(u)
        du = np.array(du)
    else:
        u = np.empty((0, n_steps + 1, n_state))
        du = np.empty_like(u)

    return {
        'u': u,
        'dudt': du,
        'dt': dt,
        't_eval': t_eval,
        'meta': meta,
    }

def transform_to_Dirac(data):
    """Convert a (V, W, I) dataset to the variables (V_C, V_L, I_C, I_L).

    ``data['u']`` and ``data['dudt']`` must be 3-D with the state on the last
    axis: (V, W, I) and (dV, dW, dI) respectively. Both are replaced by
    4-component versions, and ``data['energies']`` is added. The input dict is
    modified in place and also returned.

    With these definitions, the equations in ``update`` read
    C * dV_C/dt = I_C and L * dI_L/dt = V_L.
    """
    V, W, I = data['u'].transpose(2, 0, 1)
    dV, dW, _ = data['dudt'].transpose(2, 0, 1)

    components = (
        V,                  # V_C
        V - E - W * R,      # V_L
        I - D(V) - W,       # I_C = I - I_D - W
        W,                  # I_L
    )
    rates = (
        dV,                     # dV_C
        dV - dW * R,            # dV_L = dV - dV_R
        -dDdV(V) * dV - dW,     # dI_C (I is constant)
        dW,                     # dI_L
    )

    data['u'] = np.stack(components, axis=-1)
    data['dudt'] = np.stack(rates, axis=-1)
    data['energies'] = get_energies(data['u'])
    return data

def make_orbits_dataset(trials=1000, test_trials=10, steps=500, test_steps=2000, **kwargs):
    """Build a training set with a nested test set, both in Dirac variables.

    The training orbits are sampled before the test orbits (this order matters
    for reproducibility under a fixed global seed). Both sets are then passed
    through ``transform_to_Dirac``, and the test set is stored under the
    'test' key of the returned training dict.
    """
    raw_train = sample_orbits(trials=trials, n_steps=steps, **kwargs)
    raw_test = sample_orbits(trials=test_trials, n_steps=test_steps, **kwargs)
    train, test = map(transform_to_Dirac, (raw_train, raw_test))
    train['test'] = test
    return train


def get_dataset(name, save_dir, **kwargs):
    '''Return the orbital dataset ``name``, building and caching it if needed.

    First tries ``from_pickle`` on ``{save_dir}/dataset/{name}-dataset.pkl``.
    If loading raises, this seeds NumPy's global RNG with 99, builds the
    dataset via ``make_orbits_dataset(**kwargs)``, and saves it with
    ``to_pickle`` to the same path.

    NOTE(review): ``from_pickle`` presumably unpickles the file; only point
    ``save_dir`` at trusted locations.
    '''
    dataset_dir = f'{save_dir}/dataset'
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(dataset_dir, exist_ok=True)
    path = f'{dataset_dir}/{name}-dataset.pkl'

    try:
        data = from_pickle(path)
    except Exception:
        # Any load failure (missing/corrupt cache) triggers a rebuild. Unlike a
        # bare ``except:``, this does not swallow KeyboardInterrupt/SystemExit.
        np.random.seed(99)
        print("Had a problem loading data from {}. Rebuilding dataset...".format(path))
        data = make_orbits_dataset(**kwargs)
        to_pickle(data, path)
    else:
        print("Successfully loaded data from {}".format(path))
    return data


def get_energy(state):
    """Placeholder energy: zeros with the same shape and dtype as ``state``."""
    energy = np.zeros_like(state)
    return energy

def get_energies(state):
    """Compute energy diagnostics for states in Dirac variables.

    ``state`` is an array whose trailing axis holds (V_C, V_L, I_C, I_L).
    Returns a dict whose 'energy' entry is ``get_energy`` applied to the V_C
    component, reshaped to ``state.shape[:-1]``.
    """
    batch_shape = state.shape[:-1]
    columns = state.reshape(-1, 4).T
    v_c, _v_l, _i_c, _i_l = columns
    return {'energy': get_energy(v_c).reshape(batch_shape)}

