from anonlibrary.ontology import Ontology, WordnetConcept
from torchvision import transforms, io
import json
import os
import torch


class ImageNetDataset(torch.utils.data.Dataset):
    """
    Returns the images contained in
    an annotated ImageNet dataset.

    Each item is the image referenced by the ``'path'`` field of the
    corresponding index entry, converted to 3 channels, cast to
    ``float32`` and normalized with ``mean``/``std``.

    Args:
        index: path to a JSON file holding a list of entries; each
            entry's ``'path'`` is resolved relative to the directory
            that contains this file.
        mean: per-channel means used for normalization.
        std: per-channel standard deviations used for normalization.
    """
    def __init__(self,
                 index,
                 mean=(0.485, 0.456, 0.406),
                 std=(0.229, 0.224, 0.225)):
        # Defaults are tuples: default values are shared between calls,
        # so they must not be mutable.
        with open(index, 'r', encoding='utf-8') as fp:
            self.index = json.load(fp)

        self.data_directory = os.path.dirname(index)

        self.normalizer = transforms.Compose([
            transforms.ConvertImageDtype(torch.float32),
            transforms.Normalize(mean=mean, std=std)
        ])

    def __len__(self):
        return len(self.index)

    @staticmethod
    def _to_rgb(img_arr):
        """
        Return ``img_arr`` (channels first) with exactly 3 channels.

        - 1 channel (gray): replicated three times.
        - 2 channels (gray + alpha): alpha dropped, gray replicated.
        - 3 channels: returned unchanged (same object).
        - 4 channels (RGBA): alpha dropped.

        Raises:
            ValueError: for any other channel count, which the 3-channel
                normalizer could not handle.
        """
        channels = img_arr.shape[0]
        if channels == 1:
            return torch.repeat_interleave(img_arr, 3, dim=0)
        if channels == 2:
            return torch.repeat_interleave(img_arr[:1], 3, dim=0)
        if channels == 4:
            return img_arr[:3]
        if channels != 3:
            raise ValueError(
                f'Unsupported number of image channels: {channels}')
        return img_arr

    def __getitem__(self, idx):
        # Samplers may hand over indices as tensors
        if torch.is_tensor(idx):
            idx = idx.tolist()

        image = self.index[idx]
        img_path = os.path.join(self.data_directory, image['path'])
        img_arr = self._to_rgb(io.read_image(img_path))

        return self.normalizer(img_arr)


class ImageNetAnnotations:
    """
    Returns the annotations contained in
    an annotated ImageNet dataset.

    ``index`` is the path to a JSON file holding a list of entries; every
    entry's ``'path'`` is joined with the directory containing that file.

    Iterating yields lists of at most ``batch_size`` ``ImageNetExample``
    objects covering ``index[start:end]``. Call the instance to configure
    that window (``annotations(start, end, batch_size)``) before iterating.
    """
    def __init__(self, index):
        with open(index, 'r', encoding='utf-8') as fp:
            self.index = json.load(fp)

        # Adjust path: resolve images against the index's directory
        self.data_directory = os.path.dirname(index)
        for img in self.index:
            img['path'] = os.path.join(self.data_directory, img['path'])

        # Iteration window and cursor (configured through __call__)
        self.start = 0
        self.end = len(self.index)
        self.pos = 0
        self.batch_size = 1

    def __len__(self):
        # Size of the whole index, independent of the iteration window
        return len(self.index)

    def __iter__(self):
        # The instance is its own iterator: restarting resets the shared
        # cursor to the beginning of the window.
        self.pos = self.start
        return self

    def __getitem__(self, idx):
        return self.index[idx]

    def __next__(self):
        # Retrieve batch (the last one may be shorter than batch_size)
        stop = min(self.pos + self.batch_size, self.end)
        batch = self.index[self.pos:stop]

        if not batch:
            raise StopIteration

        # Update counter
        self.pos += len(batch)

        return [ImageNetExample(img) for img in batch]

    def __call__(self, start=0, end=None, batch_size=1):
        """
        Configure the iteration window and batch size.

        Args:
            start: index of the first entry to iterate over.
            end: index one past the last entry; ``None`` means the end
                of the index.
            batch_size: number of examples per yielded batch.

        Returns:
            ``self``, so the call can be used directly in a ``for`` loop.

        Raises:
            ValueError: if ``batch_size`` is smaller than 1 (it would
                otherwise silently yield nothing).
        """
        if batch_size < 1:
            raise ValueError(
                f'batch_size must be a positive integer, got {batch_size}')
        self.start = start
        self.end = len(self.index) if end is None else end
        self.batch_size = batch_size
        return self


