import os

import numpy as np

from sklearn.base import BaseEstimator

import keras
import tensorflow as tf

from metrics import score


class _Base:
    """Common base that stores the training data and the shapes seen by `fit`.

    The number of features and outputs are exposed through the read-only
    properties `n_features` and `n_outputs`, which raise until a suitable
    call to `fit` has been made.
    """

    def __init__(self, random_state=None):
        """Optionally seed NumPy's global RNG and clear the fitted shapes.

        Parameters
        ----------
        random_state : optional
            When not None, it is forwarded to ``np.random.seed``. Note that
            this seeds the *global* NumPy generator.
        """
        if random_state is not None:
            np.random.seed(random_state)

        self._n_features = None
        self._n_outputs = None

    def fit(self, X, y=None):
        """Store the training data and record its second-axis sizes.

        Parameters
        ----------
        X : array-like with a ``shape`` attribute
            Training inputs; ``X.shape[1]`` becomes `n_features`.
        y : array-like with a ``shape`` attribute, optional
            Training targets; ``y.shape[1]`` becomes `n_outputs`.

        Returns
        -------
        self
            The instance itself, so calls can be chained.
        """
        self._X = X
        self._y = y
        self._n_features = self._X.shape[1]
        # Recomputed on every call so that a refit without targets does not
        # keep reporting the output count of an earlier fit.
        self._n_outputs = None if self._y is None else self._y.shape[1]
        return self

    def _not_been_fit(self):
        """Raise the ValueError used when fitted state is requested too early."""
        raise ValueError("The model has not yet been fit. "
                         "Try to call 'fit()' first with some training data.")

    @property
    def n_features(self):
        """Size of the second axis of the `X` passed to the last `fit`.

        Raises
        ------
        ValueError
            If no value has been recorded yet.
        """
        if self._n_features is None:
            self._not_been_fit()
        return self._n_features

    @property
    def n_outputs(self):
        """Size of the second axis of the `y` passed to the last `fit`.

        Raises
        ------
        ValueError
            If no value has been recorded, which includes the case where
            the last `fit` was called without `y`.
        """
        if self._n_outputs is None:
            self._not_been_fit()
        return self._n_outputs


class _BaseLDL(_Base):
    """Adds a `predict` placeholder and a multi-metric `score` on top of `_Base`."""

    def predict(self, _):
        """Placeholder that ignores its argument and returns None.

        NOTE(review): presumably meant to be overridden by concrete models.
        """
        return None

    def score(self, X, y, metrics=None):
        """Evaluate predictions for `X` against `y` with the module-level `score`.

        Parameters
        ----------
        X : object
            Passed unchanged to ``self.predict``.
        y : object
            Reference targets, passed as the first argument of `score`.
        metrics : list of str, optional
            Metric names forwarded to `score`. When None, a fixed list of
            six names is used.

        Returns
        -------
        Whatever the imported `score` function returns.
        """
        default_metrics = ["chebyshev", "clark", "canberra", "kl_divergence",
                           "cosine", "intersection"]
        chosen = default_metrics if metrics is None else metrics
        predictions = self.predict(X)
        return score(y, predictions, metrics=chosen)


class BaseLDL(_BaseLDL, BaseEstimator):
    """Label-distribution base class that is also a scikit-learn estimator."""

    def __init__(self, random_state=None):
        """Initialise the base state and expose `random_state` as a parameter.

        Parameters
        ----------
        random_state : optional
            Forwarded to the parent initialiser, which uses it to seed
            NumPy's global RNG when it is not None.
        """
        super().__init__(random_state)
        # scikit-learn's parameter introspection (get_params, clone, repr)
        # looks up every __init__ argument as an attribute of the same name,
        # so the value has to be stored unchanged on the instance.
        self.random_state = random_state


class _BaseDeep(keras.Model):
    """Keras model base that remembers hidden/latent sizes and seeds TensorFlow."""

    def __init__(self, n_hidden=None, n_latent=None, random_state=None):
        """Initialise the Keras model, optionally seed TensorFlow, store sizes.

        Parameters
        ----------
        n_hidden : optional
            Stored as ``self._n_hidden``; not interpreted here.
        n_latent : optional
            Stored as ``self._n_latent``; not interpreted here.
        random_state : optional
            When not None, it is forwarded to ``tf.random.set_seed``.
        """
        # The Keras initialiser runs first, before any attribute is assigned.
        keras.Model.__init__(self)

        seed_requested = random_state is not None
        if seed_requested:
            tf.random.set_seed(random_state)

        self._n_latent, self._n_hidden = n_latent, n_hidden


class BaseDeepLDL(_BaseLDL, _BaseDeep):
    """Combines the label-distribution base with the Keras-backed deep base."""

    def __init__(self, n_hidden=None, n_latent=None, random_state=None):
        """Run both parent initialisers explicitly with the given settings.

        Parameters
        ----------
        n_hidden : optional
            Forwarded to ``_BaseDeep.__init__``.
        n_latent : optional
            Forwarded to ``_BaseDeep.__init__``.
        random_state : optional
            Forwarded to both parent initialisers.
        """
        _BaseLDL.__init__(self, random_state=random_state)
        _BaseDeep.__init__(self, n_hidden=n_hidden, n_latent=n_latent,
                           random_state=random_state)

    def fit(self, X, y):
        """Record the training data, then cast the stored copies to float32.

        Parameters
        ----------
        X : object
            Training inputs, handed to ``_BaseLDL.fit``.
        y : object
            Training targets, handed to ``_BaseLDL.fit``.
        """
        _BaseLDL.fit(self, X, y)
        # Replace the stored inputs first, then the targets, with the
        # results of tf.cast to float32.
        for stored in ("_X", "_y"):
            setattr(self, stored,
                    tf.cast(getattr(self, stored), dtype=tf.float32))
