import tensorflow as tf
import abc
import numpy as np

from sklearn.metrics import roc_curve, roc_auc_score

from losses.losses import *

def get_loss_helper(args, class_anchors, nb_classes):
    """Build the loss helper selected by ``args.loss``.

    Args:
        args: Configuration object. ``args.loss`` selects the helper; depending
            on the selection, ``args.osr_score``, ``args.use_softmax``,
            ``args.max_dist`` and ``args.nb_features`` are read as well.
        class_anchors: Anchors forwarded to the anchor-based helpers
            ("cac" and "dist"); unused for "crossentropy".
        nb_classes: Number of classes, forwarded to ``DistanceHelper`` only.

    Returns:
        A ``CrossEntropyHelper``, ``CACHelper`` or ``DistanceHelper`` instance.

    Raises:
        ValueError: If ``args.loss`` is not one of "crossentropy", "cac", "dist".
    """
    print("Training with loss:", args.loss)

    if args.loss == "crossentropy":
        return CrossEntropyHelper(args.osr_score, args.use_softmax)
    if args.loss == "cac":
        return CACHelper(class_anchors)
    if args.loss == "dist":
        return DistanceHelper(
            class_anchors,
            args.max_dist,
            nb_classes,
            args.nb_features,
            args.osr_score,
        )
    raise ValueError(f"Loss {args.loss} not implemented")

def softmin(x, axis=-1):
    """Return the softmin of ``x`` along ``axis``.

    Softmin is ``exp(-x) / sum(exp(-x))``, i.e. the softmax of ``-x``. The
    normalisation is done along ``axis`` (the last axis by default) so that,
    for a batched input, every row is normalised independently instead of
    sharing one denominator across the whole tensor. For a 1-D input the
    result is the same as normalising over all elements.

    Args:
        x: Floating-point tensor of values.
        axis: Axis along which the values are normalised. Defaults to -1.

    Returns:
        A tensor with the same shape as ``x`` whose entries sum to 1 along
        ``axis``.
    """
    # tf.nn.softmax handles the exponentials in a numerically stable way.
    return tf.nn.softmax(-x, axis=axis)

def distance(x, multiple_y, metric="euclidean"):
    """Compute the distance between each row of ``x`` and each element of ``multiple_y``.

    Args:
        x: Tensor of points; a new axis is inserted at position 1 so that it
            broadcasts against ``multiple_y``.
            # assumes x is (batch_size, nb_features) — TODO confirm
        multiple_y: Tensor of reference points.
            # assumes multiple_y is (nb_points, nb_features) — TODO confirm
        metric: Name of the distance; only "euclidean" is supported.

    Returns:
        The norm of the pairwise differences taken along axis 2.

    Raises:
        ValueError: If ``metric`` is not supported.
    """
    if metric != "euclidean":
        # Fail loudly instead of implicitly returning None to the caller.
        raise ValueError(f"Unknown distance metric: {metric}")
    differences = tf.expand_dims(x, 1) - multiple_y
    return tf.norm(differences, axis=2)
    
def modify_score(score_known, score_unknown, score_type):
    """Orient OSR scores for AUROC computation.

    For AUROC the target is 0 for known samples and 1 for unknown samples, so
    known samples should have scores as low as possible and unknown samples
    scores as high as possible.

    Args:
        score_known: Scores of the known samples.
        score_unknown: Scores of the unknown samples.
        score_type: "min" if a lower score means "more likely known" (scores
            are returned unchanged), "max" if a higher score means "more
            likely known" (scores are negated).

    Returns:
        A ``(score_known, score_unknown)`` tuple with the orientation described
        above.

    Raises:
        ValueError: If ``score_type`` is neither "min" nor "max".
    """
    if score_type == "min":
        # the lower the score, the more likely it is to be known
        return score_known, score_unknown
    if score_type == "max":
        # the higher the score, the more likely it is to be known,
        # so we invert the sign
        return -score_known, -score_unknown
    # Fail loudly instead of implicitly returning None, which would only
    # surface later as an unpacking error in the caller.
    raise ValueError(f"Unknown score type: {score_type}")

