import tensorflow as tf
import numpy as np

class RegressionFunc:
    """Base class for ridge regression on a fixed feature map.

    Subclasses implement ``_feature_fn``. ``fit`` builds the feature matrix and
    optionally standardises the features (every column except the bias column)
    and the outputs. It then solves the (optionally weighted) regularised normal
    equations and maps the solution back to the unnormalised feature space.
    """

    def __init__(self, reg_fact, normalize, normalize_output, bias_entry=None):
        """
        Args:
            reg_fact: accepted for API compatibility; not used by this class.
            normalize: standardise feature columns before solving.
            normalize_output: additionally standardise the outputs; only has an
                effect when ``normalize`` is True.
            bias_entry: index (e.g. ``-1``) of the constant feature column. That
                column is neither normalised nor regularised. ``None`` if there
                is no bias column.
        """
        self._normalize = normalize
        self._normalize_output = normalize_output
        self._bias_entry = bias_entry
        self._params = None  # set by fit()
        self.o_std = None  # output std from the last fit() (1.0 if outputs were not normalised)

    def __call__(self, inputs):
        """Evaluate the last fitted model: ``features(inputs) @ params``.

        ``params`` are expressed in whatever input coordinates the base ``fit``
        saw. Subclasses may transform inputs before delegating to it.

        Raises:
            AssertionError: if ``fit`` has not been called yet.
        """
        if self._params is None:
            raise AssertionError("Model not trained yet")
        features = self._feature_fn(tf.shape(inputs)[0], x=inputs)
        # params is 1-D, so use matvec; ``@`` (matmul) requires rank >= 2 operands.
        return tf.linalg.matvec(features, self._params)

    def _feature_fn(self, num_samples, dim=None, x=None):
        """Return the design matrix for inputs ``x`` (one row per observation)."""
        raise NotImplementedError

    def _normalize_features(self, features):
        """Standardise every feature column except the bias column.

        Returns:
            ``(normalized_features, mean, std)``. For the bias column, mean is 0
            and std is 1, so that column passes through unchanged.
        """
        mean = tf.reduce_mean(features, axis=0)
        std = tf.math.reduce_std(features, axis=0)
        if self._bias_entry is not None:
            # Take the width from the data rather than from a subclass attribute,
            # so this works for any feature map (e.g. LinFunc, or QuadFunc without
            # its constant column).
            column_ids = tf.range(tf.shape(features)[-1])
            is_bias = tf.equal(column_ids, column_ids[self._bias_entry])
            mean = tf.where(is_bias, tf.zeros_like(mean), mean)
            std = tf.where(is_bias, tf.ones_like(std), std)
        normalized_features = (features - mean) / std
        return normalized_features, mean, std

    def _normalize_outputs(self, outputs):
        """Standardise ``outputs`` by their scalar mean/std; returns ``(outputs, mean, std)``."""
        mean = tf.reduce_mean(outputs)
        std = tf.math.reduce_std(outputs)
        outputs = (outputs - mean) / std
        return outputs, mean, std

    def _undo_normalization(self, params, f_mean, f_std, o_mean, o_std):
        """Map parameters fitted on normalised data back to the raw feature space.

        The normalised model is ``(y - o_mean) / o_std = sum_i p_i (f_i - m_i) / s_i``.
        Solving for y gives weights ``p'_i = o_std * p_i / s_i``. The constant part
        ``o_mean - sum_i p'_i m_i`` is folded into the bias entry (whose m = 0, s = 1).
        """
        if self._normalize_output:
            params = params * (o_std / f_std)
            # The output mean is added back, so it is *subtracted* from the amount
            # removed from the bias.
            offset = tf.reduce_sum(params * f_mean) - o_mean
        else:
            params = params / f_std
            offset = tf.reduce_sum(params * f_mean)
        # NOTE(review): without a bias column, the mean shift cannot be absorbed
        # and the returned params are only the rescaled weights.
        if self._bias_entry is not None:
            bias_index = tf.range(tf.shape(params)[0])[self._bias_entry]
            params = tf.tensor_scatter_nd_sub(params, [[bias_index]], [offset])
        return params

    def fit(self, regularizer, num_samples, inputs, outputs, weights=None):
        """Solve the ridge-regularised (weighted) least-squares problem.

        Args:
            regularizer: coefficient added to the diagonal of the Gram matrix.
                The bias entry is not regularised.
            num_samples: forwarded to ``_feature_fn``.
            inputs: forwarded to ``_feature_fn`` as ``x``.
            outputs: targets, one per feature row; squeezed if rank > 1.
            weights: optional per-row weights. A rank-1 vector is expanded to a
                column before multiplying the feature matrix.

        Returns:
            ``(params, o_std)``, with params in the unnormalised feature space.
            ``o_std`` is 1.0 unless outputs were normalised.
        """
        f_mean, f_std = 0.0, 1.0
        o_mean, o_std = 0.0, 1.0
        if len(outputs.shape) > 1:
            outputs = tf.squeeze(outputs)
        features = self._feature_fn(num_samples, x=inputs)
        if self._normalize:
            features, f_mean, f_std = self._normalize_features(features)
            if self._normalize_output:
                outputs, o_mean, o_std = self._normalize_outputs(outputs)

        if weights is not None:
            if len(weights.shape) == 1:
                weights = tf.expand_dims(weights, 1)
            weighted_features_t = tf.transpose(weights * features)
        else:
            weighted_features_t = tf.transpose(features)

        # Size the ridge penalty from the actual design matrix. This covers
        # QuadFunc with or without its constant column, and LinFunc.
        num_features = tf.shape(features)[-1]
        reg_mat = tf.eye(num_features, dtype=features.dtype) * regularizer
        if self._bias_entry is not None:
            bias_index = tf.range(num_features)[self._bias_entry]
            reg_mat = tf.tensor_scatter_nd_update(
                reg_mat, [[bias_index, bias_index]], tf.zeros([1], dtype=reg_mat.dtype))

        gram = weighted_features_t @ features + reg_mat
        rhs = weighted_features_t @ tf.expand_dims(outputs, 1)
        params = tf.squeeze(tf.linalg.solve(gram, rhs), axis=1)
        if self._normalize:
            params = self._undo_normalization(params, f_mean, f_std, o_mean, o_std)
        self._params = params
        self.o_std = o_std
        return params, o_std


