from anonlibrary.loader.data_loader import SegmentationPrefetcher, SegmentationData
from anonlibrary.ontology import Ontology, WordnetConcept
from torchvision import transforms, io
import numpy as np
import os
import torch

# Annotation categories whose maps are inspected for every Broden example.
BRODEN_CATEGORIES = ["object", "part", "material"]


class BrodenDataset(torch.utils.data.Dataset):
    """
    Returns the images contained in
    an annotated Broden dataset.

    Image file names are taken from the first column of the ``index``
    CSV (its header line is skipped) and resolved against the
    ``images`` directory next to it. Each item is a 3-channel tensor
    converted to float32 by ``ConvertImageDtype`` (which rescales
    integer images to [0, 1]) and then normalized with ``mean``/``std``.

    The BGR mean was:
    109.5388, 118.6897, 124.6901
    Here it is normalized (divided by 255) and in RGB.
    """
    def __init__(self,
                 index,
                 mean=(0.48898, 0.46544, 0.42956),
                 std=(1, 1, 1)):
        # Tuples instead of lists: immutable defaults cannot leak state
        # between instances.
        self.index = [{'path': path}
                      for path in self._read_image_paths(index)]

        self.normalizer = transforms.Compose([
            transforms.ConvertImageDtype(torch.float32),
            transforms.Normalize(mean=mean, std=std)
        ])

    @staticmethod
    def _read_image_paths(index):
        """Return the image paths listed in the first column of ``index``."""
        image_dir = os.path.join(os.path.dirname(index), 'images')
        with open(index, 'r') as fp:
            # splitlines() keeps the last row even without a trailing
            # newline; [1:] skips the CSV header.
            rows = fp.read().splitlines()[1:]
        return [os.path.join(image_dir, row.split(',')[0])
                for row in rows if row.strip()]

    @staticmethod
    def _to_rgb(img):
        """
        Return ``img`` (channels first) with exactly 3 channels.

        1-channel (gray) and 2-channel (gray + alpha) images have their
        first channel replicated; images with 4 or more channels (e.g.
        RGBA) keep only the first three.
        """
        channels = img.shape[0]
        if channels in (1, 2):
            return img[:1].repeat(3, 1, 1)
        if channels > 3:
            return img[:3]
        return img

    def __len__(self):
        return len(self.index)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_arr = io.read_image(self.index[idx]['path'])
        return self.normalizer(self._to_rgb(img_arr))


class BrodenConcept(WordnetConcept):
    """
    Represents a concept in the Broden ontology: a WordNet concept
    annotated with the set of Broden label IDs aligned to it.
    """
    def __init__(self, name, hypernyms=None, hyponyms=None):
        super().__init__(name, hypernyms, hyponyms)

        # Broden IDs corresponding to this concept. This must be a set
        # (not ``{}``, which is an empty dict): it is combined with other
        # ID sets via ``|=``, ``&`` and ``==``.
        self.b_ids = set()

    def is_placeholder(self):
        """
        Return True when the concept has exactly one hyponym and that
        hyponym carries exactly the same Broden IDs.
        """
        return (len(self.hyponyms) == 1
                and self.hyponyms[0].b_ids == self.b_ids)


