#!/usr/bin/env python3
from omegaconf import OmegaConf
import os
import pickle
from pathlib import Path
import argparse
from uimnet import utils
from uimnet import workers
from uimnet import algorithms
from uimnet import measures
import torch
from filelock import FileLock
import time
import collections


def take_measurements(algorithm, Measures, num_batches=10,
                      batch_shape=(32, 3, 224, 224), device='cuda'):
  """Run every applicable measure on all-zero batches and gather the outputs.

  Args:
    algorithm: Object handed to each measure constructor. Its
      ``has_native_measure`` attribute is read only when a measure named
      ``'Native'`` is present.
    Measures: Iterable of measure classes (callables taking ``algorithm``).
    num_batches: Number of dummy batches fed to each measure.
    batch_shape: Shape of each all-zero input batch.
    device: Device on which the dummy batches are allocated. The default,
      ``'cuda'``, is the current CUDA device.

  Returns:
    A dict mapping each measure's class name to its per-batch outputs
    concatenated along dim 0, detached and moved to the CPU. Measures that
    were skipped, or that never ran because ``num_batches`` is 0, have no
    entry.
  """

  def _unsupported(name):
    # Evaluated lazily so `has_native_measure` is only touched when a
    # 'Native' measure is actually requested.
    return name == 'Native' and not algorithm.has_native_measure

  measures = []
  for Measure in Measures:
    # Skip before construction so an unsupported native measure is never
    # built; callables without a __name__ are checked after construction.
    if _unsupported(getattr(Measure, '__name__', None)):
      continue
    measure = Measure(algorithm)
    if _unsupported(type(measure).__name__):
      continue
    measures.append(measure)

  measurements = collections.defaultdict(list)

  for _ in range(num_batches):
    # A fresh batch per iteration, in case a measure modifies its input.
    x = torch.zeros(tuple(batch_shape), device=device)

    for measure in measures:
      measurements[type(measure).__name__].append(measure(x))

  return {name: torch.cat(chunks, dim=0).detach().cpu()
          for name, chunks in measurements.items()}




def main():
  """Time every measure on every algorithm and log the elapsed seconds.

  For each algorithm (a resnet18 with 266 classes on ``cuda:0``) and each
  measure, the measure is constructed and called on 10 all-zero batches of
  shape (32, 3, 224, 224) under ``torch.no_grad()``. The ``Native`` measure
  is skipped for algorithms whose ``has_native_measure`` is falsy. One
  message per (algorithm, measure) pair is emitted through ``utils.message``.
  """
  # For a quick smoke test, narrow this list to one entry, e.g. ['MCDropout'].
  algorithm_names = [
    'ERM', 'MCDropout', 'MIMO', 'RBF', 'SoftLabeler', 'Mixup', 'RND', 'OC',
    'DUE',
  ]
  measure_names = [
    'Augmentations',
    'Gap',
    'Jacobian',
    'MixtureOfGaussians',
    'Entropy',
    'Largest',
    'Native',
  ]
  Algorithms = [algorithms.__dict__[name] for name in algorithm_names]
  Measures = [measures.__dict__[name] for name in measure_names]

  device = 'cuda:0'
  cuda_device = torch.device(device)
  num_batches = 10
  batch_shape = (32, 3, 224, 224)

  for Algorithm in Algorithms:
    algorithm = Algorithm(num_classes=266,
                          arch='resnet18',
                          device=device,
                          seed=0,
                          use_mixed_precision=True
                          )
    algorithm.initialize()
    for Measure in Measures:
      if Measure.__name__ == 'Native' and not algorithm.has_native_measure:
        continue
      algorithm.eval()

      # CUDA kernels are launched asynchronously. Drain pending work first so
      # it is not billed to this measure.
      torch.cuda.synchronize(cuda_device)
      # perf_counter is monotonic and high-resolution, unlike time.time().
      start = time.perf_counter()
      meas = Measure(algorithm)
      with torch.no_grad():
        for _ in range(num_batches):
          # Allocate directly on the device so no host-to-device copy is
          # included in the timing.
          x = torch.zeros(batch_shape, device=device)
          out = meas(x)
          # Release the batch and result before the next allocation.
          del x
          del out
      # Wait for the queued kernels to finish before reading the clock.
      torch.cuda.synchronize(cuda_device)
      eval_time = time.perf_counter() - start

      del meas

      utils.message(f'Evaluating {Algorithm.__name__} on {Measure.__name__}. Time taken: {eval_time:.2f} (s)')

def main2():
  """Instantiate each algorithm and run ``take_measurements`` on it.

  Every algorithm is built as a resnet18 with 266 classes on ``cuda:0``,
  initialized, and then passed with the full list of measures to
  ``take_measurements`` under ``torch.no_grad()``. The collected
  measurements are not used further.
  """
  algorithm_names = (
    'ERM', 'MCDropout', 'MIMO', 'RBF', 'SoftLabeler', 'Mixup', 'RND', 'OC',
    'DUE',
  )
  Algorithms = [vars(algorithms)[name] for name in algorithm_names]
  device = 'cuda:0'

  measure_names = (
    'Augmentations',
    'Gap',
    'Jacobian',
    'MixtureOfGaussians',
    'Entropy',
    'Largest',
    'Native',
  )
  Measures = [vars(measures)[name] for name in measure_names]

  # Constructor arguments shared by every algorithm.
  settings = dict(
    num_classes=266,
    arch='resnet18',
    device=device,
    seed=0,
    use_mixed_precision=True,
  )

  for Algorithm in Algorithms:
    algorithm = Algorithm(**settings)
    algorithm.initialize()
    with torch.no_grad():
      collected = take_measurements(algorithm, Measures)






# Script entry point: only main() is run; main2() is defined above but is not
# invoked here.
if __name__ == '__main__':
  main()
