#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import time, os
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim

from policy_interface import PlanningPolicyInterface
from torch.utils.data import DataLoader, TensorDataset

from tqdm.auto import tqdm

from utils.utils import mse, normalize
from utils.network_types import NN, LSTM, LNET
from utils.logging_settings import logger


class NL_DS(PlanningPolicyInterface):
    """ Approximation of a dynamical system using nonlinear approaches.

    # TODO: Follow ideas on Lipschitz and also stability via augmenting stable data!

    Since a DS dataset can be seen as time series data, with velocities acting as labels, neural
    networks sound like a plausible option for estimating a nonlinear DS.
    """

    def __init__(self, network: str = 'nn', data_dim: int = 2, plot_model: bool = False):
        """ Initialize a nonlinear DS estimator.

        Note: the 'nn' method is equivalent to using behavioral cloning.

        Args:
            network (str, optional): Network type. One of 'nn' (NN), 'lstm' (LSTM)
                or 'lnet' (LNET). Defaults to 'nn'.

            data_dim (int, optional): Dimension of the input data; also used as the output
                dimension of the network. Defaults to 2.
            plot_model (bool, optional): Choose to plot or not. Stored on the instance but not
                used by any method of this class. Defaults to False.

        Raises:
            ValueError: If `network` is not one of the supported identifiers.
        """

        # supported identifiers mapped to their module classes
        network_classes = {'nn': NN, 'lstm': LSTM, 'lnet': LNET}
        if network not in network_classes:
            # fail early with a clear message instead of an AttributeError on a None module
            raise ValueError(
                f"Unknown network type '{network}'; expected one of {sorted(network_classes)}")

        self.__network_type = network
        self.__nl_module: nn.Module = network_classes[network](input_shape=data_dim,
                                                               output_shape=data_dim)

        logger.info(f'{network.upper()} network initialized')

        self.__dataset: DataLoader = None
        self.__plot: bool = plot_model

    def fit(self, trajectory: np.ndarray, velocity: np.ndarray, n_epochs: int = 200, batch_size: int = 128,
        show_ds: bool = False, title: str = None, show_stats: bool = True, stat_freq: int = 50,
        trajectory_test: np.ndarray = None, velocity_test: np.ndarray = None, normalize: bool = False):
        """ Fit a nonlinear model to estimate a dynamical system.

        Args:
            trajectory (np.ndarray): Trajectory data in shape (samples, features).
            velocity (np.ndarray): Velocity data in shape (samples, features).
            n_epochs (int, optional): Number of optimization steps. Defaults to 200.
            batch_size (int, optional): Mini-batch size of the data loader. Defaults to 128.
            show_ds (bool, optional): Whether to show the final DS or not. Currently unused by
                this method. Defaults to False.
            title (str, optional): Plot title for the model. Currently unused by this method.
                Defaults to None.
            show_stats (bool, optional): Show training statistics on the progress bar.
                Defaults to True.
            stat_freq (int, optional): Refresh the statistics every `stat_freq` steps.
                Defaults to 50.
            trajectory_test (np.ndarray, optional): Held-out trajectories used to report a test
                error. Only used when `velocity_test` is given as well. Defaults to None.
            velocity_test (np.ndarray, optional): Held-out velocities matching
                `trajectory_test`. Defaults to None.
            normalize (bool, optional): Pass the velocities (training and test) through the
                project's `normalize` helper before fitting. Defaults to False.
        """

        # build the dataset
        self.__dataset = self._prepare_dataset(trajectory, velocity, batch_size, normalize)

        # a test error can only be computed when both halves of the test set are present
        has_test_set = trajectory_test is not None and velocity_test is not None
        if has_test_set and normalize:
            # the `normalize` flag shadows the imported function here, hence the helper
            velocity_test = self._normalize_velocities(velocity_test)

        # build the optimizer
        optimizer = optim.Adam(self.__nl_module.parameters(), lr=0.001)
        criterion = nn.MSELoss()

        start_time = time.time()
        logger.info('Starting the training sequence')

        for epoch in (par := tqdm(range(n_epochs))):
            # NOTE(review): each "epoch" draws a single mini-batch from a fresh loader iterator
            # rather than sweeping the whole dataset; kept as-is to preserve training behavior.
            trajs_t, vels_t = next(iter(self.__dataset))

            # prediction and loss
            prediction = self.__nl_module(trajs_t)
            loss = criterion(prediction, vels_t)

            # optimization and back-propagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # tracking the learning process
            if show_stats and epoch % stat_freq == 0:
                # NOTE(review): the training loss is reported scaled by 2 as in the original;
                # presumably tied to the default data_dim of 2 -- confirm.
                description = f'MSE Train > {(loss.item() * 2):.4f}'
                if has_test_set:
                    test_error = mse(self.predict(trajectory_test), velocity_test)
                    description += f' | MSE Test > {test_error:.4f}'
                par.set_description(description)

        total_time = time.time() - start_time
        logger.info(f'Training concluded in {total_time:.4f} seconds')


    def predict(self, trajectory: np.ndarray):
        """ Predict estimated velocities from learning NN_DS.

        Args:
            trajectory (np.ndarray): Trajectory in shape (sample size, dimension).

        Returns:
            np.ndarray: Estimated velocities in shape (sample size, dimension).
        """

        x = self._to_input_tensor(trajectory)

        # inference only: skip building the autograd graph
        with torch.no_grad():
            res = self.__nl_module(x)
        return res.detach().cpu().numpy()


    def load(self, model_name: str, dir: str = '../res'):
        """ Load the torch model.

        The file is expected at `<dir>/<model_name>.pt`, as written by `save`. The configured
        network type of this instance is not updated by loading.

        Args:
            model_name (str): Name of the model.
            dir (str, optional): Load directory. Defaults to '../res'.
        """

        # SECURITY: torch.load unpickles the file; only load checkpoints from trusted sources.
        self.__nl_module = torch.load(os.path.join(dir, f'{model_name}.pt'))


    def save(self, model_name: str, dir: str = '../res'):
        """ Save the torch model.

        The whole module object is serialized to `<dir>/<model_name>.pt`; the directory is
        created when missing.

        Args:
            model_name (str): Name of the model.
            dir (str, optional): Save directory. Defaults to '../res'.
        """

        os.makedirs(dir, exist_ok=True)
        torch.save(self.__nl_module, os.path.join(dir, f'{model_name}.pt'))


    def _prepare_dataset(self, trajs: np.ndarray, vels: np.ndarray, batch_size: int,
                         normalize: bool):
        """ Convert npy data to tensor dataset.

        Args:
            trajs (np.ndarray): Demonstrated trajectories.
            vels (np.ndarray): Demonstrated velocities.
            batch_size (int): Size of data batches for the loader.
            normalize (bool): Pass the velocities through the project's `normalize` helper
                before building the dataset.

        Returns:
            DataLoader: Shuffling loader over (trajectory, velocity) float32 tensor pairs.
        """

        # normalize velocity vectors; the boolean flag shadows the imported function in this
        # scope, so the call goes through a helper that still sees the module-level name
        if normalize:
            vels = self._normalize_velocities(vels)

        # convert npy to tensor
        x = self._to_input_tensor(trajs)
        y = torch.from_numpy(vels.astype(np.float32))

        # generate a dataloader
        dataset = TensorDataset(x, y)
        return DataLoader(dataset, batch_size=batch_size, shuffle=True)


    def _to_input_tensor(self, trajs: np.ndarray):
        """ Convert trajectories to a float32 tensor shaped for the configured network.

        Args:
            trajs (np.ndarray): Trajectories in shape (samples, features).

        Returns:
            torch.Tensor: Float32 tensor; reshaped to (samples, 1, features) for 'lstm'.
        """

        x = torch.from_numpy(trajs.astype(np.float32))

        # the lstm variant is fed a middle axis of length one
        if self.__network_type == 'lstm':
            x = torch.reshape(x, (x.shape[0], 1, x.shape[1]))
        return x


    @staticmethod
    def _normalize_velocities(vels: np.ndarray):
        """ Apply the module-level `normalize` helper to velocity data.

        Exists so that methods taking a boolean `normalize` argument can still reach the
        imported function of the same name.

        Args:
            vels (np.ndarray): Velocity data.

        Returns:
            The result of `normalize(vels)`.
        """

        return normalize(vels)

