import logging
from typing import Dict, Optional

import numpy as np
import torch
from sklearn.neighbors import KNeighborsClassifier
from torch.utils.data import TensorDataset, DataLoader

# Module-wide logger, looked up by the fixed name 'custom' rather than
# __name__; its handlers/level are presumably configured by the application.
logger = logging.getLogger(name='custom')


class NNClassifier:
    """
    Nearest-neighbor classifier for cross-modal evaluation.

    A ``KNeighborsClassifier`` is fit on a *support* set of embeddings and
    used to classify a *query* set. Both sets are labelled with
    ``self.dataset.s['y']``, so every embedding tensor in ``data`` is assumed
    to be row-aligned with those labels (TODO confirm against the caller).
    """

    # mod -> (key of the conditional "a|b" embeddings, key of the reference
    # embeddings) in ``data``. The conditional key also prefixes result names.
    # A fine-tuned variant ('x1|x2_emb_ft' vs 'x1_emb_ft') was previously
    # present here as commented-out code.
    _CROSSMODAL_KEYS = {
        'x2': ('x2|x1', 'x2'),          # caption classification
        'x1': ('x1|x2_emb', 'x1_emb'),  # image classification
    }

    def __init__(self, dataset):
        # Only ``dataset.s['y']`` (the labels) is read by this class.
        self.dataset = dataset

    def run(self, data, mod):
        """
        Evaluate ``data`` for modality ``mod``, log the results and return them.

        :param data: mapping with ``'y'`` and the embedding keys for ``mod``
        :param mod: ``'x1'`` or ``'x2'``
        :return: ``{f'knn_{k}': {result_name: accuracy_percent}}``
        """
        acc = self._run(data, mod)
        self.print_results(acc)
        return acc

    def print_results(self, acc):
        """Log every accuracy in ``acc`` (``{method: {result_name: percent}}``)."""
        for method_name, cur_acc in acc.items():
            logger.info(f'\nMethod "{method_name}":')
            for result_name, value in cur_acc.items():
                logger.info(f'{result_name}: {value:.1f}%')

    def _run(self, data, mod):
        """
        Run k-NN evaluation for k = 1 and k = floor(sqrt(N)), N = len(data['y']).

        :return: ``{f'knn_{k}': accuracies}`` in that order
        """
        sqrtn = int(np.sqrt(len(data['y'])))
        acc = {}
        # dict.fromkeys de-duplicates while preserving order: for N < 4,
        # sqrtn == 1 and k = 1 would otherwise be evaluated twice.
        for k in dict.fromkeys((1, sqrtn)):
            acc[f'knn_{k}'] = self.compute_crossmodal_accs(
                data, mod, method='knn', k=k)
        return acc

    def compute_crossmodal_accs(self, data, mod, **kwargs):
        """
        Score conditional vs. reference embeddings in both directions.

        ``<prefix>_test_in_query`` fits on the conditional embeddings and
        classifies the reference ones; ``<prefix>_test_in_support`` swaps them.

        :param data: mapping holding the embedding tensors for ``mod``
        :param mod: ``'x2'`` (caption) or ``'x1'`` (image) classification
        :param kwargs: forwarded to ``calculate_accuracy`` (``bs``, ``k``, ``method``)
        :return: dict of two accuracies in percent
        :raises ValueError: if ``mod`` is not ``'x1'`` or ``'x2'``
        """
        try:
            cond_key, ref_key = self._CROSSMODAL_KEYS[mod]
        except KeyError:
            raise ValueError(f"mod must be 'x1' or 'x2', got {mod!r}") from None

        labels = self.dataset.s['y']
        cond = {'x': data[cond_key], 'y': labels}
        ref = {'x': data[ref_key], 'y': labels}
        return {
            f'{cond_key}_test_in_query': self.calculate_accuracy(
                support=cond, query=ref, **kwargs),
            f'{cond_key}_test_in_support': self.calculate_accuracy(
                support=ref, query=cond, **kwargs),
        }

    def calculate_accuracy(self,
                           support: Dict[str, torch.Tensor],
                           query: Dict[str, torch.Tensor],
                           bs=1024,
                           **kwargs):
        """
        Classify ``query`` against ``support`` in mini-batches of size ``bs``.

        :param support: ``{'x': features, 'y': labels}`` the classifier is fit on
        :param query: ``{'x': features, 'y': labels}`` to classify
        :param bs: query batch size
        :param kwargs: forwarded to ``_calc_accuracy`` (``k``, ``method``)
        :return: accuracy in percent (0-100)
        :raises ValueError: if ``query`` is empty or ``method`` is invalid
        """
        dataset = TensorDataset(query['x'], query['y'])
        if len(dataset) == 0:
            raise ValueError('query set is empty; accuracy is undefined.')
        loader = DataLoader(dataset, batch_size=bs)

        # The support set is the same for every batch, so fit only once.
        # 1 mirrors the default of ``_calc_accuracy``'s ``k``.
        neigh = self._fit(support, kwargs.get('k', 1))

        acc = 0.0
        for x, y in loader:
            cur_acc = self._calc_accuracy(
                support,
                query={'x': x, 'y': y},
                neigh=neigh,
                **kwargs)
            # Weight by batch size so the final value is a per-sample mean.
            acc += cur_acc * x.size(0)
        return acc / len(dataset) * 100

    @staticmethod
    def _fit(support, k):
        """Return a ``KNeighborsClassifier(n_neighbors=k)`` fit on ``support``."""
        neigh = KNeighborsClassifier(n_neighbors=k)
        neigh.fit(support['x'], support['y'])
        return neigh

    @staticmethod
    def _calc_accuracy(support, query, k=1, method: Optional[str] = None,
                       neigh=None):
        """
        :param support: ``{'x', 'y'}`` set the neighbors are drawn from
        :param query: ``{'x', 'y'}`` set to classify
        :param k: number of nearest-neighbors to compute
        :param method:
            'knn': majority vote among nearest-neighbors
            'topk': top-k accuracy (hit if the true label is among the
            labels of the k nearest neighbors)
        :param neigh: optional classifier already fit on ``support`` with
            ``n_neighbors=k``; when omitted one is fit here
        :return: mean accuracy in [0, 1]
        :raises ValueError: if ``method`` is not 'knn' or 'topk'
        """
        # Validate before the (potentially expensive) fit.
        if method not in ('knn', 'topk'):
            raise ValueError(f'{method} is illegal method.')
        if neigh is None:
            neigh = NNClassifier._fit(support, k)

        if method == 'knn':
            return neigh.score(query['x'], query['y'])

        _, idx = neigh.kneighbors(query['x'])  # N x k indices into support
        preds = np.asarray(support['y'])[idx]
        truth = np.asarray(query['y'])
        return float(np.mean([t in p for t, p in zip(truth, preds)]))
