import numpy as np
import json
import os
import pickle

from numpy.random import default_rng
from collections import namedtuple, defaultdict
from torchvision.datasets.utils import download_url
from tqdm import tqdm

from comln.datasets.base import MetaDataset


class LEOMetaDataset(MetaDataset):
    """Meta-dataset over the precomputed LEO image embeddings.

    Embeddings are read from pickle files laid out as
    ``<root>/embeddings/<name>/<crop>/<split>_embeddings.pkl``, where ``name``
    is supplied by subclasses and ``crop`` by the constructor. ``self.root``
    and ``self.splits`` are presumably set by ``MetaDataset`` — TODO confirm.
    """
    # NOTE(review): presumably used by MetaDataset to build self.root — confirm.
    folder = 'leo'
    zip_url = 'http://storage.googleapis.com/leo-embeddings/embeddings.zip'
    # Declared per-example embedding shape; not validated against loaded data.
    shape = (640,)

    def __init__(
            self,
            root,
            batch_size,
            shots=5,
            ways=5,
            test_shots=None,
            size=None,
            crop='center',
            split='train',
            seed=0,
            download=False
        ):
        # `crop` is set before the base constructor in case it triggers
        # `download()` / `_check_integrity()`, which read `split_filenames`
        # (and hence `self.crop`) — TODO confirm against MetaDataset.
        self.crop = crop
        super().__init__(root, batch_size, shots=shots, ways=ways,
            test_shots=test_shots, size=size, split=split, seed=seed,
            download=download)
        self.load_data()

    def load_data(self):
        """Load and concatenate the embeddings of every split file.

        Populates ``self._data`` with one array stacking the embeddings of all
        files in ``split_filenames`` (in that order). It also populates
        ``self._labels2indices``, which maps each class name to a NumPy array
        of row indices into ``self._data``. If data is already loaded, this is
        a no-op.

        Returns:
            ``self``, to allow chaining.

        Raises:
            ValueError: if ``split_filenames`` is empty.
            FileNotFoundError: if a split file is missing.

        Note:
            Files are read with ``pickle.load``, which can execute arbitrary
            code. Only load embeddings from a trusted source.
        """
        # getattr: do not depend on the base class having pre-set `_data`.
        if getattr(self, '_data', None) is not None:
            return self

        arrays, labels2indices = [], defaultdict(list)
        offset = 0
        for filename in self.split_filenames:
            with open(filename, 'rb') as f:
                # SECURITY: unpickling trusted archive data only.
                # 'latin' (latin-1) lets Python 3 read str objects pickled
                # by Python 2.
                data = pickle.load(f, encoding='latin')

            embeddings = data['embeddings']
            arrays.append(embeddings)
            for i, key in enumerate(data['keys']):
                if isinstance(key, bytes):
                    # Decode rather than rely on the "b'...'" repr of str(bytes).
                    key = key.decode('latin-1')
                # Keys must have exactly three '-'-separated fields; the
                # middle one is the class name.
                _, class_name, _ = str(key).split('-')
                # Row index is global across all concatenated split files.
                labels2indices[class_name].append(offset + i)
            offset += embeddings.shape[0]

        if not arrays:
            raise ValueError('No embeddings to load: `split_filenames` is empty.')

        self._data = np.concatenate(arrays, axis=0)
        self._labels2indices = {
            label: np.asarray(indices)
            for label, indices in labels2indices.items()
        }
        return self

    @property
    def data(self):
        """Concatenated embeddings of all loaded splits (set by ``load_data``)."""
        return self._data

    @property
    def labels2indices(self):
        """Mapping ``class name -> array of row indices`` into ``data``."""
        return self._labels2indices

    @property
    def split_filenames(self):
        """Tuple of pickle paths, one per entry of ``self.splits``."""
        return tuple(os.path.join(self.root, 'embeddings', self.name, self.crop,
            f'{split}_embeddings.pkl') for split in self.splits)

    def _check_integrity(self):
        """Return True iff every file in ``split_filenames`` exists."""
        return all(map(os.path.isfile, self.split_filenames))

    def download(self):
        """Download and extract the LEO embeddings archive into ``self.root``.

        This is a no-op if every file in ``split_filenames`` already exists.
        Otherwise it fetches the archive from ``zip_url``, extracts it into
        ``self.root``, and then deletes the zip file.
        """
        import zipfile

        if self._check_integrity():
            return

        # Download dataset (stored as <root>/<basename of zip_url>).
        download_url(self.zip_url, self.root)
        filename = os.path.join(self.root, os.path.basename(self.zip_url))

        # Always extract: the integrity check failed, so any existing
        # `embeddings/` folder is missing files (e.g. interrupted extraction).
        # Skipping extraction there would delete the fresh archive without
        # ever producing the missing files.
        with zipfile.ZipFile(filename, 'r') as f:
            f.extractall(self.root)

        if os.path.isfile(filename):
            os.remove(filename)


class LEOMiniImagenet(LEOMetaDataset):
    """LEO embeddings of miniImageNet.

    ``name`` selects the ``embeddings/miniImageNet/`` subdirectory from which
    the parent class reads its split pickle files.
    """

    name = 'miniImageNet'


class LEOTieredImagenet(LEOMetaDataset):
    """LEO embeddings of tieredImageNet.

    ``name`` selects the ``embeddings/tieredImageNet/`` subdirectory from which
    the parent class reads its split pickle files.
    """

    name = 'tieredImageNet'
