import os
import sys

# Make the project-local modules in ../SEV importable.
sys.path.append("../SEV/")

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import pacmap
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import GradientBoostingClassifier

from FCMCluster import FuzzyCMeans
from Encoder import DataEncoder
from data_loader import data_loader

# Per-dataset settings. Every list below is index-aligned with `datasets`:
#   n_neighbors_per_dataset -> PaCMAP `n_neighbors` (None = PaCMAP default)
#   n_class_per_dataset     -> number of clusters to fit
#   ms                      -> candidate FuzzyCMeans `m` values
#                              NOTE(review): not read anywhere in this script.
datasets = ["adult", "german", "compas", "diabetes", "fico", "mimic", "headline1", "headline2", "headline3", "headline_total"]
n_neighbors_per_dataset = [None, 20, 20, 20, None, None, None, None, None, None]
n_class_per_dataset = [7, 3, 5, 4, 4, 4, 3, 2, 3, 2]
ms = [1.01, 1.01, 2, 1.01, 3, 1.01, 1.01, 1.01, 1.01, 1.01]

if not (len(datasets) == len(n_neighbors_per_dataset) == len(n_class_per_dataset) == len(ms)):
    raise ValueError("per-dataset configuration lists must be index-aligned with `datasets`")

# Select the dataset to analyse; its settings come from the tables above
# rather than being re-typed by hand (which let the two drift apart).
dataset = "fico"
_dataset_index = datasets.index(dataset)
n_neighbors = n_neighbors_per_dataset[_dataset_index]
n_class = n_class_per_dataset[_dataset_index]

# Figure with three side-by-side panels; `ax[k]` is used by the plots below.
fig, ax = plt.subplots(1, 3, figsize=(15, 5))

# Load the data and encode it (`standard=True`). The encoder is fitted on the
# negative samples only, then applied to both the full data and the negatives.
X, y, X_neg = data_loader(dataset)
encoder = DataEncoder(standard=True)
encoder.fit(X_neg)
encoded_X = encoder.transform(X)
encoded_X_neg = encoder.transform(X_neg)
print("Working on the dataset {}".format(dataset))

# 2-D PaCMAP embedding fitted on the negatives; the full data set is then
# projected into that space with the negatives as the basis. `n_neighbors`
# comes from the dataset configuration instead of being hard-coded.
pacmapper = pacmap.PaCMAP(n_components=2, n_neighbors=n_neighbors, MN_ratio=1, FP_ratio=2.0)
pacmapper.fit(encoded_X_neg, init="pca")
X_embedded = pacmapper.transform(encoded_X, encoded_X_neg)
# NOTE(review): assumes the rows of X_neg are exactly X[y == 0] in the same
# order, so that X_embedded_neg lines up row-for-row with encoded_X_neg —
# confirm in data_loader.
X_embedded_neg = X_embedded[y == 0]
X_embedded_pos = X_embedded[y == 1]

# Classifier handed to FuzzyCMeans and used to score the cluster medians.
model = GradientBoostingClassifier(n_estimators=200, max_depth=1, random_state=42)
model.fit(encoded_X, y)
# Cluster the negative samples with FuzzyCMeans (m=3). `n_clusters` follows
# `n_class` so it always agrees with the per-cluster loop below (it was a
# hard-coded 4 before).
fcm = FuzzyCMeans(model, n_clusters=n_class, m=3)
fcm.fit(X_embedded_neg, encoded_X_neg)
labels = fcm.predict(X_embedded_neg, encoded_X_neg)

print("The count of each label is")
print(pd.Series(labels).value_counts())

# For every cluster: take the feature-wise median of its members, project it
# into the embedding, and record clusters whose median the model predicts as
# positive (non-zero class).
X_med_embedded = []
positive_labels = []

for i in range(n_class):
    members = encoded_X_neg[labels == i]
    if len(members) == 0:
        # An empty cluster has no median: np.median would give NaN and the
        # model rejects NaN input. Keep a NaN placeholder so row i of
        # X_med_embedded still belongs to cluster i (Matplotlib skips NaNs).
        print("Cluster {} is empty; no median computed".format(i))
        X_med_embedded.append(np.full((1, 2), np.nan))
        continue
    X_med = np.median(members, axis=0)
    query = X_med.reshape(1, -1)
    X_med_embedded.append(pacmapper.transform(query, encoded_X_neg))
    if model.predict(query)[0] != 0:
        print("The median of the cluster {} is predicted as positive".format(i))
        print(model.predict_proba(query))
        print("The median is {}".format(X_med))
        positive_labels.append(i)

X_med_embedded = np.array(X_med_embedded).reshape(-1, 2)


