import os
import time
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from tqdm import tqdm

warnings.simplefilter(action='ignore', category=FutureWarning)
from datafold import InitialCondition, TSCDataFrame
from datafold.appfold.mpc import LQR
from datafold.utils._systems import VanDerPol
from swimnetworks import Dense, Linear
from sklearn.pipeline import Pipeline
from kirnn import KIRNN

# Directory that receives every figure and result array written by this script.
control_path = "control"
# ``exist_ok=True`` replaces the former exists()/mkdir() pair: it is idempotent
# and free of the check-then-create race.
os.makedirs(control_path, exist_ok=True)

# One complete experiment (sampling, fit, closed-loop control) is run per seed.
seeds = [1, 2, 3, 4, 5]
# seeds = [5]  # single-seed variant for quick debugging

n_timeseries = 150  # number of timeseries in training set
n_timesteps = 50  # how many timesteps for every time series
dt = 0.05  # delta time
# regularization passed to the Linear layer that maps lifted controls back
regularization_scale_control_inverse = 1e-6
layer_width = 128  # width of the hidden layer acting on the states
layer_width_control = 32  # width of the hidden layer lifting the controls

# Per-seed results, appended at the end of every run of the experiment loop.
norm_pred = []  # norm of the controlled state at every time step
costs = []  # accumulated closed-loop cost
# Run the full experiment once per seed: sample training data from the
# Van der Pol system, fit the KIRNN model on states and lifted controls,
# steer the true system with an LQR controller, evaluate the closed-loop
# cost, plot the results and store the per-seed metrics.
for seed in seeds:
    rng = np.random.default_rng(seed)
    vdp = VanDerPol(control_coord="y")

    time_values = np.arange(0, n_timesteps * dt, dt)

    # ---- training data: random initial states, one constant control each ----
    X_ic = rng.uniform(-3.0, 3.0, size=(n_timeseries, 2))
    idx = pd.MultiIndex.from_arrays([np.arange(n_timeseries), np.zeros(n_timeseries)])
    X_ic = TSCDataFrame(X_ic, index=idx, columns=vdp.feature_names_in_)
    # one control value per time series, repeated over all of its time steps
    U_tsc = rng.uniform(-3.0, 3.0, size=(n_timeseries, 1, 1))
    U_tsc = np.tile(U_tsc, (1, n_timesteps - 1, 1))
    U_tsc = TSCDataFrame.from_tensor(
        U_tsc,
        time_series_ids=X_ic.ids,
        feature_names=vdp.control_names_in_,
        time_values=time_values[:-1],
    )
    X_tsc, U_tsc = vdp.predict(X_ic, U=U_tsc)

    # define the RNN acting on the states and the controls
    steps = [
        (
            "hidden",
            Dense(
                layer_width=layer_width,
                activation="tanh",
                parameter_sampler="tanh",
                random_seed=rng.integers(99999999999, size=1),
            ),
        ),
    ]
    network_dictionary = Pipeline(steps=steps)
    kirnn_model = KIRNN(network_dictionary, n_features_in=2, control=True)

    # define a separate nonlinear network that converts the original controls
    # to nonlinear, high-dimensional controls.
    steps = [
        (
            "hidden",
            Dense(
                layer_width=layer_width_control,
                activation="tanh",
                parameter_sampler="tanh",
                sample_uniformly=True,
                random_seed=rng.integers(9999999999, size=1),
            ),
        ),
    ]
    control_dictionary = Pipeline(steps=steps)
    # linear map fitted to recover the original controls from the lifted ones
    control_dictionary_inverse = Pipeline(
        steps=[
            (
                "linear",
                Linear(regularization_scale=regularization_scale_control_inverse),
            )
        ]
    )
    # We sample the neurons uniformly, i.e. we do not need any "output" function to bias the sampling.
    U_tsc_numpy = U_tsc.to_numpy()
    U_tsc_nonlinear = control_dictionary.fit_transform(U_tsc_numpy, U_tsc_numpy)
    control_dictionary_inverse.fit(U_tsc_nonlinear, U_tsc_numpy)

    # convert back to TSCDataFrame: cut the stacked rows into one slice per
    # time series, in the order of U_tsc.ids
    U_tsc_nonlinear_list = []
    current_idx = 0
    for ts_id in U_tsc.ids:
        timeseries = U_tsc.loc[ts_id]
        U_tsc_nonlinear_list.append(
            U_tsc_nonlinear[current_idx:(current_idx + timeseries.shape[0]), :]
        )
        current_idx += timeseries.shape[0]
    U_tsc_nonlinear = TSCDataFrame.from_tensor(
        np.array(U_tsc_nonlinear_list),
        time_series_ids=U_tsc.ids,
        feature_names=[f'neuron{k}' for k in range(U_tsc_nonlinear.shape[1])],
        time_values=U_tsc.time_values(),
    )

    start_time_fit = time.time()
    kirnn_model.fit(X_tsc, U=U_tsc_nonlinear)
    end_time_fit = time.time()
    print(f"Fit time = {end_time_fit - start_time_fit}s.")

    # number of time steps and time values for controlled time series
    # NOTE(review): np.linspace includes the end point, so the spacing of this
    # grid is n_timesteps_oos * dt / (n_timesteps_oos - 1), slightly larger
    # than the training dt -- confirm whether an exact dt grid was intended.
    n_timesteps_oos = 200
    time_values_oos = np.linspace(
        0, n_timesteps_oos * X_tsc.delta_time, n_timesteps_oos
    )

    # fixed initial state for the out-of-sample runs
    # NOTE(review): the feature names "0"/"1" differ from the "x1"/"x2" used
    # for the target state; X_pred is indexed with "0"/"1" below -- confirm.
    X_ic_oos = np.array([[-1.5, -1]])
    X_ic_oos = InitialCondition.from_array(
        X_ic_oos, feature_names=["0", "1"], time_value=0
    )

    # define target state
    target_state = InitialCondition.from_array(
        np.array([0, 0]), feature_names=["x1", "x2"], time_value=0
    )

    # check if predictions are reasonable: constant control u = 2 over 1000 steps
    time_values_dev = np.arange(0, 1000 * dt, dt)
    U_pred = 2 * np.ones(shape=(len(time_values_dev), 1))
    U_pred = TSCDataFrame.from_array(
        U_pred, feature_names=vdp.control_names_in_, time_values=time_values_dev
    )
    U_pred_nonlinear = control_dictionary.transform(U_pred)
    U_pred_nonlinear = TSCDataFrame.from_array(
        U_pred_nonlinear.to_numpy(), feature_names=list(U_tsc_nonlinear.columns), time_values=time_values_dev
    )

    # predict state evolution only
    X_pred = kirnn_model.predict(X_ic_oos, U=U_pred_nonlinear)
    X_true, U_true = vdp.predict(X_ic_oos, U=U_pred)

    plt.plot(
        X_true.loc[:, "x1"].to_numpy(), X_true.loc[:, "x2"].to_numpy(), label="true"
    )
    plt.plot(
        X_pred.loc[:, "0"].to_numpy(),
        X_pred.loc[:, "1"].to_numpy(),
        label="prediction",
    )
    plt.xlabel("x1")
    plt.ylabel("x2")
    plt.title("Predictions with KIRNN")
    plt.savefig(os.path.join(control_path, "VdP_control_nonlinear_predictions.pdf"))
    plt.show()

    # ---- LQR weights: unit running cost, first two entries weighted by 10 ----
    # NOTE(review): a previous, immediately overwritten assignment used length
    # layer_width + 2; confirm that layer_width matches the lifted dimension
    # (lqr.edmd.n_features_out_).
    cost_input = 1.0 / layer_width_control
    cost_running_array = np.ones(shape=layer_width)
    cost_running_array[0] = 10
    cost_running_array[1] = 10

    start_time_lqr_optimize = time.time()
    lqr = LQR(
        edmd=kirnn_model.edmd, cost_running=cost_running_array, cost_input=cost_input
    )
    end_time_lqr_optimize = time.time()
    print(f"LQR optimization time={end_time_lqr_optimize - start_time_lqr_optimize}s.")
    lqr.preset_target_state(target_state)

    # allocate data structures and fill in the following system loop
    X_oos = TSCDataFrame.from_array(
        np.zeros((n_timesteps_oos, 2)),
        feature_names=vdp.feature_names_in_,
        time_values=time_values_oos,
    )
    U_oos = TSCDataFrame.from_array(
        np.zeros((n_timesteps_oos - 1, 1)),
        feature_names=vdp.control_names_in_,
        time_values=time_values_oos[:-1],
    )
    U_oos_nonlinear = control_dictionary.transform(U_oos)

    X_oos.iloc[0, :] = X_ic_oos.to_numpy()

    # closed loop: the controller returns a lifted control for the current
    # state, the inverse map converts it to the original control, and the
    # true system is advanced by one step with it
    for i in tqdm(range(1, n_timesteps_oos)):
        state = X_oos.iloc[[i - 1], :]
        U_oos_nonlinear.iloc[i - 1, :] = lqr.control_sequence(X=state)
        U_oos.iloc[i - 1, :] = control_dictionary_inverse.transform(
            U_oos_nonlinear.iloc[i - 1, :]
        )
        new_state, _ = vdp.predict(
            state, U=U_oos.iloc[[i - 1], :], time_values=time_values_oos[i - 1: i + 1]
        )
        X_oos.iloc[i, :] = new_state.iloc[[1], :].to_numpy()

    # calculate cost
    # (function-scope import of a private datafold helper, only needed here)
    from datafold.appfold.mpc import _cost_to_array

    cost_diagonal = _cost_to_array(
        lqr.cost_running, n_elements=lqr.edmd.n_features_out_
    )
    Q = np.diag(cost_diagonal)
    lifted_target = lqr.edmd.transform(target_state).to_numpy()

    # Terminal term e^T Q e with e = lifted final state - lifted target.
    # The error has to appear on BOTH sides of Q; the former right-hand factor
    # (lifted_target - lifted_target) was identically zero and silently
    # removed every state cost from the result.
    terminal_error = (
        lqr.edmd.transform(X_oos.iloc[[-1], :]).to_numpy() - lifted_target
    )
    cost = terminal_error @ Q @ terminal_error.T

    # Running terms for every step that has a control, i.e. rows
    # 0 .. n_timesteps_oos - 2 of X_oos / U_oos (the former upper bound
    # n_timesteps_oos - 1 skipped the last of these steps).
    for i in range(1, n_timesteps_oos):
        lifted_state = lqr.edmd.transform(X_oos.iloc[[i - 1], :]).to_numpy()
        state_error = lifted_state - lifted_target
        cost += (
            state_error @ Q @ state_error.T
            + U_oos.iloc[i - 1, :].to_numpy() ** 2 * lqr.cost_input
        )
    print("cost=", cost[0][0])

    # reference trajectory of the system with zero control input
    trajectory_uncontrolled, _ = vdp.predict(
        X_ic_oos, U=np.zeros((n_timesteps_oos - 1)), time_values=time_values_oos
    )

    # ---- phase plot: controlled vs. uncontrolled trajectory ----
    plt.figure(figsize=(5, 3))
    plt.plot(
        X_oos.loc[:, "x1"].to_numpy(),
        X_oos.loc[:, "x2"].to_numpy(),
        c="red",
        label="controlled traj.",
    )
    # arrows start at the visited states; the horizontal component is zero and
    # the vertical component is the applied control divided by the time step
    plt.quiver(
        *X_oos.to_numpy()[:-1, :].T,
        *np.column_stack(
            [np.zeros_like(U_oos.to_numpy()), U_oos.to_numpy() / X_oos.delta_time]
        ).T,
        color="blue",
        label="control",
    )
    plt.plot(X_oos.iloc[0, 0], X_oos.iloc[0, 1], "o", c="red")
    plt.plot(
        trajectory_uncontrolled.loc[:, "x1"].to_numpy(),
        trajectory_uncontrolled.loc[:, "x2"].to_numpy(),
        c="black",
        label="uncontrolled traj.",
    )
    plt.plot(
        trajectory_uncontrolled.iloc[0, 0],
        trajectory_uncontrolled.iloc[0, 1],
        "o",
        c="black",
        label="initial state",
    )
    plt.plot(
        target_state.iloc[0, 0],
        target_state.iloc[0, 1],
        "*",
        c="black",
        label="target state",
    )
    plt.xlabel("$y_1$")
    plt.ylabel("$y_2$")
    plt.legend(loc="lower right")
    plt.grid()
    plt.tight_layout()
    plt.savefig(os.path.join(control_path, "VdP_control_nonlinear.pdf"))
    plt.show()

    # ---- both controlled state coordinates over time ----
    plt.figure(figsize=(10, 7))
    plt.plot(
        X_oos.time_values(), X_oos.loc[:, "x1"].to_numpy(), c="black", label="$y_1$"
    )
    plt.plot(
        X_oos.time_values(),
        X_oos.loc[:, "x2"].to_numpy(),
        c="blue",
        label="$y_2$ (controlled)",
    )
    plt.xlabel("t")
    plt.legend()
    plt.ylabel("$y_1, y_2$")
    plt.grid()
    plt.savefig(os.path.join(control_path, "VdP_control_nonlinear_x1x2.pdf"))
    plt.show()

    # ---- norm of the controlled state compared with the target norm ----
    plt.figure(figsize=(10, 7))
    plt.plot(np.linalg.norm(X_oos.to_numpy(), axis=1), c="blue", label="state")
    plt.axhline(np.linalg.norm(target_state), c="red", label="target")
    plt.xlabel("$t$")
    plt.ylabel("||$ y$||")
    plt.grid()
    plt.legend()
    plt.savefig(os.path.join(control_path, "VdP_control_nonlinear_error.pdf"))
    plt.show()

    # collect the per-seed metrics for the summary after the loop
    norm_pred.append(np.linalg.norm(X_oos.to_numpy(), axis=1))
    costs.append(cost[0][0])

