# https://github.com/sdascoli/odeformer/blob/main/odeformer/odebench/solve_and_plot.py

from copy import deepcopy
import json

from matplotlib import pyplot as plt
import numpy as np
from scipy.integrate import solve_ivp
import sympy as sp
import torch
from tqdm import tqdm

from odebench.strogatz_equations import equations
from solver.ode_forward import ode_forward


# Integration settings forwarded verbatim to scipy.integrate.solve_ivp(**config).
config = dict(
    t_span=(0, 10),                    # integration interval
    method="LSODA",                    # integrator
    rtol=1e-5,                         # relative tolerance (deliberately strict)
    atol=1e-7,                         # absolute tolerance (deliberately strict)
    first_step=1e-6,                   # initial step size (deliberately strict)
    t_eval=np.linspace(0, 10, 150),    # times at which the solution is reported
    min_step=1e-10,                    # minimum step size (LSODA-only option)
)


def _substitute_nested(node, subs):
    """Return a sympified copy of ``node`` with ``subs`` applied to every leaf.

    Lists and tuples are walked recursively and rebuilt (lists as lists, tuples
    as tuples), so the input structure is never modified and each call returns
    an independent object.
    """
    if isinstance(node, list):
        return [_substitute_nested(child, subs) for child in node]
    if isinstance(node, tuple):
        return tuple(_substitute_nested(child, subs) for child in node)
    return sp.sympify(node).subs(subs)


def process_equations(equations):
    """Create sympy expressions for each equation and each of its parameter settings.

    For every ``eq_dict`` in ``equations`` three keys are added. Each holds one
    entry per row of ``eq_dict['consts']``, with the constants ``c_0, c_1, ...``
    of that row substituted in:

    - ``'substituted'``: expressions parsed from the ``'|'``-separated ``eq_dict['eq']``.
    - ``'substituted_gt'``: expressions parsed from the ``'|'``-separated ``eq_dict['eq_gt']``.
    - ``'substituted_mnn'``: a copy of the (possibly nested) ``eq_dict['mnn']``
      structure with every leaf sympified and substituted.

    The source fields ``'eq'``, ``'eq_gt'`` and ``'mnn'`` are left unchanged.
    """
    for eq_dict in equations:
        all_consts = eq_dict['consts']
        n_consts = len(all_consts[0]) if all_consts else 0
        const_symbols = sp.symbols([f'c_{i}' for i in range(n_consts)])
        # One substitution mapping per parameter setting, shared by all three views.
        subs_per_setting = [dict(zip(const_symbols, consts)) for consts in all_consts]

        eq_dict['substituted'] = [
            [sp.sympify(eq).subs(subs) for eq in eq_dict['eq'].split('|')]
            for subs in subs_per_setting
        ]
        eq_dict['substituted_gt'] = [
            _substitute_nested(eq_dict['eq_gt'].split('|'), subs)
            for subs in subs_per_setting
        ]
        # A fresh copy per setting. Substituting into eq_dict['mnn'] in place would
        # alias every entry to one list holding only the first setting's constants.
        eq_dict['substituted_mnn'] = [
            _substitute_nested(eq_dict['mnn'], subs)
            for subs in subs_per_setting
        ]
    print("PROCESSING DONE")


def _as_float_nested(node):
    """Recursively convert every leaf of a nested list/tuple (or a bare leaf) to ``float``."""
    if isinstance(node, (list, tuple)):
        return [_as_float_nested(child) for child in node]
    return float(node)


def _evaluate_on_grid(expr, t_grid):
    """Evaluate a sympy expression in ``t`` on ``t_grid``, always returning ``t_grid.shape``.

    ``lambdify`` returns a bare scalar for expressions that do not depend on ``t``
    (e.g. a vanishing derivative). Broadcasting keeps every row the same length
    as the time grid.
    """
    return np.broadcast_to(sp.lambdify('t', expr, 'numpy')(t_grid), t_grid.shape)


