# -*- encoding: utf-8 -*-
'''
@File: orthogonal_pgd.py
@Description: 
@Time: 2022/07/25 12:46:22
@Author: Zhiyuan He
'''
from ast import parse
import os 
import math
from turtle import forward
import numpy as np
import argparse
import pickle
import random
import sklearn
from sklearn.metrics import roc_auc_score, roc_curve, auc
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from tqdm.auto import tqdm

from utils.resnet_factory import ResNet18, SimSiamWithCls
from orthogonal_pgd import PGD

def multi_transform(img, transforms, times=50):
    """Apply ``transforms`` to ``img`` repeatedly and stack the results.

    Args:
        img: Input passed unchanged to every ``transforms`` call.
        transforms: Callable returning a tensor; invoked ``times`` times.
            (Inside this function the name shadows the imported
            ``torchvision.transforms`` module.)
        times: Number of views to generate.

    Returns:
        A tensor holding the ``times`` outputs stacked along a new
        dimension 1.
    """
    views = []
    for _ in range(times):
        views.append(transforms(img))
    return torch.stack(views, dim=1)

def get_cos_similarity_3d(m1, m2, eps=1e-8):
    """Batched pairwise cosine similarity between the rows of ``m1`` and ``m2``.

    For every batch index ``b`` the result holds, at ``[b, i, j]``, the cosine
    of the angle between ``m1[b, i]`` and ``m2[b, j]``.

    Args:
        m1: Tensor indexed as (batch, n, dim).
        m2: Tensor indexed as (batch, m, dim).
        eps: Lower bound applied to the product of the two norms so that an
            all-zero row yields a similarity of 0 instead of NaN (0/0).

    Returns:
        Tensor of shape (batch, n, m).
    """
    dot_products = torch.matmul(m1, m2.transpose(1, 2))
    norms_1 = torch.norm(m1, p=2, dim=2).unsqueeze(dim=2)
    norms_2 = torch.norm(m2, p=2, dim=2).unsqueeze(dim=1)
    # Clamp instead of dividing by a raw zero: keeps the output finite.
    denominator = (norms_1 * norms_2).clamp_min(eps)
    return dot_products / denominator

class EasyDict(dict):
    """Dictionary whose items can also be read and written as attributes."""

    def __init__(self, *args, **kwargs):
        super(EasyDict, self).__init__(*args, **kwargs)
        # Use the mapping itself as the instance namespace, so that
        # ``obj.key`` and ``obj['key']`` refer to the same storage.
        self.__dict__ = self

