#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os
import pickle
import numpy as np

from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import PolynomialFeatures
from utils.logging_settings import logger

from policy_interface import PlanningPolicyInterface


class PLY_REG_DS(PlanningPolicyInterface):
    """ Approximation of a dynamical system (DS) using a polynomial regression.

    A DS dataset is treated as a supervised regression problem: positions along
    a trajectory are the inputs and the corresponding velocities are the labels.
    The positions are expanded into polynomial features up to a maximum degree
    and an ordinary least-squares linear model is fitted on top of them.

    This estimator offers no stability guarantees for the learned DS.
    """

    def __init__(self, maximum_degree: int = 5, data_dim: int = 2, plot_model: bool = False):
        """ Initialize a nonlinear DS estimator.

        Args:
            maximum_degree (int, optional): Maximum degree of the polynomial. Defaults to 5.
            data_dim (int, optional): Dimension of the input data. Defaults to 2.
                Currently unused by this class; kept for interface compatibility.
            plot_model (bool, optional): Choose to plot or not. Defaults to False.
                Stored on the instance but not read by any method of this class.
        """

        self.__max_degree = maximum_degree
        logger.info(f'{self.__max_degree}-degree polynomial initialized')

        self.__regressor: LinearRegression = LinearRegression()
        self.__poly_transform = PolynomialFeatures(degree=self.__max_degree)
        self.__plot: bool = plot_model

    def __model_path(self, model_name: str, dir: str) -> str:
        """ Build the pickle file path shared by load() and save().

        Args:
            model_name (str): Name of the model.
            dir (str): Directory containing the pickle file.

        Returns:
            str: Path of the form '<dir>/<model_name>_ply<degree>.pickle'.
        """

        return os.path.join(dir, f'{model_name}_ply{self.__max_degree}.pickle')

    def fit(self, trajectory: np.ndarray, velocity: np.ndarray,
            trajectory_test: np.ndarray = None, velocity_test: np.ndarray = None) -> None:
        """ Fit a polynomial model to estimate a dynamical system.

        Args:
            trajectory (np.ndarray): Trajectory data in shape (samples, features).
            velocity (np.ndarray): Velocity data in shape (samples, features).
            trajectory_test (np.ndarray, optional): Test data points. Defaults to None.
                Currently unused.
            velocity_test (np.ndarray, optional): Test data labels. Defaults to None.
                Currently unused.
        """

        # Expand positions into polynomial terms, then solve a linear
        # least-squares problem from those terms to the velocities.
        features_poly = self.__poly_transform.fit_transform(trajectory)
        logger.info(f'Features are: \n{self.__poly_transform.get_feature_names_out()}')
        self.__regressor.fit(features_poly, velocity)

    def predict(self, trajectory: np.ndarray) -> np.ndarray:
        """ Predict estimated velocities from the learned polynomial DS.

        Args:
            trajectory (np.ndarray): Trajectory in shape (sample size, dimension).

        Returns:
            np.ndarray: Estimated velocities in shape (sample size, dimension).
        """

        # NOTE: fit_transform (rather than transform) is intentional. load()
        # restores only the regressor, so after loading a model the feature
        # transformer has never been fitted and transform() alone would fail.
        features_poly = self.__poly_transform.fit_transform(trajectory)
        return self.__regressor.predict(features_poly)

    def load(self, model_name: str, dir: str = '../res') -> None:
        """ Load a pickled regressor from disk and use it for prediction.

        Only the linear regressor is restored; the polynomial degree used to
        locate the file is the one this instance was constructed with.

        Args:
            model_name (str): Name of the model.
            dir (str, optional): Load directory. Defaults to '../res'.

        Raises:
            OSError: If the pickle file cannot be opened (e.g. it does not exist).
        """

        # SECURITY: unpickling executes arbitrary code embedded in the file;
        # only load models from trusted sources.
        with open(self.__model_path(model_name, dir), 'rb') as model_file:
            self.__regressor = pickle.load(model_file)

    def save(self, model_name: str, dir: str = '../res') -> None:
        """ Save the fitted regressor to disk as a pickle file.

        The target directory is created if it does not exist.

        Args:
            model_name (str): Name of the model.
            dir (str, optional): Save directory. Defaults to '../res'.
        """

        os.makedirs(dir, exist_ok=True)

        # The context manager guarantees the file is flushed and closed even
        # if pickling raises.
        with open(self.__model_path(model_name, dir), 'wb') as model_file:
            pickle.dump(self.__regressor, model_file)
