import matplotlib.pyplot as plt
import os
import pandas as pd
import numpy as np
basepath = 'trained_models/'
basepath_multi = 'multi_attack_trained_models/'

# When True, stop right after the full-dataset label-distribution plot below
# and skip the per-sample-count sweep that follows.  Defaults to True to keep
# the script's previous behaviour (it used to exit unconditionally here).
ONLY_PLOT_DATASET_DISTRIBUTION = True

# Labels for every sample in the dataset.
data = np.load('data.npz')
labels = data['label']

# Indices of the selected subset (presumably the output of a min-max
# balancing step, judging by the file name -- confirm).
all_ids = np.load('200k_2000cluster/minmax_balance_0.5_index.npy')
print(all_ids.shape)
all_labels = labels[all_ids]

unique_mm, counts_mm = np.unique(all_labels, return_counts=True)
print(all_labels.shape)

# Number of selected samples per label.  Plot against the actual label values
# (not their position in `unique_mm`) so the x axis matches its
# 'Label value' caption even when some labels are absent from the subset.
plt.figure(figsize=(10,7))
plt.grid()
plt.plot(unique_mm, counts_mm, label = 'Raw signal, minmax')
plt.xlabel('Label value')
plt.ylabel('Number of sample within label')
plt.legend()
plt.savefig('Label_Distribution_wavelet_minmax.png')
plt.close()

if ONLY_PLOT_DATASET_DISTRIBUTION:
    # `exit()` is injected by the `site` module for interactive use only;
    # raising SystemExit is the portable way to end a script (exit status 0).
    raise SystemExit

# Label value whose number of occurrences is tracked in every training subset.
test_key = 1733

# Training-set sizes swept by the experiment; each size has its own
# trained-model directory under `basepath` (raw signal) and
# `basepath_multi` (wavelet cD1).
num_samples = [2000, 4000, 6000, 8000, 10000, 12000, 14000, 16000]
all_mm = []                  # min 'Attack Mean Rank' per size, wavelet cD1 runs
all_random = []              # min 'Attack Mean Rank' per size, raw-signal runs
all_occurences_random = []   # occurrences of `test_key` in each raw-signal subset
all_occurences_mm = []       # occurrences of `test_key` in each wavelet subset
mean_mm = []                 # mean number of samples per label, wavelet subset
mean_rnd = []                # mean number of samples per label, raw-signal subset

# Alternative run directories used in earlier experiments (swap into the
# paths below to compare them):
#   wavelet cA2:  'baseline_minmax_wave_0_200000_cA2__{n}_mm'  /  'baseline_none_wave_0_200000_cA2_0_{n}'
#   true random:  'baseline_random_wave_0_200000_{n}'
for num_sample in num_samples:
    minmax_path = f'baseline_none_wave_0_200000_cD1__{num_sample}_mm/attack_rank.csv'
    random_path = f'baseline_none_wave_0_200000_{num_sample}_mm/attack_rank.csv'
    random_ids_path = f'baseline_none_wave_0_200000_{num_sample}_mm/all_ids.npy'
    mm_ids_path = f'baseline_none_wave_0_200000_cD1__{num_sample}_mm/all_ids.npy'

    # Best (lowest) mean attack rank reached by each run.
    result_mm = pd.read_csv(os.path.join(basepath_multi, minmax_path))
    result_random = pd.read_csv(os.path.join(basepath, random_path))
    all_mm.append(np.min(result_mm['Attack Mean Rank']))
    all_random.append(np.min(result_random['Attack Mean Rank']))

    # Sample indices each run was trained on, mapped to their labels.
    random_ids = np.load(os.path.join(basepath, random_ids_path))
    mm_ids = np.load(os.path.join(basepath_multi, mm_ids_path))
    label_random = labels[random_ids]
    label_mm = labels[mm_ids]

    # Per-label sample counts; `unique_rnd`/`counts_rnd` from the last
    # iteration are reused by the distribution plot further down.
    unique_mm, counts_mm = np.unique(label_mm, return_counts=True)
    unique_rnd, counts_rnd = np.unique(label_random, return_counts=True)
    mean_mm.append(np.mean(counts_mm))
    mean_rnd.append(np.mean(counts_rnd))

    all_occurences_random.append(int(np.count_nonzero(label_random == test_key)))
    all_occurences_mm.append(int(np.count_nonzero(label_mm == test_key)))

# Plot the best (minimum) mean attack rank against training-set size.
plt.figure(figsize=(10,7))
plt.grid()
plt.plot(all_mm, label = 'Wavelet, cD1')
plt.plot(all_random, label = 'Raw signal')
# Points sit at positions 0..len-1; relabel those ticks with the sample counts.
x = np.arange(len(num_samples))
plt.xticks(x, num_samples)
plt.xlabel('Num. Samples')
plt.ylabel('Mean Attack Rank')
plt.legend()
plt.savefig('Num_sample_wavelet_minmax.png')
# Close (not just clear) the figure so it is released before the next one.
plt.close()

# Plot the label distribution of the raw-signal subset from the final loop
# iteration (num_samples[-1]); x values are the actual label values.
plt.figure(figsize=(10,7))
plt.grid()
# Optional overlays:
#plt.plot(all_occurences_mm, label = 'Wavelet, cA2')
#plt.plot(mean_mm, '--', label = 'cA2average num.samples per label')
#plt.plot(mean_rnd, '--', label = 'Raw signal average num.samples per label')
plt.plot(unique_rnd, counts_rnd, label = 'Raw signal, minmax')
plt.xlabel('Label value')
plt.ylabel('Number of sample within label')
plt.legend()
# NOTE(review): this filename is the same one used for the full-dataset
# distribution plot earlier in the script, so it overwrites that image --
# consider giving it a distinct name.
plt.savefig('Label_Distribution_wavelet_minmax.png')
plt.close()