# Panel 1: the model's predicted probability of class 1 for every negative
# sample, drawn at its embedded position.
y_predicted_proba = model.predict_proba(encoded_X_neg)[:, 1]
proba_scatter = ax[0].scatter(X_embedded_neg[:, 0], X_embedded_neg[:, 1], c=y_predicted_proba, s=10, alpha=0.1, cmap="RdBu_r")
ax[0].set_xticks([])
ax[0].set_yticks([])
# Colorbar for the probability scale, labelled so the panel is self-explanatory.
cbar = fig.colorbar(proba_scatter, ax=ax[0])
cbar.set_label("Predicted probability")
ax[0].set_title("The predicted probability of each query")

# Panel 2: first clustering (m=3). Positive samples (y == 1) are grey, each
# negative cluster has its own Pastel1 color, and cluster medians are stars:
# red when the model predicts the median as positive, blue otherwise.
# The colormap is fetched via pyplot because `matplotlib.cm.get_cmap` was
# removed in Matplotlib 3.9.
colors = plt.get_cmap("Pastel1")(np.linspace(0, 0.7, n_class))
ax[1].scatter(X_embedded_pos[:, 0], X_embedded_pos[:, 1], c="gray", s=10, alpha=0.1)
for i in range(n_class):
    in_cluster = labels == i
    # `color=` rather than `c=`: an RGBA tuple passed as `c` is value-mapped
    # through a colormap when the cluster contains exactly 4 points.
    ax[1].scatter(X_embedded_neg[in_cluster, 0], X_embedded_neg[in_cluster, 1], color=colors[i], s=10, alpha=0.7)
ax[1].scatter(X_med_embedded[:, 0], X_med_embedded[:, 1], marker="*", s=100, c="blue")
for positive_label in positive_labels:
    ax[1].scatter(X_med_embedded[positive_label, 0], X_med_embedded[positive_label, 1], marker="*", s=100, c="red")
ax[1].set_xticks([])
ax[1].set_yticks([])
ax[1].set_title("Original Soft K-Means Clustering")

# Second clustering of the negatives, with m=100 (vs. m=3 for the first fit).
# NOTE(review): despite the name `fcm_original`, this result is drawn in the
# panel titled "Score-based Soft K-Means Clustering" — confirm the naming.
fcm_original = FuzzyCMeans(model, n_clusters=n_class, m=100)
fcm_original.fit(X_embedded_neg, encoded_X_neg)
labels = fcm_original.predict(X_embedded_neg, encoded_X_neg)
predicted_labels = model.predict(encoded_X_neg)

print("The count of each label is")
print(pd.Series(labels).value_counts())

# For every cluster: median of the members the model predicts as class 0,
# projected into the embedding; report if that median is itself predicted
# positive.
X_med_embedded = []

for i in range(n_class):
    members = encoded_X_neg[(labels == i) & (predicted_labels == 0)]
    if len(members) == 0:
        # No model-negative members: np.median would give NaN and the model
        # rejects NaN input. Keep a NaN placeholder so row i still belongs to
        # cluster i (Matplotlib skips NaN points).
        print("Cluster {} has no members predicted as negative; no median computed".format(i))
        X_med_embedded.append(np.full((1, 2), np.nan))
        continue
    X_med = np.median(members, axis=0)
    query = X_med.reshape(1, -1)
    X_med_embedded.append(pacmapper.transform(query, encoded_X_neg))
    if model.predict(query)[0] != 0:
        print("The median of the cluster {} is predicted as positive".format(i))
        print(model.predict_proba(query))
        print("The median is {}".format(X_med))

X_med_embedded = np.array(X_med_embedded).reshape(-1, 2)

# Panel 3: second clustering (m=100). Positive samples (y == 1) are grey, each
# negative cluster has its own Pastel1 color, and the per-cluster medians are
# blue stars. `plt.get_cmap` replaces `matplotlib.cm.get_cmap`, which was
# removed in Matplotlib 3.9.
colors = plt.get_cmap("Pastel1")(np.linspace(0, 0.7, n_class))
ax[2].scatter(X_embedded_pos[:, 0], X_embedded_pos[:, 1], c="gray", s=10, alpha=0.1)
for i in range(n_class):
    in_cluster = labels == i
    # `color=` rather than `c=`: an RGBA tuple passed as `c` is value-mapped
    # through a colormap when the cluster contains exactly 4 points.
    ax[2].scatter(X_embedded_neg[in_cluster, 0], X_embedded_neg[in_cluster, 1], color=colors[i], s=10, alpha=0.7)
ax[2].scatter(X_med_embedded[:, 0], X_med_embedded[:, 1], marker="*", s=100, c="blue")
ax[2].set_xticks([])
ax[2].set_yticks([])
ax[2].set_title("Score-based Soft K-Means Clustering")

# Save the whole figure; the file name follows the selected dataset
# ("FICO.png" for fico) and the output directory is created if missing so
# savefig does not fail with FileNotFoundError.
os.makedirs("../Results/figures", exist_ok=True)
plt.savefig("../Results/figures/{}.png".format(dataset.upper()))