import numpy as np
import torch
from openood.evaluators.ood_edl_evaluator import OODEDLEvaluator
from openood.postprocessors import get_postprocessor
import openood.utils.comm as comm
from openood.datasets import get_dataloader, get_ood_dataloader
from openood.evaluators import get_evaluator
from openood.networks import get_network
from openood.recorders import get_recorder
from openood.trainers import get_trainer
from openood.utils import setup_logger


class TrainTestPipeline:
    """Train a network, then evaluate test accuracy, OOD detection and
    accuracy/ECE on the clean and corrupted test splits.

    All components (dataloaders, network, trainer, evaluators,
    postprocessor, recorder) are built from ``config`` through the openood
    factory helpers.
    """

    # Splits evaluated for accuracy/ECE after training: the clean test set
    # followed by corrupted test sets 1..5. Splits that have no dataloader
    # in the dict returned by ``get_dataloader`` are skipped.
    eval_splits = ['test'] + [f'corrupted_test{i}' for i in range(1, 6)]

    def __init__(self, config) -> None:
        """Store the experiment configuration.

        Args:
            config: experiment config; ``run`` reads ``seed``, ``network``
                and ``optimizer.num_epochs`` from it and passes it to the
                openood factory helpers.
        """
        self.config = config

    def run(self):
        """Train for ``config.optimizer.num_epochs`` epochs, reload the
        checkpoint kept by the recorder, then run test, OOD and per-split
        evaluations.
        """
        # generate output directory and save the full config file
        setup_logger(self.config)

        # set random seed
        torch.manual_seed(self.config.seed)
        np.random.seed(self.config.seed)

        # get dataloaders
        loader_dict = get_dataloader(self.config)
        ood_loader_dict = get_ood_dataloader(self.config)
        # init ood postprocessor
        postprocessor = get_postprocessor(self.config)
        train_loader, val_loader = loader_dict['train'], loader_dict['val']
        test_loader = loader_dict['test']
        # init network
        net = get_network(self.config.network)
        # init trainer and evaluators
        trainer = get_trainer(net, train_loader, val_loader, self.config)
        evaluator = get_evaluator(self.config)
        ood_evaluator = OODEDLEvaluator(self.config)

        # The recorder only exists on the main process; every use below is
        # guarded by comm.is_main_process(). Bind it on all ranks so the
        # name is never unbound.
        recorder = None
        if comm.is_main_process():
            recorder = get_recorder(self.config)
            print('Start training...', flush=True)

        for epoch_idx in range(1, self.config.optimizer.num_epochs + 1):
            # train and eval the model
            net, train_metrics = trainer.train_epoch(epoch_idx)
            val_metrics = evaluator.eval_acc(net, val_loader, postprocessor,
                                             epoch_idx)
            comm.synchronize()
            if comm.is_main_process():
                # save model and report the result
                recorder.save_model(net, val_metrics)
                recorder.report(train_metrics, val_metrics)

        if comm.is_main_process():
            recorder.summary()
            print(u'\u2500' * 70, flush=True)
            print('Start testing...', flush=True)
            # load the checkpoint kept by the recorder before testing
            # NOTE(review): only the main process reloads weights here;
            # other ranks keep their last-epoch weights — confirm intended.
            recorder.load_model(net)

        # evaluate on test set
        test_metrics = evaluator.eval_acc(net, test_loader,
                                          postprocessor=postprocessor)
        if comm.is_main_process():
            print('\nComplete Evaluation, Last accuracy {:.2f}'.format(
                100.0 * test_metrics['acc']),
                  flush=True)

        # OOD evaluation
        postprocessor.setup(net, loader_dict, ood_loader_dict)
        ood_evaluator.eval_ood(net, loader_dict, ood_loader_dict, postprocessor)

        # accuracy / ECE on clean and corrupted test splits
        self._eval_splits(net, evaluator, loader_dict, postprocessor)
        print('Completed!', flush=True)

    def _eval_splits(self, net, evaluator, loader_dict, postprocessor):
        """Report accuracy and ECE for every split in ``eval_splits``.

        Splits without a dataloader in ``loader_dict`` are skipped with a
        message. Evaluation failures are reported with the split name and do
        not stop the remaining splits (best-effort).
        """
        for split in self.eval_splits:
            try:
                split_loader = loader_dict[split]
            except KeyError:
                if comm.is_main_process():
                    print(f'Skipping split {split}: no dataloader available',
                          flush=True)
                continue
            try:
                metrics = evaluator.eval_acc(net, split_loader,
                                             postprocessor=postprocessor,
                                             epoch_idx=split)
                if comm.is_main_process():
                    print('\nComplete Evaluation, Last accuracy {:.2f}, '
                          'ECE {:.2f}, Split: {}'.format(
                              100.0 * metrics['acc'],
                              100.0 * metrics['ece'], split),
                          flush=True)
            except Exception as e:  # best-effort: keep evaluating other splits
                print(f'Evaluation failed on split {split}: {e!r}',
                      flush=True)