class Detector(nn.Module):
    """Consistency features for adversarial-example detection.

    ``forward`` compares each input with ``aug_time`` randomly augmented views
    of itself using a self-supervised model that exposes ``backbone``,
    ``projector`` and ``classifier`` sub-modules.
    ``detect_adv_by_representation`` turns such features into a 0/1 score; it
    is a separate utility and is not called by ``forward``.
    """

    def __init__(self, ssl_model = None, target_model=None, augmentation=None, aug_time=50, batch_size=100, device=None):
        """Store the models and settings and put the SSL model in eval mode.

        Args:
            ssl_model: Module with ``backbone``, ``projector`` and
                ``classifier`` attributes. ``eval()`` is called on it here,
                so ``None`` is not actually usable.
            target_model: Model under attack; ``forward`` moves it to
                ``device`` and switches it to eval mode.
            augmentation: Callable applied to a batch to produce one
                augmented view.
            aug_time: Number of augmented views per sample.
            batch_size: Number of samples processed per inner batch.
            device: Device on which the models are run.
        """
        super(Detector, self).__init__()
        self.target_model = target_model
        self.model = ssl_model
        self.model.eval()
        self.augmentation = augmentation
        self.aug_time = aug_time
        self.batch_size = batch_size
        self.device = device

    def detect_adv_by_representation(self, ori_labels, aug_labels, sim_with_ori, sim_representation,
                                     t_sim=0.8, t_c=30):
        """Flag samples whose augmented views disagree with the original.

        A sample gets score 1 when either
          * the most frequent label among its augmented views differs from
            ``ori_labels``, or
          * more than ``t_c`` of its views have similarity below ``t_sim``.

        Args:
            ori_labels: Integer labels, one per sample.
            aug_labels: Integer labels indexed as (sample, view).
            sim_with_ori: Similarities indexed as (sample, view).
            sim_representation: Unused; kept for interface compatibility.
            t_sim: Similarity threshold below which a view counts as
                dissimilar.
            t_c: Number of dissimilar views that must be exceeded.

        Returns:
            Float tensor with one 0/1 score per sample.
        """
        total_num = len(ori_labels)

        majority_label = aug_labels.mode(dim=1)[0]
        label_mismatch = majority_label != ori_labels

        # Per-sample count of views that fall below the similarity threshold.
        dissimilar_views = (sim_with_ori < t_sim).reshape(total_num, -1).sum(dim=1)
        too_many_dissimilar = dissimilar_views > t_c

        return (label_mismatch | too_many_dissimilar).float()

    def forward(self, samples, clean_labels=None):
        """Compute classifier outputs and view/original similarities.

        Autograd is intentionally left enabled (presumably so an attack can
        back-propagate through the detector - confirm against the PGD code).

        Args:
            samples: Batch of images; augmented views are reshaped to
                ``(-1, 3, 32, 32)`` before entering the backbone.
            clean_labels: Unused; kept for interface compatibility.

        Returns:
            Tuple ``(ssl_labels, aug_labels, sim_with_ori)`` of CPU tensors:
              * ``ssl_labels``: raw classifier outputs for the normalized,
                un-augmented samples (not arg-maxed).
              * ``aug_labels``: raw classifier outputs for the views,
                reshaped to (sample, aug_time, -1).
              * ``sim_with_ori``: cosine similarity between each sample's
                projection and the projection of each of its views,
                indexed as (sample, view).
        """
        device = self.device

        # The target model is not evaluated here (its predictions were never
        # used), but it is still moved to the device and put in eval mode so
        # the module-state side effects callers may rely on are unchanged.
        self.target_model.to(device).eval()
        backbone = self.model.backbone.to(device).eval()
        classifier = self.model.classifier.to(device).eval()
        projector = self.model.projector.to(device).eval()

        # Per-channel mean/std (presumably the CIFAR-10 statistics); built
        # once instead of once per batch.
        normalization = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

        # Each list is seeded with an empty 1-D tensor, which torch.cat
        # accepts next to tensors of any shape. This reproduces the result of
        # repeatedly concatenating onto ``torch.Tensor()`` (including the
        # empty-input case) with a single concatenation at the end.
        ssl_chunks = [torch.Tensor()]
        aug_chunks = [torch.Tensor()]
        sim_chunks = [torch.Tensor()]

        total = len(samples)
        for start in range(0, total, self.batch_size):
            end = min(start + self.batch_size, total)
            n_batch = end - start

            batch_samples = samples[start:end].to(device)
            trans_images = multi_transform(batch_samples, self.augmentation, times=self.aug_time).to(device)

            # Original images: normalize here (the augmentation pipeline is
            # expected to normalize its own outputs - TODO confirm).
            ssl_backbone_out = backbone(normalization(batch_samples))
            ssl_repre = projector(ssl_backbone_out)
            ssl_label = classifier(ssl_backbone_out)

            # All views of all samples go through the backbone as one batch.
            aug_backbone_out = backbone(trans_images.reshape(-1, 3, 32, 32))
            aug_repre = projector(aug_backbone_out).reshape(n_batch, self.aug_time, -1)
            aug_label = classifier(aug_backbone_out).reshape(n_batch, self.aug_time, -1)

            sim_repre = F.cosine_similarity(ssl_repre.unsqueeze(dim=1), aug_repre, dim=2)

            ssl_chunks.append(ssl_label.cpu())
            aug_chunks.append(aug_label.cpu())
            sim_chunks.append(sim_repre.cpu())

        ssl_labels = torch.cat(ssl_chunks, dim=0)
        aug_labels = torch.cat(aug_chunks, dim=0)
        sim_with_ori = torch.cat(sim_chunks, dim=0)
        return ssl_labels, aug_labels, sim_with_ori