def solve_equations(equations, config):
    """Solve all equations for a given config.

    Expects ``process_equations`` to have been run first. For every ``eq_dict`` three
    keys are added. Each is a list (one entry per parameter setting) of lists
    (one entry per initial condition in ``eq_dict['init']``) of solution dicts:

    - ``'solutions'``: ``scipy.integrate.solve_ivp(**config)`` applied to ``'substituted'``.
    - ``'solutions_gt'``: the closed-form ``'substituted_gt'`` and its first two time
      derivatives (``'y'``, ``'y1'``, ``'y2'``), evaluated on ``np.arange(0., 10., 1e-2)``.
    - ``'solutions_mnn'``: ``ode_forward`` applied to ``'substituted_mnn'``.
      Entry ``[0]`` is used as coefficients and ``[1]`` as the right-hand side, as
      in the original indexing of ``eq_dict['mnn']``.
    """
    t_symbol = sp.symbols('t')
    t_grid = np.arange(0., 10., 1e-2)  # evaluation grid for the closed-form solutions
    dtype = torch.float64
    device = torch.device('cpu')
    n_orders = 3  # coefficient vectors are zero-padded up to this many entries
    step_length = 1e-2
    n_steps = int(10. / step_length)
    t_mnn = np.arange(0., 10., step_length)

    for eq_dict in tqdm(equations):
        var_symbols = sp.symbols([f'x_{i}' for i in range(eq_dict['dim'])])
        init_symbols = sp.symbols([f'i_{i}' for i in range(eq_dict['dim'])])

        # --- Numerical integration of the substituted right-hand sides. ---
        eq_dict['solutions'] = []
        for i, fns in enumerate(eq_dict['substituted']):
            # Compile once per setting instead of on every solver evaluation.
            compiled_fns = [sp.lambdify(var_symbols, eq, 'numpy') for eq in fns]

            def callable_fn(t, x, compiled_fns=compiled_fns):
                return np.array([fn(*x) for fn in compiled_fns])

            setting_solutions = []
            for initial_conditions in eq_dict['init']:
                sol = solve_ivp(callable_fn, **config, y0=initial_conditions)
                if sol.status != 0:
                    print(f"Error in equation {eq_dict['id']}: {eq_dict['eq_description']}, constants {i}, initial conditions {initial_conditions}: {sol.message}")
                setting_solutions.append({
                    "success": sol.success,
                    "message": sol.message,
                    "t": sol.t.tolist(),
                    "y": sol.y.tolist(),
                    "nfev": int(sol.nfev),
                    "njev": int(sol.njev),
                    "nlu": int(sol.nlu),
                    "status": int(sol.status),
                    'consts': eq_dict['consts'][i],
                    'init': initial_conditions,
                })
            eq_dict['solutions'].append(setting_solutions)

        # --- Closed-form ("ground truth") solutions and their derivatives. ---
        eq_dict['solutions_gt'] = []
        for i, fns in enumerate(eq_dict['substituted_gt']):
            # Derivatives do not depend on the initial conditions: differentiate once.
            exprs_by_order = (
                list(fns),
                [sp.diff(eq, t_symbol, 1) for eq in fns],
                [sp.diff(eq, t_symbol, 2) for eq in fns],
            )
            setting_solutions = []
            for initial_conditions in eq_dict['init']:
                init_subs = list(zip(init_symbols, initial_conditions))
                y, y1, y2 = (
                    np.array([_evaluate_on_grid(expr.subs(init_subs), t_grid) for expr in exprs])
                    for exprs in exprs_by_order
                )
                setting_solutions.append({
                    "t": t_grid.tolist(),
                    "y": y.tolist(),
                    "y1": y1.tolist(),
                    "y2": y2.tolist(),
                    'consts': eq_dict['consts'][i],
                    'init': initial_conditions,
                })
            eq_dict['solutions_gt'].append(setting_solutions)

        # --- S-MNN forward solve, using this setting's substituted coefficients. ---
        eq_dict['solutions_mnn'] = []
        for i, fns in enumerate(eq_dict['substituted_mnn']):
            coefficients = torch.tensor([_as_float_nested(fns[0])], dtype=dtype, device=device)
            pad = n_orders - coefficients.size(-1)
            if pad > 0:
                zeros = coefficients.new_zeros((*coefficients.shape[:-1], pad))
                coefficients = torch.cat([coefficients, zeros], dim=-1)
            rhs_equation = torch.tensor([_as_float_nested(fns[1])], dtype=dtype, device=device)
            steps = torch.tensor([step_length], dtype=dtype, device=device)
            setting_solutions = []
            for initial_conditions in eq_dict['init']:
                init_vars = torch.tensor(initial_conditions, dtype=dtype, device=device)[None, ..., None]
                y = ode_forward(coefficients, rhs_equation, init_vars, steps, n_steps=n_steps, enable_central_smoothness=False).permute(1, 2, 0)
                setting_solutions.append({
                    "t": t_mnn.tolist(),
                    "y": y.tolist(),
                    'consts': eq_dict['consts'][i],
                    'init': initial_conditions,
                })
            eq_dict['solutions_mnn'].append(setting_solutions)

    print("SOLVING DONE")


