from data_reader import data_reader
from algorithms import *
from sys import argv
import math
import numpy as np
import random
import matplotlib.pyplot as plt
from sklearn.neural_network import MLPRegressor
from sklearn.ensemble import RandomForestRegressor
from xgboost import XGBRegressor

# Get dataset of choice, desired accuracy and confidence, and the extra number of samples for the baselines
if len(argv) < 5:
    raise SystemExit("Usage: <script> <dataset> <epsilon> <delta> <M>")
dataset, epsilon, delta, M = argv[1], float(argv[2]), float(argv[3]), int(argv[4])

# Fail fast on arguments that would otherwise only crash (or silently misbehave)
# after all the expensive training below.
if epsilon <= 0:
    # Every absolute error is divided by epsilon in the testing phase.
    raise SystemExit("epsilon must be > 0")
if not 0 < delta < 1:
    # For delta >= 1 the sample size N below is <= 0; a negative N would make
    # D_1[:N] silently drop trailing rows instead of taking a prefix.
    raise SystemExit("delta must be in (0, 1)")
if M < 0:
    raise SystemExit("M (extra baseline samples) must be >= 0")

# Fetch data
(D_1, D_2) = data_reader(dataset)

# Choose feature weights for the distance functions
d = D_1.shape[1]
a = list(np.random.uniform(0.0, 1.0, d))
b = list(np.random.uniform(0.0, 1.0, d))
s = list(np.random.uniform(-0.01, 0.01, d))

# Sets of samples for our algorithms: the first N rows of each dataset,
# with N = ceil(ln(1 / delta^2) / delta).
N = math.ceil(np.log(1 / (delta ** 2)) / delta)
# Rows N onwards are used as test pairs; with none left, the error statistics
# would divide by zero at the very end of the run.
if min(len(D_1), len(D_2)) <= N:
    raise SystemExit(
        "Each dataset needs more than N = " + str(N)
        + " rows (the first N are sampled, the rest are tested)"
    )
S_1 = D_1[:N]
S_2 = D_2[:N]

# Compute representatives for the query-minimization algorithms
R_1 = Cluster(S_1, epsilon, 2)
R_2 = Cluster(S_2, epsilon, 2)

# Baselines: three off-the-shelf regressors, all fit on the same N + M labelled pairs.
n_train = N + M

# Training set: each row concatenates one random row of S_1 with one random row
# of S_2; its target is across_distance evaluated on that pair.
X = np.zeros((n_train, 2 * d))
y = np.zeros(n_train)
for row in range(n_train):
    x_1, x_2 = random.choice(S_1), random.choice(S_2)
    X[row, :] = np.concatenate((x_1, x_2))
    y[row] = across_distance(x_1, x_2, a, b, s)

mlp = MLPRegressor(
    hidden_layer_sizes=(32, 32, 32),
    activation='relu',
    solver='adam',
    max_iter=500,
)
rf = RandomForestRegressor(n_estimators=200)
xgb = XGBRegressor(n_estimators=200)

# Fit in a fixed order (MLP, RF, XGB); no model is given an explicit random_state.
for model in (mlp, rf, xgb):
    model.fit(X, y)

# Error statistic lists: for each predictor, the relative error (in percent)
# and the absolute error measured in units of epsilon.
rel_err_naive, abs_err_naive = [], []
rel_err_cluster, abs_err_cluster = [], []
rel_err_mlp, abs_err_mlp = [], []
rel_err_rf, abs_err_rf = [], []
rel_err_xgb, abs_err_xgb = [], []

# Testing phase
# nQ / cQ are passed to Predictor, and their final sizes are reported as the
# naive / cluster "online queries" (presumably Predictor records each queried
# pair in them -- confirm in algorithms).
nQ = set()
cQ = set()

# Held-out pairs: row N + i of D_1 with row N + i of D_2 (rows < N formed S_1 / S_2).
n_test = min(1000, min(len(D_1), len(D_2)) - N)
for i in range(n_test):
    # Named x_1 / x_2 (not x / y) so the module-level training target `y` is not clobbered.
    x_1 = D_1[N + i]
    x_2 = D_2[N + i]
    z = np.concatenate((x_1, x_2)).reshape(1, -1)  # single-row feature matrix for the baselines

    pred_naive = Predictor(S_1, S_2, x_1, x_2, a, b, s, nQ)
    pred_cluster = Predictor(R_1, R_2, x_1, x_2, a, b, s, cQ)
    pred_mlp = mlp.predict(z)[0]
    pred_rf = rf.predict(z)[0]
    pred_xgb = xgb.predict(z)[0]

    real = across_distance(x_1, x_2, a, b, s)
    # Relative error is |pred - real| / |real|. Dividing by the signed value would
    # yield negative errors if across_distance can be negative (s has negative
    # entries), and np.histogram later drops values below 0 from the CDF.
    # For positive distances this equals dividing by `real`.
    # NOTE(review): a zero `real` still divides by zero here.
    scale = abs(real)

    for pred, rel_list, abs_list in (
        (pred_naive, rel_err_naive, abs_err_naive),
        (pred_cluster, rel_err_cluster, abs_err_cluster),
        (pred_mlp, rel_err_mlp, abs_err_mlp),
        (pred_rf, rel_err_rf, abs_err_rf),
        (pred_xgb, rel_err_xgb, abs_err_xgb),
    ):
        err = abs(pred - real)
        rel_list.append(err * 100 / scale)
        # Absolute errors in terms of epsilon
        abs_list.append(err / epsilon)
    