class BrodenOntology(Ontology):
    """
    WordNet-based ontology whose concepts carry Broden label IDs.

    Parameters
    ----------
    directory : str
        Directory containing the ontology and alignment files.
    fname : str
        Ontology file: one whitespace-separated ``hypernym hyponym``
        pair per line.
    alignment_fn : str
        Comma-separated alignment file; column 0 is a Broden ID and
        column 2 the name of the WordNet concept it is aligned to.
    vanilla_nd : bool
        If False (default), each concept's ``b_ids`` is extended with the
        ``b_ids`` of every concept returned by its ``get_descendants()``.

    Raises
    ------
    ValueError
        If the ontology does not have exactly one root synset.
    """
    def __init__(self, directory, fname='ontology.txt',
                 alignment_fn='broden_wordnet_alignment.csv',
                 vanilla_nd=False):
        # Retrieve ontology from file and parse is_a relationships
        raw_ontology = [
            line.split()
            for line in self._read_nonblank_lines(
                os.path.join(directory, fname))
        ]

        # Partition the synsets
        hypernyms = {e[0] for e in raw_ontology}
        hyponyms = {e[1] for e in raw_ontology}
        synsets = hypernyms | hyponyms
        roots = hypernyms - hyponyms
        # Picking an arbitrary element of a multi-root set would depend on
        # string hash order and silently drop the other subtrees.
        if len(roots) != 1:
            raise ValueError(
                f"expected exactly one root synset in {fname}, "
                f"found {sorted(roots)}")
        root_syn, = roots

        # Build nodes
        nodes = {s: BrodenConcept(s) for s in synsets}

        # Assign Broden labels to WordNet concepts in a single pass over
        # the alignment file. Only rows aligned to a known concept are
        # converted with int(), so a header row is never parsed.
        names = {concept.name for concept in nodes.values()}
        ids_by_name = {}
        alignment_path = os.path.join(directory, alignment_fn)
        for row in self._read_nonblank_lines(alignment_path):
            fields = row.split(',')
            if len(fields) > 2 and fields[2] in names:
                ids_by_name.setdefault(fields[2], set()).add(int(fields[0]))

        for concept in nodes.values():
            # Fresh set per concept: sets are later mutated in place.
            concept.b_ids = set(ids_by_name.get(concept.name, ()))

        # Connect nodes
        for hypernym, hyponym in raw_ontology:
            nodes[hypernym].hyponyms += [nodes[hyponym]]
            nodes[hyponym].hypernyms += [nodes[hypernym]]

        # Identify root
        root = nodes[root_syn]

        # Cumulate Broden labels to eventually
        # retrieve higher level visual concepts
        if not vanilla_nd:
            for concept in nodes.values():
                for descendant in concept.get_descendants():
                    concept.b_ids |= descendant.b_ids

        # Init superclass
        super().__init__(root)

    @staticmethod
    def _read_nonblank_lines(path):
        """Return the non-blank lines of the text file at ``path``."""
        # splitlines() keeps the last line without a trailing newline and
        # strips '\r' from CRLF files.
        with open(path) as fp:
            return [line for line in fp.read().splitlines() if line.strip()]


class BrodenAnnotations:
    """
    Iterable over the annotations of a Broden dataset.

    Calling the instance with ``(start, end, batch_size)`` configures a
    ``SegmentationPrefetcher`` over images ``[start, end)``; iterating
    then yields one list of ``BrodenExample`` per fetched batch. Indexing
    returns ``{'path': filename}`` for an image.
    """
    def __init__(self, index):
        data_directory = os.path.dirname(index)
        concept_index = os.path.basename(index)
        self.data = SegmentationData(data_directory, concept_index,
                                     categories=list(BRODEN_CATEGORIES),
                                     filter_images=False)

        # FIXME: workaround to provide path
        self.index = [{'path': self.data.filename(i)}
                      for i in range(self.data.size())]

        self.start = 0
        self.end = self.data.size()
        self.pos = 0
        self.batch_size = 1
        # Set by __call__; built lazily with the settings above if the
        # object is iterated without being called first.
        self.prefetcher = None

    def __len__(self):
        return self.data.size()

    def __iter__(self):
        self._ensure_prefetcher()
        self.pos = self.start
        return self

    def __getitem__(self, idx):
        return self.index[idx]

    def __next__(self):
        self._ensure_prefetcher()

        # Retrieve batch
        batch = self.prefetcher.fetch_batch()

        if batch is None:
            raise StopIteration

        # Update counter
        self.pos += len(batch)

        return [BrodenExample(img) for img in batch]

    def __call__(self, start, end, batch_size):
        self.start = start
        self.end = end if end is not None else self.data.size()
        self.batch_size = batch_size
        self.prefetcher = SegmentationPrefetcher(
            self.data,
            categories=self.data.category_names(),
            once=True,
            batch_size=self.batch_size,
            ahead=4,
            start=self.start,
            end=self.end,
            thread=False,
            n_procs=0)
        return self

    def _ensure_prefetcher(self):
        """Build a prefetcher from the current settings if none exists."""
        if self.prefetcher is None:
            self(self.start, self.end, self.batch_size)


