import argparse
import json
import pickle

import numpy as np

from collections import defaultdict
from tqdm import tqdm
import random
import os
import codecs

# Command-line interface for training the cluster classifier and labelling
# the test vocabulary.
parser = argparse.ArgumentParser()
parser.add_argument("--train-vocab-file", help="numpy train vocabulary file")
parser.add_argument("--train-point-file", help="numpy train vocabulary representations")
parser.add_argument("--clustering-file", help="clustering results file")
parser.add_argument("--output-path", help="output path for model and results file")
# The pickled model is written to <output-path>/<output-model>. The script
# always read args.output_model, but the option was never declared.
parser.add_argument("--output-model", default="model.pkl", help="file name of the pickled model inside --output-path (default: model.pkl)")
parser.add_argument("--test-vocab-file", help="numpy test vocabulary file")
parser.add_argument("--test-point-file", help="numpy test vocabulary representations")
# Minimum top-class probability a test prediction needs in order to be kept.
# Parsed as float up front so a malformed value fails before training starts.
parser.add_argument("--threshold", type=float, help="probability threshold")
# "--delimiter" is accepted as a correctly spelled alias. The attribute name
# stays "delimeter" because the rest of the script reads args.delimeter.
parser.add_argument("--delimeter", "--delimiter", dest="delimeter", default='|||', help="delimeter (default: '|||')")

args = parser.parse_args()

# Map every training vocabulary entry to its row index. Row i of
# `representations` is assumed to be the vector for vocabulary entry i.
tokens = {word: idx for idx, word in enumerate(np.load(args.train_vocab_file))}
# The argparse attribute is `train_point_file` (from --train-point-file), not
# `train_points_file`.
representations = np.load(args.train_point_file)

print("Loading clustering file")
# cluster id (string) -> list of tokens assigned to that cluster
clustering = defaultdict(list)
with open(args.clustering_file, "r") as fp:
	for line in fp:
		line = line.strip()
		if not line:
			# A blank line (e.g. a trailing empty line) would otherwise register
			# an empty token under an empty cluster id and break the count check.
			continue
		# Each line is "<token><delimeter><cluster_id>". Split on the LAST
		# delimiter only, so a token that itself contains delimiter characters
		# stays intact. Use the configured delimiter rather than a hard-coded "|||".
		token, _, cluster_id = line.rpartition(args.delimeter)
		# Tokens are stored as backslash-escaped ASCII: undo the escaping, then
		# round-trip through UTF-16 with surrogatepass so escaped surrogate pairs
		# become real code points. Decoding back to str is required: the keys of
		# `tokens` come from the numpy vocabulary (presumably str), and a bytes
		# object would never match them.
		token = token.replace("\\\\", "\\").encode("ascii").decode("unicode-escape")
		token = token.encode("utf16", "surrogatepass").decode("utf16", "surrogatepass")
		clustering[cluster_id].append(token)

# Every vocabulary token must appear in exactly one cluster, and must have
# exactly one representation row.
n_clustered = sum(len(members) for members in clustering.values())
if len(tokens) != n_clustered:
	raise ValueError(f"Mismatch in dataset ({len(tokens)}) and clustering file ({n_clustered})")
if len(tokens) != representations.shape[0]:
	raise ValueError("Mismatch in dataset and representations")


# Flatten the clustering into three parallel training lists, one entry per
# clustered token: its representation row, the token itself, and its cluster id.
X_train = []
Y_train = []
tokens_train = []
for cluster_id in tqdm(clustering):
	members = clustering[cluster_id]
	member_rows = representations[[tokens[member] for member in members], :]
	# Iterating the 2-D slice yields one row per member, in member order.
	X_train.extend(member_rows)
	tokens_train.extend(members)
	Y_train.extend([cluster_id] * len(members))

from sklearn import metrics  # NOTE(review): appears unused in this script
print("Starting Logistic Regression")
from sklearn.linear_model import LogisticRegression
# Multinomial logistic regression over the cluster ids. A fixed random_state
# keeps runs reproducible.
# NOTE(review): `multi_class` is deprecated since scikit-learn 1.5. On newer
# versions, drop it (lbfgs already fits a multinomial model for >2 classes).
model = LogisticRegression(random_state=42, multi_class='multinomial', solver='lbfgs')
model.fit(X_train, Y_train)
# --output-model may not be declared on the parser; fall back to a fixed name
# instead of failing with AttributeError after the (slow) fit.
model_name = getattr(args, "output_model", "model.pkl")
with open(os.path.join(args.output_path, model_name), "wb") as fp:
	pickle.dump(model, fp)


# `test_vocab_file` is only reachable through `args`. The bare name raised
# NameError. Announce the load before doing it.
print("Loading data: " + args.test_vocab_file)
test_words = np.load(args.test_vocab_file)
X_test = np.load(args.test_point_file)
# Each test word needs exactly one representation row. Check before the
# predictions are later filtered with a boolean mask over the rows.
if len(test_words) != X_test.shape[0]:
	raise ValueError(f"Mismatch in test vocabulary ({len(test_words)}) and test representations ({X_test.shape[0]})")

print("Predicting Logistic Regression: " + args.test_vocab_file)
# One row per test point, one column per class, in the order of model.classes_.
probabilities = model.predict_proba(X_test)

with open(os.path.join(args.output_path, 'probabilities'), "wb") as fp:
	pickle.dump(probabilities, fp)

threshold = float(args.threshold)
# Highest class probability, and the column index of that class, for every test point.
max_probs = np.max(probabilities, axis=1)
max_ids = np.argmax(probabilities, axis=1)

# Keep only the predictions whose top probability reaches the threshold.
selected_ids = max_probs >= threshold
selected_pred = model.classes_[max_ids[selected_ids]]
selected_words = test_words[selected_ids]

n_total = probabilities.shape[0]
n_discarded = n_total - int(np.count_nonzero(selected_ids))
# An empty test set would otherwise raise ZeroDivisionError.
discarded_percent = (100.0 * n_discarded) / n_total if n_total else 0.0
print("Percentage words discarded = " + str(discarded_percent) + "%")

# newline="\n" writes the same bytes codecs.open did (no platform newline
# translation). The with-block guarantees the file is closed even on error.
with open(os.path.join(args.output_path, 'results.txt'), 'w', encoding='utf-8', newline='\n') as target:
	for word, pred_class in zip(selected_words, selected_pred):
		target.write(word + args.delimeter + pred_class + "\n")
