import os
import time
import random
import argparse
import numpy as np
from PIL import Image

import torch
import torch.backends.cudnn as cudnn
from torchvision import transforms
from sklearn.metrics import roc_auc_score, roc_curve, auc, accuracy_score, f1_score, precision_score, recall_score, classification_report, confusion_matrix
import sys
sys.path.append('../') 
import torch.distributed as dist
import models
import torch.nn as nn
from tqdm import tqdm as tq
from sklearn.metrics import accuracy_score
from sklearn.metrics import precision_score
from sklearn.metrics import roc_curve
import yaml
from utils import *
import torch.nn.functional as F
import tensorflow as tf
import matplotlib.pyplot as plt


def reduce_tensor(tensor, n):
    """Return a copy of ``tensor`` summed element-wise via ``dist.all_reduce``.

    The input tensor is left untouched; the reduction is applied to a clone.

    Args:
        tensor: Tensor to reduce.
        n: Not used by this function.
            NOTE(review): similar helpers often divide the sum by ``n`` to
            obtain a mean -- confirm with callers that a plain sum is intended.

    Returns:
        The cloned tensor after an in-place SUM all-reduce.
    """
    summed = tensor.clone()
    dist.all_reduce(summed, op=dist.ReduceOp.SUM)
    return summed

# custom attack
from attacks.linf import DI_fgsm, GA_DI_fgsm,   TDI_fgsm, GA_TDI_fgsm, TDMI_fgsm, GA_TDMI_fgsm
# from attacks.feature import DI_FSA, DMI_FSA, GA_DMI_FSA, GA_DI_FSA, Feature_Adam_Attack
# from perceptual_advex.attacks import ReColorAdvAttack

from utils import MyCustomDataset, get_architecture, Input_diversity, MultiEnsemble, get_dataset, get_model
from utils import CrossEntropyLoss, MarginLoss

# Command-line interface of the script:
#   --config      path of the YAML experiment file (read by parse_config)
#   --world-size  number of per-rank ``.npz`` result files to load
#                 (passed to get_outputs_labels)
# Neither flag has a default, so an omitted flag is parsed as None.
parser = argparse.ArgumentParser(description='PyTorch Unrestricted Attack')
parser.add_argument('--config', type=str)
parser.add_argument('--world-size', type=int)

# NOTE(review): not referenced anywhere in the visible part of this file;
# presumably the class count of the evaluated dataset -- confirm before use.
NUM_CLASSES = 10

def normalize(item):
    """Min-max scale ``item`` so its smallest value maps to 0 and its largest to 1.

    Args:
        item: Object exposing ``max()`` / ``min()`` and supporting subtraction
            and division (for example a NumPy array).

    Returns:
        ``(item - item.min()) / (item.max() - item.min())``. The range is not
        checked, so a constant input divides by zero.
    """
    upper = item.max()
    lower = item.min()
    value_range = upper - lower
    return (item - lower) / value_range

def softmax_by_row(logits, T = 1.0):
    """Compute a temperature-scaled softmax along the last axis of ``logits``.

    Args:
        logits: NumPy array of scores; the softmax is taken over axis -1.
        T: Temperature the shifted logits are divided by before exponentiation.

    Returns:
        Array of the same shape as ``logits`` whose entries along the last
        axis sum to 1.
    """
    # Subtracting the per-row maximum keeps the exponentials from overflowing.
    shifted = logits - np.max(logits, axis=-1, keepdims=True)
    weights = np.exp(shifted/T)
    return weights/np.sum(weights, axis=-1, keepdims=True)


