# -*- encoding: utf-8 -*-
'''
@File: 3090_detect_adv_independent.py
@Description: Detect adversarial examples (AEs) using self-supervised (SSL) representations
@Time: 2022/07/03 16:12:13
@Author: Zhiyuan He
'''
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision import transforms

from utils.resnet_factory import SimSiamWithCls
from resnet18_32x32 import ResNet18_32x32


import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import auc
import random
import os
import math
import argparse


def plot_roc(auc_metrics, labels, path):
    '''Plot ROC curves, print the AUC of each curve and save the figure.

    Args:
        auc_metrics: 2-D array-like. Row 0 holds the x values (plotted as the
            false positive rate); each following row holds the y values (true
            positive rate) of one curve, aligned element-wise with row 0.
        labels: Legend label for each curve in ``auc_metrics[1:]``. Labels and
            curves are zipped, so surplus entries on either side are ignored.
        path: Destination passed to ``plt.savefig``.
    '''
    auc_metrics = np.array(auc_metrics)
    fig = plt.figure()
    try:
        x = auc_metrics[0]
        # Sort by x so the plotted curve and sklearn's ``auc`` see monotonic x.
        x_sort_index = x.argsort()
        # Anchor every curve at (0, 0) and (1, 1).
        x = np.append(np.insert(x[x_sort_index], 0, 0.), 1.0)

        for y, label in zip(auc_metrics[1:], labels):
            # Reorder y with the same permutation that sorted x.
            y = np.append(np.insert(y[x_sort_index], 0, 0.), 1.0)
            plt.plot(x, y, lw=2, label=label)
            print(f"For {label}, AUC: {auc(x, y)}")

        plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.title('Receiver operating characteristic example')
        plt.legend(loc="lower right")
        plt.savefig(path)
    finally:
        # pyplot keeps every open figure alive; close it so repeated calls
        # (e.g. once per configuration in a loop) do not leak memory.
        plt.close(fig)


def multi_transform(img, transforms, times=50):
    '''Call ``transforms`` on ``img`` ``times`` times and stack the outputs.

    The transform is invoked once per view. The outputs are stacked along a
    new dimension 1: if each output has shape (B, ...), the result has shape
    (B, times, ...).
    '''
    views = []
    for _ in range(times):
        views.append(transforms(img))
    return torch.stack(views, dim=1)

def get_cos_similarity_3d(m1, m2, eps=1e-8):
    '''Batched pairwise cosine similarity.

    Args:
        m1: Tensor of shape (B, N, D).
        m2: Tensor of shape (B, M, D).
        eps: Lower bound applied to the product of norms. It prevents 0/0 = NaN
            when a row is all zeros; such pairs get similarity 0.

    Returns:
        Tensor of shape (B, N, M) where ``res[b, i, j]`` is the cosine
        similarity between ``m1[b, i]`` and ``m2[b, j]``.
    '''
    dots = torch.matmul(m1, m2.transpose(1, 2))
    norm_prod = (torch.norm(m1, p=2, dim=2).unsqueeze(dim=2)
                 * torch.norm(m2, p=2, dim=2).unsqueeze(dim=1))
    return dots / norm_prod.clamp_min(eps)

