import os
import sys
import time

from utils import DatasetNumpy


import random

import PIL
from torch import optim
from torch.optim.lr_scheduler import ExponentialLR
from tqdm import tqdm

from core.defenses.STRIP import STRIP
from core.defenses.Frequency import Frequency
# from core.defenses.Lava_D import LAVA

import argparse
import core

import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import DatasetFolder
from torchvision.transforms import Compose, RandomHorizontalFlip, ToTensor, ToPILImage, Resize
from torch.utils.data import DataLoader
import torch
import torch.nn as nn
from torch.utils.data import Subset

import matplotlib.pyplot as plt
from sklearn import metrics
from sklearn.metrics import roc_auc_score
import numpy as np
import cv2
import torchvision.models as models
from sklearn.metrics import precision_score, recall_score

class GetPoisonedDataset(torch.utils.data.Dataset):
    """Expose two index-aligned sequences as a map-style dataset.

    Args:
        data_list (list): per-sample data; each entry is converted with
            ``torch.FloatTensor`` when it is accessed.
        labels (list): per-sample labels, aligned with ``data_list`` and
            converted the same way.
    """

    def __init__(self, data_list, labels):
        self.data_list, self.labels = data_list, labels

    def __len__(self):
        # The sample list alone defines the dataset size.
        return len(self.data_list)

    def __getitem__(self, index):
        sample = self.data_list[index]
        target = self.labels[index]
        return torch.FloatTensor(sample), torch.FloatTensor(target)


def read_image(img_path, type=None):
    """Load an image with OpenCV, optionally converting its colour space.

    Args:
        img_path: path handed to ``cv2.imread``.
        type: ``None`` to return the array exactly as ``cv2.imread`` produced
            it, or the string ``"RGB"`` / ``"GRAY"`` (case-insensitive) to
            convert it with the matching ``cv2.COLOR_BGR2*`` code.

    Raises:
        NotImplementedError: if ``type`` is anything else.
    """
    image = cv2.imread(img_path)
    if type is None:
        return image

    # Non-string selectors fall through to the NotImplementedError below.
    mode = type.upper() if isinstance(type, str) else None
    if mode == "RGB":
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    if mode == "GRAY":
        return cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    raise NotImplementedError

def gen_grid(height, k, intensity=1):
    """Build an identity sampling grid and a random noise (warping) grid.

    Args:
        height (int): side length of the square output grids.
        k (int): side length of the coarse random grid that is upsampled to
            ``height`` x ``height``.
        intensity (float): multiplier applied to the noise grid.

    Returns:
        tuple: ``(identity_grid, noise_grid)``, each of shape
        ``1 x height x height x 2``.
    """
    # Coarse random field in [-1, 1), normalised to unit mean absolute value.
    ins = torch.rand(1, 2, k, k) * 2 - 1
    ins = ins / torch.mean(torch.abs(ins))

    # ``nn.functional.upsample`` is deprecated; ``interpolate`` is the call it
    # forwards to, so the result is unchanged.
    noise_grid = nn.functional.interpolate(ins, size=height, mode="bicubic", align_corners=True)
    noise_grid = intensity * noise_grid.permute(0, 2, 3, 1)  # 1*height*height*2

    array1d = torch.linspace(-1, 1, steps=height)  # coordinates in [-1, 1]
    try:
        # Explicit "ij" keeps the historical default and silences the warning.
        x, y = torch.meshgrid(array1d, array1d, indexing="ij")
    except TypeError:  # older torch releases have no ``indexing`` argument
        x, y = torch.meshgrid(array1d, array1d)
    identity_grid = torch.stack((y, x), 2)[None, ...]  # 1*height*height*2

    return identity_grid, noise_grid