def main(config):
    """Evaluate the one-shot membership attack and plot its metrics.

    For every threshold in the experiment's ``thres_list`` and every attack
    method, the saved model outputs are loaded from
    ``./data/<trainer>/<defense>/<attack config>/target/<attack>/thres_<t>``.
    A sample is predicted to be a training member when the model's prediction
    on its adversarial example still equals the true label. Train samples are
    the ground-truth members (1) and test samples the non-members (0).
    Precision, recall, F1, TPR, FPR and the fraction of unchanged predictions
    for train/test samples are collected per threshold, printed, and handed to
    ``pre_rec_plot`` and ``tpr_fpr_plot``.

    Args:
        config: Parsed command-line namespace providing ``config`` (path of
            the YAML experiment file) and ``world_size`` (number of per-rank
            result files). The parameter name is historical; the YAML content
            is loaded inside this function.
    """
    # Work on the namespace that was passed in instead of relying on a
    # module-level ``args`` that only exists when the file runs as a script.
    args = config
    config = parse_config(args.config)

    # Source model
    source_id_list = [int(item) for item in config.source_list.split('_')]
    print("Source id list: {}".format(source_id_list))

    # Auxiliary model
    auxiliary_id_list = [int(item) for item in config.auxiliary_list.split('_')]
    print("Auxiliary id list: {}".format(auxiliary_id_list))

    # The directory components of the config path select the data directory.
    all_dir_names = get_folder_names(args.config)
    config_dir_name = all_dir_names[2]  # trainer
    config_second_dir_name = all_dir_names[3]  # defense
    config_third_dir_name = all_dir_names[4]  # attack config

    thres_list = config.thres_list

    # attack_methods  = ['GA_MI_fgsm','GA_TI_fgsm','GA_DI_fgsm', 'GA_TDMI_fgsm', 'GA_SI_fgsm', 'GA_TDMSI_fgsm', 'GA_TDMSAI_fgsm']
    attack_methods = ['GA_TDMI_fgsm']

    # Indexed [threshold][attack][sample]: True where the adversarial
    # prediction still matches the label.
    all_mem_corr = []
    all_non_corr = []
    for thres in thres_list:
        print(f'************thresh_{thres}***************')
        mem_corr = []
        non_corr = []

        for attack_method in attack_methods:
            print(f'##########{attack_method}###########')
            data_path = f'./data/{config_dir_name}/{config_second_dir_name}/{config_third_dir_name}/target/{attack_method}/thres_{thres}'
            (output_train_benign, train_label, output_train_adversarial, _,
             output_test_benign, test_label, output_test_adversarial, _) = get_outputs_labels(args.world_size, data_path)
            mem_corr.append(np.argmax(output_train_adversarial, 1) == train_label)
            non_corr.append(np.argmax(output_test_adversarial, 1) == test_label)

            adv_train_acc = accuracy(output_train_adversarial, train_label)
            adv_test_acc = accuracy(output_test_adversarial, test_label)
            acc_total = adv_train_acc + adv_test_acc
            # Precision is undefined when no adversarial example keeps its
            # label on either split; report NaN instead of dividing by zero.
            precision_one_shot = adv_train_acc / acc_total if acc_total else float('nan')
            recall_one_shot = adv_train_acc
            print('benign train test acc:', accuracy(output_train_benign, train_label), accuracy(output_test_benign, test_label))
            print('adv train test acc:', adv_train_acc, adv_test_acc)
            print(f'One-shot Precision: {precision_one_shot} | Recall: {recall_one_shot}')

        all_mem_corr.append(mem_corr)
        all_non_corr.append(non_corr)

    # Re-order to [attack][threshold][sample].
    all_mem_corr = np.array(all_mem_corr).transpose(1, 0, 2)
    all_non_corr = np.array(all_non_corr).transpose(1, 0, 2)

    stats_all_attacks = dict()
    for i, attack_method in enumerate(attack_methods):
        thre_precision_one_shot = []
        thre_recall_one_shot = []
        thre_f1_one_shot = []
        thre_tpr = []
        thre_fpr = []
        thre_ssr_mem = []
        thre_ssr_non = []
        for j in range(len(thres_list)):
            mem_hits = all_mem_corr[i][j]
            non_hits = all_non_corr[i][j]
            predicted_membership = np.concatenate([mem_hits, non_hits]).astype(int)
            ground_membership = np.concatenate([
                np.ones(len(mem_hits), dtype=int),
                np.zeros(len(non_hits), dtype=int),
            ])

            thre_precision_one_shot.append(precision_score(ground_membership, predicted_membership, pos_label=1))
            thre_recall_one_shot.append(recall_score(ground_membership, predicted_membership, pos_label=1))
            thre_f1_one_shot.append(f1_score(ground_membership, predicted_membership, pos_label=1))

            # Fixing the label set guarantees a 2x2 matrix even when one
            # class never occurs among the labels.
            _tn, _fp, _fn, _tp = confusion_matrix(
                ground_membership, predicted_membership, labels=[0, 1]
            ).ravel()
            thre_tpr.append(_tp / (_tp + _fn))
            thre_fpr.append(_fp / (_tn + _fp))

            thre_ssr_mem.append(np.sum(mem_hits) / len(mem_hits))
            thre_ssr_non.append(np.sum(non_hits) / len(non_hits))

        # attack_method[3:-5] strips the leading 'GA_' and trailing '_fgsm'.
        stats_all_attacks[f'One-Shot Attack-{attack_method[3:-5]} (Ours)'] = [
            thre_precision_one_shot,
            thre_recall_one_shot,
            thre_f1_one_shot,
            thre_tpr,
            thre_fpr,
            thre_ssr_mem,
            thre_ssr_non,
        ]

    print('*********stats_all************')
    print(stats_all_attacks)

    graph_save_path = './graph'
    os.makedirs(graph_save_path, exist_ok=True)

    # plot pre and rec for all attacks
    pre_rec_plot(stats_all_attacks)

    # plot TPR and FPR for all attacks
    tpr_fpr_plot(stats_all_attacks)


