#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Calibration metrics.

Calibration is a property of probabilistic prediction models: a model is said to
be well-calibrated if its predicted probabilities over a class of events match
long-term frequencies over the sampling distribution.
"""



import tensorflow.compat.v2 as tf

from keras import backend

from  sklearn.metrics import roc_curve as roc
from  sklearn.metrics import auc as auc
from absl import logging, flags

__all__ = [
    'create_calibration_metrics',
    'sparse_categorical_matches',
    'eval_calibration',
]

# Module-level handle on the absl flag registry.
FLAGS = flags.FLAGS


def create_calibration_metrics(uncertainty_scores, num_backets):
    """Build one ``tf.keras.metrics.Mean`` per uncertainty score name.

    Args:
      uncertainty_scores: iterable of score names. When a dict is passed, its
        keys are used.
      num_backets: not used by this function; kept so existing call sites
        keep working.

    Returns:
      Dict mapping ``f'calibration_{name}_auroc'`` to a freshly created
      ``tf.keras.metrics.Mean`` for every ``name`` in ``uncertainty_scores``.
    """
    return {
        f'calibration_{score_name}_auroc': tf.keras.metrics.Mean()
        for score_name in uncertainty_scores
    }


def sparse_categorical_matches(y_true, y_pred, dtype=tf.int32):
    """Flag, per example, whether argmax(y_pred) equals the integer label.

    Args:
      y_true: integer class labels. When its static rank equals that of
        ``y_pred`` (e.g. labels of shape ``(N, 1)`` against scores of shape
        ``(N, C)``), its last axis is squeezed; that axis must have size 1.
      y_pred: class scores; the predicted class is the argmax over the last
        axis.
      dtype: dtype of the returned tensor.

    Returns:
      Tensor of ``dtype`` holding 1 where the predicted class equals the
      label and 0 elsewhere. If ``y_true`` was squeezed, the result is
      reshaped back to ``y_true``'s original shape.
    """
    y_pred = tf.convert_to_tensor(y_pred)
    y_true = tf.convert_to_tensor(y_true)
    label_shape = tf.shape(y_true)

    # Only compare ranks when both are statically known.
    same_rank = (
        y_true.shape.ndims is not None
        and y_pred.shape.ndims is not None
        and len(backend.int_shape(y_true)) == len(backend.int_shape(y_pred))
    )
    if same_rank:
        # Drop the trailing singleton axis so labels line up with argmax output.
        y_true = tf.squeeze(y_true, [-1])

    predicted_class = tf.math.argmax(y_pred, axis=-1)
    label_dtype = backend.dtype(y_true)
    if backend.dtype(predicted_class) != label_dtype:
        # Align the prediction dtype with the labels before comparing.
        predicted_class = tf.cast(predicted_class, label_dtype)

    hits = tf.cast(tf.equal(y_true, predicted_class), backend.floatx())
    if same_rank:
        hits = tf.reshape(hits, shape=label_shape)
    return tf.cast(hits, dtype)
                



def eval_calibration(strategy, metrics, matches, uncertainty_scores):
    """Log uncertainty statistics and update calibration AUROC metrics.

    For each uncertainty score this:
      1. logs the mean score over correct and over wrong predictions;
      2. computes the ROC AUC of the score as a detector of wrong predictions
         (wrong = positive class) with scikit-learn;
      3. feeds that AUROC into ``metrics[f'calibration_{name}_auroc']`` via
         ``strategy.run``.

    If ``matches`` contains only one class (all correct or all wrong), the
    AUROC is undefined: scikit-learn would return NaN, and that NaN would
    permanently poison the running metric. In that case the metric update is
    skipped and a warning is logged.

    Args:
      strategy: a ``tf.distribute`` strategy whose ``run`` executes the
        metric update.
      metrics: dict holding, for every score name, an object with
        ``update_state`` under ``f'calibration_{name}_auroc'`` (e.g. as
        built by ``create_calibration_metrics``).
      matches: eager tensor with non-zero (1) for correct and 0 for wrong
        predictions. It is expected to be 1-D, since ``roc_curve`` needs 1-D
        labels.
      uncertainty_scores: dict mapping a score name to an eager float32
        tensor of per-example scores aligned with ``matches``.
    """
    # Everything that depends only on `matches` is computed once, not per score.
    correct_weight = tf.cast(matches, tf.float32)
    wrong_weight = tf.cast(1 - matches, tf.float32)
    num_correct = tf.reduce_sum(tf.cast(matches != 0, tf.float32), axis=-1)
    num_wrong = tf.reduce_sum(tf.cast(matches == 0, tf.float32), axis=-1)

    # Positive class for the ROC curve is "prediction was wrong".
    wrong_labels = 1 - matches.numpy()
    has_both_classes = (
        wrong_labels.size > 0 and wrong_labels.min() != wrong_labels.max())

    for name, val in uncertainty_scores.items():
        # Mean score per group; NaN if the group is empty (0 / 0).
        uncertainty_correct = tf.reduce_sum(
            tf.math.multiply(correct_weight, val), axis=-1) / num_correct
        uncertainty_wrong = tf.reduce_sum(
            tf.math.multiply(wrong_weight, val), axis=-1) / num_wrong

        # Pass `name` as an argument so a '%' in it cannot break formatting.
        logging.info(
            "Uncertainty %s of correct predictions %.4f, wrong predictions %.4f",
            name, uncertainty_correct, uncertainty_wrong)

        if not has_both_classes:
            logging.warning(
                "Skipping %s Calibration AUROC: predictions are all correct "
                "or all wrong, so AUROC is undefined.", name)
            continue

        fpr, tpr, _ = roc(wrong_labels, val.numpy())
        calibration_auroc = auc(x=fpr, y=tpr)

        logging.info("Done with %s Calibration AUROC: %.4f",
                     name, calibration_auroc)

        # The closure is invoked immediately below, so capturing the loop
        # variables `name` / `calibration_auroc` is safe.
        # NOTE(review): a fresh tf.function per score per call retraces every
        # time; consider hoisting it if this runs on a hot path.
        @tf.function
        def update_calibration_metrics_fn():
            metrics[f'calibration_{name}_auroc'].update_state(calibration_auroc)

        strategy.run(update_calibration_metrics_fn)
