"""

code adapted from:
https://github.com/a-lucic/focus
"""

from typing import Dict, Optional

import numpy as np
import pandas as pd
import tensorflow as tf
from sklearn.tree import DecisionTreeClassifier
from library.up import trees

def distance_func(name, x1, x2, eps: float = 0.0):
    """Compute the distance named ``name`` between ``x1`` and ``x2``.

    Args:
        name: one of ``"l1"``, ``"l2"`` or ``"cosine"``.
        x1, x2: tensors to compare. ``"l1"``/``"l2"`` reduce over axis 1 and
            ``"cosine"`` over the last axis (presumably 2-D
            ``(batch, features)`` inputs -- TODO confirm with callers).
        eps: small constant forwarded to the distance for numerical stability.

    Returns:
        The tensor produced by the selected distance helper.

    Raises:
        ValueError: if ``name`` is not a supported distance.
    """
    if name == "l1":
        return l1_dist(x1, x2, 1, eps)
    if name == "l2":
        return l2_dist(x1, x2, 1, eps)
    if name == "cosine":
        return cosine_dist(x1, x2, -1, eps)
    # Fail loudly instead of returning None, which would only surface later
    # as an obscure error inside the loss computation.
    raise ValueError(
        f"Unknown distance function {name!r}; expected 'l1', 'l2' or 'cosine'."
    )


def l1_dist(x1, x2, ax: int, eps: float = 0.0):
    """Return the L1 norm of ``x1 - x2`` along ``ax``, with ``eps`` added after summing."""
    abs_diff = tf.abs(x1 - x2)
    summed = tf.reduce_sum(abs_diff, axis=ax)
    return summed + eps


def l2_dist(x1, x2, ax: int, eps: float = 0.0):
    """Return the smoothed L2 norm ``sqrt(sum((x1 - x2) ** 2) + eps)`` along ``ax``."""
    delta = x1 - x2
    squared_sum = tf.reduce_sum(delta**2, axis=ax)
    # eps goes inside the square root, keeping it away from sqrt(0).
    return (squared_sum + eps) ** 0.5


def cosine_dist(x1, x2, ax: int, eps: float = 0.0):
    """Cosine distance between ``x1`` and ``x2`` along ``ax``, plus ``eps``.

    Both inputs are first L2-normalized along ``ax``. Normalization divides by
    ``sqrt(max(sum(x**2), 1e-12))``. ``tf.compat.v1.losses.cosine_distance``
    expects unit-normalized inputs.

    Args:
        x1, x2: tensors of the same shape.
        ax: axis along which vectors are normalized and compared.
        eps: constant added to the distance.

    Returns:
        float64 tensor with the ``ax`` dimension removed, i.e. one distance per
        vector.
    """
    normalize_x1 = tf.nn.l2_normalize(x1, axis=ax)
    normalize_x2 = tf.nn.l2_normalize(x2, axis=ax)
    # `tf.losses.cosine_distance` only exists in TF1; the compat.v1 symbol
    # works in both TF1 and TF2.
    dist = (
        tf.compat.v1.losses.cosine_distance(
            normalize_x1,
            normalize_x2,
            axis=ax,
            reduction=tf.compat.v1.losses.Reduction.NONE,
        )
        + eps
    )
    # With Reduction.NONE the reduced axis is kept with size 1; squeeze only that
    # axis so a batch of one still yields a 1-D result instead of a 0-d scalar.
    dist = tf.squeeze(dist, axis=ax)
    return tf.cast(dist, tf.float64)

