import time
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from datafold import InitialCondition, TSCDataFrame
from datafold.appfold.mpc import LQR
from datafold.utils._systems import VanDerPol
from sklearn.pipeline import Pipeline
from swimnetworks import Dense
from tqdm import tqdm

# Hide every FutureWarning. This deliberately runs before importing kirnn,
# presumably so that warnings raised while importing it are silenced too.
warnings.filterwarnings("ignore", category=FutureWarning)

from kirnn import KIRNN

# Closed-loop LQR control of the Van der Pol oscillator with a KIRNN surrogate.
# For every seed: simulate training data, fit the surrogate, build an LQR
# controller on its lifted (EDMD) model, drive the true system towards the
# target state, evaluate the LQR objective and plot / store the results.

from datafold.appfold.mpc import _cost_to_array  # private datafold helper

# Use seeds = [1, 2, 3, 4, 5] to reproduce the five-run statistics.
seeds = [1]

N_TIMESERIES = 150  # number of time series in the training set
N_TIMESTEPS = 50  # number of time steps of every training time series
DT = 0.05  # sampling time step
LAYER_WIDTH = 128  # width of the sampled hidden layer
N_TIMESTEPS_OOS = 200  # number of time steps of the controlled trajectory
N_TIMESTEPS_DEV = 1000  # length of the open-loop sanity-check prediction


def _generate_training_data(rng, vdp):
    """Simulate controlled Van der Pol trajectories used for training.

    Initial states are drawn uniformly from [-3, 3]^2; every time series gets a
    single control value drawn uniformly from [-3, 3] that is held constant.

    Returns the ``(X_tsc, U_tsc)`` pair produced by ``vdp.predict``.
    """
    # np.arange(n) * dt always has exactly n entries (a float-step arange may not)
    time_values = np.arange(N_TIMESTEPS) * DT

    ic = rng.uniform(-3.0, 3.0, size=(N_TIMESERIES, 2))
    idx = pd.MultiIndex.from_arrays([np.arange(N_TIMESERIES), np.zeros(N_TIMESERIES)])
    X_ic = TSCDataFrame(ic, index=idx, columns=vdp.feature_names_in_)

    # one control value per series, repeated over the N_TIMESTEPS - 1 transitions
    U = rng.uniform(-3.0, 3.0, size=(N_TIMESERIES, 1, 1))
    U = np.tile(U, (1, N_TIMESTEPS - 1, 1))
    U_tsc = TSCDataFrame.from_tensor(
        U,
        time_series_ids=X_ic.ids,
        feature_names=vdp.control_names_in_,
        time_values=time_values[:-1],
    )
    return vdp.predict(X_ic, U=U_tsc)


def _simulate_closed_loop(lqr, vdp, X_ic, time_values):
    """Drive the true system with LQR feedback, one time step at a time.

    Returns ``(X, U)``: the ``len(time_values)`` visited states and the
    ``len(time_values) - 1`` applied controls.
    """
    n_steps = len(time_values)
    X = TSCDataFrame.from_array(
        np.zeros((n_steps, 2)),
        feature_names=vdp.feature_names_in_,
        time_values=time_values,
    )
    U = TSCDataFrame.from_array(
        np.zeros((n_steps - 1, 1)),
        feature_names=vdp.control_names_in_,
        time_values=time_values[:-1],
    )
    X.iloc[0, :] = X_ic.to_numpy()

    for i in tqdm(range(1, n_steps)):
        state = X.iloc[[i - 1], :]
        U.iloc[i - 1, :] = lqr.control_sequence(X=state)
        # integrate the true system over [t_{i-1}, t_i] with the chosen control
        new_state, _ = vdp.predict(
            state, U=U.iloc[[i - 1], :], time_values=time_values[i - 1: i + 1]
        )
        X.iloc[i, :] = new_state.iloc[[1], :].to_numpy()
    return X, U


def _trajectory_cost(lqr, X, U, target_state):
    """Evaluate the LQR objective of a closed-loop trajectory in lifted space.

    With z_k the lifted state, z* the lifted target, Q = diag(cost_running) and
    r = cost_input:

        J = sum_{k=0}^{N-2} [(z_k - z*)^T Q (z_k - z*) + r * u_k^2]
            + (z_{N-1} - z*)^T Q (z_{N-1} - z*)

    Returns J as a Python float.
    """
    Q = np.diag(_cost_to_array(lqr.cost_running, n_elements=lqr.edmd.n_features_out_))
    lifted_target = lqr.edmd.transform(target_state).to_numpy()

    def state_cost(row):
        delta = lqr.edmd.transform(X.iloc[[row], :]).to_numpy() - lifted_target
        return (delta @ Q @ delta.T).item()

    cost = state_cost(len(X) - 1)  # terminal cost
    # one running-cost term per applied control (states 0 .. N-2)
    for k in range(len(U)):
        u = U.iloc[k, :].to_numpy()
        cost += state_cost(k) + float(np.sum(u ** 2 * lqr.cost_input))
    return cost