def prepare_dataset(args):
    """Build the clean datasets for ``args.dataset`` and their poisoned versions.

    Args:
        args: parsed CLI namespace. Reads ``dataset``, ``attack_method`` and
            ``poisoned_rate``; additionally ``datasets_root_dir`` (Cifar10
            branch only) and ``trigger_size`` (BadNet branch only).

    Returns:
        tuple: ``(clean_trainset, clean_testset, poisoned_trainset,
        poisoned_testset)``.

    Raises:
        NotImplementedError: unknown dataset, unknown attack method, or an
            attack with no transform pipeline configured for GTSRB.
        ValueError: the BadNet trigger would be smaller than one pixel.
        FileNotFoundError: the Blend trigger image ``image.png`` is unreadable.
    """
    # ``global_seed`` is only defined at module level when this file runs as a
    # script; fall back to the script's value (0) when the module is imported.
    seed = globals().get("global_seed", 0)

    # ------------------------------------------------------------------
    # 1. Clean datasets and the transform indices the attacks hook into.
    # ------------------------------------------------------------------
    if args.dataset == "Cifar10":
        dataset = torchvision.datasets.CIFAR10
        img_width = 32
        img_height = 32

        transform_train = Compose([
            Resize((img_width, img_height)),
            transforms.ToTensor(),
        ])
        transform_test = Compose([
            Resize((img_width, img_height)),
            ToTensor()
        ])
        clean_trainset = dataset(args.datasets_root_dir, train=True, transform=transform_train, download=True)
        clean_testset = dataset(args.datasets_root_dir, train=False, transform=transform_test, download=True)

        target_label = 0
        poisoned_transform_train_index = 0
        poisoned_transform_test_index = 0

    elif args.dataset == "GTSRB":
        import cv2

        img_width, img_height = 32, 32

        # NOTE(review): this branch ignores args.datasets_root_dir.
        datasets_root_dir = "./data"

        if args.attack_method in ["BadNet", "Blend", "ISSBA"]:
            transform_train = Compose([
                ToPILImage(),
                Resize((img_width, img_height)),
                ToTensor()
            ])
            transform_test = Compose([
                ToPILImage(),
                Resize((img_width, img_height)),
                ToTensor()
            ])

            target_label = 1
            poisoned_transform_train_index = 2
            poisoned_transform_test_index = 2

        elif args.attack_method in ["WaNet"]:
            # No target_label is set here; the WaNet branch below passes
            # y_target=0 directly.
            transform_train = Compose([
                ToTensor(),
                RandomHorizontalFlip(),
                transforms.ToPILImage(),
                transforms.Resize((32, 32)),
                ToTensor()
            ])
            transform_test = Compose([
                ToTensor(),
                transforms.ToPILImage(),
                transforms.Resize((32, 32)),
                ToTensor()
            ])
            poisoned_transform_train_index = 0
            poisoned_transform_test_index = 0
        else:
            raise NotImplementedError

        clean_trainset = DatasetFolder(
            root=os.path.join(datasets_root_dir, 'GTSRB', 'Train'),  # please replace this with path to your training set
            loader=cv2.imread,
            extensions=('ppm',),
            transform=transform_train,
            target_transform=None,
            is_valid_file=None)

        clean_testset = DatasetFolder(
            root=os.path.join(datasets_root_dir, 'GTSRB', 'Test'),  # please replace this with path to your test set
            loader=cv2.imread,
            extensions=('png',),
            transform=transform_test,
            target_transform=None,
            is_valid_file=None)

    elif args.dataset == "ImageNet_Subset":
        import cv2

        datasets_root_dir = "./data/imagenette2-160"
        img_width, img_height = 224, 224
        transform_train = Compose([
            ToPILImage(),
            RandomHorizontalFlip(),
            Resize((img_width, img_height)),
            ToTensor()
        ])
        transform_test = Compose([
            ToPILImage(),
            Resize((img_width, img_height)),
            ToTensor()
        ])

        clean_testset = DatasetFolder(root=os.path.join(datasets_root_dir, 'val'),
                                      transform=transform_test,
                                      loader=cv2.imread,
                                      extensions=('jpeg',),
                                      target_transform=None,
                                      is_valid_file=None,
                                      )
        clean_trainset = DatasetFolder(root=os.path.join(datasets_root_dir, 'train'),
                                       transform=transform_train,
                                       loader=cv2.imread,
                                       extensions=('jpeg',),
                                       target_transform=None,
                                       is_valid_file=None,
                                       )

        target_label = 0
        poisoned_transform_train_index = 3
        poisoned_transform_test_index = 2

    else:
        raise NotImplementedError("Please specify a dataset name")

    # ------------------------------------------------------------------
    # 2. Poisoned datasets for the requested attack.
    # ------------------------------------------------------------------
    if args.attack_method == "BadNet":
        # A zero-pixel trigger would turn the ``[-0:]`` slices below into
        # whole-image slices, so reject it explicitly.
        trigger_px = int(img_width * args.trigger_size)
        if trigger_px < 1:
            raise ValueError(
                "trigger_size={} gives an empty trigger for image width {}".format(
                    args.trigger_size, img_width))

        pattern = torch.zeros((img_width, img_height), dtype=torch.uint8)

        if args.dataset in ["Cifar10", "GTSRB"]:
            # Four bright corners of the bottom-right 3x3 patch; every other
            # pixel stays 0.
            # NOTE(review): the original code assigned pattern[-2, -2] twice
            # (255, then 0), so the centre pixel ends up 0. The second store
            # was probably meant for [-2, -1]. The effective trigger is kept
            # unchanged so previously trained checkpoints remain valid.
            pattern[-3, -3] = 255
            pattern[-3, -1] = 255
            pattern[-1, -3] = 255
            pattern[-1, -1] = 255
        elif args.dataset in ["ImageNet_Subset_"]:
            # NOTE(review): no dataset accepted above carries the trailing
            # underscore, so this branch looks deliberately disabled.
            pattern[-int(img_width * 0.1):, -int(img_width * 0.1):] = 255
        else:
            # Random-valued square trigger in the bottom-right corner.
            pattern[-trigger_px:, -trigger_px:] = torch.rand((trigger_px, trigger_px)) * 255

        weight = torch.zeros((img_width, img_height), dtype=torch.float32)
        weight[-trigger_px:, -trigger_px:] = 1.0

        attack = core.BadNets(
            train_dataset=clean_trainset,
            test_dataset=clean_testset,
            model=core.models.ResNet(18),
            loss=nn.CrossEntropyLoss(),
            y_target=target_label,
            poisoned_rate=args.poisoned_rate,
            pattern=pattern,
            weight=weight,
            poisoned_transform_train_index=poisoned_transform_train_index,  # GTSRB
            poisoned_transform_test_index=poisoned_transform_test_index  # GTSRB
        )
        poisoned_trainset, poisoned_testset = attack.get_poisoned_dataset()

    elif args.attack_method == "LC":
        schedule = {
            'device': 'GPU',
            'CUDA_VISIBLE_DEVICES': "1",
            'GPU_num': 1,

            'benign_training': False,  # Train Attacked Model
            'batch_size': 128,
            'num_workers': 8,

            'lr': 0.1,
            'momentum': 0.9,
            'weight_decay': 5e-4,
            'gamma': 0.1,
            'schedule': [150, 180],

            'epochs': 200,

            'log_iteration_interval': 100,
            'test_epoch_interval': 10,
            'save_epoch_interval': 10,
        }

        # Hard-coded 32x32 trigger: a 4-pixel mark in each image corner.
        pattern = torch.zeros((32, 32), dtype=torch.uint8)
        pattern[-1, -1] = 255
        pattern[-1, -3] = 255
        pattern[-3, -1] = 255
        pattern[-2, -2] = 255

        pattern[0, -1] = 255
        pattern[1, -2] = 255
        pattern[2, -3] = 255
        pattern[2, -1] = 255

        pattern[0, 0] = 255
        pattern[1, 1] = 255
        pattern[2, 2] = 255
        pattern[2, 0] = 255

        pattern[-1, 0] = 255
        pattern[-1, 2] = 255
        pattern[-2, 1] = 255
        pattern[-3, 0] = 255

        weight = torch.zeros((32, 32), dtype=torch.float32)
        weight[:3, :3] = 1.0
        weight[:3, -3:] = 1.0
        weight[-3:, :3] = 1.0
        weight[-3:, -3:] = 1.0

        eps = 8
        alpha = 1.5
        steps = 100
        max_pixel = 255

        print(transform_train)
        # NOTE(review): the cache directory name embeds args.poisoned_rate while
        # the attack itself is built with a fixed poisoned_rate of 0.3 - confirm
        # which of the two is intended.
        attack = core.LabelConsistent(
            train_dataset=clean_trainset,
            test_dataset=clean_testset,
            model=core.models.ResNet(18),
            adv_model=core.models.ResNet(18),
            adv_dataset_dir=f'./adv_dataset/CIFAR-10_eps{eps}_alpha{alpha}_steps{steps}_poisoned_rate{args.poisoned_rate}_seed{seed}',
            loss=nn.CrossEntropyLoss(),
            y_target=target_label,
            poisoned_rate=0.3,
            pattern=pattern,
            weight=weight,
            eps=eps,
            alpha=alpha,
            steps=steps,
            max_pixel=max_pixel,
            poisoned_transform_train_index=0,
            poisoned_transform_test_index=0,
            poisoned_target_transform_index=0,
            schedule=schedule,
            seed=seed,
            deterministic=True
        )
        poisoned_trainset, poisoned_testset = attack.get_poisoned_dataset()

    elif args.attack_method == "Refool":
        # Load the first 200 files of the reflection set as candidates.
        reflection_data_dir = "/home/xxxx/BackdoorBox/data/VOCdevkit/VOC2012/JPEGImages/"  # please replace this with path to your desired reflection set
        reflection_image_path = os.listdir(reflection_data_dir)
        reflection_images = [read_image(os.path.join(reflection_data_dir, img_path)) for img_path in
                             reflection_image_path[:200]]
        attack = core.Refool(
            train_dataset=clean_trainset,
            test_dataset=clean_testset,
            model=core.models.ResNet(18),
            loss=nn.CrossEntropyLoss(),
            y_target=target_label,
            poisoned_rate=args.poisoned_rate,
            poisoned_transform_train_index=poisoned_transform_train_index,
            poisoned_transform_test_index=poisoned_transform_test_index,
            poisoned_target_transform_index=0,
            schedule=None,
            seed=seed,
            deterministic=True,
            reflection_candidates=reflection_images,
        )
        poisoned_trainset, poisoned_testset = attack.get_poisoned_dataset()

    elif args.attack_method == "WaNet":
        identity_grid, noise_grid = gen_grid(img_width, int(img_height/8), 50 if args.dataset == "ImageNet_Subset" else 1)
        attack = core.WaNet(
            train_dataset=clean_trainset,
            test_dataset=clean_testset,
            model=None,
            loss=None,
            y_target=0,
            poisoned_rate=args.poisoned_rate,
            identity_grid=identity_grid,
            noise_grid=noise_grid,
            noise=False,
            # The two indices were added because ImageNet Subset ran into
            # problems without them; they may be removed later.
            poisoned_transform_train_index=poisoned_transform_train_index,
            poisoned_transform_test_index=poisoned_transform_test_index
        )
        poisoned_trainset, poisoned_testset = attack.get_poisoned_dataset()

    elif args.attack_method == "Blend":
        import cv2

        pattern = cv2.imread('image.png')
        if pattern is None:
            # cv2.imread signals an unreadable file by returning None.
            raise FileNotFoundError("Blend trigger image 'image.png' could not be read")
        pattern = torch.from_numpy(cv2.resize(pattern, (img_width, img_height))).permute(2, 0, 1)
        print(pattern.shape)

        # Blend weight 0.8 everywhere except the first 20% of rows/columns.
        weight = torch.zeros((1, img_width, img_height), dtype=torch.float32)
        weight[:, int(0.2 * img_width):, int(0.2 * img_height):] = 0.8
        attack = core.Blended(
            train_dataset=clean_trainset,
            test_dataset=clean_testset,
            model=core.models.ResNet(18),
            loss=nn.CrossEntropyLoss(),
            pattern=pattern,
            weight=weight,
            y_target=1,
            poisoned_rate=args.poisoned_rate,
            seed=seed,
            deterministic=True,
            poisoned_transform_train_index=poisoned_transform_train_index,  # GTSRB
            poisoned_transform_test_index=poisoned_transform_test_index  # GTSRB
        )
        poisoned_trainset, poisoned_testset = attack.get_poisoned_dataset()

    elif args.attack_method == "ISSBA":
        trainset_cache = "ISSBA_poisoned_trainset_{}.pth".format(args.dataset)
        testset_cache = "ISSBA_poisoned_testset_{}.pth".format(args.dataset)

        # Rebuild unless BOTH cached files exist; checking only the train file
        # would crash on a missing test file.
        if not (os.path.exists(trainset_cache) and os.path.exists(testset_cache)):
            secret_size = 20

            # One random bit-string secret per image of the train AND test set.
            train_data_set = []
            train_secret_set = []
            for source in (clean_trainset, clean_testset):
                for img, _ in source:
                    train_data_set.append(img.tolist())
                    secret = np.random.binomial(1, .5, secret_size).tolist()
                    train_secret_set.append(secret)

            train_steg_set = GetPoisonedDataset(train_data_set, train_secret_set)

            encoder_schedule = {
                'secret_size': secret_size,
                'enc_height': 32,
                'enc_width': 32,
                'enc_in_channel': 3,
                'enc_total_epoch': 20,
                'enc_secret_only_epoch': 2,
                'enc_use_dis': False,
            }

            schedule = {
                'device': 'GPU',
                'GPU_num': 1,

                'benign_training': False,
                'batch_size': 128,
                'num_workers': 8,

                'lr': 0.1,
                'momentum': 0.9,
                'weight_decay': 5e-4,
                'gamma': 0.1,
                'schedule': [150, 180],

                'epochs': 0,

                'log_iteration_interval': 100,
                'test_epoch_interval': 10,
                'save_epoch_interval': 100,

                # 'pretrain': "ResNet18_ISSBA.pth"
            }

            attack = core.ISSBA(
                dataset_name="Cifar10",  # to avoid normalizer
                train_dataset=clean_trainset,
                test_dataset=clean_testset,
                train_steg_set=train_steg_set,
                model=core.models.ResNet(18),
                loss=nn.CrossEntropyLoss(),
                y_target=0,
                poisoned_rate=args.poisoned_rate,  # follow the default configure in the original paper
                encoder_schedule=encoder_schedule,
                encoder=None,
                seed=seed,
                schedule=schedule
            )

            attack.train(schedule=schedule)

            poisoned_trainset, poisoned_testset = attack.get_poisoned_dataset()

            torch.save(poisoned_trainset, trainset_cache)
            torch.save(poisoned_testset, testset_cache)
        else:
            # NOTE: torch.load unpickles arbitrary objects - only load cache
            # files this script produced itself.
            poisoned_trainset = torch.load(trainset_cache)
            poisoned_testset = torch.load(testset_cache)

            # Quick sanity print of the first cached sample.
            for img, target in poisoned_trainset:
                print(target, type(poisoned_trainset))
                break

    else:
        raise NotImplementedError("Please specify an attack method")

    return clean_trainset, clean_testset, poisoned_trainset, poisoned_testset