class BrodenExample:
    """
    Wraps the annotation dictionary of a single Broden image and reports
    which Broden IDs (and therefore which concepts) it contains and where.
    """
    def __init__(self, dict_example, categories=None):
        # A 'dict_example' is a dictionary for
        # each image with the following keys:
        #   fn:str  filename
        #   i:int   unique index
        #   sh:int  map height resolution
        #   sw:int  map width resolution
        #   color:np.Array  array with shape (sh, sw) containing
        #                   in each position the index of the
        #                   given color in the original image.
        #                   The same holds for the keys object,
        #                   part, scene and texture.
        self.dict_example = dict_example
        # Keys of 'dict_example' holding concept annotations. Defaults to
        # BRODEN_CATEGORIES; overridable for Broden-like datasets.
        self.categories = list(BRODEN_CATEGORIES if categories is None
                               else categories)

    @property
    def index(self):
        return self.dict_example['i']

    @property
    def shape(self):
        return (self.dict_example['sh'], self.dict_example['sw'])

    def get_broden_ids(self):
        """
        Return the set of non-zero Broden IDs annotated in the image.

        Each category entry is either a scalar annotation (a single value
        or a 1-D sequence of IDs) or a pixel-level map (2-D, or a stack of
        2-D maps along the first axis).
        """
        b_ids = set()
        for category in self.categories:
            category_map = np.asarray(self.dict_example[category])
            if category_map.ndim < 2:
                # Scalar annotation(s)
                b_ids.update(int(v) for v in category_map.ravel())
            else:
                # Pixel-level annotation: every ID that appears at least once
                b_ids.update(np.unique(category_map).tolist())

        # '0' is not a broden id
        b_ids.discard(0)
        return b_ids

    def select_concepts(self, concepts):
        """Return the concepts sharing at least one Broden ID with the image."""
        # Broden IDs contained in the image
        b_ids = self.get_broden_ids()

        # Select concepts with relevant broden IDs
        return [c for c in concepts if c.b_ids & b_ids]

    def get_concept_mask(self, concept, c_mask=None):
        '''
        Given a concept and an image
        returns the concept map L_c(x).

        The result is a boolean mask that is True wherever a pixel-level
        annotation carries one of ``concept.b_ids``. A scalar annotation
        matching the concept activates the whole mask.

        If ``c_mask`` is given, it is cleared and filled in place, then
        returned. Otherwise a new boolean array of shape ``self.shape``
        is allocated.
        '''
        if c_mask is None:
            c_mask = np.zeros(self.shape, dtype=bool)
        else:
            # Init (reuse) the caller's buffer
            c_mask &= False

        target_ids = list(concept.b_ids)
        for category in self.categories:
            category_map = np.asarray(self.dict_example[category])

            if category_map.ndim < 2:
                # Scalar annotation: if any of its IDs belongs to the
                # concept, the whole image is active and the remaining
                # categories are irrelevant.
                if any(int(v) in concept.b_ids for v in category_map.ravel()):
                    c_mask |= True
                    return c_mask
                continue

            # Pixel-by-pixel annotation: match all IDs at once and collapse
            # stacked layers (first axis) into a single 2-D mask.
            hits = np.isin(category_map, target_ids)
            c_mask |= hits.reshape((-1,) + hits.shape[-2:]).any(axis=0)

        return c_mask
