#! /usr/bin/env python3

import torch

import os
import numpy as np
import torch.nn.functional as F




# When True, features are re-extracted and existing cache files are overwritten.
FORCE_RUN = False


def _pooled_features(feature_list):
    """Global-average-pool every feature map and concatenate along dim 1.

    ``flatten(1)`` is used instead of a bare ``squeeze()`` so the batch axis
    survives when a batch holds a single sample.
    """
    return torch.cat(
        [F.adaptive_avg_pool2d(layer_feat, 1).flatten(1) for layer_feat in feature_list],
        dim=1,
    )


def _collect_outputs(model, loader, feat_dim, num_classes, device, with_labels):
    """Run ``model.feature_list`` over ``loader`` and gather its outputs.

    Returns a ``(feat_log, score_log, label_log)`` tuple of NumPy arrays with
    one row per dataset sample; ``label_log`` is ``None`` when ``with_labels``
    is false.
    """
    num_samples = len(loader.dataset)
    feat_log = np.zeros((num_samples, feat_dim))
    score_log = np.zeros((num_samples, num_classes))
    label_log = np.zeros(num_samples) if with_labels else None

    # Rows are filled with a running offset based on the size of each batch
    # actually produced, so the result does not depend on a separately
    # supplied batch size.
    offset = 0
    with torch.no_grad():  # inference only: do not build autograd graphs
        for batch_idx, (inputs, targets) in enumerate(loader):
            inputs = inputs.to(device)
            score, feature_list = model.feature_list(inputs)
            out = _pooled_features(feature_list)

            end = offset + out.shape[0]
            feat_log[offset:end, :] = out.cpu().numpy()
            score_log[offset:end] = score.cpu().numpy()
            if label_log is not None:
                label_log[offset:end] = targets.cpu().numpy()
            offset = end

            if batch_idx % 100 == 0:
                print(f"{batch_idx}/{len(loader)}")
    return feat_log, score_log, label_log


def _save_cache(cache_name, arrays):
    """Save ``arrays`` (of differing shapes) as one 1-D object array.

    The object array is built explicitly: letting NumPy infer it from a tuple
    of differently shaped arrays is an error in recent NumPy releases.
    """
    packed = np.empty(len(arrays), dtype=object)
    for idx, arr in enumerate(arrays):
        packed[idx] = arr
    np.save(cache_name, packed)


def extract_features(model, trainloaderIn, testloaderIn, outloaders, num_classes, batch_size, device, config):
    """Extract per-layer pooled features and scores and cache them on disk.

    For the in-distribution train/val loaders and for every OOD loader, the
    model's ``feature_list`` outputs are average-pooled, concatenated and
    written to ``./knn/cache``. A cache file that already exists is left
    untouched (and is not read) unless ``FORCE_RUN`` is true.

    Cache layout (1-D object arrays; load with ``allow_pickle=True``):
      * in-distribution: ``(features.T, scores.T, labels)``
      * OOD:             ``(features.T, scores.T)``
    NOTE: these files are pickle-based; only load cache files you trust.

    Args:
        model: module exposing ``feature_list(x) -> (scores, [feature maps])``;
            it is switched to eval mode. Assumed to already live on
            ``device`` — TODO confirm with callers.
        trainloaderIn: in-distribution loader cached under the ``train`` split.
        testloaderIn: in-distribution loader cached under the ``val`` split.
        outloaders: iterable of OOD loaders, paired in order with
            ``config['ood_datasets']``.
        num_classes: width of the score vector returned by the model.
        batch_size: kept for interface compatibility; row offsets are now
            derived from the actual size of each batch.
        device: device the inputs are moved to.
        config: mapping providing ``'id_dataset'`` and ``'ood_datasets'``.

    Returns:
        None. The results are only written to the cache files.
    """
    os.makedirs('./knn/cache', exist_ok=True)

    # Switch to eval mode *before* the probing pass so it cannot update
    # normalisation statistics, and probe on the caller's device rather than
    # a hard-coded GPU index.
    model.eval()
    with torch.no_grad():
        dummy_input = torch.zeros((1, 3, 32, 32)).to(device)
        _, feature_list = model.feature_list(dummy_input)
    feat_dim = sum(feat.shape[1] for feat in feature_list)

    for split, in_loader in (('train', trainloaderIn), ('val', testloaderIn)):
        cache_name = f"./knn/cache/{config['id_dataset']}_{split}_in_alllayers.npy"
        if not FORCE_RUN and os.path.exists(cache_name):
            continue
        feat_log, score_log, label_log = _collect_outputs(
            model, in_loader, feat_dim, num_classes, device, with_labels=True
        )
        _save_cache(cache_name, (feat_log.T, score_log.T, label_log))

    for ood_loader, ood_dataset in zip(outloaders, config['ood_datasets']):
        cache_name = f"./knn/cache/{ood_dataset}vs{config['id_dataset']}_out_alllayers.npy"
        if not FORCE_RUN and os.path.exists(cache_name):
            continue
        ood_feat_log, ood_score_log, _ = _collect_outputs(
            model, ood_loader, feat_dim, num_classes, device, with_labels=False
        )
        _save_cache(cache_name, (ood_feat_log.T, ood_score_log.T))