def pre_rec_plot(stats_all_attacks, save_path='./graph/pre_rec.png'):
    """Plot one precision-versus-recall curve per attack and save the figure.

    Args:
        stats_all_attacks: Mapping from attack name (used as legend label) to
            a sequence whose element 0 holds the precision values and element
            1 the recall values.
        save_path: File the figure is written to. Its parent directory is
            created when it does not exist yet.
    """
    plt.figure(figsize=(4,3), dpi= 200)
    for key, stats in stats_all_attacks.items():
        # All attacks share one line style; the former per-name branches were
        # identical, so the distinction has been dropped.
        plt.plot(stats[1], stats[0], label=key, alpha=.7)
    plt.xlabel("Recall",fontsize=10)  # x-axis label
    plt.ylabel("Precision",fontsize=10)
    plt.ylim(0,1.0)
    plt.xlim(0.03,1.0)
    plt.legend(loc = "best",fontsize=8)
    # savefig does not create missing directories itself.
    os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)
    plt.savefig(save_path, dpi=300,bbox_inches='tight')

def tpr_fpr_plot(stats_all_attacks, save_path='./graph/fprtpr.png'):
    """Plot TPR against FPR on log-log axes and print TPR at low FPR.

    For each attack the TPR at 0.001 and 0.01 FPR is linearly interpolated
    and printed, then the raw (FPR, TPR) points are drawn as one curve.

    Args:
        stats_all_attacks: Mapping from attack name (used as legend label) to
            a sequence whose element 3 holds the TPR values and element 4 the
            FPR values.
        save_path: File the figure is written to. Its parent directory is
            created when it does not exist yet.
    """
    print('*********stats_all_TPR_FPR************')
    plt.figure(figsize=(4,3), dpi= 200)
    for key, stats in stats_all_attacks.items():
        print(f'*********{key}************')
        if key.startswith('One'):
            # First index whose FPR is at most 0.001 (0 when there is none).
            valid_idx = np.argmax(np.array(stats[4])<= 0.001)
            print('valid:',valid_idx)
            TPR = np.flip(np.array(stats[3]))
            FPR = np.flip(np.array(stats[4]))
            print('TPR',stats[3], TPR)
            print('FPR',stats[4], FPR)
        else:
            TPR = np.array(stats[3])
            FPR = np.array(stats[4])
        # np.interp requires increasing x-coordinates and does not check
        # them, so sort by FPR instead of trusting the incoming order. The
        # stable sort leaves already-ordered data exactly as it was.
        order = np.argsort(FPR, kind='stable')
        FPR_sorted = FPR[order]
        TPR_sorted = TPR[order]
        print('TPR under 0.001 FPR:', np.interp(0.001,FPR_sorted,TPR_sorted))
        print('TPR under 0.01 FPR:', np.interp(0.01,FPR_sorted,TPR_sorted))
        plt.plot(stats[4], stats[3], label=key, alpha=.7)
    plt.semilogx()
    plt.semilogy()
    plt.xlim(1e-3,1)
    plt.ylim(1e-3,1)
    plt.xlabel("False Positive Rate", fontsize=10)
    plt.ylabel("True Positive Rate", fontsize=10)
    # Diagonal reference line (TPR == FPR).
    plt.plot([0, 1], [0, 1], ls='--', color='gray')
    plt.legend(loc = "best",fontsize=8)
    # savefig does not create missing directories itself.
    os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)
    plt.savefig(save_path, bbox_inches='tight')

