#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""OOD utilities for CIFAR-10 and CIFAR-100."""

import tensorflow as tf
import datasets
import sklearn
from absl import logging




def create_ood_metrics(ood_dataset_names, ood_scores):
  """Build the running-mean metrics used to track OOD detection quality.

  For every (dataset name, score name) pair, two `tf.keras.metrics.Mean`
  objects are created: one for AUPRC and one for AUROC.

  Args:
    ood_dataset_names: Iterable of dataset names used as the key prefix.
    ood_scores: Iterable of OOD score names; it is re-iterated once per
      dataset name.

  Returns:
    A dict mapping `'{dataset_name}_{ood_score}_ood_auprc'` and
    `'{dataset_name}_{ood_score}_ood_auroc'` to a fresh `Mean` metric.
  """
  ood_metrics = {}
  for dataset_name in ood_dataset_names:
    for ood_score in ood_scores:
      # AUPRC is created before AUROC so the construction order is stable.
      ood_metrics[f'{dataset_name}_{ood_score}_ood_auprc'] = (
          tf.keras.metrics.Mean())
      ood_metrics[f'{dataset_name}_{ood_score}_ood_auroc'] = (
          tf.keras.metrics.Mean())
  return ood_metrics


def eval_on_ood(strategy, metrics, ood_labels, ood_scores, ood_dataset_name):
  """Compute OOD detection AUPRC/AUROC with sklearn and record them.

  These curve-based metrics are computed outside of TensorFlow, by calling
  sklearn on the full set of labels and scores passed in. The resulting
  scalars are then pushed into the matching entries of `metrics`.

  According to the original author's notes, the scores evaluated are:
    1. the cluster distance (KL divergence / rate) from the closest centroid;
    2. the entropy of the final classifier.
  The function itself does not depend on what each score means.

  Args:
    strategy: Object whose `run` method executes the metric update
      (presumably a `tf.distribute.Strategy`; confirm against the caller).
    metrics: Mapping from metric name to an object with `update_state`. It
      must contain `'{dataset}_{score}_ood_auprc'` and
      `'{dataset}_{score}_ood_auroc'` for every key of `ood_scores`, where
      `{dataset}` is `ood_dataset_name` with every `'ood/'` removed.
    ood_labels: Tensor-like with a `.numpy()` method; passed to sklearn as the
      ground-truth labels.
    ood_scores: Mapping from score name to a tensor-like with a `.numpy()`
      method; each value is passed to sklearn as the predicted score.
    ood_dataset_name: Dataset name, used in log messages and (after removing
      `'ood/'`) as the metric-key prefix.
  """
  # `import sklearn` alone does not guarantee that the `metrics` submodule is
  # loaded (depending on the scikit-learn version), so import it explicitly.
  import sklearn.metrics  # pylint: disable=g-import-not-at-top

  # Both values are loop-invariant: convert/derive them once.
  labels = ood_labels.numpy()
  metric_prefix = ood_dataset_name.replace('ood/', '')

  for name, val in ood_scores.items():
    scores = val.numpy()

    precision, recall, _ = sklearn.metrics.precision_recall_curve(
        labels, scores)
    auprc = sklearn.metrics.auc(x=recall, y=precision)
    logging.info("Done with OOD eval on %s, %s AUPRC %.4f",
                 ood_dataset_name, name, auprc)

    fpr, tpr, _ = sklearn.metrics.roc_curve(labels, scores)
    auroc = sklearn.metrics.auc(x=fpr, y=tpr)
    logging.info("Done with OOD eval on %s, %s AUROC %.4f",
                 ood_dataset_name, name, auroc)

    auprc_key = f'{metric_prefix}_{name}_ood_auprc'
    auroc_key = f'{metric_prefix}_{name}_ood_auroc'

    # Defined and run within the same iteration, so the closure always sees
    # this iteration's keys and values.
    @tf.function
    def update_ood_metrics_fn():
      metrics[auprc_key].update_state(auprc)
      metrics[auroc_key].update_state(auroc)

    strategy.run(update_ood_metrics_fn)


def load_ood_datasets(ood_dataset_names,
                      in_dataset_builder,
                      in_dataset_validation_percent,
                      batch_size,
                      drop_remainder=False):
  """Build the OOD evaluation datasets and their step counts.

  Args:
    ood_dataset_names: Iterable of keys into `datasets.DATASETS`.
    in_dataset_builder: In-distribution dataset builder, forwarded as the
      first positional argument to each OOD dataset builder.
    in_dataset_validation_percent: Forwarded as `validation_percent`.
    batch_size: Batch size used to load each dataset and to derive the number
      of evaluation steps.
    drop_remainder: Forwarded to each OOD dataset builder.

  Returns:
    A `(datasets_dict, steps)` tuple of dicts, both keyed by
    `'ood/{ood_dataset_name}'`: the loaded dataset and the number of steps
    (in-distribution examples // batch_size + OOD examples // batch_size).
  """
  datasets_dict = {}
  steps = {}
  for ood_dataset_name in ood_dataset_names:
    key = f'ood/{ood_dataset_name}'
    builder_cls = datasets.make_ood_dataset(
        datasets.DATASETS[ood_dataset_name])

    builder_kwargs = {
        'split': 'test',
        'validation_percent': in_dataset_validation_percent,
        'drop_remainder': drop_remainder,
        'seed': 0,
    }
    # Non-CIFAR OOD datasets are normalized by CIFAR statistics, since all
    # test datasets should be preprocessed the same way.
    if 'cifar' not in ood_dataset_name:
      builder_kwargs['normalize_by_cifar'] = True
    builder = builder_cls(in_dataset_builder, **builder_kwargs)

    datasets_dict[key] = builder.load(batch_size=batch_size)

    in_distribution_steps = builder.num_examples(
        'in_distribution') // batch_size
    ood_steps = builder.num_examples('ood') // batch_size
    steps[key] = in_distribution_steps + ood_steps

  return datasets_dict, steps