def _filter_hinge_loss(n_class, mask_vector, features, sigma, temperature, model_fn):
    """Evaluate ``model_fn`` only on the rows selected by ``mask_vector``.

    Unselected rows get an all-zero row in the output. The caller uses these
    for rows whose label has already flipped.

    Args:
        n_class: number of output columns.
        mask_vector: 1-D array; non-zero entries mark rows to evaluate.
        features: tensor whose first dimension has one entry per row.
        sigma, temperature: Python/NumPy scalars, or per-row arrays that are
            masked together with ``features``.
        model_fn: callable ``(features, sigma, temperature)`` returning one row
            of ``n_class`` values per selected input row.

    Returns:
        An ``(n_input, n_class)`` result. If no row is selected it is a NumPy
        float64 array of zeros. Otherwise it is a tensor with ``model_fn``'s
        output scattered into the selected rows.
    """
    n_input = features.shape[0]
    # Accept 0/1 float indicators as well as booleans.
    mask = np.asarray(mask_vector).astype(bool)

    # if mask is all False, i.e. all labels flipped
    if not np.any(mask):
        return np.zeros((n_input, n_class))

    # filters feature input based on the mask
    filtered_input = tf.boolean_mask(features, mask)

    # Per-row hyperparameters must be filtered like the features; plain
    # scalars broadcast unchanged.
    scalar_types = (int, float, np.number)
    if not isinstance(sigma, scalar_types):
        sigma = tf.boolean_mask(sigma, mask)
    if not isinstance(temperature, scalar_types):
        temperature = tf.boolean_mask(temperature, mask)

    # compute loss
    filtered_loss = tf.convert_to_tensor(model_fn(filtered_input, sigma, temperature))

    indices = np.where(mask)[0]
    # Match the base tensor's dtype to the updates so the scatter works for
    # float32 and float64 model outputs.
    zero_loss = tf.zeros((n_input, n_class), dtype=filtered_loss.dtype)
    # add sparse updates to an existing tensor according to indices
    return tf.tensor_scatter_nd_add(
        tensor=zero_loss, indices=indices[:, None], updates=filtered_loss
    )


