import cv2
import numpy as np
import torch
from collections.abc import Iterable
import learn2learn as l2l
import pickle
import os


def get_bovw_trainer(train_data: torch.utils.data.Dataset,
                     n_sample: int = 100, k: int = 10):
    """
    Build a SIFT bag-of-visual-words k-means trainer from a dataset.

    SIFT descriptors are extracted from the images of the first ``n_sample``
    items of ``train_data`` and added to a ``cv2.BOWKMeansTrainer`` with
    ``k`` clusters. The trainer is returned un-clustered; call
    ``.cluster()`` on it to obtain the vocabulary.

    Args:
        train_data: dataset whose items are indexable, with the image at
            position 0; the image is presumably a channel-first (C, H, W)
            tensor (it is transposed with axes (1, 2, 0)) — TODO confirm.
        n_sample: number of leading items to use. Clamped to
            ``len(train_data)`` when the dataset defines ``__len__``.
        k: number of visual words (k-means clusters).

    Returns:
        cv2.BOWKMeansTrainer holding the collected descriptors.
    """
    try:
        # Avoid an IndexError when the dataset is smaller than n_sample.
        n_sample = min(n_sample, len(train_data))
    except TypeError:
        # Dataset without __len__: trust the caller-supplied n_sample.
        pass

    detector = cv2.SIFT_create()
    bow_trainer = cv2.BOWKMeansTrainer(k)
    for ii in range(n_sample):
        img = train_data[ii][0]
        if img.max() <= 1.:
            # Rescale [0, 1] images so the uint8 cast below keeps their range.
            img = img * 255
        # (C, H, W) -> (H, W, C) uint8 array for OpenCV.
        npimg = img.numpy().astype(np.uint8).transpose((1, 2, 0))
        _, descriptors = detector.detectAndCompute(npimg, None)
        # descriptors may be None (e.g. when no keypoints are found).
        if descriptors is not None:
            bow_trainer.add(descriptors.astype(np.float32))

    return bow_trainer


def get_centroids(train_data: torch.utils.data.Dataset,
                  n_sample: int = 100, k: int=10):
    """
    Cluster SIFT descriptors from ``train_data`` into ``k`` visual words.

    Builds a descriptor trainer with ``get_bovw_trainer`` (same ``n_sample``
    and ``k``) and returns the result of its ``cluster()`` call, i.e. the
    BoVW vocabulary.
    """
    return get_bovw_trainer(train_data, n_sample=n_sample, k=k).cluster()


def get_centroids_from_tasks(dataset: l2l.data.task_dataset.TaskDataset,
                             n_sample: int = 100, k: int = 10):
    """
    Cluster SIFT descriptors gathered from a learn2learn task set.

    For each of the first ``n_sample`` tasks, every image in ``task[0]``
    contributes its SIFT descriptors to a k-means bag-of-visual-words
    trainer. Images are presumably channel-first (C, H, W) tensors, since
    they are transposed with axes (1, 2, 0) — TODO confirm.

    Returns:
        The result of ``cv2.BOWKMeansTrainer(k).cluster()``: the vocabulary.
    """
    sift = cv2.SIFT_create()
    trainer = cv2.BOWKMeansTrainer(k)

    def _collect(image):
        # Rescale [0, 1] images so the uint8 cast keeps their range.
        scaled = image * 255 if image.max() <= 1. else image
        # (C, H, W) -> (H, W, C) uint8 array for OpenCV.
        arr = scaled.numpy().astype(np.uint8).transpose((1, 2, 0))
        descs = sift.detectAndCompute(arr, None)[1]
        # descs may be None (e.g. when no keypoints are found).
        if descs is not None:
            trainer.add(descs.astype(np.float32))

    for task_idx in range(n_sample):
        images = dataset[task_idx][0]
        for image in images:
            _collect(image)

    return trainer.cluster()