# Persist the per-seed results next to the figures.
norm_pred_file = os.path.join(control_path, "norm_pred_5runs_nonlinear")
costs_file = os.path.join(control_path, "costs_5runs_nonlinear")
np.save(norm_pred_file, np.array(norm_pred))
np.save(costs_file, np.array(costs))

# Cost statistics over all runs.
print(f'==== Cost Values ({len(seeds)} runs) ====')
print(f'mean cost: {np.mean(costs)}')
print(f'min cost: {np.min(costs)}')
print(f'max cost: {np.max(costs)}')

# Average of the state-norm curves over the runs, one value per time step.
mean_norm = np.mean(norm_pred, axis=0)

# Summary figure: every individual run (faint, dashed), their mean and the
# zero line that marks the target.
plt.figure(figsize=(5, 3))
for run_norm in norm_pred:
    plt.plot(run_norm, c='b', linestyle='--', alpha=0.3, label='run 1-5')
plt.plot(mean_norm, c='orange', alpha=1, label='mean(run1-5)')
plt.plot(np.zeros_like(mean_norm), c='green', label='target')

# All runs share one label; keying the handles by label leaves a single
# legend entry per distinct label.
legend_handles, legend_labels = plt.gca().get_legend_handles_labels()
unique_entries = {
    label: handle for handle, label in zip(legend_handles, legend_labels)
}
plt.legend(unique_entries.values(), unique_entries.keys())

plt.grid()
plt.ylabel('$|| h ||$')
plt.xlabel('$t$')
summary_figure_file = os.path.join(control_path, 'VdP_control_nonlinear_5runs.pdf')
plt.savefig(summary_figure_file)
plt.show()
