# coding=utf-8
# Copyright 2023.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import numpy as np
import matplotlib.pyplot as plt
from absl import app
from collections import defaultdict
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
import tensorflow as tf

# Font sizes shared by all matplotlib text produced by this module.
SMALL_SIZE = 20
MEDIUM_SIZE = 25
BIGGER_SIZE = 30

# Global matplotlib rc defaults, applied once at import time.
plt.rc('font', size=BIGGER_SIZE)  # default size for all text
# Axes title uses the large size; x/y axis labels use the medium size.
plt.rc('axes', titlesize=BIGGER_SIZE, labelsize=MEDIUM_SIZE)
# Tick labels on both axes use the small size.
plt.rc(('xtick', 'ytick'), labelsize=SMALL_SIZE)
plt.rc('legend', fontsize=BIGGER_SIZE)  # legend text
plt.rc('figure', titlesize=BIGGER_SIZE)  # figure-level title


from absl import flags

# Directory whose entries are tensorboard run directories. main() resolves it
# against the current working directory.
flags.DEFINE_string(
    'path', '/summaries_uaib/',
    'Directory containing one tensorboard run per entry, relative to the '
    'current working directory.')
# NOTE(review): nothing in this file reads FLAGS.export_to_csv; the flag is
# kept so existing command lines that pass it keep parsing.
flags.DEFINE_boolean(
    'export_to_csv', False,
    'Export the computed metrics to CSV (not implemented by this script).')
FLAGS = flags.FLAGS

class FalseDict(object):
    """Mapping stand-in that claims to contain every key, each mapped to 0.

    Used below as the ``size_guidance`` for ``EventAccumulator``. Tensorboard
    documents a size guidance of 0 as "keep all items", so this asks the
    accumulator to retain every event of every kind rather than sampling.
    """

    def __contains__(self, key):
        """Report any ``key`` as present."""
        return True

    def __getitem__(self, key):
        """Return 0 for any ``key``."""
        return 0
        


# A tensor tag is reported if it contains any of these substrings...
_REPORTED_TAG_SUBSTRINGS = (
    'ood', 'test/accuracy', 'positive_wrong_auroc', 'calibration')
# ...unless it also contains any of these.
_EXCLUDED_TAG_SUBSTRINGS = ('95tpr', 'likelihood')


def _is_reported_tag(tag):
    """Return True if tensor ``tag`` should appear in the metric summary."""
    return (any(s in tag for s in _REPORTED_TAG_SUBSTRINGS)
            and not any(s in tag for s in _EXCLUDED_TAG_SUBSTRINGS))


def compute_ood_metrics(dpath):
    """Print mean/std across runs of the final value of each selected tag.

    Each entry of ``dpath`` is loaded with ``EventAccumulator`` (using
    ``FalseDict`` as size guidance so all events are kept). For every tensor
    tag accepted by ``_is_reported_tag``, the run's last logged value is
    collected. Then, per tag, the tag name is printed followed by the mean
    and standard deviation of those values. ``np.std`` is used with its
    default ddof=0, i.e. population std.

    Args:
      dpath: Directory whose entries are tensorboard run directories. A
        trailing path separator is optional.

    Returns:
      A dict mapping each reported tag to the list of per-run final values,
      in the (sorted) order the runs were read.
    """
    metrics = defaultdict(list)

    # os.listdir order is arbitrary; sort so the run order, and therefore the
    # order tags are first seen and printed, is reproducible.
    for run in sorted(os.listdir(dpath)):
        # os.path.join works whether or not dpath ends with a separator.
        accumulator = EventAccumulator(
            os.path.join(dpath, run), size_guidance=FalseDict()).Reload()

        for tag in accumulator.Tags()['tensors']:
            if not _is_reported_tag(tag):
                continue
            events = accumulator.Tensors(tag)
            if not events:
                continue
            # Only the final logged value is used, so decode just that event.
            metrics[tag].append(tf.make_ndarray(events[-1].tensor_proto).item())

    for tag, values in metrics.items():
        print(tag)
        print(f'mean: {np.mean(values)} std:{np.std(values)} ')

    return dict(metrics)

        
    

  
def main(argv):
    """absl entry point: summarize OOD metrics under ``FLAGS.path``.

    ``FLAGS.path`` is resolved relative to the current working directory.

    Args:
      argv: Leftover positional arguments from absl flag parsing; unused.
    """
    del argv  # Unused.
    # The flag's default ('/summaries_uaib/') starts with '/', which the old
    # string concatenation relied on. Strip leading separators so
    # os.path.join treats the value as relative instead of discarding cwd,
    # which also makes values without a leading '/' work.
    relative_path = FLAGS.path.lstrip('/' + os.sep)
    # The trailing '' component makes the result end with a separator, as the
    # default flag value did.
    compute_ood_metrics(os.path.join(os.getcwd(), relative_path, ''))
    
if __name__ == '__main__':
    # Let absl parse flags, then dispatch to main().
    app.run(main)
