import sys, tempfile
import numpy as np
import pandas as pd
from lng.lca.lc_anc import lca
from lng.L2SCA.analyzeText import sca
from datasets import load_from_disk
from nltk.stem import WordNetLemmatizer
from nltk import word_tokenize, pos_tag
from sklearn.preprocessing import StandardScaler

# The dataset directory is taken from the first command-line argument.
dataset = sys.argv[1]
data = load_from_disk(dataset)

from nltk.corpus import wordnet

lemmatizer = WordNetLemmatizer()

# First character of a POS tag -> WordNet part-of-speech constant.
# process_l falls back to 'n' for any tag whose first character is not a key.
tag_dict = dict(
    J=wordnet.ADJ,
    N=wordnet.NOUN,
    V=wordnet.VERB,
    R=wordnet.ADV,
)

def process_l(sent1, sent2):
    """Compute the ``lca`` result for a pair of sentences.

    Each sentence is tokenized and POS-tagged, every token is lemmatized
    (using ``tag_dict`` to pick the WordNet part of speech, defaulting to
    ``'n'``), and rewritten as space-separated ``lemma_TAG`` items. The two
    resulting strings are passed to ``lca`` as a two-element list.

    Args:
        sent1: First sentence (raw text).
        sent2: Second sentence (raw text).

    Returns:
        A dict with the single key ``'lca'`` holding whatever ``lca`` returns.
    """
    def lemma_tag_string(sentence):
        # Build "lemma_TAG lemma_TAG ..." for one sentence.
        pieces = []
        for word, tag in pos_tag(word_tokenize(sentence)):
            wn_pos = tag_dict.get(tag[0], 'n')
            lemma = lemmatizer.lemmatize(word, wn_pos)
            pieces.append("{}_{}".format(lemma, tag))
        return " ".join(pieces)

    first = lemma_tag_string(sent1)
    second = lemma_tag_string(sent2)
    return {'lca': lca([first, second])}

def process_s(sent1, sent2):
    """Compute the ``sca`` result for a pair of sentences.

    The two sentences are written to a uniquely named text file under
    ``tmp/`` (separated by a blank line) and the file path is handed to
    ``sca``. The file is removed again once ``sca`` has returned, so repeated
    calls do not leave one file per example behind.

    Args:
        sent1: First sentence (raw text).
        sent2: Second sentence (raw text).

    Returns:
        A dict with the single key ``'sca'`` holding whatever ``sca`` returns.
    """
    # Function-scope imports: the module's top-level imports do not provide
    # these names, and this keeps the function self-contained for workers.
    import os
    import uuid

    # Make sure the scratch directory exists instead of failing on open().
    os.makedirs('tmp', exist_ok=True)

    # A random UUID gives a collision-free name across parallel workers
    # without relying on the private tempfile._get_candidate_names() helper.
    tmp_file = 'tmp/%s.txt' % uuid.uuid4().hex
    try:
        # NOTE(review): the file is written with the platform default
        # encoding, as before — confirm which encoding ``sca`` expects.
        with open(tmp_file, 'w') as f:
            f.write(sent1 + '\n\n' + sent2)
        scares = sca(tmp_file)
    finally:
        # Best-effort cleanup; the file may be missing if open() failed or
        # if the analyzer already removed it.
        try:
            os.remove(tmp_file)
        except FileNotFoundError:
            pass
    return {'sca': scares}

# Add the 'sca' column first (process_s, run with num_proc=128), then the
# 'lca' column (process_l). Both read the two sentence columns.
data = data.map(
    process_s,
    input_columns=['sentence1', 'sentence2'],
    num_proc=128,
)
data = data.map(
    process_l,
    input_columns=['sentence1', 'sentence2'],
)

def _feature_sums(x):
    """Return the summed normalised LCA/SCA features and their total."""
    lca_sum = sum(x['lca_norm'])
    sca_sum = sum(x['sca_norm'])
    return {'lca_sum': lca_sum, 'sca_sum': sca_sum,
            'lns_sum': lca_sum + sca_sum}


# Normalise the feature columns split by split. 'train' must be processed
# first: both scalers are fitted on it and then reused for 'dev' and 'test'.
for split in ['train', 'dev', 'test']:
    subdata = data[split]

    # Replace missing LCA values with the value from the preceding row
    # (forward fill). DataFrame.ffill() is used instead of
    # fillna(method='ffill'), whose ``method`` keyword is deprecated in pandas.
    df = pd.DataFrame(subdata['lca']).ffill()
    df_list = df.values.tolist()
    subdata = subdata.map(lambda x, i: {'lca': df_list[i]}, with_indices=True)

    if split == 'train':
        # Fit the standardisation on the training split only.
        lca_scaler = StandardScaler()
        sca_scaler = StandardScaler()
        lca_scaler.fit(subdata['lca'])
        sca_scaler.fit(subdata['sca'])

    # transform() expects a 2-D input, hence the single-row wrap and [0].
    subdata = subdata.map(lambda x: {'lca_norm': lca_scaler.transform([x['lca']])[0],
                                     'sca_norm': sca_scaler.transform([x['sca']])[0]})
    subdata = subdata.map(_feature_sums)
    data[split] = subdata

# Class boundaries: the 33rd and 66th percentiles of each summed score,
# measured on the training split only.
train_split = data['train']
lca_sums = train_split['lca_sum']
sca_sums = train_split['sca_sum']
lns_sums = train_split['lns_sum']

thresh11 = np.percentile(lca_sums, 33)
thresh12 = np.percentile(lca_sums, 66)

thresh21 = np.percentile(sca_sums, 33)
thresh22 = np.percentile(sca_sums, 66)

thresh31 = np.percentile(lns_sums, 33)
thresh32 = np.percentile(lns_sums, 66)

def _three_way(value, low, high):
    """Return 0 if value < low, 1 if value < high, otherwise 2."""
    if value < low:
        return 0
    if value < high:
        return 1
    return 2


def classes(x):
    """Bucket an example's summed scores into three classes each.

    Every score is compared against its pair of module-level thresholds
    (``thresh11``/``thresh12`` for ``lca_sum``, ``thresh21``/``thresh22`` for
    ``sca_sum``, ``thresh31``/``thresh32`` for ``lns_sum``): below the first
    threshold is class 0, below the second is class 1, anything else is 2.

    Args:
        x: Mapping with the keys ``'lca_sum'``, ``'sca_sum'`` and ``'lns_sum'``.

    Returns:
        A dict with the keys ``'lca_class'``, ``'sca_class'`` and
        ``'lns_class'``, each 0, 1 or 2.
    """
    return {
        'lca_class': _three_way(x['lca_sum'], thresh11, thresh12),
        'sca_class': _three_way(x['sca_sum'], thresh21, thresh22),
        'lns_class': _three_way(x['lns_sum'], thresh31, thresh32),
    }

# Attach the three class labels to every split, then write the augmented
# dataset next to the input, under the input path with '_lng' appended.
for split in ('train', 'dev', 'test'):
    data[split] = data[split].map(classes)

output_path = dataset + '_lng'
data.save_to_disk(output_path)
