import io
import json
import torch
import numpy as np

from IPython import embed
import spacy

from allennlp.data import Instance
from allennlp.data.dataset_readers import DatasetReader
from allennlp.data.tokenizers import Token
from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer
from allennlp.data.fields import TextField, MultiLabelField, MetadataField


class FineEntityTyping(DatasetReader):
    """Dataset reader for entity-typing examples stored one JSON object per line.

    Each line is expected to decode to a JSON object with the keys:

    * ``tokens``: the list of word strings of the sentence,
    * ``start`` / ``end``: the mention span inside ``tokens`` (``end`` is
      exclusive; both are passed through ``int``),
    * ``labels``: the list of label names, each of which must be a key of
      ``label_to_idx``.
    """

    def __init__(self, label_to_idx, rank_loss=False, context_window=10):
        """Configure the reader.

        Args:
            label_to_idx: Mapping from label name to integer label index.
            rank_loss: When true, labels are emitted as a ``MetadataField``
                holding a list padded with ``-1`` to ``len(label_to_idx)``
                entries; otherwise a ``MultiLabelField`` is emitted.
            context_window: Maximum number of tokens kept on each side of
                the mention. Defaults to 10, the previously hard-coded value.

        Raises:
            ValueError: If ``context_window`` is negative.
        """
        super().__init__(lazy=False)
        if context_window < 0:
            raise ValueError("context_window must be non-negative")
        self.token_indexers = {"tokens": SingleIdTokenIndexer()}
        self.feature_token_indexers = {"features": SingleIdTokenIndexer(namespace='features')}
        self.label_to_idx = label_to_idx
        self.rank_loss = rank_loss
        self.context_window = context_window

    def _to_text_field(self, words):
        """Wrap a list of word strings into a ``TextField`` using ``self.token_indexers``."""
        return TextField([Token(word) for word in words], self.token_indexers)

    def text_to_instance(self, line):
        """Build an ``Instance`` from one JSON-encoded example.

        Args:
            line: JSON string with the keys ``start``, ``end``, ``tokens``
                and ``labels``.

        Returns:
            An ``Instance`` with the fields ``mention_tokens``,
            ``left_tokens``, ``right_tokens`` and ``labels``.

        Raises:
            json.JSONDecodeError: If ``line`` is not valid JSON.
            KeyError: If a required key is missing from the example or a
                label is not present in ``label_to_idx``.
        """
        example = json.loads(line)
        start = int(example['start'])
        end = int(example['end'])
        words = example['tokens']
        labels = example['labels']

        window = self.context_window
        left_words = words[:start]
        right_words = words[end:]
        # Use an explicit start offset: a plain ``[-window:]`` slice would
        # return the whole list when ``window`` is 0.
        left_words = left_words[max(0, len(left_words) - window):]
        right_words = right_words[:window]

        label_idx = [self.label_to_idx[label] for label in labels]
        num_labels = len(self.label_to_idx)

        if self.rank_loss:
            # Pad with -1 so the list has one entry per known label.
            padding = [-1] * (num_labels - len(label_idx))
            label_field = MetadataField({'idx': label_idx + padding})
        else:
            label_field = MultiLabelField(label_idx,
                                          skip_indexing=True,
                                          num_labels=num_labels)

        data = {
            'mention_tokens': self._to_text_field(words[start:end]),
            'left_tokens': self._to_text_field(left_words),
            'right_tokens': self._to_text_field(right_words),
            'labels': label_field
        }

        return Instance(data)

    def _read(self, file_path):
        """Yield one ``Instance`` per non-blank line of ``file_path``.

        The file is read as UTF-8 and streamed line by line instead of being
        loaded into memory up front. Blank or whitespace-only lines are
        skipped rather than passed to the JSON decoder.
        """
        with open(file_path, encoding='utf-8') as f:
            for line in f:
                if not line.strip():
                    continue
                yield self.text_to_instance(line)