def reasoning_plot(stats_all_attacks, thres_list):
    """Save the analysis figures under ``./graph``.

    First, every entry of ``other_label_only_analysis.npy`` is split in half
    (first half drawn as "Train Samples", second half as "Test Samples") and
    shown as reversed cumulative histograms. Second, elements 5 and 6 of the
    ``'One-Shot Attack-TDMI (Ours)'`` statistics are drawn as bars, one
    position per threshold.

    Args:
        stats_all_attacks: Mapping that must contain the key
            ``'One-Shot Attack-TDMI (Ours)'``.
        thres_list: Thresholds used as x tick labels of the bar chart.
    """
    # NOTE(review): allow_pickle makes np.load unpickle arbitrary objects;
    # only load this file from a trusted location.
    metric_dictionary = np.load('other_label_only_analysis.npy',allow_pickle='TRUE').item()
    for key, values in metric_dictionary.items():
        half = len(values)//2
        mem_logits = values[:half]
        nonmem_logits = values[half:]
        plt.figure(figsize=(8,5), dpi= 100)
        # Both histograms share the bin edges derived from the second half.
        bins = np.histogram_bin_edges(nonmem_logits, bins=50)
        hist_options = dict(bins=bins, alpha=.7, cumulative=-1, density=True, histtype='bar')
        plt.hist(mem_logits, label="Train Samples", **hist_options)
        plt.hist(nonmem_logits, label="Test Samples", **hist_options)
        plt.xlabel("Magnitude of Perturbation",fontsize=15)  # x-axis label
        plt.ylabel("Cumulative Fraction",fontsize=15)
        plt.title(f'Boundary Attack',fontsize=20)
        plt.legend(loc = "best",fontsize=15)
        plt.savefig(f'./graph/analysis_{key}.png',dpi=300,bbox_inches='tight')

    plt.figure(figsize=(8,5), dpi= 100)
    key = 'TDMI'
    one_shot_stats = stats_all_attacks[f'One-Shot Attack-TDMI (Ours)']
    positions = np.arange(len(thres_list))
    plt.bar(positions, one_shot_stats[5], label="Train Samples", alpha=.7)
    plt.bar(positions, one_shot_stats[6], label="Test Samples", alpha=.7)
    plt.xlabel("Threshold",fontsize=15)  # x-axis label
    plt.ylabel("Failure Rate of Adversarial Attack",fontsize=15)
    plt.xticks(np.arange(len(thres_list)),thres_list)
    plt.title(f'One-Shot Attack',fontsize=20)
    plt.legend(loc = "best",fontsize=15)
    plt.savefig(f'./graph/analysis_{key}.png',dpi=300,bbox_inches='tight')


def accuracy(outputs,y_true):
    """Return the accuracy of the arg-max predictions against ``y_true``.

    Args:
        outputs: Array of per-class scores; the predicted class of each row is
            the index of its largest entry along axis 1.
        y_true: Ground-truth class labels.

    Returns:
        The value computed by ``accuracy_score(y_true, predictions)``.
    """
    return accuracy_score(y_true, np.argmax(outputs, axis=1))
    

