import pandas as pd
import os
import argparse
import numpy as np

def parse_arguments() -> argparse.ArgumentParser:
    """Build the command-line parser for the rank aggregation script.

    Returns:
        The configured ``argparse.ArgumentParser``. The caller is expected
        to invoke ``parse_args()`` on it. Every option is an integer flag;
        ``--num_sample`` has no default and is ``None`` when omitted.
    """
    parser = argparse.ArgumentParser(
        description='Aggregate per-run rank CSV files into one result CSV.')
    parser.add_argument(
        '--num_sample', type=int,
        help='sample count embedded in the model directory names')
    parser.add_argument(
        '--type', type=int, default=0,
        help='0: mean rank per model, '
             '1: mean absolute difference between the two rank files')
    parser.add_argument(
        '--minmax', type=int, default=0,
        help='1: read the "_minmax" variant of the model directories')
    parser.add_argument(
        '--random', type=int, default=0,
        help='1: read the random-baseline model directories')
    parser.add_argument(
        '--num_iteration', type=int, default=10,
        help='number of run directories (numbered from 1) to aggregate')
    parser.add_argument(
        '--all_path', type=int, default=1,
        help='1: aggregate the model_best<N>.keras columns, '
             'otherwise aggregate the Mean_Rank column')
    return parser

# Aggregate the rank CSVs of every run directory into one summary CSV.
#
# This runs at module level: it parses the command line, reads two CSV files
# (named "on_device" and "cross_device" below) from each run directory, and
# writes the summary into the current working directory.
parser = parse_arguments()
args = parser.parse_args()

# Loop-invariant pieces, computed once up front.
minmax_name = '_minmax' if args.minmax == 1 else ''

# Summary column labels for the per-checkpoint mode. The list stays empty in
# Mean_Rank mode. Binding it before the loop keeps it defined even when
# --num_iteration is 0 (it used to be bound only inside the loop).
all_name = []
if args.all_path == 1:
    all_name = ['model_best{}'.format(j) for j in range(0, 1000, 100)]

all_on_dict = []
all_cross_dict = []
for i in range(1, args.num_iteration + 1):
    # --random selects a different directory family and ignores --minmax.
    if args.random == 1:
        fname = 'multi_attack_trained_models/test_random_baseline_random_wave_0_200000_{}_{}_1_0_lr_1e-05/'.format(args.num_sample, i)
    else:
        fname = 'multi_attack_trained_models/test_rebaseline{}_baseline_none_wave_0_200000_{}_{}_1_0_lr_1e-05/'.format(minmax_name, args.num_sample, i)
    print(fname)

    on_device_fpath = os.path.join(fname, 'mean_rank_full_attack_multi_data_300key_100.csv')
    cross_device_fpath = os.path.join(fname, 'mean_rank_full_attack_300key_100_cross_Device3_OldDev_100.csv')
    on_df = pd.read_csv(on_device_fpath)
    cross_df = pd.read_csv(cross_device_fpath)

    # One summary row per run for each of the two input files.
    on_dict = {'On No': i}
    cross_dict = {'Cross No': i}

    if args.all_path == 1:
        for name in all_name:
            # Input columns carry a ".keras" suffix that the summary drops.
            column = '{}.keras'.format(name)
            curr_on = on_df[column].tolist()
            curr_cross = cross_df[column].tolist()

            if args.type == 1:
                # Mean element-wise absolute difference between both files.
                # Only the "on" row receives it; the "cross" row keeps just
                # its run number.
                abs_mean = np.mean(np.abs(np.array(curr_on) - np.array(curr_cross)))
                print(abs_mean)
                on_dict[name] = abs_mean
            else:
                on_dict[name] = np.mean(curr_on)
                cross_dict[name] = np.mean(curr_cross)
    else:
        # NOTE(review): --type is not consulted in this mode, yet the output
        # file name below still says 'abs' for a non-zero --type -- confirm
        # that this is intended.
        on_dict['Mean_Rank'] = np.mean(on_df['Mean_Rank'].tolist())
        cross_dict['Mean_Rank'] = np.mean(cross_df['Mean_Rank'].tolist())

    all_on_dict.append(on_dict)
    all_cross_dict.append(cross_dict)

print(len(all_on_dict))

# All "on" rows first, then all "cross" rows. Building the frame in a single
# call replaces row-by-row pd.concat onto an empty frame, which was quadratic
# and emits a FutureWarning on recent pandas versions.
rows = all_on_dict + all_cross_dict
# Column order: the all_name labels first, then every other key in order of
# first appearance -- the order the repeated concat used to produce.
columns = list(dict.fromkeys([*all_name, *(key for row in rows for key in row)]))
res_df = pd.DataFrame(rows, columns=columns)

typetext = 'mean' if args.type == 0 else 'abs'
res_df.to_csv('result_{}_{}_{}_{}_{}.csv'.format(args.num_sample, typetext, args.num_iteration, args.random, args.minmax))

# Equivalent to the former exit() call, without relying on the helper that
# the site module injects (absent under "python -S" or in frozen builds).
raise SystemExit
