from openood.datasets import get_dataloader
from openood.evaluators import get_evaluator
from openood.networks import get_network
from openood.utils import setup_logger


class FeatExtractPipeline:
    """Evaluate a network on one configured loader, then extract features.

    The loader is picked from the dict returned by ``get_dataloader`` using
    ``config.pipeline.extract_target``. The same loader is used both for
    the accuracy pass and for the feature-extraction pass.
    """

    def __init__(self, config) -> None:
        """Store the experiment configuration.

        Args:
            config: Experiment configuration. ``run`` reads
                ``config.pipeline.extract_target`` and ``config.network``
                from it, and passes the whole object to ``setup_logger``,
                ``get_dataloader`` and ``get_evaluator``.
        """
        self.config = config

    def run(self):
        """Run an accuracy evaluation followed by feature extraction.

        Raises:
            KeyError: If ``config.pipeline.extract_target`` does not name
                one of the loaders returned by ``get_dataloader``.
        """
        # generate output directory and save the full config file
        setup_logger(self.config)

        # get dataloader selected for extraction
        loader_dict = get_dataloader(self.config)
        target = self.config.pipeline.extract_target
        if target not in loader_dict:
            # Same exception type as a plain lookup, but list the valid
            # choices so a config typo is easy to spot.
            raise KeyError(
                'extract_target {!r} not found in dataloaders; '
                'available: {}'.format(target, list(loader_dict)))
        test_loader = loader_dict[target]

        # init network
        net = get_network(self.config.network)

        # init evaluator
        evaluator = get_evaluator(self.config)

        # start calculating accuracy
        print('\nStart evaluation...', flush=True)
        test_metrics = evaluator.eval_acc(net, test_loader)
        # NOTE(review): the x100 scaling assumes 'acc' is a fraction in
        # [0, 1] - confirm against the evaluator implementation.
        print('\nComplete Evaluation, accuracy {:.2f}%'.format(
            100 * test_metrics['acc']),
            flush=True)

        # start extracting features
        print('\nStart Feature Extraction...', flush=True)
        evaluator.extract(net, test_loader)
        # Flush like the other progress messages so the completion line is
        # not left sitting in the stdout buffer.
        print('\nComplete Feature Extraction!', flush=True)