class BaseLossHelper():
    """Common interface for loss helpers used in open-set recognition (OSR).

    Subclasses provide ``loss``, ``predicted_class`` and ``osr_score`` and are
    expected to set ``self.score_type`` ("min" or "max"), which
    ``_format_score`` reads. The ``abc.abstractmethod`` markers are not
    enforced at instantiation time because this class does not use
    ``abc.ABCMeta``.
    """

    def __init__(self):
        pass

    @property
    def distance_based(self):
        """Whether the helper is distance based; ``False`` unless overridden."""
        return False

    @abc.abstractmethod
    def loss(self, y_pred, y_true):
        """Return the loss."""

    @abc.abstractmethod
    def predicted_class(self, y_pred, class_anchors=None):
        """Return the index of the predicted class."""

    @abc.abstractmethod
    def osr_score(self, y_pred, class_anchors=None):
        """Compute OSR scores, one value per prediction.

        ``y_pred`` is a tensor of shape (batch_size, nb_features).
        """

    def predict_w_threshold(self, y_pred, threshold, type="min", class_anchors=None):
        """Return the predicted class per sample, or -1 when rejected.

        With ``type == "min"`` a prediction is kept when its OSR score is
        strictly below ``threshold``; with ``type == "max"`` when it is
        strictly above. Any other ``type`` raises ``ValueError``.
        """
        osr_scores = self.osr_score(y_pred, class_anchors)
        class_ids = self.predicted_class(y_pred, class_anchors)

        if type not in ("min", "max"):
            raise ValueError(f"Unknown type for threshold: {type}")

        accepted = osr_scores < threshold if type == "min" else osr_scores > threshold
        return tf.where(accepted, class_ids, -1)

    def _format_score(self, score_known, score_unknown):
        """Build the ``(y_true, y_pred)`` pair used for AUROC.

        Known samples are labelled 0 and unknown samples 1; the scores are
        oriented by ``modify_score`` according to ``self.score_type``.
        """
        labels = tf.concat(
            [np.zeros_like(score_known), np.ones_like(score_unknown)],
            axis=0,
        )
        oriented_known, oriented_unknown = modify_score(
            score_known, score_unknown, self.score_type
        )
        scores = tf.concat([oriented_known, oriented_unknown], axis=0)
        return labels, scores

    def auroc(self, pred_known, pred_unknown):
        """Return the AUROC separating known from unknown predictions."""
        y_true, y_pred = self._format_score(
            self.osr_score(pred_known),
            self.osr_score(pred_unknown),
        )
        return roc_auc_score(y_true, y_pred)
    
    
class CrossEntropyHelper(BaseLossHelper):
    """Loss helper built on sparse categorical cross-entropy over logits.

    Without class anchors, predictions are treated as logits. Once
    ``use_class_anchors`` has been called, predictions are treated as points
    in the anchor space and compared to the anchors by distance.
    """

    def __init__(self, osr_score="min", use_softmax=False):
        """Create the helper.

        Args:
            osr_score: Score orientation, "min" or "max".
            use_softmax: If true, logits are passed through a softmax before
                the OSR score is taken (only when no anchors are set).

        Raises:
            ValueError: If ``osr_score`` is neither "min" nor "max".
        """
        self.loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True,
            reduction=tf.keras.losses.Reduction.NONE,
        )
        self.use_softmax = use_softmax
        if osr_score not in ("min", "max"):
            raise ValueError(f"Unknown OSR score for cross entropy loss: {osr_score}")
        self.score_type = osr_score
        self.class_anchors = None

    def use_class_anchors(self, class_anchors):
        """Switch to anchor-based prediction; forces ``score_type`` to "min"."""
        self.class_anchors = class_anchors
        self.score_type = "min"

    def loss(self, y_pred, y_true):
        """Return the per-sample cross-entropy loss.

        The two arguments are forwarded positionally, in the order received,
        to the Keras loss object.
        """
        # NOTE(review): Keras loss objects are called as (y_true, y_pred), so
        # the first argument here lands in the y_true slot — confirm that
        # callers pass the labels first despite the parameter names.
        return self.loss_fn(y_pred, y_true)

    def predicted_class(self, y_pred, class_anchors=None, threshold=None):
        """Return the predicted class index for each sample.

        ``class_anchors`` and ``threshold`` are accepted but not used; only
        the anchors stored on the instance are considered.
        """
        anchors = self.class_anchors
        if anchors is not None:
            # with anchors, y_pred is the representation in the anchors space
            return tf.argmin(distance(y_pred, anchors), axis=1)
        # without anchors, y_pred is supposed to be logits
        return tf.argmax(y_pred, axis=1)

    def osr_score(self, y_pred, class_anchors=None):
        """Compute one OSR score per sample.

        Without anchors the score is the maximum logit (or softmax
        probability when ``use_softmax`` is set). With anchors it is the
        minimum or maximum distance to the stored anchors, depending on
        ``score_type``. The ``class_anchors`` argument is not used.

        Raises:
            ValueError: If anchors are set and ``score_type`` is invalid.
        """
        if self.class_anchors is None:
            # without anchors, y_pred is supposed to be logits
            values = tf.nn.softmax(y_pred) if self.use_softmax else y_pred
            return tf.math.reduce_max(values, axis=1)

        # with anchors, y_pred is the representation in the anchors space
        dist = distance(y_pred, self.class_anchors)
        if self.score_type == "min":
            return tf.math.reduce_min(dist, axis=1)
        if self.score_type == "max":
            return tf.math.reduce_max(dist, axis=1)
        raise ValueError(f"Unknown OSR score for cross entropy loss: {self.score_type}")
    
