import torch
from IPython import embed

from typing import Iterator, List, Dict
import torch.optim as optim
import numpy as np

import csv
import torch
import torch.nn as nn

# train_dataset, test_dataset = text_classification.DATASETS['AG_NEWS']

from allennlp.data import Instance
from allennlp.data.fields import TextField, SequenceLabelField, LabelField

from allennlp.data.dataset_readers import DatasetReader
from allennlp.data.tokenizers import Token
from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer

from allennlp.data.vocabulary import Vocabulary
from allennlp.models import Model

from allennlp.modules.text_field_embedders import TextFieldEmbedder, BasicTextFieldEmbedder
from allennlp.modules.token_embedders import Embedding
from allennlp.modules.seq2seq_encoders import Seq2SeqEncoder, PytorchSeq2SeqWrapper
from allennlp.nn.util import get_text_field_mask, sequence_cross_entropy_with_logits

from allennlp.data.iterators import BucketIterator, DataIterator
from allennlp.training.trainer import Trainer
from allennlp.predictors import SentenceTaggerPredictor

from allennlp.nn.util import masked_softmax
from allennlp.training.metrics import CategoricalAccuracy
from allennlp.common.params import Params
from allennlp.data.tokenizers.word_tokenizer import WordTokenizer


# Module-level AllenNLP WordTokenizer, constructed once at import time.
# NOTE(review): SnipDataset in this file tokenizes with str.split() and does not
# reference TOKENIZER; confirm whether other code uses it before removing.
TOKENIZER = WordTokenizer()

# Maps each SNIPS intent name (the first column of a data row) to the fixed
# integer class id stored in the "labels" LabelField (skip_indexing=True).
# The ids are a permutation of 0..6 and are intentionally not listed in
# ascending order; renumbering them changes what each class index means.
snip_label_to_idx = dict(
    weather=0,
    music=1,
    restaurant=2,
    book=5,
    movie=4,
    search=3,
    playlist=6,
)

class SnipDataset(DatasetReader):
    """AllenNLP dataset reader for tab-separated SNIPS intent data.

    Each non-empty row is expected to hold the intent label in its first
    column and the utterance in its second column; any further columns are
    ignored. Every row becomes an ``Instance`` with:

    * ``"sentence"``: a ``TextField`` of the utterance split on whitespace;
    * ``"labels"``: a ``LabelField`` holding the integer id from
      ``snip_label_to_idx`` (not re-indexed through the vocabulary).
    """

    def __init__(
        self,
        token_indexers: Dict[str, TokenIndexer] = None,
        lazy: bool = False,
    ) -> None:
        """
        Args:
            token_indexers: indexers for the sentence ``TextField``. When
                falsy, a single ``SingleIdTokenIndexer`` under the key
                ``"tokens"`` is used.
            lazy: forwarded to ``DatasetReader.__init__``; defaults to
                ``False``, matching the previous hard-coded behaviour.
        """
        super().__init__(lazy=lazy)
        self.token_indexers = token_indexers or {"tokens": SingleIdTokenIndexer()}

    def text_to_instance(self, csv_list: List[str]) -> Instance:
        """Build an ``Instance`` from one parsed row.

        Args:
            csv_list: ``[label, sentence, ...]`` as produced by ``csv.reader``.

        Returns:
            An ``Instance`` with ``"sentence"`` and ``"labels"`` fields.

        Raises:
            KeyError: if the label is not a key of ``snip_label_to_idx``.
            IndexError: if the row has fewer than two columns.
        """
        # Label ids are fixed by snip_label_to_idx, so bypass vocabulary indexing.
        label = LabelField(snip_label_to_idx[csv_list[0]], skip_indexing=True)
        tokens = [Token(token) for token in csv_list[1].split()]
        sentence_field = TextField(tokens, self.token_indexers)
        return Instance({"sentence": sentence_field, "labels": label})

    def _read(self, file_path: str) -> Iterator[Instance]:
        """Yield one ``Instance`` per non-empty row of the tab-separated file."""
        # newline="" is what the csv module requires for correct parsing; an
        # explicit UTF-8 encoding keeps decoding independent of the OS locale.
        with open(file_path, newline="", encoding="utf-8") as f:
            for row in csv.reader(f, delimiter="\t"):
                # csv.reader yields [] for blank lines (e.g. a trailing empty
                # line); skip them instead of failing with IndexError.
                if not row:
                    continue
                yield self.text_to_instance(row)