class ImageNetOntology(Ontology):
    """
    Ontology built from the ``ontology.txt`` file of a dataset directory.

    Each non-blank line of the file holds one ``is_a`` relationship as two
    whitespace-separated synset ids: ``<hypernym> <hyponym>``. The single
    synset that never appears as a hyponym becomes the root.

    Raises:
        ValueError: if the relationships do not have exactly one root.
    """
    def __init__(self, directory):
        # Retrieve ontology from file and parse is_a relationships.
        # Iterating over lines (instead of split('\n')[:-1]) keeps the last
        # relationship when the file lacks a trailing newline, and blank
        # lines are skipped instead of crashing.
        path = os.path.join(directory, 'ontology.txt')
        with open(path, encoding='utf-8') as fp:
            raw_ontology = [line.split() for line in fp if line.strip()]

        # Partition the synsets
        hypernyms = {e[0] for e in raw_ontology}
        hyponyms = {e[1] for e in raw_ontology}
        synsets = hypernyms | hyponyms

        # A well-formed tree has exactly one synset without a hypernym;
        # picking one arbitrarily would depend on set iteration order.
        roots = hypernyms - hyponyms
        if len(roots) != 1:
            raise ValueError(
                f'Expected exactly one root synset in {path}, '
                f'found {len(roots)}: {sorted(roots)}')
        root_syn = next(iter(roots))

        # Build nodes
        nodes = {s: WordnetConcept(s) for s in synsets}

        # Connect nodes
        for hypernym, hyponym in raw_ontology:
            nodes[hypernym].hyponyms += [nodes[hyponym]]
            nodes[hyponym].hypernyms += [nodes[hypernym]]

        # Init superclass with the root node
        super().__init__(nodes[root_syn])


class ImageNetExample:
    """
    Wrapper around one entry of an annotated ImageNet index.

    ``dict_example`` is expected to provide ``'idx'``, ``'height'``,
    ``'width'`` and ``'boxes'``, where each element of ``'boxes'`` is a
    ``(synset, ((min_x, min_y), (max_x, max_y)))`` pair.
    """
    def __init__(self, dict_example):
        self.dict_example = dict_example
        # Synsets that have at least one bounding box in this image
        self.synsets = {e[0] for e in self.dict_example['boxes']}
        # concept name -> synsets shared with the image (see select_concepts)
        self.intersection = {}

    @property
    def index(self):
        return self.dict_example['idx']

    @property
    def shape(self):
        return (self.dict_example['height'], self.dict_example['width'])

    def intersect(self, concept):
        # Intersect the concept leaves names
        # and the synsets of the current img
        return {c.name for c in concept.leaves if c.name in self.synsets}

    def select_concepts(self, concepts):
        """
        Return the concepts having at least one leaf synset in the image.

        The per-concept intersections are cached in ``self.intersection``
        (replacing any previous cache) for later use by
        ``get_concept_mask``.
        """
        # For each concept compute the intersection
        self.intersection = {c.name: self.intersect(c) for c in concepts}

        # Select only concepts with a non empty intersection
        return [c for c in concepts if self.intersection[c.name]]

    def _get_boxes(self, synset):
        return [e[1] for e in self.dict_example['boxes'] if e[0] == synset]

    def get_concept_mask(self, concept, c_mask=None):
        '''
        Given a concept and an image
        returns the concept map L_c(x)

        Args:
            concept: object exposing ``name`` and ``leaves``.
            c_mask: optional 2-D mask buffer (e.g. a boolean tensor) that
                is cleared and filled in place, so it can be reused across
                calls. When ``None`` a new ``torch.bool`` tensor of
                ``self.shape`` is allocated.

        Returns:
            The mask with the bounding boxes of every synset shared by the
            concept and the image set to ``True``.
        '''
        # Init mask
        if c_mask is None:
            c_mask = torch.zeros(self.shape, dtype=torch.bool)
        else:
            c_mask &= False

        # Synsets relevant to the concept: reuse the select_concepts cache
        # when available, otherwise compute them instead of raising KeyError
        synsets = self.intersection.get(concept.name)
        if synsets is None:
            synsets = self.intersect(concept)

        # NOTE(review): the first mask axis is indexed with x while the
        # default mask is (height, width); confirm the preprocessor's box
        # coordinate order, this may be transposed for non-square images.
        for synset in synsets:
            for (pmin_x, pmin_y), (pmax_x, pmax_y) in self._get_boxes(synset):
                # Consider only positive bounding box coordinates
                # NOTE: this could be done by the preprocesser
                pmin_x = max(int(pmin_x), 0)
                pmin_y = max(int(pmin_y), 0)
                pmax_x = max(int(pmax_x), 0)
                pmax_y = max(int(pmax_y), 0)

                # Set the bounding box as true
                c_mask[pmin_x:pmax_x, pmin_y:pmax_y] = True

        return c_mask