# Plot Error Distributions
# Shared bin edges; bar group k shows the empirical P(error <= bins[k + 1]).
bins = [0, 0.5, 1, 2, 5, 10, 20, 30, 100]
t = np.arange(len(bins) - 1)  # one tick per upper bin edge (derived, not hard-coded)

rel_errs = [rel_err_naive, rel_err_cluster, rel_err_mlp, rel_err_rf, rel_err_xgb]
abs_errs = [abs_err_naive, abs_err_cluster, abs_err_mlp, abs_err_rf, abs_err_xgb]

labels = ["Naive", "Cluster", "MLP", "RF", "XGB"]
colors = ['blue', 'green', 'red', 'purple', 'orange']
bar_width = 0.15

plt.rcParams['figure.figsize'] = (34, 18)
plt.rcParams.update({'font.size': 30})


def _plot_error_cdf(err_lists, xlabel, title, path):
    """Plot each predictor's empirical error CDF as grouped bars, save to *path*, clear the axes.

    err_lists: one list of error values per predictor, in the order of
    ``labels`` / ``colors``. Values outside [bins[0], bins[-1]] fall in no bin
    and are not counted, so the bar at the last edge can be below 1.
    """
    import os

    # Centre the series around each tick: offsets -0.3 .. +0.3 for five series of width 0.15.
    first_offset = -(len(err_lists) - 1) / 2 * bar_width
    for k, errs in enumerate(err_lists):
        hist, _ = np.histogram(errs, bins=bins, density=False)
        hist = hist / len(errs)
        plt.bar(t + first_offset + k * bar_width, np.cumsum(hist), width=bar_width,
                color=colors[k], label=labels[k], edgecolor='black')

    plt.xticks(t, bins[1:])
    plt.xlabel(xlabel)
    plt.ylabel("Empirical Probability")
    plt.title(title)
    plt.legend(loc='center left', bbox_to_anchor=(1.0, 0.5))
    # savefig raises if the output directory is missing; create it rather than
    # losing the whole run at this final step.
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    # The legend is anchored outside the axes; "tight" keeps it from being cropped.
    plt.savefig(path, bbox_inches="tight")
    plt.cla()


_plot_error_cdf(rel_errs, "Relative Error Value", "The CDF of the Relative Error",
                "Plots/RelError-" + dataset + ".png")
_plot_error_cdf(abs_errs, "Value of (Absolute Error / epsilon)", "The CDF of (Absolute Error / epsilon) ",
                "Plots/AbsError-" + dataset + ".png")

# Save statistics
def _mean_std(values):
    """Return (mean, population standard deviation) of *values* as plain Python floats.

    float() keeps NumPy >= 2 scalar reprs such as "np.float64(...)" out of the
    text file. An empty list yields (nan, nan) instead of raising ZeroDivisionError.
    """
    if not values:
        return (float("nan"), float("nan"))
    return (float(sum(values) / len(values)), float(np.std(values)))


with open(dataset + "-statistics.txt", "w") as f:
    f.write("Naive online queries: " + str(len(nQ)) + "\n")
    f.write("Cluster online queries: " + str(len(cQ)) + "\n")
    # Percentage by which |R_1| * |R_2| representative pairs undercut the N * N sample pairs.
    f.write("Worst case query improvement percentage: " + str((1 - len(R_1)*len(R_2)/(N**2))*100) + "\n")
    # One line per predictor: (mean, std) of relative errors --- (mean, std) of absolute errors / epsilon.
    for name, rel_list, abs_list in (
        ("Naive", rel_err_naive, abs_err_naive),
        ("Cluster", rel_err_cluster, abs_err_cluster),
        ("MLP", rel_err_mlp, abs_err_mlp),
        ("RF", rel_err_rf, abs_err_rf),
        ("XGB", rel_err_xgb, abs_err_xgb),
    ):
        f.write(name + " statistics: " + str(_mean_std(rel_list)) + " --- " + str(_mean_std(abs_list)) + "\n")