def get_outputs_labels(world_size, data_path):
    """Load the per-rank result archives and concatenate them across ranks.

    One file named ``world_size{world_size}_rank{rank}.npz`` is read from
    ``data_path`` for every rank in ``range(world_size)``.

    Args:
        world_size: Number of rank files to read.
        data_path: Directory containing the ``.npz`` files.

    Returns:
        An 8-tuple of arrays, each concatenated over ranks in rank order:
        ``(output_train_benign, train_label, output_train_adversarial,
        transfer_train_pertub_transfer, output_test_benign, test_label,
        output_test_adversarial, transfer_test_pertub_transfer)``.

    Raises:
        ValueError: From ``np.concatenate`` when ``world_size`` is 0.
    """
    keys = (
        'output_train_benign',
        'train_label',
        'output_train_adversarial',
        'transfer_train_pertub_transfer',
        'output_test_benign',
        'test_label',
        'output_test_adversarial',
        'transfer_test_pertub_transfer',
    )
    collected = {key: [] for key in keys}

    for rank in range(world_size):
        path = os.path.join(data_path, f'world_size{world_size}_rank{rank}.npz')
        # np.load returns an NpzFile that keeps the archive open; the context
        # manager closes it once the arrays have been read into memory.
        with np.load(path) as data:
            for key in keys:
                collected[key].append(data[key])

    return tuple(np.concatenate(collected[key]) for key in keys)
    
def get_target_shadow_outputs_labels(world_size, data_path):
    """Load per-rank shadow/target result archives and concatenate them.

    One file named ``world_size{world_size}_rank{rank}.npz`` is read from
    ``data_path`` for every rank in ``range(world_size)``.

    Args:
        world_size: Number of rank files to read.
        data_path: Directory containing the ``.npz`` files.

    Returns:
        An 8-tuple of arrays, each concatenated over ranks in rank order:
        ``(output_shadow_mem, label_shadow_mem, output_shadow_nonmem,
        label_shadow_nonmem, output_target_mem, label_target_mem,
        output_target_nonmem, label_target_nonmem)``.

    Raises:
        ValueError: From ``np.concatenate`` when ``world_size`` is 0.
    """
    keys = (
        'output_shadow_mem',
        'label_shadow_mem',
        'output_shadow_nonmem',
        'label_shadow_nonmem',
        'output_target_mem',
        'label_target_mem',
        'output_target_nonmem',
        'label_target_nonmem',
    )
    collected = {key: [] for key in keys}

    for rank in range(world_size):
        path = os.path.join(data_path, f'world_size{world_size}_rank{rank}.npz')
        # np.load returns an NpzFile that keeps the archive open; the context
        # manager closes it once the arrays have been read into memory.
        with np.load(path) as data:
            for key in keys:
                collected[key].append(data[key])

    return tuple(np.concatenate(collected[key]) for key in keys)

def dict2namespace(config):
    """Recursively convert a dict into an ``argparse.Namespace``.

    Args:
        config: Dictionary whose keys become attribute names. Values that are
            themselves dicts are converted the same way; every other value is
            stored unchanged.

    Returns:
        A namespace with one attribute per key of ``config``.
    """
    namespace = argparse.Namespace()
    for name, entry in config.items():
        converted = dict2namespace(entry) if isinstance(entry, dict) else entry
        setattr(namespace, name, converted)
    return namespace

def parse_config(config_path=None):
    """Read a YAML file and return its content as a nested namespace.

    Args:
        config_path: Path of the YAML file. The default of ``None`` is not a
            usable value; ``open`` rejects it.

    Returns:
        The parsed mapping converted with ``dict2namespace``.
    """
    with open(config_path, 'r') as f:
        # NOTE(review): yaml.Loader can construct arbitrary Python objects;
        # only point this at trusted configuration files.
        raw_config = yaml.load(f, Loader=yaml.Loader)
    return dict2namespace(raw_config)

if __name__ == "__main__":
    # Script entry point: parse the command-line flags (--config,
    # --world-size) and run the evaluation with them.
    args = parser.parse_args()
    main(args)

