import pickle
from minicons import cwe

from tqdm import tqdm

import csv

from torch.utils.data import DataLoader

from collections import defaultdict

# Contextual word embedding extractor used for every sentence below.
# The model is pinned to the first CUDA device, so a GPU must be available.
MODEL_NAME = "bert-base-uncased"
DEVICE = "cuda:0"
bert = cwe.CWE(MODEL_NAME, DEVICE)

def _make_id_table():
    """Return a mapping that hands out the next unused integer to each new key."""
    # On a miss the factory runs before the key is stored, so the first unseen
    # key gets 0, the next 1, and so on; known keys keep their id.
    table = defaultdict(lambda: len(table))
    return table


# Records of (dataset_id, sent_id, sentence, word, position, sense).
sense_data = []

# Looking a key up in either table is enough to register it.
words = _make_id_table()
sents = _make_id_table()

# Load the pre-built sentence records.
# NOTE: unpickling can execute arbitrary code, so this file must come from a
# trusted source.
with open("../sent_data/all_sent_data.pkl", "rb") as pkl_file:
    semeval = pickle.load(pkl_file)

# Each record is a 6-tuple. The sentence id stored in the file is discarded and
# replaced by the id that the shared `sents` table assigns to the sentence text.
for dataset_id, _, sentence, word, position, sense in semeval:
    sent_id = sents[sentence]
    word_id = words[word]  # the lookup itself registers the word in `words`
    sense_data.append((dataset_id, sent_id, sentence, word, position, sense))

# Append the rows of the multi-sense CSV to `sense_data`, tagged "semcor".
# newline="" lets the csv module do its own line-ending handling, which is
# required for quoted fields that contain embedded newlines.
with open("../data/multi_sense.csv", "r", newline="") as f:
    reader = csv.DictReader(f)
    for line in reader:
        # Rows whose POS tag is NNP or RB are left out of the dataset.
        if line['pos'] in ('NNP', 'RB'):
            continue

        # Make sure every sentence ends with a full stop. endswith() also
        # copes with an empty context cell, where indexing [-1] would raise
        # IndexError.
        context = line['context']
        sentence = context if context.endswith(".") else context + " ."

        sent_id = sents[sentence]
        word_id = words[line['word']]  # the lookup registers the word in `words`
        position = int(line['index'])
        sense_data.append(("semcor", sent_id, sentence, line['word'], position, line['lex_sense']))


# Serve the records in batches of 100. Each batch is unpacked field by field
# below, i.e. it arrives as one collated column per record field.
sense_dl = DataLoader(sense_data, batch_size=100)

# One (position, word, sense, embedding) tuple per record, in input order.
embedding_data = []
for batch in tqdm(sense_dl):
    # The dataset id and sentence id columns are not needed here.
    _, _, sentences, target_words, positions, senses = batch

    # Plain Python ints for the target positions of this batch.
    offsets = positions.tolist()

    # Pair every sentence with the span (i, i + 1) for its target position;
    # presumably this addresses the single word at index i.
    queries = [(text, (start, start + 1)) for text, start in zip(sentences, offsets)]

    # Detach and move to the CPU so stored vectors do not hold on to GPU memory.
    vectors = bert.extract_representation(queries).detach().cpu()
    embedding_data.extend(zip(offsets, target_words, senses, vectors))

## Do something with the embedding data.