import matplotlib.pyplot as plt
import os
import pandas as pd
import numpy as np
# --- Experiment configuration ---------------------------------------------
# Root directory that holds one sub-directory per trained model run.
basepath = 'trained_models/'

# Label value the sampled labels are compared against further down
# (presumably the label of the attack set -- confirm with the experiment).
test_key = 1733

# Training-set sizes that were evaluated: 2000, 4000, ..., 16000.
num_samples = list(range(2000, 16001, 2000))

# --- Input data -------------------------------------------------------------
# 'data.npz' has to provide a 'label' array; it is indexed below with the
# sample ids stored next to every trained model.
data = np.load('data.npz')
labels = data['label']

# --- Accumulators (one entry appended per element of num_samples) -----------
all_mm, all_random = [], []                        # minimum 'Attack Mean Rank'
all_labels = []                                    # never filled in this script
all_occurences_mm, all_occurences_random = [], []  # samples labelled test_key
count_mm, count_rnd = [], []                       # per-label sample counts

# For every training-set size, record the best attack rank and the label
# make-up of the selected samples for both sampling strategies.
for num_sample in num_samples:
    # Locations (relative to basepath) of the attack results and of the ids
    # of the samples each run used. An alternative baseline stored under
    # 'baseline_none_wave_0_200000_{}_00cluster/' was used here before and is
    # currently disabled.
    mm_rank_csv = 'baseline_none_wave_0_200000_{}_mm/attack_rank.csv'.format(num_sample)
    rnd_rank_csv = 'baseline_random_wave_0_200000_{}/attack_rank.csv'.format(num_sample)
    rnd_ids_npy = 'baseline_random_wave_0_200000_{}/all_ids.npy'.format(num_sample)
    mm_ids_npy = 'baseline_none_wave_0_200000_{}_mm/all_ids.npy'.format(num_sample)

    # Lowest value of the 'Attack Mean Rank' column for each strategy.
    mm_ranks = pd.read_csv(os.path.join(basepath, mm_rank_csv))
    rnd_ranks = pd.read_csv(os.path.join(basepath, rnd_rank_csv))
    all_mm.append(mm_ranks['Attack Mean Rank'].min())
    all_random.append(rnd_ranks['Attack Mean Rank'].min())

    # Labels of the samples that each strategy selected.
    label_random = labels[np.load(os.path.join(basepath, rnd_ids_npy))]
    label_mm = labels[np.load(os.path.join(basepath, mm_ids_npy))]

    # Number of selected samples per distinct label value.
    count_mm.append(np.unique(label_mm, return_counts=True)[1])
    count_rnd.append(np.unique(label_random, return_counts=True)[1])

    # Number of selected samples whose label equals test_key.
    all_occurences_random.append(int(np.count_nonzero(label_random == test_key)))
    all_occurences_mm.append(int(np.count_nonzero(label_mm == test_key)))

def get_mean(input_list):
    """Return the mean of every entry of ``input_list``.

    Args:
        input_list: Iterable whose items are each passed to ``np.mean``.

    Returns:
        A list holding one ``np.mean`` result per item, in input order.
    """
    return [np.mean(values) for values in input_list]

def get_min(input_list):
    """Return the minimum of every entry of ``input_list``.

    Args:
        input_list: Iterable whose items are each passed to ``np.min``.

    Returns:
        A list holding one ``np.min`` result per item, in input order.
    """
    return [np.min(values) for values in input_list]

def get_max(input_list):
    """Return the maximum of every entry of ``input_list``.

    Args:
        input_list: Iterable whose items are each passed to ``np.max``.

    Returns:
        A list holding one ``np.max`` result per item, in input order.
    """
    return [np.max(values) for values in input_list]

def get_std(input_list):
    """Return the standard deviation of every entry of ``input_list``.

    Each item is passed to ``np.std`` with its default arguments.

    Args:
        input_list: Iterable whose items are each passed to ``np.std``.

    Returns:
        A list holding one ``np.std`` result per item, in input order.
    """
    return [np.std(values) for values in input_list]
# Figure 1: lowest attack mean rank reached by each sampling strategy,
# plotted per training-set size.

# Tick positions, one per entry of num_samples; reused by the second figure.
x = np.arange(len(num_samples))

rank_fig, rank_ax = plt.subplots(figsize=(10, 7))
rank_ax.plot(all_mm, label='min max sampling')
rank_ax.plot(all_random, label='random sampling')

# Show the actual training-set sizes instead of the list positions.
rank_ax.set_xticks(x)
rank_ax.set_xticklabels(num_samples)
rank_ax.set_xlabel('Num. Samples')
rank_ax.set_ylabel('Mean Attack Rank')
rank_ax.legend()
rank_ax.grid()

rank_fig.savefig('Num_sample_mm_rnd.png')
rank_fig.clf()
# Plot labels distribution: for every training-set size, the number of
# selected samples whose label equals test_key, for both sampling strategies.
# Written to 'Label_Distribution_Balance.png'.

plt.figure(figsize=(10,7))
plt.plot(all_occurences_mm, label = 'min max sampling')
plt.grid()
# Optional per-label statistics for min-max sampling (disabled).
#plt.plot(get_mean(count_mm), '--', label = 'min max average num.samples per label')
#plt.plot(get_min(count_mm), '--', label = 'min max min num.samples per label')
#plt.plot(get_max(count_mm), '--', label = 'min max max num.samples per label')
#plt.plot(get_std(count_mm), '--', label = 'min max std num.samples per label')
# This curve shows the random-sampling runs; it used to be labelled
# 'min max label sampling', which made both legend entries read as min-max.
plt.plot(all_occurences_random, label = 'random sampling')
# Optional per-label statistics for random sampling (disabled).
#plt.plot(get_mean(count_rnd), '--', label = 'random average num.samples per label')
#plt.plot(get_min(count_rnd), '--', label = 'random min num.samples per label')
#plt.plot(get_max(count_rnd), '--', label = 'random max num.samples per label')
#plt.plot(get_std(count_rnd), '--', label = 'random std num.samples per label')
# x holds the tick positions 0..len(num_samples)-1 defined for the first
# figure; the ticks are relabelled with the actual training-set sizes.
plt.xticks(x, num_samples)
plt.xlabel('Num. Samples')
plt.ylabel('Number of sample with same label as the attack set')
plt.legend()
plt.savefig('Label_Distribution_Balance.png')