def _stringify_leaves(node):
    """Return a copy of a nested list/tuple structure with every leaf converted via ``str``."""
    if isinstance(node, (list, tuple)):
        return [_stringify_leaves(child) for child in node]
    return str(node)


def _sympy_json_default(obj):
    """``json`` fallback: write stray sympy objects as strings and reject anything else."""
    if isinstance(obj, sp.Basic):
        return str(obj)
    raise TypeError(f'Object of type {type(obj).__name__} is not JSON serializable')


def save_to_disk(equations, filename):
    """Save the equations (including substituted sympy expressions) to ``filename`` as JSON.

    sympy expressions cannot be JSON-serialized. Every leaf of ``'substituted'``,
    ``'substituted_gt'`` and ``'substituted_mnn'`` is therefore written as its string
    form, and any other sympy object met during encoding is stringified by the
    ``json`` fallback. Keys that are absent are skipped. ``equations`` itself is
    not modified.
    """
    store = []
    for eq_dict in equations:
        # Shallow copy is enough: only the rebound keys below differ from the input.
        eq_copy = dict(eq_dict)
        for key in ('substituted', 'substituted_gt', 'substituted_mnn'):
            if key in eq_copy:
                eq_copy[key] = _stringify_leaves(eq_copy[key])
        store.append(eq_copy)
    with open(filename, 'w', encoding='utf-8') as f:
        json.dump(store, f, default=_sympy_json_default)
    print("SAVING DONE")


def _style_trajectory_axis(ax):
    """Apply the shared grid / background / tick styling to one subplot."""
    ax.grid()
    ax.set_facecolor((1., 1., 1., 1.))
    ax.xaxis.set_ticks_position('none')
    ax.yaxis.set_ticks_position('none')


def _label_trajectory_column(axs, column, title):
    """Set the title on the top subplot and the shared t-axis on the bottom one."""
    axs[0, column].set_title(title)
    axs[-1, column].set_xlabel('$t$')
    axs[-1, column].set_xlim(0., 10.)


def _plot_trajectory_pair(ax, t_mnn, y_mnn, t_gt, y_gt, column, row, plot_colors):
    """Overlay S-MNN and closed-form curves on ``ax``, print their MSE, then style the axis."""
    ax.plot(t_mnn, y_mnn, label='S-MNN Solution', color=plot_colors[0], linestyle='-', linewidth=2., alpha=1.)
    ax.plot(t_gt, y_gt, label='Closed-Form Solution', color=plot_colors[3], linestyle=':', linewidth=3., alpha=1.)
    diff = np.array(y_mnn) - np.array(y_gt)
    print(column, row, np.mean(np.abs(diff) ** 2.))
    _style_trajectory_axis(ax)


