import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.tree import export_text
from sklearn.tree import DecisionTreeClassifier
from RUG import RUGClassifier
from Datasets import wdbc_original

# --- Experiment configuration -------------------------------------------
# Seed shared by the train/test split and by both classifiers.
randomState = 29_232_516

# Depth limit handed to RUGClassifier (max_depth) and to the decision tree.
maxDepth = 3
# Forwarded to RUGClassifier as max_RMP_calls.
maxRMPcalls = 30
# Forwarded to RUGClassifier as rule_length_cost.
ruleLengthCost = True

# NOTE(review): not referenced anywhere else in the visible script.
rhsEps = 0.01

# Upper bound on how many RUG rules are evaluated and plotted.
selectedRuleNum = 12

# Dataset loader to use; it is called with the directory holding the data.
problem = wdbc_original

# Load everything into one 2-D array: the last column is the target,
# every preceding column is a feature.
df = np.array(problem('datasets/'))
X, y = df[:, :-1], df[:, -1]

# Hold out 30% of the samples for testing (reproducible via randomState).
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=randomState)

# Fit the rule-generation classifier on the training split.
rugParams = {
    'max_depth': maxDepth,
    'max_RMP_calls': maxRMPcalls,
    'rule_length_cost': ruleLengthCost,
    'random_state': randomState,
}
RUG = RUGClassifier(**rugParams)
RUG_fit = RUG.fit(X_train, y_train)

# Fit a plain decision tree with the same depth limit and seed; its rules
# are printed at the end of the script next to the RUG rules.
treeParams = {'max_depth': maxDepth, 'random_state': randomState}
DT = DecisionTreeClassifier(**treeParams)
DT_fit = DT.fit(X_train, y_train)

# Evaluate the fitted RUG model on the test split using a growing prefix of
# its rules (first 1 rule, first 2 rules, ...), recording accuracy and the
# share of test samples covered at each step.

# Number of rules actually evaluated; also used to slice the weights so the
# weight, accuracy and coverage sequences always have the same length.
numRules = min(selectedRuleNum, RUG.getNumOfRules())

weights = np.round(RUG.getWeights()[:numRules], decimals=2)
accuracies = []
coverages = []  # holds missed-sample counts until converted below

for indx in range(numRules):
    print(indx)
    # Second argument is the set of rule indices to use for prediction
    # (here: rules 0..indx) -- presumably; verify against RUGClassifier.
    RUG_pred_test = RUG.predict(X_test, range(indx+1))
    # scikit-learn's signature is accuracy_score(y_true, y_pred).
    accuracies.append(accuracy_score(y_test, RUG_pred_test))
    # Assumes getNumOfMissed() reports how many samples of the most recent
    # predict() call were not covered by any rule -- TODO confirm.
    coverages.append(RUG.getNumOfMissed())

accuracies = np.round(accuracies, decimals=3)
# Predictions were made on X_test, so the missed counts must be normalised
# by the test-set size (not by the size of the full dataset).
coverages = np.round(1.0-(np.array(coverages)/len(y_test)), decimals=2)*100
coverages = np.round(coverages, decimals=1)
# Percentage labels drawn above the bars in the plot.
txtmisses = [str(cover)+'%' for cover in coverages]


# Plotting
# One figure with two y-axes sharing the rule index on x:
#   left axis  - bar per rule showing its weight, annotated with coverage
#   right axis - point plot of test accuracy using the first k rules
data = {'Rules': list(range(1, len(accuracies) + 1)),
        'Weights': weights,
        'Coverages': coverages,
        'Accuracies': accuracies}

# Separate name so the dataset array bound to `df` is not overwritten.
plotDf = pd.DataFrame(data)

# sns.color_palette() only *returns* a palette; set_palette() is what
# actually installs it, and it must run before the axes are created.
sns.set_palette('pastel')
fig, ax1 = plt.subplots()

bars = sns.barplot(data=plotDf, x='Rules', y='Weights', ax=ax1,
                   color='green', alpha=0.3)

# Coverage label slightly above each bar (bars sit at x = 0, 1, 2, ...).
for xpos, (txtmiss, weight) in enumerate(zip(txtmisses, weights)):
    bars.annotate(txtmiss, xy=(xpos, weight+0.02),
                  color='purple', horizontalalignment='center',
                  fontsize=8)

# Reference line at 1.0 on the shared weight/accuracy scale.
bars.axhline(1.0, color='lightgray', linestyle='--')

ax1.set_ylabel('Rule Weight', color='darkgreen')
ax1.set_ylim([0.0, 1.1])

# Second y-axis for accuracy, forced to the same limits as the weight axis
# so both series are read against one scale.
ax2 = ax1.twinx()

sns.pointplot(data=plotDf, x='Rules', y='Accuracies', ax=ax2)
ax2.set_ylabel('Mean Accuracy', color='darkblue')
ax2.set_ylim(ax1.get_ylim())
ax2.grid(False)


fig.tight_layout()

# Render the fitted decision tree as indented text rules.
tree_rules = export_text(DT)

# Console report: per-step test accuracies, the first three RUG rules,
# then the decision-tree rules for comparison.
print(accuracies)
RUG.printRules(range(3))
print(tree_rules)
