import os
import torch
import pickle
import csv
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score

def emb_data_load(layer, type_):
    """Load the saved BERT embedding file for one layer.

    Args:
        layer: Layer index. It is converted with ``int()`` before being
            placed in the file name, so ``3`` and ``"3"`` are equivalent.
        type_: Which embedding set to read: ``'original'`` or ``'laser'``.

    Returns:
        The object deserialized by ``torch.load`` from
        ``<type_>/<type_>_all_<layer>.pt`` under the embedding directory.

    Raises:
        ValueError: If ``type_`` is not one of the supported names, or if
            ``layer`` is a string that ``int()`` cannot parse.
    """
    # Validate before touching any global state; previously an unknown
    # type_ surfaced as an UnboundLocalError after the chdir had happened.
    if type_ not in ('original', 'laser'):
        raise ValueError(
            f"unknown embedding type {type_!r}; expected 'original' or 'laser'"
        )
    layer_idx = int(layer)

    # NOTE: this changes the working directory of the whole process. It is
    # kept because code outside this function may rely on it.
    os.chdir('/embedding_dir_path/ms_embs/bert/')

    # NOTE(review): torch.load unpickles the file, so it must only be
    # pointed at trusted embedding dumps.
    return torch.load(f'{type_}/{type_}_all_{layer_idx}.pt')

def sense_data_load(path="/makesense_dir_path/data/sense_metadata.csv"):
    """Read the sense metadata CSV into a list of per-example tuples.

    The first record of the file is treated as a header and skipped. Every
    remaining record must have exactly six columns; the first, fourth,
    fifth and sixth are used as id, word, position and sense, and the
    other two are ignored.

    Args:
        path: Location of the CSV file. Defaults to the project's
            ``sense_metadata.csv``.

    Returns:
        A list of ``(id, position, word, sense)`` tuples in file order,
        where ``position`` has been converted to ``int`` and the other
        fields are the raw strings from the CSV.

    Raises:
        ValueError: If a non-blank record does not have six columns, or if
            its position column is not an integer.
    """
    data = []
    # newline='' lets the csv module handle line endings itself, which is
    # required for quoted fields that contain embedded newlines.
    with open(path, "r", newline="") as f:
        reader = csv.reader(f)
        # Skip the header through the reader (not the file) so a quoted
        # multi-line header is consumed whole; the default avoids a
        # StopIteration on an empty file.
        next(reader, None)
        for row in reader:
            if not row:
                # csv.reader yields [] for blank lines.
                continue
            if len(row) != 6:
                raise ValueError(
                    f"{path}: line {reader.line_num}: expected 6 columns, "
                    f"got {len(row)}"
                )
            id_, _, _, word, position, sense = row
            data.append((id_, int(position), word, sense))

    # Nothing in this function allocates CUDA memory; the call is kept so
    # the side effect seen by existing callers does not change.
    torch.cuda.empty_cache()

    return data