from typing import Any

import faiss
import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm

from .base_postprocessor import BasePostprocessor
from torch.utils.data import DataLoader

def normalizer(x):
    """Return *x* L2-normalized along its last axis.

    The 1e-10 epsilon is added to the norm (the denominator), not to the
    quotient, so an all-zero row maps to zeros instead of NaN and non-zero
    rows are not shifted by a constant offset.
    """
    return x / (np.linalg.norm(x, axis=-1, keepdims=True) + 1e-10)


class KNNPostprocessor(BasePostprocessor):
    """OOD scorer based on the distance to the K-th nearest ID training feature.

    ``setup`` builds a faiss ``IndexFlatL2`` over L2-normalized features of the
    in-distribution loader; ``postprocess`` scores a batch by the negated
    K-th smallest distance returned by that index (higher = more ID-like).
    """

    def __init__(self, config):
        super(KNNPostprocessor, self).__init__(config)
        # TODO: read K from config.postprocessor.postprocessor_args once wired up.
        self.K = 50  # suggested: 50 for CIFAR-10, 200 for ImageNet
        self.activation_log = None
        # faiss index over the ID features; populated by setup().
        self.index = None

    @staticmethod
    def _pool_and_normalize(features: torch.Tensor) -> np.ndarray:
        """Average-pool all dims after dim 1, L2-normalize rows, return float32.

        For 2-D features of shape (B, D) the pooling is a no-op; for spatial
        maps (B, D, ...) the trailing dims are averaged. The result is a
        C-contiguous float32 array because faiss requires that layout/dtype.
        """
        feats = features.detach().cpu().numpy()
        feats = feats.reshape(feats.shape[0], feats.shape[1], -1).mean(2)
        return np.ascontiguousarray(normalizer(feats), dtype=np.float32)

    def setup(self, net: nn.Module, id_loader: DataLoader):
        """Extract normalized ID features from ``id_loader`` and index them.

        Args:
            net: model called as ``net(data, return_feature=True)`` and
                expected to return ``(logits, features)``.
            id_loader: yields ``(data, label)`` batches; labels are ignored.

        Raises:
            ValueError: if ``id_loader`` yields no batches.
        """
        activation_log = []
        net.eval()
        with torch.no_grad():
            print('Extracting id training feature')
            for data, _ in id_loader:
                data = data.cuda().float()
                _, features = net(data, return_feature=True)
                activation_log.append(self._pool_and_normalize(features))

        if not activation_log:
            raise ValueError('id_loader yielded no batches; cannot build KNN index')

        self.activation_log = np.concatenate(activation_log, axis=0)
        self.index = faiss.IndexFlatL2(self.activation_log.shape[1])
        self.index.add(self.activation_log)

    @torch.no_grad()
    def postprocess(self, net: nn.Module, data: Any):
        """Return ``(pred, score)`` for a batch.

        ``pred`` is the argmax class of the logits. ``score`` is the negated
        K-th nearest-neighbour distance to the ID features (faiss IndexFlatL2
        reports squared L2 distances), so larger means closer to ID data.

        Raises:
            RuntimeError: if ``setup`` has not been called yet.
        """
        if self.index is None:
            raise RuntimeError(
                'KNNPostprocessor.setup() must be called before postprocess()')
        output, feature = net(data, return_feature=True)
        # Same pooling/normalization as setup() so queries live in the index's space.
        feature_normed = self._pool_and_normalize(feature)
        # Asking faiss for more neighbours than indexed vectors pads the result
        # with sentinel distances, so clamp K to the index size.
        k = min(int(self.K), self.index.ntotal)
        D, _ = self.index.search(feature_normed, k)
        kth_dist = -D[:, -1]
        # softmax is monotonic, so argmax over raw logits gives the same class.
        pred = output.argmax(dim=1)
        return pred, torch.from_numpy(kth_dist)

    def set_hyperparam(self, hyperparam: list):
        """Set K from the first element of ``hyperparam``."""
        self.K = hyperparam[0]

    def get_hyperparam(self):
        """Return the current K."""
        return self.K