def train_model(args, poisoned_trainset, poisoned_testset, clean_testset, use_saved=True):
    """Build the network named by ``args.model_name`` and train it on the poisoned set.

    Args:
        args: parsed CLI namespace; reads ``model_name``, ``attack_method``,
            ``model_root`` and ``epoch_number``.
        poisoned_trainset: dataset used for training.
        poisoned_testset: evaluated after every epoch (reported as "poisoned").
        clean_testset: evaluated after every epoch (reported as "clean").
        use_saved (bool): when True, first try to load
            ``<model_root>/<model_name>_<attack_method>.pth`` and return that
            model without training; training runs if the file is missing.

    Returns:
        The loaded or trained model, placed on the module-level ``device``.

    Raises:
        NotImplementedError: for an unknown ``args.model_name``.
    """
    if args.model_name == "ResNet18":
        model = core.models.ResNet(18)
    elif args.model_name == "ResNet18-64":
        resnet18 = models.resnet18(pretrained=True)
        resnet18.fc = torch.nn.Linear(512, 200)
        model = resnet18
    elif args.model_name == "ResNet18-GTSRB":
        model = core.models.ResNet(18, 43)
    elif args.model_name == "ResNet18-ImageNet":
        model = models.resnet18(pretrained=True)
        model.fc = torch.nn.Linear(512, 10)
    elif args.model_name == "EfficientNet-b3":
        from efficientnet_pytorch import EfficientNet
        model = EfficientNet.from_pretrained('efficientnet-b3', num_classes=200)
    elif args.model_name == "EfficientNet-b0":
        from efficientnet_pytorch import EfficientNet
        model = EfficientNet.from_pretrained('efficientnet-b0', num_classes=200)
    else:
        raise NotImplementedError

    model = model.to(device)
    print(torch.cuda.device_count())

    checkpoint_path = os.path.join(args.model_root, "{}_{}.pth".format(args.model_name, args.attack_method))

    if use_saved:
        try:
            model.load_state_dict(torch.load(checkpoint_path))
            return model
        except FileNotFoundError:
            print("Not Found, retrain the model")

    # Make sure the checkpoint directory exists before spending an epoch on
    # training only to fail in torch.save.
    if args.model_root:
        os.makedirs(args.model_root, exist_ok=True)

    poisoned_trainset_loader = DataLoader(poisoned_trainset, batch_size=128, shuffle=True)  # 128 for cifar10
    criterion = torch.nn.CrossEntropyLoss()

    # EfficientNet-b0 uses Adam with a gentler schedule; everything else SGD.
    if args.model_name == "EfficientNet-b0":
        optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[30, 100, 180], gamma=0.5)
    else:
        optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)  # 0.1 for cifar10
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[50, 100, 150, 180], gamma=0.1)

    running_loss = 0.0
    batches_since_log = 0
    model.train()
    for epoch in tqdm(range(args.epoch_number)):
        for idx, (imgs, labels) in enumerate(poisoned_trainset_loader):
            imgs, labels = imgs.to(device), labels.to(device)

            optimizer.zero_grad()
            preds = model(imgs)
            loss = criterion(preds, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            batches_since_log += 1
            if idx % 50 == 0:
                # Mean loss over the batches accumulated since the last log
                # line (previously divided by the batch size by mistake).
                print(f'[{epoch + 1}, {idx + 1:5d}] loss: {running_loss / batches_since_log:.3f}')
                running_loss = 0.0
                batches_since_log = 0

        evaluate(model, poisoned_testset, mode="poisoned")
        evaluate(model, clean_testset, mode="clean")
        model.train()  # evaluate() switches the model to eval mode
        scheduler.step()

        # Checkpoint after every epoch.
        torch.save(model.state_dict(), checkpoint_path)
    print("saved")
    return model





def evaluate(model, dataset, alpha=0, mode="Poisoned"):
    """Print and return the top-1 accuracy of ``model`` on ``dataset``.

    Every batch has ``alpha * torch.rand(...)`` added to it before inference.

    Args:
        model: network to evaluate; switched to eval mode by this call.
        dataset: dataset yielding ``(image, label)`` pairs.
        alpha (float): scale of the uniform random noise added to the inputs.
        mode (str): tag inserted into the printed message.

    Returns:
        float: accuracy in percent, or ``0.0`` when the dataset is empty.
    """
    correct = 0
    total = 0
    testloader = DataLoader(dataset, batch_size=2048)  # 1024 for cifar10
    model.eval()
    # Follow the model's own device instead of relying on a module global.
    model_device = next(model.parameters()).device
    # No gradients are needed for evaluation.
    with torch.no_grad():
        for images, labels in tqdm(testloader):
            images, labels = images.to(model_device), labels.to(model_device)
            images = alpha * torch.rand(images.shape, device=model_device) + images
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total == 0:
        # Avoid a ZeroDivisionError on an empty dataset.
        print(f'Accuracy of the network on the ' + mode + ' images: n/a (empty dataset)')
        return 0.0

    print(f'Accuracy of the network on the ' + mode + f' images: {100 * correct // total} %')
    return 100.0 * correct / total


def _first_flip_scores(model, dataset, perturbation):
    """Score each sample by how long its prediction survives growing noise.

    ``perturbation`` holds one noise tensor per alpha (first dimension). The
    score of a sample is the last index ``i`` such that predictions ``0..i``
    all equal prediction ``0``.
    """
    scores = []
    last_index = perturbation.shape[0] - 1
    with torch.no_grad():  # inference only; avoids building autograd graphs
        for img, _ in tqdm(dataset):
            # Broadcasting expands the single image over all noise levels.
            image_batch = (img.to(perturbation.device) + perturbation).float()
            labels = torch.max(model(image_batch), 1).indices.tolist()
            score = last_index
            for i, label in enumerate(labels):
                if label != labels[0]:
                    break
                score = i
            scores.append(score)
    return scores


def mixup_detect(model, clean_testset, poisoned_testset, alpha_range):
    """Compute noise-robustness scores for poisoned and clean test samples.

    For every sample, one noisy copy per entry of ``alpha_range`` is classified
    and the sample is scored by the index at which its prediction first
    differs from the prediction at ``alpha_range[0]`` (minus one).

    Args:
        model: trained classifier; switched to eval mode.
        clean_testset: dataset of ``(image, label)`` pairs without trigger.
        poisoned_testset: dataset of ``(image, label)`` pairs with trigger.
        alpha_range: sequence of noise scales, one noisy copy per entry.

    Returns:
        tuple: ``(scores_poi, scores_clean)``, two lists of int scores.
    """
    model.eval()

    # Schedules tried in earlier experiments (the caller now passes alpha_range):
    #   GTSRB:           [0, 0.2, 1.2, 1.5, 1.5, 1.5, 1, 0.5]
    #   Cifar10:         [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]
    #   ImageNet_Subset: [0, 0.02, 0.04, 0.2, 0.4, 0.6, 0.8, 1]

    # Informational: accuracy of both sets at every noise level.
    for alpha in alpha_range:
        evaluate(model, poisoned_testset, alpha=alpha)
        evaluate(model, clean_testset, alpha=alpha, mode="clean")

    sample_shape = clean_testset[0][0].shape
    channel_num, width, height = sample_shape[0], sample_shape[1], sample_shape[2]

    # One fixed noise tensor per alpha, shared by every sample.
    # NOTE(review): the seed is the alpha value itself, and this also reseeds
    # torch's global RNG. torch documents the seed as an int, so fractional
    # alphas presumably collapse onto the same seed (and thus the same noise)
    # - confirm this is intended before changing it, since the detection
    # thresholds were tuned against this behaviour.
    random_noises = torch.zeros((len(alpha_range), channel_num, width, height)).to(device)
    for idx, alpha in enumerate(alpha_range):
        torch.manual_seed(alpha)
        random_noises[idx, :, :, :] = torch.rand((channel_num, width, height))

    alpha_range_ = torch.tensor(alpha_range).view(-1, 1, 1, 1).to(device)
    # Loop-invariant, so compute the scaled noise once.
    perturbation = alpha_range_ * random_noises

    scores_poi = _first_flip_scores(model, poisoned_testset, perturbation)

    # Previously printed under the default "Poisoned" tag although it
    # evaluates the clean set.
    evaluate(model, clean_testset, alpha=0, mode="clean")

    scores_clean = _first_flip_scores(model, clean_testset, perturbation)

    return scores_poi, scores_clean

def AUROC_Score(pred_in, pred_out, file):
    """Plot the ROC curve of two score lists and return the area under it.

    Args:
        pred_in (list): scores labelled 0.
        pred_out (list): scores labelled 1 (the positive class).
        file (str): curve label; the figure is saved as ``file + ".png"``.

    Returns:
        The value of ``roc_auc_score`` for the combined labels and scores.
    """
    labels = [0] * len(pred_in) + [1] * len(pred_out)
    scores = pred_in + pred_out

    fpr, tpr, _ = metrics.roc_curve(labels, scores, pos_label=1)
    plt.plot(fpr, tpr, label=file)
    plt.savefig(file + ".png", bbox_inches='tight')

    return roc_auc_score(labels, scores)


def _scaled_prediction_consistency(model, dataset, scales):
    """Return, per sample, the fraction of scaled copies predicted like the first copy."""
    scores = []
    with torch.no_grad():  # inference only; avoids building autograd graphs
        for img, _ in tqdm(dataset):
            # One pixel-amplified copy per scale, clipped back to [0, 1].
            image_batch = torch.clamp(scales * img.to(scales.device), 0, 1).float()
            labels = torch.max(model(image_batch), 1).indices
            scores.append(np.mean((labels == labels[0]).cpu().numpy()))
    return scores


def scaleup(model, clean_testset, poisoned_testset):
    """Compute SCALE-UP style consistency scores for poisoned and clean samples.

    Each image is multiplied by the factors 1, 3, 5, 7, 9 and clamped to
    ``[0, 1]``; the score is the share of those copies whose prediction
    matches the prediction for factor 1.

    Args:
        model: trained classifier; switched to eval mode.
        clean_testset: dataset of ``(image, label)`` pairs without trigger.
        poisoned_testset: dataset of ``(image, label)`` pairs with trigger.

    Returns:
        tuple: ``(scores_poi, scores_clean)``, two lists of float scores.
    """
    model.eval()

    alpha_range = np.arange(1, 11, 2)  # amplification factors 1, 3, 5, 7, 9
    alpha_range_ = torch.tensor(alpha_range).view(-1, 1, 1, 1).to(device)

    evaluate(model, poisoned_testset, alpha=0)
    scores_poi = _scaled_prediction_consistency(model, poisoned_testset, alpha_range_)

    # Previously printed under the default "Poisoned" tag although it
    # evaluates the clean set.
    evaluate(model, clean_testset, alpha=0, mode="clean")
    scores_clean = _scaled_prediction_consistency(model, clean_testset, alpha_range_)

    return scores_poi, scores_clean


def precision_recall(predictions, targets):
    """Return ``(precision, recall)`` of ``predictions`` measured against ``targets``."""
    # Both sklearn metrics take the ground truth first, then the predictions.
    return precision_score(targets, predictions), recall_score(targets, predictions)

def extract_dataset(args):
    """Load the stored clean and poisoned test sets.

    Both datasets are built with ``DatasetNumpy`` from names derived from
    ``args.existing_dataset_path``, which is used as a plain string prefix
    (no path separator is inserted).

    Returns:
        tuple: ``(clean_testset, poisoned_testset)``.
    """
    prefix = args.existing_dataset_path
    # NOTE(review): the meaning of DatasetNumpy's second argument is not
    # visible here - verify against utils.DatasetNumpy.
    locations = {
        "clean": (prefix + "clean_testset.npy", prefix + "_clean_testset"),
        "poisoned": (prefix + "poisoned_testset.npy", prefix + "_poisoned_testset"),
    }
    clean_testset = DatasetNumpy(*locations["clean"])
    poisoned_testset = DatasetNumpy(*locations["poisoned"])
    return clean_testset, poisoned_testset

def str2bool(value):
    """Parse a command-line boolean for ``argparse``.

    ``type=bool`` treats every non-empty string - including ``"False"`` - as
    True; this parser understands the usual true/false spellings instead.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in {"true", "t", "yes", "y", "1"}:
        return True
    # The empty string stays False, as it was with ``type=bool``.
    if text in {"false", "f", "no", "n", "0", ""}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
                    prog='MixUpDetection',
                    description='')
    parser.add_argument("--dataset", default="GTSRB")
    parser.add_argument("--attack_method", default="BadNet")
    parser.add_argument("--datasets_root_dir",  default= 'data')
    parser.add_argument("--model_name", default="ResNet18-GTSRB")
    parser.add_argument("--epoch_number", type=int, default=200)
    parser.add_argument("--poisoned_rate", type=float, default=0.1)
    parser.add_argument("--trigger_size", type=float, default=0.1)
    parser.add_argument("--model_root", type=str, default="./models")

    # str2bool instead of bool so that "--use_existing_model False" is False.
    parser.add_argument("--use_existing_dataset", type=str2bool, default=False)
    parser.add_argument("--use_existing_model", type=str2bool, default=False)
    parser.add_argument("--existing_model_path", type=str, default="")
    parser.add_argument("--existing_dataset_path", type=str, default="")

    args = parser.parse_args()

    # Prefer the first GPU, but do not fail at start-up on a CPU-only machine.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Read by prepare_dataset() as a module-level global.
    global_seed = 0

    torch.manual_seed(global_seed)

    if not args.use_existing_dataset:
        clean_trainset, clean_testset, poisoned_trainset, poisoned_testset = prepare_dataset(args)
    else:
        clean_trainset, _, poisoned_trainset, _ = prepare_dataset(args)
        clean_testset, poisoned_testset = extract_dataset(args)

    if not args.use_existing_model:
        model = train_model(args, poisoned_trainset, poisoned_testset, clean_testset, use_saved=True)
    else:
        # NOTE: this loads a fully pickled model object; torch.load executes
        # pickle code, so only point it at trusted files.
        model = torch.load(args.existing_model_path, map_location=device)
        model.eval()

    evaluate(model, clean_testset, mode="Clean")
    evaluate(model, poisoned_testset, mode="Poisoned")

    # The LAVA / STRIP / Frequency baselines below are kept disabled.

    # print("====================={}=====================".format("LAVA"))

    # import time

    # start = time.time()
    # plt.figure()

    # counter = {}
    # chosen_imgs_idx = []
    # for idx, (img, label) in enumerate(clean_trainset):
    #     if (counter.get(label, 0) >= 5):
    #         continue
    #     chosen_imgs_idx.append(idx)
    #     counter[label] = counter.get(label, 0) + 1

    # print(chosen_imgs_idx)

    # lava_ = LAVA(args, poisoned_testset, clean_testset, Subset(clean_trainset, chosen_imgs_idx), device)
    # scores_poi, scores_clean = lava_.verify()
    # plt.hist(scores_poi, color="red", alpha=0.6)
    # plt.hist(scores_clean, color="blue", alpha=0.6, bins=30)
    # plt.show()

    # auc = AUROC_Score(scores_poi, scores_clean, "LAVA_cifar")

    # predictions = []
    # for i in scores_poi + scores_clean:
    #     if i >= 0:
    #         predictions.append(1)
    #     else:
    #         predictions.append(0)
    # targets = [1] * len(scores_poi) + [0] * len(scores_clean)
    # precision, recall = precision_recall(predictions, targets)

    # print(precision, recall, auc)

    # print("LAVA ", time.time() - start)



    # print("====================={}=====================".format("STRIP"))
    # start = time.time()
    # plt.figure()
    # x_clean_trainset = np.stack([img.permute(1, 2, 0).numpy() for (img, label) in clean_trainset])
    # strip = STRIP(x_clean_trainset, model, poisoned_testset, clean_testset, device)
    # scores_poi, scores_clean = strip.cal_score()
    # plt.hist(scores_poi, color="red", alpha=0.6)
    # plt.hist(scores_clean, color="blue", alpha=0.6, bins=30)
    # plt.show()
    # auc = AUROC_Score(scores_poi, scores_clean, "cifar")

    # predictions = []

    # for i in scores_poi + scores_clean:
    #     if i <= -20: # 1E-3 0.45 for Cifar10
    #         predictions.append(1)
    #     else:
    #         predictions.append(0)
    # targets = [1] * len(scores_poi) + [0] * len(scores_clean)
    # precision, recall = precision_recall(predictions, targets)

    # print(precision, recall, auc)

    # print("STRIP ", time.time() - start)

    # print("====================={}=====================".format("Frequency"))
    # start = time.time()
    # plt.figure()

    # counter = {}
    # chosen_imgs = []
    # for (img, label) in clean_trainset:
    #     if(counter.get(label, 0) >= 2):
    #         continue
    #     chosen_imgs.append(img.permute(1,2,0).numpy())
    #     counter[label] = counter.get(label, 0) + 1

    # print(counter)

    # x_clean_trainset = np.stack(chosen_imgs)
    # frequency_detector = Frequency(args, poisoned_testset, clean_testset, x_clean_trainset)
    # scores_poi, scores_clean = frequency_detector.verify()

    # auc = AUROC_Score(scores_poi, scores_clean, "Frequency_cifar")
    # plt.hist(scores_poi, color="red", alpha=0.6)
    # plt.hist(scores_clean, color="blue", alpha=0.6, bins=30)
    # plt.show()
    # #
    # predictions = []
    # for i in scores_poi + scores_clean:
    #     if i >= 0.5:
    #         predictions.append(1)
    #     else:
    #         predictions.append(0)
    # targets = [1] * len(scores_poi) + [0] * len(scores_clean)
    # precision, recall = precision_recall(predictions, targets)

    # print(precision, recall, auc)

    # print("Frequency ", time.time() - start)
    #
    print("====================={}=====================".format("SCALE-UP"))
    start = time.time()
    plt.figure()
    scores_poi, scores_clean = scaleup(model, clean_testset, poisoned_testset)
    plt.hist(scores_poi, color="red", alpha=0.6)
    plt.hist(scores_clean, color="blue", alpha=0.6, bins=30)
    plt.show()
    auc = AUROC_Score(scores_poi, scores_clean, "SCALE-UP_cifar")
    plt.show()

    all_scores = scores_poi + scores_clean
    predictions = []
    if all_scores:
        # The threshold does not depend on the element, so compute it once
        # instead of rescanning both lists for every score.
        threshold = (min(scores_poi) + max(scores_clean)) / 2
        predictions = [1 if score >= threshold else 0 for score in all_scores]
    targets = [1] * len(scores_poi) + [0] * len(scores_clean)
    precision, recall = precision_recall(predictions, targets)

    print(precision, recall, auc)
    print("SCALE-UP ", time.time() - start)

    print("====================={}=====================".format("BBCaL"))
    start = time.time()
    alpha_range = np.arange(0, (0.25) * 7 + 0.01, 0.25)
    print("alpha_range", alpha_range)
    scores_poi, scores_clean = mixup_detect(model, clean_testset, poisoned_testset, alpha_range)
    plt.figure()
    plt.hist(scores_poi, color="red", alpha=0.6)
    plt.hist(scores_clean, color="blue", alpha=0.6, bins=30)
    plt.show()
    auc = AUROC_Score(scores_poi, scores_clean, "BBCaL_cifar")

    # Flag a sample as poisoned when its score lies at either extreme.
    predictions = [
        1 if (score <= 0.028 or score > 6.8) else 0  # 1E-3 0.45 for Cifar10
        for score in scores_poi + scores_clean
    ]
    targets = [1] * len(scores_poi) + [0] * len(scores_clean)
    precision, recall = precision_recall(predictions, targets)

    print(precision, recall, auc)
    # print(scores_clean)
    print("alpha_range", alpha_range)

    print("BBCaL ", time.time() - start)