def _third_order_closed_form(initial_conditions, t):
    """Closed-form solution of y' + y'' + y''' = 0 and its first two derivatives on ``t``.

    ``initial_conditions`` are substituted for the symbols ``i_0, i_1, i_2``. At
    ``t = 0`` the expression and its first two derivatives reduce to these three
    values. Returns ``{'y': ..., 'y1': ..., 'y2': ...}``, each of shape ``(1, len(t))``.
    """
    eq_gt = 'i_0 + i_1 + i_2 + exp(-.5 * t) * (-(i_1 + i_2) * cos(sqrt(.75) * t) + (i_1 - i_2) / sqrt(3.) * sin(sqrt(.75) * t))'
    t_symbol = sp.symbols('t')
    expr = sp.sympify(eq_gt)
    init_subs = list(zip(['i_0', 'i_1', 'i_2'], initial_conditions))
    result = {}
    for order, key in enumerate(('y', 'y1', 'y2')):
        derived = sp.diff(expr, t_symbol, order) if order else expr
        values = sp.lambdify('t', derived.subs(init_subs), 'numpy')(t)
        # Broadcast in case the expression collapses to a constant for these inputs.
        result[key] = np.array([np.broadcast_to(values, t.shape)])
    return result


def plot_all_equations(equations):
    """Plot S-MNN vs. closed-form trajectories (y, y', y'') and save them as a PDF.

    Column ``j`` shows ``equations[j]`` (first constant set, first initial
    condition). The last column shows an extra third-order example solved
    directly here. The mean squared difference of every panel is printed as
    ``column row mse``.
    """
    plot_colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf']

    experiment_titles = ['RC Circuit', 'Population', 'Language Death', 'Harmonic', 'Harmonic Damping']

    rows = 3
    cols = len(experiment_titles) + 1
    const_i, init_i = 0, 0
    fig, axs = plt.subplots(rows, cols, figsize=(1.5 * cols, 1. * rows), sharex=True)

    # One column per benchmark equation.
    for j, eq_dict in enumerate(tqdm(equations)):
        _label_trajectory_column(axs, j, experiment_titles[j])
        solution_mnn = eq_dict['solutions_mnn'][const_i][init_i]
        solution_gt = eq_dict['solutions_gt'][const_i][init_i]
        for i in range(rows):
            gt_key = 'y' if i == 0 else f'y{i}'
            _plot_trajectory_pair(
                axs[i, j], solution_mnn['t'], solution_mnn['y'][0][i],
                solution_gt['t'], solution_gt[gt_key][0], j, i, plot_colors,
            )

    for i in range(rows):
        # Built outside the f-string: reusing the enclosing quote inside one needs Python 3.12+.
        primes = "'" * i
        axs[i, 0].set_ylabel(f'$y{primes}(t)$')

    # Last column: third-order example with coefficients (0, 1, 1, 1) and zero right-hand side.
    step_length = 1e-2
    t = np.arange(0., 10., step_length)
    initial_conditions = [0., -1., 1.]
    y_mnn = ode_forward(
        torch.tensor([[[[0., 1., 1., 1.]]]], dtype=torch.float64),
        torch.tensor([[0.]], dtype=torch.float64),
        torch.tensor([[initial_conditions]], dtype=torch.float64),
        torch.tensor([step_length], dtype=torch.float64),
        n_steps=int(10. / step_length),
        enable_central_smoothness=False,
    ).permute(1, 2, 0)
    solution_gt = _third_order_closed_form(initial_conditions, t)

    j = cols - 1
    _label_trajectory_column(axs, j, "Third-Order")
    for i in range(rows):
        gt_key = 'y' if i == 0 else f'y{i}'
        _plot_trajectory_pair(axs[i, j], t, y_mnn[0][i], t, solution_gt[gt_key][0], j, i, plot_colors)

    handles, labels = axs[0, j].get_legend_handles_labels()
    fig.legend(handles[::-1], labels[::-1], loc='lower center', ncol=2, framealpha=1., bbox_to_anchor=(.5, -.05))
    fig.set_facecolor((1., 1., 1., 0.))
    fig.tight_layout()

    plot_filename = 'odebench/odebench_trajectories.pdf'
    fig.savefig(plot_filename, format='pdf', bbox_inches='tight', pad_inches=.01, transparent=False)
    plt.show()
    plt.close(fig)
    print("PLOTTING ALL EQUATIONS DONE")


if __name__ == "__main__":
    # Pipeline: build sympy expressions -> solve them -> plot the comparison figure.
    process_equations(equations=equations)
    solve_equations(equations=equations, config=config)
    plot_all_equations(equations=equations)
