import time

from openood.datasets import get_dataloader, get_ood_dataloader
from openood.evaluators import get_evaluator
from openood.networks import get_network
from openood.postprocessors import get_postprocessor
from openood.utils import setup_logger, get_config_default
from .temperature import TemperatureWrapper


class TestOODPipeline:
    def __init__(self, config) -> None:
        self.config = config

    def run(self):
        # generate output directory and save the full config file
        setup_logger(self.config)

        ood_scheme = get_config_default(self.config.evaluator, "ood_scheme", "ood")

        # get dataloader
        id_loader_dict = get_dataloader(self.config)
        ood_loader_dict = get_ood_dataloader(self.config)

        # init network
        net = get_network(self.config.network)

        # init ood evaluator
        evaluator = get_evaluator(self.config)

        # init ood postprocessor
        postprocessor = get_postprocessor(self.config)
        # setup for distance-based methods
        postprocessor.setup(net, id_loader_dict, ood_loader_dict)
        print('\n', flush=True)
        print(u'\u2500' * 70, flush=True)

        # tune temperature if necessary
        if get_config_default(self.config.pipeline, "tune_temperature", False):
            net = TemperatureWrapper(net)
            net.fit_temperature(id_loader_dict["val"])

        # start calculating accuracy
        print('\nStart evaluation...', flush=True)
        if ood_scheme == 'fsood':
            acc_metrics = evaluator.eval_acc(
                net,
                id_loader_dict['test'],
                postprocessor,
                fsood=True,
                csid_data_loaders=ood_loader_dict['csid'])
        else:
            acc_metrics = evaluator.eval_acc(net, id_loader_dict['test'],
                                             postprocessor)
        print('\nAccuracy {:.2f}%'.format(100 * acc_metrics['acc']),
              flush=True)
        print(u'\u2500' * 70, flush=True)

        # start evaluating ood detection methods
        timer = time.time()
        if ood_scheme == 'fsood':
            evaluator.eval_ood(net,
                               id_loader_dict,
                               ood_loader_dict,
                               postprocessor,
                               fsood=True)
        elif hasattr(evaluator, "eval_ood"):
            evaluator.eval_ood(net, id_loader_dict, ood_loader_dict,
                               postprocessor)
        print('Time used for eval_ood: {:.0f}s'.format(time.time() - timer))
        print('Completed!', flush=True)
