import numpy as np
import pickle
import os
from clustering import Clustering
# Failure monitor under evaluation; its predictions are compared against the
# recorded outcome of each episode.
clustering_sim = Clustering()

# Input locations. One states pickle and one trajectory directory exist per
# start heading (theta); both are keyed by the first five characters of
# str(theta).
states_dir = '../failure_clustering_datasets/waypointnav/states/area1/patch0'
thetas = [
    -3.14159265, -2.51327412, -1.88495559, -1.25663706, -0.62831853,
    0, 0.62831853, 1.25663706, 1.88495559, 2.51327412,
]
traj_dataset_dir = '../failure_clustering_datasets/waypointnav/area1/patch0'

# Output location for the per-theta result pickles.
results_dir = '../results/monitoring/waypointnav/area1/patch/'

# Episode tallies summed over every theta processed so far.
overall_actual_failure_count = 0
overall_actual_success_count = 0
overall_pred_failure_count = 0
overall_pred_success_count = 0

# Confusion matrix accumulated across all thetas; "positive" means failure.
TP = 0
FP = 0
TN = 0
FN = 0


def _safe_ratio(numerator, denominator):
    """Return numerator / denominator, or NaN when the denominator is zero.

    A rate is undefined until at least one episode of the relevant class has
    been seen; reporting NaN keeps the script running instead of raising
    ZeroDivisionError.
    """
    if denominator == 0:
        return float('nan')
    return numerator / denominator


def _report_metrics(count_prefix, rate_prefix, tp, tn, fp, fn):
    """Print confusion-matrix counts, the four derived rates and the F1 score.

    count_prefix is prepended to the count labels and rate_prefix to the rate
    labels; the F1 label takes no prefix.
    """
    print(count_prefix + 'True Positive: ', tp)
    print(count_prefix + 'True Negative: ', tn)
    print(count_prefix + 'False Positive: ', fp)
    print(count_prefix + 'False Negative: ', fn)
    print(rate_prefix + 'True Positive Rate: ', _safe_ratio(tp, tp + fn))
    print(rate_prefix + 'True Negative Rate: ', _safe_ratio(tn, tn + fp))
    print(rate_prefix + 'False Positive Rate: ', _safe_ratio(fp, fp + tn))
    print(rate_prefix + 'False Negative Rate: ', _safe_ratio(fn, fn + tp))
    print('F1 Score: ', _safe_ratio(2 * tp, 2 * tp + fp + fn))


for theta in thetas:
    print('theta: ', theta)
    # File-name tag for this heading: '-3.14', '0.628', ... and '0' for zero.
    theta_tag = str(theta)[:5]
    # NOTE(review): the tag is appended directly to 'patch0' with no path
    # separator (e.g. '.../patch0-3.14/'); confirm this matches the on-disk
    # layout.
    fail_images_dir = traj_dataset_dir + theta_tag + '/'
    states_file = f'{states_dir}/random_states{theta_tag}.pkl'
    # NOTE: pickle.load can execute arbitrary code; only run this against
    # dataset files from a trusted source.
    with open(states_file, 'rb') as f:
        data = pickle.load(f)
    # Only the number of states is used: one episode directory per state.
    fail_states = data['states']

    # Per-theta records written to the result pickle.
    overall_preds = []
    overall_actual = []
    overall_ids = []

    actual_success_count = 0
    actual_failure_count = 0
    pred_success_count = 0
    pred_failure_count = 0

    for i in range(len(fail_states)):
        episode_dir = fail_images_dir + str(i) + '/'
        fail_images_dir_i = episode_dir + 'images/'
        print(fail_images_dir_i)

        idx = len(os.listdir(fail_images_dir_i))
        print('no of images: ', idx)

        # Episodes with five or fewer images are skipped; this also keeps the
        # smallest position queried below (idx - 6) non-negative.
        if idx <= 5:
            continue

        overall_ids.append(i)

        # Ground truth: the episode counts as a success only when its metadata
        # exists and says so; missing metadata is counted as a failure.
        metadata_i = episode_dir + 'rgb_resnet50_nn_waypoint_simulator/trajectories/metadata.pkl'
        episode_succeeded = False
        if os.path.exists(metadata_i):
            with open(metadata_i, 'rb') as f:
                traj_i = pickle.load(f)
            episode_succeeded = traj_i['episode_type_string'][0] == 'Success'

        if episode_succeeded:
            actual_success_count += 1
            curr_actual = 0
            print('ACTUAL SUCCESS')
        else:
            actual_failure_count += 1
            curr_actual = 1
            print('ACTUAL FAILURE')
        overall_actual.append(curr_actual)

        # Query the monitor at three consecutive positions near the end of the
        # episode (presumably image indices -- confirm against
        # Clustering.cluster).
        pred1, pred2, pred3 = [
            clustering_sim.cluster(fail_images_dir_i, idx - offset)
            for offset in (6, 5, 4)
        ]
        print('pred: ', pred1, ',', pred2, ',', pred3)
        overall_preds.append([pred1, pred2, pred3])

        # A failure is predicted only when all three queries return something
        # other than None or 'SAFE'.
        if all(p is not None and p != 'SAFE' for p in (pred1, pred2, pred3)):
            pred_failure_count += 1
            curr_pred = 1
            print('PREDICTED FAILURE')
        else:
            pred_success_count += 1
            curr_pred = 0
            print('PREDICTED SUCCESS')

        print(' ')

        if curr_actual == 1 and curr_pred == 1:
            TP += 1
        elif curr_actual == 1 and curr_pred == 0:
            FN += 1
        elif curr_actual == 0 and curr_pred == 1:
            FP += 1
        elif curr_actual == 0 and curr_pred == 0:
            TN += 1

    # Persist this theta's records; a key is omitted when its list is empty
    # because np.vstack rejects an empty sequence.
    curr_dict = {}
    if len(overall_ids) > 0:
        curr_dict['ids'] = np.vstack(overall_ids)
    if len(overall_preds) > 0:
        curr_dict['preds'] = np.vstack(overall_preds)
    if len(overall_actual) > 0:
        curr_dict['actual'] = np.vstack(overall_actual)
    os.makedirs(results_dir, exist_ok=True)
    with open(results_dir + 'clustering_' + theta_tag + '.pkl', 'wb') as fp:
        pickle.dump(curr_dict, fp)

    overall_actual_failure_count += actual_failure_count
    overall_actual_success_count += actual_success_count
    overall_pred_failure_count += pred_failure_count
    overall_pred_success_count += pred_success_count

    # Running totals over all thetas processed so far.
    _report_metrics('Current ', 'Current ', TP, TN, FP, FN)

    print(' ')
    print(' ')

_report_metrics('Final ', '', TP, TN, FP, FN)