#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import math
import os
import sys
from typing import List, Union

import numpy as np
from sklearn import mixture

sys.path.append(os.pardir)
from utils.logging_settings import logger


def diff(arr: Union[np.ndarray, List]) -> List:
    """ Calculate the difference of consecutive elements.

    ``result[i] == arr[i + 1] - arr[i]``, so the result has one element fewer
    than ``arr`` and is empty when ``arr`` has fewer than two elements.

    TODO: This function is redundant and can be easily replaced (e.g. by
        ``np.diff``, keeping in mind that it returns an array, not a list).

    Args:
        arr (np.ndarray or List): The input sequence; it must support ``len``
            and integer indexing.

    Returns:
        List: Calculated difference of elements.
    """
    return [arr[i + 1] - arr[i] for i in range(len(arr) - 1)]


def select_model(bics: List):
    """ Pick a model by locating the largest discrete second derivative of BIC.

    The scores are differenced twice. Each pass is left-padded with a zero so
    the lists keep the length of ``bics``. The position of the largest value
    in the second difference is returned; ties go to the earliest position.

    NOTE(review): because of the zero padding, entry ``i >= 2`` of the second
    difference equals ``bics[i] - 2 * bics[i - 1] + bics[i - 2]`` (curvature
    centred on ``bics[i - 1]``), and entry 1 is just ``bics[1] - bics[0]``.
    The returned index is therefore one past the curvature peak. Confirm this
    offset is intended.

    Args:
        bics (List): BIC scores, one per candidate model, in candidate order.

    Returns:
        int: Index of the selected model (0 when ``bics`` is empty).
    """
    first_order = [0] + [bics[k + 1] - bics[k] for k in range(len(bics) - 1)]
    second_order = [0] + [first_order[k + 1] - first_order[k]
                          for k in range(len(first_order) - 1)]

    # max() keeps the first maximal element, matching list.index(max(...))
    return max(range(len(second_order)), key=second_order.__getitem__)


def find_limits(trajectory):
    """ Find the integer bounding box of a trajectory.

    For every coordinate axis, the minimum is rounded down (``math.floor``)
    and the maximum is rounded up (``math.ceil``). NaN coordinates are
    ignored. An axis that is entirely NaN cannot be bounded and raises.

    Args:
        trajectory (np.ndarray): Points of shape (num_points, dimension), with
            dimension 2 or 3. Anything ``np.asarray`` accepts (e.g. a nested
            list) also works.

    Raises:
        ValueError: If the points are not 2- or 3-dimensional, if the
            trajectory contains no points, or if an axis is entirely NaN.

    Returns:
        Tuple: ``(x_min, x_max, y_min, y_max)`` for 2-D points, or
            ``(x_min, x_max, y_min, y_max, z_min, z_max)`` for 3-D points,
            all as Python ints.
    """
    points = np.asarray(trajectory)

    if points.ndim != 2 or points.shape[1] not in (2, 3):
        raise ValueError(
            'Only 2 or 3 dimensional points are supported, got an array of '
            'shape {}!'.format(points.shape))
    if points.shape[0] == 0:
        raise ValueError('Cannot find the limits of an empty trajectory!')

    # nanmin/nanmax ignore NaN coordinates instead of propagating them
    lower = np.nanmin(points, axis=0)
    upper = np.nanmax(points, axis=0)

    limits = list()
    for low, high in zip(lower, upper):
        # int() guarantees plain Python ints regardless of the numpy dtype
        limits.append(int(math.floor(low)))
        limits.append(int(math.ceil(high)))
    return tuple(limits)


def fit(trajectory: np.ndarray, is_linear: bool = False, num_components_max: int = 10):
    """ Fit a Gaussian mixture model (GMM) to a desired trajectory.

    Without ``is_linear``, a GMM is fitted for every component count in
    ``1..num_components_max``. Each is scored with BIC, and the model at the
    index chosen by ``select_model`` is returned. With ``is_linear``, the BIC
    search is skipped and a single-component GMM is returned.

    Args:
        trajectory (np.ndarray): The main trajectory to fit the mixture model.
        is_linear (bool): Set true if the underlying data generation process is a
            linear dynamical system.
        num_components_max (int, optional): Choosing the maximum number
            of Gaussian components.

    Raises:
        ValueError: If ``num_components_max`` is smaller than 1 (only checked
            when the BIC search is performed).

    Returns:
        mixture.GaussianMixture: the resulting GMM
    """

    # the linear case always uses the single-component model, so fitting the
    # larger mixtures for BIC scoring would be wasted work
    if is_linear:
        logger.warning('Adapting a linear model instead of bic scoring')
        gmm = mixture.GaussianMixture(n_components=1)
        gmm.fit(trajectory)
        return gmm

    if num_components_max < 1:
        raise ValueError('num_components_max must be at least 1!')

    # store the bic scores and the corresponding GMMs, index i <-> i + 1 components
    bics = list()
    gmms = list()
    for num in range(1, num_components_max + 1):
        gmm = mixture.GaussianMixture(n_components=num)
        gmm.fit(trajectory)
        gmms.append(gmm)
        bics.append(gmm.bic(trajectory))

    # return the best model
    return gmms[select_model(bics)]