def _plot_results(X_oos, U_oos, trajectory_uncontrolled, target_state, state_norm):
    """Draw the phase portrait, the state time series and the state-norm evolution."""
    # phase portrait: controlled vs. uncontrolled trajectory, controls as arrows
    plt.figure(figsize=(5, 3))
    plt.plot(
        X_oos.loc[:, "x1"].to_numpy(),
        X_oos.loc[:, "x2"].to_numpy(),
        c="red",
        label="controlled traj.",
    )
    plt.quiver(
        *X_oos.to_numpy()[:-1, :].T,
        # arrow components: 0 along x1, u / dt along x2
        *np.column_stack(
            [np.zeros_like(U_oos.to_numpy()), U_oos.to_numpy() / X_oos.delta_time]
        ).T,
        color="blue",
        label="control",
    )
    plt.plot(X_oos.iloc[0, 0], X_oos.iloc[0, 1], "o", c="red")
    plt.plot(
        trajectory_uncontrolled.loc[:, "x1"].to_numpy(),
        trajectory_uncontrolled.loc[:, "x2"].to_numpy(),
        c="black",
        label="uncontrolled traj.",
    )
    plt.plot(
        trajectory_uncontrolled.iloc[0, 0],
        trajectory_uncontrolled.iloc[0, 1],
        "o",
        c="black",
        label="initial state",
    )
    plt.plot(
        target_state.iloc[0, 0],
        target_state.iloc[0, 1],
        "*",
        c="black",
        label="target state",
    )
    plt.xlabel("$y_1$")
    plt.ylabel("$y_2$")
    plt.legend(loc='lower right')
    plt.grid()
    plt.tight_layout()
    plt.savefig('VdP_control.pdf')
    plt.show()

    # state components over time
    plt.figure(figsize=(10, 7))
    plt.plot(X_oos.time_values(), X_oos.loc[:, "x1"].to_numpy(), c="black", label="$y_1$")
    plt.plot(X_oos.time_values(), X_oos.loc[:, "x2"].to_numpy(), c="blue", label="$y_2$ (controlled)")
    plt.xlabel("t")
    plt.legend()
    plt.ylabel("$y_1, y_2$")
    plt.grid()
    plt.show()

    # distance to the origin over time (x-axis is physical time, as labelled)
    plt.figure(figsize=(10, 7))
    plt.plot(X_oos.time_values(), state_norm, c="blue", label="state")
    plt.axhline(np.linalg.norm(target_state), c="red", label="target")
    plt.xlabel("$t$")
    plt.ylabel("||$ y$||")
    plt.grid()
    plt.legend()
    plt.savefig('VdP_control_error.pdf')
    plt.show()


norm_pred = []
costs = []
for seed in seeds:
    rng = np.random.default_rng(seed)
    vdp = VanDerPol(control_coord="y")

    X_tsc, U_tsc = _generate_training_data(rng, vdp)

    network_dictionary = Pipeline(steps=[
        ("hidden",
         Dense(layer_width=LAYER_WIDTH, activation='tanh', parameter_sampler='tanh',
               random_seed=rng.integers(100, size=1))),
    ])
    kirnn_model = KIRNN(dictionary=network_dictionary, n_features_in=2, control=True)

    start = time.perf_counter()
    kirnn_model.fit(X_tsc, U=U_tsc)
    print(f'Fit time = {time.perf_counter() - start}s.')

    # Same step as the training data: the surrogate (and hence the LQR) is a
    # discrete-time model with step delta_time. np.linspace(0, n*dt, n) would
    # yield the step n*dt/(n-1) instead.
    time_values_oos = np.arange(N_TIMESTEPS_OOS) * X_tsc.delta_time

    # fixed initial state of the controlled run
    X_ic_oos = InitialCondition.from_array(
        np.array([[-1.5, -1.0]]), feature_names=vdp.feature_names_in_, time_value=0
    )
    # target state: the origin
    target_state = InitialCondition.from_array(
        np.array([0.0, 0.0]), feature_names=vdp.feature_names_in_, time_value=0
    )

    # Open-loop sanity check: surrogate vs. true system under constant input u = 2.
    # NOTE(review): X_pred / X_true are not used further; inspect them interactively.
    time_values_dev = np.arange(N_TIMESTEPS_DEV) * DT
    U_pred = TSCDataFrame.from_array(
        2 * np.ones(shape=(len(time_values_dev), 1)),
        feature_names=vdp.control_names_in_,
        time_values=time_values_dev,
    )
    X_pred = kirnn_model.predict(X_ic_oos, U=U_pred)
    X_true, _ = vdp.predict(X_ic_oos, U=U_pred)

    # unit running cost on every lifted feature, unit input cost
    cost_running_array = np.ones(kirnn_model.edmd.n_features_out_)
    cost_input = 1
    start = time.perf_counter()
    lqr = LQR(edmd=kirnn_model.edmd, cost_running=cost_running_array, cost_input=cost_input)
    print(f'LQR optimization time={time.perf_counter() - start}s.')
    lqr.preset_target_state(target_state)

    X_oos, U_oos = _simulate_closed_loop(lqr, vdp, X_ic_oos, time_values_oos)

    cost = _trajectory_cost(lqr, X_oos, U_oos, target_state)
    print('cost=', cost)

    trajectory_uncontrolled, _ = vdp.predict(
        X_ic_oos, U=np.zeros((N_TIMESTEPS_OOS - 1)), time_values=time_values_oos
    )

    state_norm = np.linalg.norm(X_oos.to_numpy(), axis=1)
    _plot_results(X_oos, U_oos, trajectory_uncontrolled, target_state, state_norm)

    norm_pred.append(state_norm)
    costs.append(cost)

# Name the result files after the actual number of runs so that a single-seed
# debug run cannot overwrite multi-seed results.
n_runs = len(seeds)
np.save(f'norm_pred_{n_runs}runs', np.array(norm_pred))
np.save(f'costs_{n_runs}runs', np.array(costs))