def detect_adv_by_representation(tar_labels, aug_labels, sim_with_ori, aug_time,
                                 rep_sim=(0.2, 0.25, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9),
                                 counts=(10, 15, 20, 25, 30, 35)):
    '''Detection rates based only on representation similarity.

    A sample is flagged for a threshold pair (t_sim, t_c) when more than
    ``t_c`` of its similarity entries are below ``t_sim``. The label-mode check
    is only printed ("First Detection"); it does not contribute to the rates.

    Args:
        tar_labels: (N,) labels predicted by the target model.
        aug_labels: Labels of the augmented views, reshaped to (N, -1). Either
            (N, views) or the same values flattened sample-major.
        sim_with_ori: (N, K) similarities; presumably K == aug_time, with each
            entry the similarity of one view to its original (confirm with caller).
        aug_time: Number of augmented views per sample (not used in the rates).
        rep_sim: Similarity thresholds ``t_sim``.
        counts: Count thresholds ``t_c``.

    Returns:
        List of floats, one per (t_sim, t_c) pair in ``rep_sim`` x ``counts``
        row-major order. Each is the fraction of the N samples flagged. If N is
        0, all entries are 0.0.
    '''
    thresholds = [(t_sim, t_c) for t_sim in rep_sim for t_c in counts]
    total_num = len(tar_labels)
    if total_num == 0:
        print(f'First Detection: 0/{total_num}')
        return [0.0] * len(thresholds)

    aug_labels = aug_labels.reshape(total_num, -1)
    aug_ne_tar = aug_labels.mode(dim=1)[0] != tar_labels
    print(f'First Detection: {aug_ne_tar.sum().item()}/{total_num}')

    auc_metrics = []
    for t_sim, t_c in thresholds:
        # Per sample: how many entries fall below the similarity threshold.
        low_sim_count = (sim_with_ori < t_sim).sum(-1)
        rate = (low_sim_count > t_c).sum().item() / total_num
        print(rate)
        auc_metrics.append(rate)

    return auc_metrics


def detect_adv_by_label_sim(tar_labels, ssl_labels, aug_labels, aug_time):
    '''Detection rates from how many augmented views keep the target label.

    Args:
        tar_labels: (N,) labels predicted by the target model.
        ssl_labels: Unused; kept for interface compatibility.
        aug_labels: Labels of the augmented views, either (N, aug_time) or the
            same values flattened sample-major (N * aug_time,).
        aug_time: Number of augmented views per sample.

    Returns:
        List of ``aug_time + 1`` floats. Entry ``k`` is the fraction of samples
        for which fewer than ``k`` views have a label equal to the target
        label. If N is 0, all entries are 0.0.
    '''
    total_num = len(tar_labels)
    if total_num == 0:
        print(f'First Detection: 0/{total_num}')
        return [0.0] * (aug_time + 1)

    # Reshape first so mode(dim=1) also works on flattened input.
    aug_labels = aug_labels.reshape(total_num, aug_time)
    aug_ne_tar = aug_labels.mode(dim=1)[0] != tar_labels
    print(f'First Detection: {aug_ne_tar.sum().item()}/{total_num}')

    # Per sample: number of views whose label matches the target-model label.
    agree_counts = (aug_labels == tar_labels.unsqueeze(dim=1)).sum(dim=-1)
    auc_metrics = [(agree_counts < threshold).sum().item() / total_num
                   for threshold in range(aug_time + 1)]

    print("detect between target label and aug labels.")
    return auc_metrics