class DistanceHelper(BaseLossHelper):
    """Loss helper for the distance-to-anchor loss built by ``loss_dist``."""

    def __init__(self, class_anchors, max_dist, nb_classes, nb_features, osr_score="min"):
        """Create the helper.

        Args:
            class_anchors: Anchors used for the loss and as default anchors
                for prediction and scoring.
            max_dist: Forwarded to ``loss_dist``.
            nb_classes: Stored on the instance.
            nb_features: Stored on the instance.
            osr_score: Score orientation, "min" or "max".

        Raises:
            ValueError: If ``osr_score`` is neither "min" nor "max".
        """
        self.class_anchors = class_anchors
        self.nb_classes = nb_classes
        self.nb_features = nb_features
        self.loss_fn = loss_dist(class_anchors, max_dist)
        if osr_score not in ("min", "max"):
            raise ValueError(f"Unknown OSR score for distance loss: {osr_score}")
        self.score_type = osr_score

    @property
    def distance_based(self):
        """This helper is distance based."""
        return True

    def loss(self, y_pred, y_true):
        """Return the loss; arguments are forwarded positionally to ``loss_fn``."""
        return self.loss_fn(y_pred, y_true)

    def predicted_class(self, pred, class_anchors=None, threshold=None):
        """Return the predicted class index for each sample.

        Uses ``class_anchors`` when given, otherwise the anchors stored on the
        instance. ``threshold`` is accepted but not used.

        Raises:
            ValueError: If ``score_type`` is neither "min" nor "max".
        """
        anchors = self.class_anchors if class_anchors is None else class_anchors
        dist = distance(pred, anchors)

        if self.score_type == "min":
            return tf.math.argmin(dist, axis=1)
        if self.score_type == "max":
            # NOTE(review): this selects the anchor with the largest
            # distance — confirm this is intended for the "max" score type.
            return tf.math.argmax(dist, axis=1)
        raise ValueError(f"Unknown OSR score for distance loss: {self.score_type}")

    def osr_score(self, pred, class_anchors=None):
        """Compute one OSR score per sample.

        The score is the minimum ("min") or maximum ("max") distance to the
        anchors, using ``class_anchors`` when given and the stored anchors
        otherwise.

        Raises:
            ValueError: If ``score_type`` is neither "min" nor "max".
        """
        anchors = self.class_anchors if class_anchors is None else class_anchors
        dist = distance(pred, anchors)

        if self.score_type == "min":
            return tf.math.reduce_min(dist, axis=1)
        if self.score_type == "max":
            return tf.math.reduce_max(dist, axis=1)
        raise ValueError(f"Unknown OSR score for distance loss: {self.score_type}")
    
    # def auroc_v2(self, pred_known, pred_unknown):
    #     nb_classes = 6
    #     nb_features = 5
    #     reshaped_pred_k = tf.reshape(pred_known, (-1, nb_classes, nb_features))
    #     reshaped_pred_u = tf.reshape(pred_unknown, (-1, nb_classes, nb_features))
    #     score_known = tf.reduce_max(tf.reduce_sum(reshaped_pred_k, axis=2), axis=1)
    #     score_unknown = tf.reduce_max(tf.reduce_sum(reshaped_pred_u, axis=2), axis=1)
    #     y_true = tf.concat([np.zeros_like(score_known), np.ones_like(score_unknown)], axis=0)
    #     score_known, score_unknown = modify_score(score_known, score_unknown, "max")
    #     y_pred = tf.concat([score_known, score_unknown], axis=0)       
    #     return roc_auc_score(y_true, y_pred)
    
class CACHelper(BaseLossHelper):
    """Loss helper for the CAC loss built by ``loss_cac``.

    Scores are always oriented as "min": a lower rejection score means the
    sample is more likely to be known.
    """

    def __init__(self, class_anchors):
        """Store the anchors and build the loss function from them."""
        self.class_anchors = class_anchors
        self.loss_fn = loss_cac(class_anchors)
        self.score_type = "min"

    def loss(self, y_pred, y_true):
        """Return the loss; arguments are forwarded positionally to ``loss_fn``."""
        return self.loss_fn(y_pred, y_true)

    def predicted_class(self, pred, class_anchors=None):
        """Return the predicted class for each prediction.

        The class is the index of the smallest rejection score. Uses
        ``class_anchors`` when given, otherwise the stored anchors.
        """
        anchors = self.class_anchors if class_anchors is None else class_anchors
        # rejection scores derived from the distance to all anchors
        rejection = self._get_rejection_scores(distance(pred, anchors))

        # TODO: take in account rejection of a prediction
        return tf.math.argmin(rejection, axis=1)

    def _get_rejection_scores(self, dist):
        """Weight each distance by ``1 - softmin(dist)``."""
        weights = 1 - softmin(dist)
        return dist * weights

    def osr_score(self, pred, class_anchors=None):
        """Compute OSR scores, this is one value per prediction.

        The score is the smallest rejection score over the anchors. Uses
        ``class_anchors`` when given, otherwise the stored anchors.
        """
        anchors = self.class_anchors if class_anchors is None else class_anchors
        rejection = self._get_rejection_scores(distance(pred, anchors))
        return tf.reduce_min(rejection, axis=1)

    def _format_score(self, score_known, score_unknown):
        """Build the ``(y_true, y_pred)`` pair used for AUROC.

        Known samples are labelled 0 and unknown samples 1; scores are
        concatenated without any sign change.
        """
        labels = tf.concat(
            [np.zeros_like(score_known), np.ones_like(score_unknown)],
            axis=0,
        )
        scores = tf.concat([score_known, score_unknown], axis=0)
        return labels, scores
    