import torch
import numpy as np
from utils import *
from scipy.stats import spearmanr

def similarity_dataset(dataset_name: str) -> list:
    """
    Load a word/text similarity dataset from
    ``word_similarity_datasets/<dataset_name>.txt``.

    Expected line format (tab or whitespace separated):
    word1  word2  score

    Lines containing a tab are split on tabs only, so entries made of
    several space-separated words stay intact; all other lines are split
    on any run of whitespace. Text is lower-cased and blank lines are
    skipped.

    Returns a list of ``[word1, word2, score]`` rows with ``score`` as float.

    Raises ``IndexError``/``ValueError`` for a non-blank line that has fewer
    than three fields or a non-numeric third field.
    """
    dataset = []
    with open(f"word_similarity_datasets/{dataset_name}.txt","r",encoding="utf-8") as f:
        for line in f:
            line = line.strip().lower()
            if not line:
                # Tolerate empty lines (e.g. a trailing newline at end of file).
                continue
            if "\t" in line:
                items = [item.strip() for item in line.split("\t")]
            else:
                # split() with no argument collapses repeated spaces, which
                # split(" ") would turn into empty fields.
                items = line.split()
            dataset.append([items[0], items[1], float(items[2])])
    return dataset

def evaluate_similarity(sim_data, model, data, normalize=True, ignore_oov=True):
    """
    Score a model on a word-similarity dataset using Spearman rank correlation.

    Args:
        sim_data: iterable of ``[word1, word2, score]`` rows (the format
            returned by ``similarity_dataset``).
        model: object providing
            ``word_similarity(idx1, idx2, normalize=...)``, called with two
            index tensors (one per word column).
        data: object providing ``word_to_idx(word)``. An index of 0 is
            treated as "out of vocabulary" here.
        normalize: forwarded to ``model.word_similarity``.
        ignore_oov: if True, pairs with an out-of-vocabulary word are dropped
            before scoring; if False they are kept and their prediction is
            set to 0.

    Returns:
        The Spearman correlation between the gold scores and the model's
        predictions, or ``nan`` when there are no pairs to score.
    """
    sim_label = np.array([row[2] for row in sim_data])
    # reshape keeps the (n, 2) layout even when there are no rows at all.
    sim_idx = np.array([[data.word_to_idx(word) for word in row[0:2]]
                        for row in sim_data]).reshape(-1, 2)

    total = sim_idx.shape[0]
    oov_mask = (sim_idx == 0).any(axis=1)
    oov_count = np.count_nonzero(oov_mask)
    oov_percent = oov_count / total * 100 if total else 0.0
    print(f"{oov_count}/{total} ({oov_percent:.1f}%) out of vocab examples")

    if ignore_oov:
        iv_mask = ~oov_mask
        sim_idx = sim_idx[iv_mask, :]
        sim_label = sim_label[iv_mask]

    if sim_idx.shape[0] == 0:
        # Nothing left to correlate; rank correlation is undefined.
        return float("nan")

    sim_idx = opt_cuda(torch.tensor(sim_idx))
    sim_predict = model.word_similarity(sim_idx[:, 0], sim_idx[:, 1],
                                        normalize=normalize)
    if isinstance(sim_predict, torch.Tensor):
        # spearmanr works on NumPy data; a CUDA or grad-tracking tensor
        # cannot be converted implicitly.
        sim_predict = sim_predict.detach().cpu().numpy()

    if not ignore_oov:
        # TODO: 0 is a placeholder; the value to substitute for
        # out-of-vocabulary pairs is still an open question.
        sim_predict[oov_mask] = 0
    return spearmanr(sim_label, sim_predict)[0]