class FOCUS:
    """Gradient-based counterfactual search for tree models (adapted FOCUS).

    The input features are optimized so that the tree model's prediction flips
    while staying close to the original input. The differentiable class
    probabilities come from ``trees.get_prob_classification_forest`` or
    ``trees.get_prob_classification_tree``. Those are parameterized by
    ``sigma`` and, for forests, ``temperature``.
    """

    def __init__(
        self,
        mlmodel,
        classes,
        optimizer="adam",
        lr=0.001,
        n_class=2,
        n_iter=1000,
        sigma_val=1.0,
        temperature=1.0,
        distance_weight=0.01,
        distance_func="l1",
    ) -> None:
        """Configure the search.

        Args:
            mlmodel: wrapped model exposing ``predict``, ``raw_model`` and
                ``data.continuous``.
            classes: class labels, in the column order of the model's
                probability output. A row predicted as ``classes[i]`` uses
                probability column ``i`` in the loss.
            optimizer: ``"adam"`` or ``"gd"`` (plain gradient descent).
            lr: learning rate of the optimizer.
            n_class: number of classes.
            n_iter: number of optimization steps.
            sigma_val: sigma hyperparameter of the probability approximation.
            temperature: temperature hyperparameter (forests only).
            distance_weight: weight of the distance term in the loss.
            distance_func: ``"l1"``, ``"l2"`` or ``"cosine"``.

        Raises:
            ValueError: if ``optimizer`` is not ``"adam"`` or ``"gd"``.
        """
        if optimizer == "adam":
            self.optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=lr)
        elif optimizer == "gd":
            self.optimizer = tf.compat.v1.train.GradientDescentOptimizer(
                learning_rate=lr
            )
        else:
            # Previously this left `self.optimizer` unset and failed much later.
            raise ValueError(
                f"Unknown optimizer {optimizer!r}; expected 'adam' or 'gd'."
            )

        self.mlmodel = mlmodel

        self.n_class = n_class
        self.n_iter = n_iter
        self.sigma_val = sigma_val
        self.temp_val = temperature
        self.distance_weight_val = distance_weight
        self.distance_function = distance_func
        self.classes = classes

    def get_counterfactual(self, factual):
        """Return the counterfactual for a single factual row.

        ``factual`` is promoted to a 2-D array (one row if it was 1-D) and
        forwarded to ``get_counterfactuals``.
        """
        return self.get_counterfactuals(np.atleast_2d(factual))

    def get_counterfactuals(self, factuals):
        """Search for counterfactuals of ``factuals``.

        Args:
            factuals: 2-D array-like of continuous features, one row per
                instance. Categorical features are not supported.

        Returns:
            ``pd.DataFrame`` with one row per factual and columns
            ``self.mlmodel.data.continuous``. Rows whose prediction never
            flipped during the search stay all zeros.
        """

        def run_search(_unused):
            return self._optimize_perturbations(factuals)

        # Little bit hacky, but needed as other tf code is graph based.
        # Graph based tf and eager execution for tf don't work together nicely.
        with tf.compat.v1.Session() as sess:
            pf = tf.compat.v1.py_func(run_search, [np.array([])], tf.float32)
            best_perturb = sess.run(pf)

        # TODO: re-enable check_counterfactuals / get_ordered_features once
        # categorical features are supported.
        return pd.DataFrame(best_perturb, columns=self.mlmodel.data.continuous)

    def _optimize_perturbations(self, factuals):
        """Run the FOCUS optimization loop; return the closest flipped input per row."""
        # doesn't work with categorical features, so they aren't used
        original_input = factuals
        ground_truth = np.asarray(self.mlmodel.predict(original_input))
        n_examples = len(original_input)

        # these will be the perturbed features, i.e. counterfactuals
        perturbed = tf.Variable(
            initial_value=original_input, name="perturbed_features", trainable=True
        )
        to_optimize = [perturbed]

        # Probability column of each row's originally predicted class. Labels
        # not found in `self.classes` fall back to column 0.
        class_index = np.zeros(n_examples, dtype=np.int64)
        if self.classes is not None:
            for i, class_name in enumerate(self.classes):
                class_index[np.equal(ground_truth, class_name)] = i
        example_class_index = tf.stack(
            (
                tf.constant(np.arange(n_examples, dtype=np.int64)),
                tf.constant(class_index, dtype=tf.int64),
            ),
            axis=1,
        )

        # 1.0 while a row's label has not flipped yet, 0.0 afterwards
        indicator = np.ones(n_examples)

        # hyperparameters
        sigma = np.full(n_examples, self.sigma_val)
        temperature = np.full(n_examples, self.temp_val)
        distance_weight = np.full(n_examples, self.distance_weight_val)

        best_distance = np.full(n_examples, 1000.0)
        best_perturb = np.zeros(perturbed.shape)
        eps = 10.0**-10

        for _ in range(self.n_iter):
            # Only the loss computation needs to be recorded.
            with tf.GradientTape() as tape:
                p_model = _filter_hinge_loss(
                    self.n_class,
                    indicator,
                    perturbed,
                    sigma,
                    temperature,
                    self._prob_from_input,
                )
                approx_prob = tf.gather_nd(p_model, example_class_index)
                distance = distance_func(
                    self.distance_function, perturbed, original_input, eps
                )
                prediction_loss = indicator * approx_prob
                distance_loss = distance_weight * distance
                total_loss = tf.reduce_mean(prediction_loss + distance_loss)

            grad = tape.gradient(total_loss, to_optimize)
            self.optimizer.apply_gradients(
                zip(grad, to_optimize),
                global_step=tf.compat.v1.train.get_or_create_global_step(),
            )
            # clip perturbed values between 0 and 1 (inclusive)
            tf.compat.v1.assign(
                perturbed, tf.math.minimum(1, tf.math.maximum(0, perturbed))
            )

            true_distance = distance_func(
                self.distance_function, perturbed, original_input, 0
            ).numpy()
            perturbed_np = perturbed.numpy()

            # get the class predictions for the perturbed features
            current_predict = self.mlmodel.predict(perturbed_np)
            indicator = np.equal(ground_truth, current_predict).astype(np.float64)

            # keep a perturbation only if it flipped the label AND is closer
            # than the best one found so far for that row
            improved = np.not_equal(ground_truth, current_predict) & np.less(
                true_distance, best_distance
            )
            best_distance[improved] = true_distance[improved]
            best_perturb[improved] = perturbed_np[improved]

        return best_perturb

    def _prob_from_input(self, perturbed, sigma, temperature):
        """Return the differentiable class probabilities of the model for ``perturbed``."""
        feat_columns = self.mlmodel.data.continuous
        if isinstance(self.mlmodel.raw_model, DecisionTreeClassifier):
            # temperature is not passed to the single-tree approximation
            return trees.get_prob_classification_tree(
                self.mlmodel.raw_model, feat_columns, perturbed, sigma=sigma
            )
        return trees.get_prob_classification_forest(
            self.mlmodel,
            feat_columns,
            perturbed,
            sigma=sigma,
            temperature=temperature,
        )