class SiftFeature:
    """
    SIFT bag-of-visual-words (BoVW) feature extractor.

    A vocabulary of ``k`` SIFT descriptor centroids is either loaded from a
    pickle cache or computed from ``train_data`` (a learn2learn
    ``TaskDataset`` or a torch ``Dataset``) and then written to the cache.
    Calling the instance on a batch of images returns one BoVW row per image.
    """

    def __init__(self, train_data=None, k=10, n_sample=None, name=None, pkl_path=None,
                 use_cache=True, cache_dir='./cache'):
        """
        Args:
            train_data: TaskDataset or torch Dataset used to compute the
                vocabulary when no usable cache exists.
            k: vocabulary size (number of k-means clusters).
            n_sample: number of tasks/items to sample; defaults to
                ``len(train_data)``.
            name: optional prefix for the generated cache file name.
            pkl_path: explicit path to an existing centroid pickle; when
                given, ``name`` and ``cache_dir`` are not used.
            use_cache: load centroids from the pickle if it exists,
                otherwise recompute (and overwrite) it.
            cache_dir: directory for the generated cache file.

        Raises:
            ValueError: if there is no ``train_data`` but centroids must be
                computed (no ``pkl_path``, or ``use_cache`` is False).
        """
        self.k = k
        self.pickle_name = None

        if train_data is None and pkl_path is None:
            raise ValueError('either train_data or pkl_path must be given')
        # Default n_sample whenever data is available, so a later recompute
        # (e.g. pkl_path given with use_cache=False) never sees None.
        if n_sample is None and train_data is not None:
            n_sample = len(train_data)
        is_taskset = isinstance(train_data, l2l.data.task_dataset.TaskDataset)

        # Set pkl path
        if pkl_path is not None:
            assert os.path.exists(pkl_path), pkl_path
            self.pkl_path = pkl_path
        else:
            if is_taskset:
                self.pickle_name = f'k{k}_ntask{n_sample}.pkl'
            else:
                assert isinstance(train_data, torch.utils.data.Dataset), type(train_data)
                self.pickle_name = f'k{k}_ndata{n_sample}.pkl'
            if name is not None:
                self.pickle_name = name + '_' + self.pickle_name
            os.makedirs(cache_dir, exist_ok=True)
            self.pkl_path = os.path.join(cache_dir, self.pickle_name)

        # Set centroid vector
        if use_cache and os.path.exists(self.pkl_path):
            # NOTE: pickle.load can execute arbitrary code; only point this
            # at cache files you trust.
            with open(self.pkl_path, 'rb') as f:
                self.centroids = pickle.load(f)
            print(f'loaded {self.pkl_path}')
        else:
            if train_data is None:
                raise ValueError('train_data is required to compute centroids '
                                 'when use_cache is False')
            if is_taskset:
                self.centroids = \
                    get_centroids_from_tasks(train_data, k=k, n_sample=n_sample)
            else:
                self.centroids = \
                    get_centroids(train_data, k=k, n_sample=n_sample)

            with open(self.pkl_path, 'wb') as f:
                pickle.dump(self.centroids, f)
            print(f'saved {self.pkl_path}')

        # Prepare BOVW
        self.detector = cv2.SIFT_create()
        matcher = cv2.BFMatcher()
        self.extractor = cv2.BOWImgDescriptorExtractor(self.detector, matcher)
        self.extractor.setVocabulary(self.centroids)

    def __call__(self, imgs):
        """
        Compute BoVW features for a batch of images.

        Args:
            imgs: list/tuple of images, or a batched ``torch.Tensor`` /
                ``np.ndarray`` whose first axis indexes images.

        Returns:
            torch.Tensor with one row per image (``len(imgs)`` rows); an
            empty batch yields a (0, k) tensor.

        Raises:
            NotImplementedError: for any other input type (a RuntimeError
                subclass, like the error previously raised here).
        """
        if not isinstance(imgs, (list, tuple, torch.Tensor, np.ndarray)):
            raise NotImplementedError(f'not implemented for input type {type(imgs)}')
        rows = [self._read_one_img(img) for img in imgs]
        if not rows:
            # torch.cat rejects an empty list; return an empty batch instead.
            return torch.zeros((0, self.k))
        return torch.cat(rows)

    @staticmethod
    def _to_uint8_image(img):
        """
        Convert one image to a uint8 numpy array OpenCV can consume.

        Accepts a ``torch.Tensor`` or ``np.ndarray`` that is either 2-D
        (H, W) or channel-first 3-D with 1 or 3 channels. Single-channel
        images are squeezed to (H, W); 3-channel images are transposed to
        (H, W, 3). Images whose max is <= 1 are scaled by 255 before the cast.
        """
        if isinstance(img, torch.Tensor):
            # detach/cpu so grad-tracking or GPU tensors convert as well.
            img = img.detach().cpu().numpy()
        else:
            assert isinstance(img, np.ndarray), f'img: {img}'

        if img.ndim == 3:
            if len(img) == 1:
                img = img[0]
            else:
                assert len(img) == 3, f'img.shape: {img.shape}'
                img = img.transpose((1, 2, 0))
        else:
            assert img.ndim == 2, f'img.shape: {img.shape}'

        if img.max() <= 1.:
            img = img * 255
        return img.astype(np.uint8)

    def _read_one_img(self, img):
        """
        Return the BoVW feature of one image as a single-row float tensor.

        A row of ``k`` zeros is returned when SIFT finds no keypoints or the
        extractor yields no descriptor.
        """
        img = self._to_uint8_image(img)
        feature = [[0] * self.k]
        keypoints = self.detector.detect(img, None)
        if keypoints:
            descriptor = self.extractor.compute(img, keypoints)
            if descriptor is not None:
                feature = descriptor

        return torch.Tensor(feature)