def score(num_images, orig, advx_final, detector, device):
    """Evaluate a detector on ``num_images`` clean and adversarial images.

    The clean and adversarial batches are concatenated (clean first) and
    scored by ``detector``; higher scores are treated as "more adversarial".

    Args:
        num_images: Number of clean images (and of adversarial images).
        orig: Clean images.
        advx_final: Adversarial images, same count as ``orig``.
        detector: Callable invoked as ``detector(images, "correct")`` that
            must return a tensor with one score per image.
        device: Device the images are moved to before scoring.

    Returns:
        Tuple ``(tpr_5, tpr_10, tpr_50, auc, adv_scores)``: the fraction of
        adversarial images scoring above the clean-score thresholds obtained
        with divisors 20, 10 and 2 (nominally 5%, 10% and 50% false
        positives), the area under the ROC curve, and the raw score array.
    """
    is_adversarial_label = np.concatenate((np.zeros(num_images), np.ones(num_images))).reshape(-1, 1)
    # torch.cat allocates a fresh tensor, so the inputs need not be cloned.
    original_and_adversarial_images = torch.cat((orig, advx_final)).float()

    # NOTE(review): ``Detector.forward`` in this file returns a tuple of three
    # tensors, on which ``.detach()`` would fail; this function expects a
    # detector that returns a single score tensor - confirm which is intended.
    adv_scores = detector(original_and_adversarial_images.to(device), "correct").detach().cpu().numpy()

    clean_scores_sorted = np.sort(adv_scores[:num_images])
    adversarial_scores = adv_scores[num_images:]

    def tpr_at(divisor):
        # Threshold chosen so that about num_images // divisor clean scores
        # lie strictly above it.
        threshold = clean_scores_sorted[num_images - num_images // divisor - 1]
        return np.mean(adversarial_scores > threshold)

    tpr_5 = tpr_at(20)
    tpr_10 = tpr_at(10)
    tpr_50 = tpr_at(2)

    # Sweep every observed score (plus +/- infinity) as a threshold to trace
    # the ROC curve.
    test_fprs = []
    test_tprs = []
    phi_range = [-np.inf] + list(np.sort(adv_scores)) + [np.inf]
    for phi in phi_range:
        is_adversarial_pred = (1 * (adv_scores > phi)).reshape(-1, 1)

        TP_count = np.sum(is_adversarial_pred * is_adversarial_label)
        TN_count = np.sum((1 - is_adversarial_pred) * (1 - is_adversarial_label))
        FP_count = np.sum(is_adversarial_pred * (1 - is_adversarial_label))
        FN_count = np.sum((1 - is_adversarial_pred) * is_adversarial_label)

        assert TP_count + TN_count + FP_count + FN_count == num_images*2
        test_tprs.append(TP_count / (TP_count + FN_count) if TP_count + FN_count != 0 else 0)
        test_fprs.append(FP_count / (FP_count + TN_count) if FP_count + TN_count != 0 else 0)

    return tpr_5, tpr_10, tpr_50, auc(test_fprs, test_tprs), adv_scores

def run_experiment(taget_model, detector, samples, labels, batch_size=30, device=None, mode='select', **attack_args):
    """Run a PGD attack and report how often it fools the target model.

    Args:
        taget_model: Classifier under attack.
        detector: Detector handed to the ``PGD`` constructor.
        samples: Clean inputs; a clone is passed to the attack.
        labels: Labels passed to the attack. For an untargeted attack they
            are also the labels the prediction must differ from.
        batch_size: Batch size forwarded to ``pgd.attack``.
        device: Device used for the attack and the evaluation pass.
        mode: Accepted for interface compatibility; not used here.
        **attack_args: Extra keyword arguments for ``PGD``. If ``target`` is
            present and not ``None`` the attack counts as successful when the
            prediction equals ``target``.

    Returns:
        The adversarial examples produced by ``pgd.attack``.
    """
    pgd = PGD(taget_model, detector, device=device, **attack_args)

    advx = pgd.attack(samples.clone(), labels, batch_size)

    # Only the arg-max of the logits is needed, so no autograd graph is built.
    with torch.no_grad():
        predictions = taget_model(advx.to(device)).argmax(1).cpu()

    target = attack_args.get('target')
    if target is not None:
        attack_succeeded = (predictions == target)
    else:
        attack_succeeded = (predictions != labels)

    sr = torch.mean(attack_succeeded.float())
    print(f"attack success rate: {sr}")

    return advx

    # tpr5, tpr10, tpr50, auc, scores = score(len(samples), samples, advx, detector, device)
    # return EasyDict(success=list(attack_succeeded.numpy()),
    #                 score=list(scores),
    #                 sr=sr.item(),
    #                 tpr5=tpr5,
    #                 tpr10=tpr10,
    #                 tpr50=tpr50)

def main(mode, epsilon, n_ae=100, target=True, gpu=None):
    """Generate adversarial examples with orthogonal PGD and save them.

    Loads clean inputs/labels and the two model checkpoints from disk, runs
    ``run_experiment`` on the first ``n_ae`` samples and writes the result to
    ``./AEs/ssl/``.

    Args:
        mode: ``'select'`` uses the base attack settings; any other value
            additionally enables ``project_detector`` and
            ``project_classifier``.
        epsilon: Attack budget, passed as ``eps`` with an ``linf`` projection.
        n_ae: Number of samples to attack.
        target: If true, run a targeted attack towards the labels obtained by
            rolling the loaded label matrix by one position along dim 1.
        gpu: CUDA device index (string or int). ``None`` selects the default
            CUDA device and leaves ``CUDA_VISIBLE_DEVICES`` untouched.
    """
    seed = 100
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.cuda.empty_cache()

    cuda_name = 'cuda' if gpu is None else f'cuda:{gpu}'
    device = torch.device(cuda_name if torch.cuda.is_available() else 'cpu')
    if gpu is not None:
        # NOTE(review): this is assigned after the CUDA queries above, so it
        # may not restrict device visibility for this process - confirm intent.
        os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu)

    samples = torch.tensor(np.load('./AEs/clean_inputs.npy'), dtype=torch.float32)
    labels = torch.tensor(np.load('./AEs/clean_labels.npy'), dtype=torch.int64)
    # Labels are loaded as a 2-D matrix (presumably one-hot); shifting the
    # columns by one gives a different class as the attack target.
    target_labels = torch.roll(labels, 1, 1)

    labels = labels.max(-1)[1]
    target_labels = target_labels.max(-1)[1]

    normalization = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

    img_transforms = transforms.Compose([
        transforms.RandomResizedCrop(32, scale=(0.2, 1.)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomApply([
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)  # not strengthened
        ], p=0.8),
        transforms.RandomGrayscale(p=0.2),
        normalization
    ])

    # map_location='cpu' lets checkpoints saved on a GPU load on any host;
    # the models are moved to ``device`` right afterwards.
    target_model = ResNet18()
    target_model.load_state_dict(torch.load('./weights/raw_9153.pth', map_location='cpu'))
    target_model.to(device)
    target_model.eval()

    ssl_model = SimSiamWithCls()
    ssl_model.load_state_dict(torch.load('./weights/simsiam-cifar10.pth', map_location='cpu'))
    ssl_model.to(device)
    ssl_model.eval()

    detector = Detector(ssl_model, target_model, augmentation=img_transforms, aug_time=50, batch_size=50, device=device)

    attack_config = {'use_projection': True, 'eps': epsilon, 'alpha': .001, 'steps': 1000,
                     'projection_norm': 'linf'}
    if mode != 'select':
        attack_config.update({'project_detector': True, 'project_classifier': True})

    if target:
        attack_labels = target_labels[:n_ae]
        attack_target = target_labels[:n_ae]
    else:
        attack_labels = labels[:n_ae]
        attack_target = None

    advx = run_experiment(target_model, detector, samples[:n_ae], attack_labels, batch_size=64,
                          device=device, mode=mode,
                          classifier_loss=nn.CrossEntropyLoss(),
                          detector_loss=None,
                          target=attack_target,
                          **attack_config)

    # Create the output directory up front so a long attack run is not lost
    # to a missing folder.
    os.makedirs('./AEs/ssl', exist_ok=True)
    # NOTE(review): the file name contains "_tar_" even when target is False,
    # so targeted and untargeted runs write to the same path.
    np.save(f'./AEs/ssl/Opgd_{mode}_tar_e{str(epsilon).replace(".", "")}_AdvSamples.npy', advx.detach().cpu().numpy())


if __name__ == "__main__":

    # Command-line parsing is currently disabled; the settings are fixed below.
    # parser = argparse.ArgumentParser(description='Orthogonal PGD')
    # parser.add_argument('-m', '--mode', default='select', type=str)
    # parser.add_argument('-bs', '--batch_size', default=1000, type=int)
    # parser.add_argument('-e', '--epsilon', default=0.01, type=float)
    # parser.add_argument('-t', '--target', action='store_true')
    # parser.add_argument('-g', '--gpu', default='0', type=str)
    # args = parser.parse_args()

    gpu = '0'
    n_ae = 1000
    target = True

    # Every (mode, epsilon) combination, modes outermost.
    sweep = [(m, eps) for m in ('select', 'orth') for eps in (0.01, 8/255)]
    for mode, e in sweep:
        print(f'========================={mode.upper()}, E={e}, =============================')
        main(mode, e, n_ae=n_ae, target=target, gpu=gpu)

