import numpy as np
import pickle
import os
import pandas as pd
from clustering import Clustering
# Failure monitor under evaluation.
clustering_sim = Clustering()

# Dataset to evaluate: a directory of per-sample folders plus a CSV of
# ground-truth labels. Point both paths at another dataset to re-run the
# evaluation on it.
traj_dataset_dir = '../failure_clustering_datasets/driving/nexar/minus_1000ms_cropped_top85_bottom60_imgs/'
labels_csv_path = '../failure_clustering_datasets/driving/nexar/selected_samples_with_adjusted_targets.csv'
labels = pd.read_csv(labels_csv_path)

# Confusion-matrix tallies; "positive" means an (actual or predicted) failure.
TP = FP = TN = FN = 0

# Per-class sample counts for the ground truth and for the monitor's output.
actual_success_count = actual_failure_count = 0
pred_success_count = pred_failure_count = 0

# Per-sample result lists (each name is bound to its own, distinct list).
# NOTE(review): the four *_states lists appear unused in this script.
failure_states, success_states = [], []
pred_failure_states, pred_success_states = [], []
overall_actual, overall_preds = [], []
cluster_output = []

# Evaluate the clustering monitor on every sample folder in traj_dataset_dir.
# Each folder name is parsed as an integer id and matched against
# labels['id']. A 'target_time_minus_1000ms' value of 0 marks the sample as an
# actual success (0); any other value marks an actual failure (1, the positive
# class). The monitor's output is mapped the same way: None or 'SAFE' counts
# as a predicted success, anything else as a predicted failure.
files = sorted(os.listdir(traj_dataset_dir))
for i, sample_name in enumerate(files):
    # os.path.join works whether or not the dataset path ends with '/'.
    fail_images_dir_i = os.path.join(traj_dataset_dir, sample_name)

    # Stray files (e.g. OS metadata such as .DS_Store) are not samples; skip
    # them instead of crashing in os.listdir()/int() below.
    if not os.path.isdir(fail_images_dir_i):
        print('skipping non-directory entry: ', fail_images_dir_i)
        continue

    print('i: ', i)
    print('dir: ', fail_images_dir_i)

    idx = len(os.listdir(fail_images_dir_i))

    label_values = labels.loc[
        labels['id'] == int(sample_name), 'target_time_minus_1000ms'
    ].values
    if len(label_values) == 0:
        # Fail loudly with context rather than an opaque IndexError.
        raise KeyError(f'no row with id == {sample_name} in the labels CSV')

    if label_values[0] == 0:
        actual_success_count += 1
        curr_actual = 0
        print('ACTUAL SUCCESS')
    else:
        actual_failure_count += 1
        curr_actual = 1
        print('ACTUAL FAILURE')
    overall_actual.append(curr_actual)

    # NOTE(review): idx - 1 presumably selects the last frame in the folder
    # (and is -1 for an empty folder) — confirm against Clustering.cluster.
    pred = clustering_sim.cluster(fail_images_dir_i, idx - 1)
    print('pred: ', pred)
    cluster_output.append(str(pred))

    if pred is not None and pred != 'SAFE':
        pred_failure_count += 1
        curr_pred = 1
        print('PREDICTED FAILURE')
    else:
        pred_success_count += 1
        curr_pred = 0
        print('PREDICTED SUCCESS')
    overall_preds.append(curr_pred)
    print(' ')

    # Update the confusion matrix (failure is the positive class).
    if curr_actual == 1 and curr_pred == 1:
        TP += 1
    elif curr_actual == 1:
        FN += 1
    elif curr_pred == 1:
        FP += 1
    else:
        TN += 1

# Bundle the per-sample results for later analysis. The dict stays empty when
# no sample was evaluated.
curr_dict = {}
if overall_preds:
    curr_dict['preds'] = np.array(overall_preds)
    curr_dict['actual'] = np.array(overall_actual)
    curr_dict['cluster_output'] = np.array(cluster_output)

results_dir = '../results/monitoring/driving/'
# exist_ok=True replaces the racy os.path.exists() check before os.makedirs().
os.makedirs(results_dir, exist_ok=True)
with open(os.path.join(results_dir, 'clustering.pkl'), 'wb') as fh:
    pickle.dump(curr_dict, fh)

def _rate(numerator, denominator):
    """Return numerator / denominator, or NaN when the denominator is zero.

    A zero denominator occurs when the dataset has no actual failures
    (TP + FN == 0) or no actual successes (TN + FP == 0). Reporting NaN keeps
    the remaining metrics printable instead of aborting with ZeroDivisionError.
    """
    return numerator / denominator if denominator else float('nan')


# Report the confusion matrix and the derived rates (failure = positive).
print('True Positive: ', TP)
print('True Negative: ', TN)
print('False Positive: ', FP)
print('False Negative: ', FN)
print('True Positive Rate: ', _rate(TP, TP + FN))
print('True Negative Rate: ', _rate(TN, TN + FP))
print('False Positive Rate: ', _rate(FP, FP + TN))
print('False Negative Rate: ', _rate(FN, FN + TP))