def detect_adv_by_rep_label(tar_labels, ssl_labels, aug_labels, sim_with_ori, aug_time):
    '''Two-stage detection: label disagreement, then representation similarity.

    Stage 1 flags every sample whose most frequent augmented-view label
    differs from the target label. Stage 2 applies only to the remaining
    samples: for each pair (t_sim, t_c), a sample is flagged when more than
    ``t_c`` of its similarity entries are below ``t_sim``.

    Args:
        tar_labels: (N,) labels predicted by the target model.
        ssl_labels: Unused; kept for interface compatibility.
        aug_labels: Labels of the augmented views, either (N, aug_time) or the
            same values flattened sample-major (N * aug_time,).
        sim_with_ori: (N, K) similarities; presumably K == aug_time, with each
            entry the similarity of one view to its original (confirm with caller).
        aug_time: Number of augmented views per sample; also sets the count
            thresholds ``range(0, aug_time, 2 if aug_time < 25 else 5)``.

    Returns:
        List of floats, one per (t_sim, t_c) pair in row-major order. Each is
        (stage-1 detections + stage-2 detections) / N. If N is 0, all entries
        are 0.0.
    '''
    sim_thresholds = [0.2, 0.25, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
    step = 2 if aug_time < 25 else 5
    counts = list(range(0, aug_time, step))
    thresholds = [(t_sim, t_c) for t_sim in sim_thresholds for t_c in counts]

    total_num = len(tar_labels)
    if total_num == 0:
        print(f'First Detection: 0/{total_num}')
        return [0.0] * len(thresholds)

    # Reshape first so mode(dim=1) also works on flattened input.
    aug_labels = aug_labels.reshape(total_num, aug_time)
    aug_ne_tar = aug_labels.mode(dim=1)[0] != tar_labels
    adv_num = aug_ne_tar.sum().item()
    print(f'First Detection: {adv_num}/{total_num}')

    # Only samples not flagged by stage 1 go through the similarity stage.
    sim_with_ori = sim_with_ori[~aug_ne_tar, :]

    auc_metrics = []
    for t_sim, t_c in thresholds:
        low_sim_count = (sim_with_ori < t_sim).sum(-1)
        detected = (low_sim_count > t_c).sum().item() + adv_num
        auc_metrics.append(detected / total_num)

    return auc_metrics

def _seed_everything(seed):
    '''Seed Python, NumPy and PyTorch RNGs and make cuDNN deterministic.'''
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)


def _load_attack_samples(attack, e, aug_time):
    '''Load the stored inputs for ``attack``; return None for an unknown attack.'''
    if attack == 'clean':
        return torch.from_numpy(np.load('./AEs/clean_inputs.npy'))
    if attack == 'ada':
        return torch.from_numpy(
            np.load(f'./AEs/ssl/Ada_a-1_e{e}_at{aug_time}_norm_s0002_AdvSamples.npy'))
    return None


def _ssl_predictions(samples, backbone, projector, classifier, normalization,
                     img_transforms, aug_time, batch_size, device):
    '''Run the SSL model on every sample and on ``aug_time`` augmented views of it.

    Returns:
        ssl_labels: (N,) classifier predictions on the normalized originals.
        aug_labels: (N, aug_time) classifier predictions on the augmented views.
        sim_with_ori: (N, aug_time) cosine similarity between the projector
            output of each original and of each of its views.
    '''
    ssl_label_parts, aug_label_parts, sim_parts = [], [], []
    for start in range(0, len(samples), batch_size):
        batch = samples[start:start + batch_size]
        n = len(batch)

        # NOTE(review): torchvision transforms applied to a batched tensor appear
        # to draw one set of random parameters per call, so every image in the
        # batch gets the same crop/flip/jitter for a given view. Confirm intended.
        trans_images = multi_transform(batch, img_transforms, times=aug_time).to(device)

        ssl_backbone_out = backbone(normalization(batch).to(device))
        ssl_repre = projector(ssl_backbone_out)
        ssl_label_parts.append(classifier(ssl_backbone_out).max(-1)[1])

        # (n, aug_time, 3, 32, 32) -> (n * aug_time, 3, 32, 32) in sample-major
        # order, so reshaping back to (n, aug_time) lines views up with samples.
        aug_backbone_out = backbone(trans_images.reshape(-1, 3, 32, 32))
        aug_repre = projector(aug_backbone_out)
        aug_label_parts.append(classifier(aug_backbone_out).max(-1)[1].reshape(n, aug_time))

        sim_parts.append(F.cosine_similarity(
            ssl_repre.unsqueeze(dim=1), aug_repre.reshape(n, aug_time, -1), dim=2))

    # Concatenate once instead of re-allocating with torch.cat every batch.
    return torch.cat(ssl_label_parts), torch.cat(aug_label_parts), torch.cat(sim_parts)


def _to_object_array(rows):
    '''Wrap possibly ragged rows in a 1-D object array so ``np.save`` accepts them.'''
    arr = np.empty(len(rows), dtype=object)
    for i, row in enumerate(rows):
        arr[i] = row
    return arr