class LinFunc(RegressionFunc):
    """Affine model ``y = x @ w + b``; the bias is the last feature column."""

    def __init__(self, reg_fact, normalize, normalize_output):
        # bias_entry=-1: _feature_fn appends the constant column last.
        super().__init__(reg_fact, normalize, normalize_output, -1)

    def _feature_fn(self, num_samples, dim=None, x=None):
        """Return ``[x, 1]``: the inputs with a constant column of ones appended.

        ``num_samples`` and ``dim`` are accepted for interface compatibility but
        are not used; the row count is taken from ``x`` itself. ``dim`` is
        optional because ``RegressionFunc.fit`` calls
        ``_feature_fn(num_samples, x=...)``.
        """
        if x is None:
            raise TypeError("_feature_fn() missing required argument 'x'")
        # ones_like on a 1-column slice keeps x's dtype and works with dynamic shapes.
        return tf.concat([x, tf.ones_like(x[:, :1])], 1)


class QuadFunc(RegressionFunc):
    """Fits ``f(x) = -0.5 * x^T R x + x^T r + r_0`` by ridge regression.

    Design-matrix columns, in order:

    - the ``dim * (dim + 1) / 2`` monomials ``x_j * x_k`` (``j <= k``) in
      row-major upper-triangular order;
    - the ``dim`` linear terms;
    - a constant 1 (the bias, last column).

    With ``withgrad=True`` the analytic gradients of these features are stacked
    below the value rows: first ``N`` value rows, then ``N`` rows of d/dx_0, ...,
    then ``N`` rows of d/dx_{dim-1}. The ``outputs`` passed to ``fit`` must follow
    the same row order.
    """

    def __init__(self, dim, reg_fact, normalize, normalize_output, no_first_order=False, withgrad=True):
        """
        Args:
            dim: input dimensionality.
            reg_fact, normalize, normalize_output: see ``RegressionFunc``.
            no_first_order: only meaningful with ``withgrad=True``. Only the
                gradient rows are used and the constant column is dropped, so no
                constant term is fitted (``fit`` returns ``const_term = 0``).
            withgrad: stack gradient-feature rows below the value rows.
        """
        super().__init__(reg_fact, normalize, normalize_output, bias_entry=-1)
        self.dim = dim
        self._no_first_order = no_first_order
        self.quad_term = None
        self.lin_term = None
        self.const_term = None
        self.withgrad = withgrad
        self.num_quad_features = dim * (dim + 1) // 2
        self.num_features = self.num_quad_features + dim + 1
        # (num_quad_features, 2) row/col indices of the upper triangle in row-major
        # order, matching the monomial order produced by _feature_fn.
        self.triu_idx = tf.constant(np.stack(np.triu_indices(dim), axis=1))
        self.grad_feature_indices_mat = self.grad_feature_indices_non_recursive(dim)
        # R_indices[i] has shape (dim, 1): the quadratic-block columns whose
        # monomial contains x_i, in increasing order. reshape (not squeeze) keeps
        # the rank correct when dim == 1.
        self.R_indices = tf.expand_dims(tf.stack([
            tf.reshape(tf.cast(tf.where(self.grad_feature_indices_mat[i]), tf.int32), [-1])
            for i in range(dim)
        ]), 2)

    @tf.function(autograph=True)
    def grad_feature_indices_non_recursive(self, dim):
        """Boolean mask of shape ``(dim, dim * (dim + 1) / 2)``.

        Entry ``[i, k]`` is True iff quadratic monomial ``k`` contains ``x_i``.
        Iterative version of ``grad_feature_indices``.
        """
        right_part_bottom = tf.ones((1, 1), dtype=tf.bool)
        right_part_bottom_width = 1
        for i in range(dim - 1):
            tf.autograph.experimental.set_loop_options(
                shape_invariants=[(right_part_bottom, tf.TensorShape([None, None]))]
            )
            # Prepend a new leading variable. It appears in all i + 2 new monomials
            # (top row of ones); each existing variable appears in exactly one of them.
            left_part = tf.concat((tf.ones((1, i + 2), dtype=tf.bool), tf.eye(i + 2, dtype=tf.bool)[1:, :]), axis=0)
            right_part_top = tf.zeros((1, right_part_bottom_width), dtype=tf.bool)
            right_part = tf.concat((right_part_top, right_part_bottom), axis=0)
            right_part_bottom = tf.concat((left_part, right_part), axis=1)
            right_part_bottom_width = right_part_bottom_width + i + 2
        return right_part_bottom

    def grad_feature_indices(self, dim=None):
        """Recursive version of ``grad_feature_indices_non_recursive``.

        Args:
            dim: size of the sub-problem; defaults to ``self.dim``.
        """
        if dim is None:
            dim = self.dim
        if dim == 1:
            return tf.ones((1, 1), dtype=tf.bool)

        left_part = tf.concat((tf.ones((1, dim), dtype=tf.bool), tf.eye(dim, dtype=tf.bool)[1:, :]), axis=0)
        right_part_bottom = self.grad_feature_indices(dim - 1)
        right_part_top = tf.zeros((1, right_part_bottom.shape[1]), dtype=tf.bool)
        right_part = tf.concat((right_part_top, right_part_bottom), axis=0)
        return tf.concat((left_part, right_part), axis=1)

    @tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.int32),
                                  tf.TensorSpec(shape=[None, None], dtype=tf.float32)])
    def get_grad_feature_matrices(self, num_samples, x):
        """Gradient rows of the design matrix.

        Returns a ``(dim * num_samples, num_features)`` matrix. Row
        ``i * num_samples + n`` holds d(features)/dx_i evaluated at ``x[n]``.
        """
        sample_ids = tf.range(num_samples)
        R_rows = []
        r_rows = []
        for i in range(self.dim):
            R_indices = tf.concat((tf.expand_dims(tf.repeat(sample_ids, self.dim), axis=1),
                                   tf.tile(self.R_indices[i], (num_samples, 1))), axis=1)
            r_indices = tf.concat((tf.reshape(sample_ids, (num_samples, 1)),
                                   tf.tile([[i]], [num_samples, 1])), axis=1)

            # d/dx_i of x_i * x_j is x_j for j != i and 2 * x_i for j == i.
            weighted_dat = tf.concat((x[:, :i], 2 * x[:, i:i + 1], x[:, i + 1:]), axis=1)
            R_rows.append(tf.scatter_nd(R_indices, tf.reshape(weighted_dat, [num_samples * self.dim]),
                                        [num_samples, self.num_quad_features]))
            # Linear block: d/dx_i of x_k is 1 iff k == i; the constant column stays 0.
            r_rows.append(tf.scatter_nd(r_indices, tf.ones(num_samples), [num_samples, self.dim + 1]))
        R_features = tf.stack(R_rows)
        r_features = tf.stack(r_rows)
        return tf.reshape((tf.concat((R_features, r_features), axis=2)), [-1, self.num_features])

    @tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.int32),
                                  tf.TensorSpec(shape=[None, None], dtype=tf.float32)
                                  ])
    def _feature_fn(self, num_samples, x):
        """Build the design matrix for ``x`` (see the class docstring for the layout)."""
        constant_feature = tf.ones_like(x[:, :1])
        # Row-major upper triangle: x_i * x_j for j >= i.
        quad_features = tf.concat(
            [tf.expand_dims(x[:, i], axis=1) * x[:, i:] for i in range(self.dim)], axis=1)
        features = tf.concat((quad_features, x, constant_feature), axis=1)
        if not self.withgrad:
            return features
        grad_features = self.get_grad_feature_matrices(num_samples, x)
        if self._no_first_order:
            # Gradient rows only; drop the constant column (its gradient is always 0).
            return grad_features[:, 0:-1]
        return tf.concat((features, grad_features), axis=0)

    def fit(self, regularizer, num_samples, inputs, outputs, weights=None, sample_mean=None, sample_chol_cov=None):
        """Fit the quadratic model and return its coefficients in input space.

        Whitening happens when both ``sample_mean`` and ``sample_chol_cov`` are
        given. The regression is then done on ``z = C^{-1} (x - sample_mean)``,
        where ``C = sample_chol_cov`` (presumably a lower Cholesky factor of the
        sample covariance; TODO confirm). The coefficients are mapped back to ``x``
        afterwards.

        NOTE(review): gradient outputs are not transformed by this method. With
        ``withgrad=True`` they presumably must already be w.r.t. the whitened
        inputs; confirm against callers.

        Args:
            regularizer, num_samples, inputs, outputs, weights: see
                ``RegressionFunc.fit``.
            sample_mean: centring vector for whitening; must be given together
                with ``sample_chol_cov``.
            sample_chol_cov: matrix whose inverse whitens the centred inputs.

        Returns:
            ``(quad_term, lin_term, const_term, o_std)`` such that
            ``f(x) ~= -0.5 * x^T quad_term x + x^T lin_term + const_term``.
            ``const_term`` is 0 when ``no_first_order`` is set.

        Raises:
            AssertionError: if only one of ``sample_mean`` / ``sample_chol_cov``
                is given.
        """
        if (sample_mean is None) != (sample_chol_cov is None):
            raise AssertionError("sample_mean and sample_chol_cov must be given together")
        whiten = sample_mean is not None

        if whiten:
            inv_samples_chol_cov = tf.linalg.inv(sample_chol_cov)
            inputs = (inputs - sample_mean) @ tf.transpose(inv_samples_chol_cov)

        params, o_std = super().fit(regularizer, num_samples, inputs, outputs, weights)

        if self._no_first_order:
            # No constant column was fitted: params = [quadratic..., linear...].
            quad_params = params[:-self.dim]
            lin_term = params[-self.dim:]
            const_term = 0
        else:
            quad_params = params[:-(self.dim + 1)]
            lin_term = params[-(self.dim + 1):-1]
            const_term = params[-1]

        # x^T qt x == sum_{j<=k} p_jk x_j x_k == -0.5 x^T (-qt - qt^T) x
        qt = tf.scatter_nd(self.triu_idx, quad_params, [self.dim, self.dim])
        quad_term = -qt - tf.transpose(qt)

        if whiten:
            # Substitute z = A (x - mu) with A = C^{-1} and expand the quadratic.
            inv_chol_t = tf.transpose(inv_samples_chol_cov)
            quad_term = inv_chol_t @ quad_term @ inv_samples_chol_cov
            t1 = tf.linalg.matvec(inv_chol_t, lin_term)
            t2 = tf.linalg.matvec(quad_term, sample_mean)
            lin_term = t1 + t2
            if not self._no_first_order:
                const_term += tf.reduce_sum(sample_mean * (-0.5 * t2 - t1))

        self.quad_term, self.lin_term, self.const_term = quad_term, lin_term, const_term
        return quad_term, lin_term, const_term, o_std