def main():
    '''Evaluate label-agreement AE detection on clean and adversarial inputs.

    For every ``aug_time`` in ``aug_times`` and every attack, the script:
      * keeps the clean inputs the target model classifies correctly, or the
        adversarial inputs it misclassifies;
      * classifies ``aug_time`` augmented views of each kept input with the
        SSL model;
      * scores them with ``detect_adv_by_label_sim``.
    After every ``aug_time`` the accumulated score lists are saved to
    ``./auc_c10_at_ablation_e16.npy`` as a 1-D object array. Load that file
    with ``np.load(..., allow_pickle=True)``.
    '''
    # Must be set before the first CUDA call (torch.cuda.is_available() in the
    # seeding below). Once CUDA is initialized, changing it has no effect.
    os.environ['CUDA_VISIBLE_DEVICES'] = '0,2'

    _seed_everything(100)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    normalization = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    img_transforms = transforms.Compose([
        transforms.RandomResizedCrop(32, scale=(0.2, 1.)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomApply([
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)  # not strengthened
        ], p=0.8),
        transforms.RandomGrayscale(p=0.2),
        normalization
    ])

    # map_location lets GPU-saved checkpoints load on a CPU-only machine.
    target_model = ResNet18_32x32()
    target_model.load_state_dict(torch.load('./weights/raw_9153.pth', map_location=device))
    target_model.to(device)
    target_model.eval()

    model = SimSiamWithCls()
    model.load_state_dict(torch.load('./weights/simsiam-cifar10.pth', map_location=device))
    model.to(device)
    backbone = model.backbone
    classifier = model.classifier
    projector = model.projector
    backbone.eval()
    classifier.eval()
    projector.eval()

    # Ground-truth labels are stored one-hot/score-style; argmax gives class ids.
    natural_labels = torch.max(torch.from_numpy(np.load('./AEs/clean_labels.npy')), -1)[1].to(device)

    aug_times = [5, 10, 25, 50, 80]
    e = 16
    attacks = ['clean', 'ada']
    auc_metrics = []

    with torch.no_grad():
        for aug_time in aug_times:
            print(f"Aug Time: {aug_time}")
            batch_size = 100 if aug_time < 50 else 50

            for attack in attacks:
                samples = _load_attack_samples(attack, e, aug_time)
                if samples is None:
                    print("Unknown Attacks")
                    break

                print('----------------', attack)
                samples = samples.to(device)

                # The target model's own predictions decide which samples count.
                labels = target_model(samples).max(-1)[1]
                mask = (labels == natural_labels) if attack == 'clean' else (labels != natural_labels)
                samples = samples[mask]
                labels = labels[mask]
                print("Success AEs Num:", len(labels))
                if len(labels) == 0:
                    print(f'{attack}: no samples left after filtering, skipped.')
                    continue

                ssl_labels, aug_labels, sim_with_ori = _ssl_predictions(
                    samples, backbone, projector, classifier, normalization,
                    img_transforms, aug_time, batch_size, device)

                print(f'{attack}, Similarity with Ori Mean: {sim_with_ori.mean(dim=-1).mean()}, Variance: {sim_with_ori.var(-1).mean()}')
                print(f'{attack}, Target Model Label equals Aug Labels Count: {(labels.unsqueeze(dim=1) == aug_labels).sum(-1).float().mean()}')
                print(f'{attack}, SSL Model Label equals Aug Labels Count: {(ssl_labels.unsqueeze(dim=1) == aug_labels).sum(-1).float().mean()}')

                auc_metrics.append(detect_adv_by_label_sim(labels, ssl_labels, aug_labels, aug_time))

            # Score lists have aug_time + 1 entries, so they differ in length
            # across aug_times. Plain np.save rejects ragged lists on NumPy >= 1.24.
            np.save('./auc_c10_at_ablation_e16.npy', _to_object_array(auc_metrics),
                    allow_pickle=True)


            
            


if __name__ == '__main__':
    # Run the evaluation only when executed as a script, not on import.
    main()
