import argparse
import torch
from datasets.cifar import CIFAR10Index
from torchvision import transforms
from sklearn.metrics import f1_score, roc_auc_score, precision_score, recall_score
from torch.utils.data import DataLoader
import numpy as np
import os
from utils import logger


def get_args():
    """Parse the command-line options of the detection script.

    Returns:
        argparse.Namespace: the parsed options. ``watermark_budget`` and
        ``poison_budget`` are given on the command line on a 0-255 scale and
        are divided by 255 before being returned.
    """
    parser = argparse.ArgumentParser(description='DETECTION')
    parser.add_argument('--experiment', type=str, required=True, help='name of experiment')

    parser.add_argument('--dataset', default='cifar10', choices=['cifar10', 'cifar100', 'tinyimagenet',
                                                                 'miniimagenet', 'imagenet100'],
                        help='the dataset used in experiment')
    parser.add_argument('--data', type=str, default='data/CIFAR10', help='the directory of dataset')
    parser.add_argument('--num-classes', default=10, type=int, help='the number of classes in the dataset')
    parser.add_argument('--batch-size', type=int, default=128)
    parser.add_argument('--num-workers', type=int, default=4)

    parser.add_argument('--poison-path', type=str, default=None, help='the path of pretrained poison')

    parser.add_argument('--watermark-path', type=str, default=None, help='the path of the watermark')

    parser.add_argument('--poison-size', type=int, default=32,
                        help='the image size of poisons')

    parser.add_argument('--seed', default=1, type=int, help='random seed')

    parser.add_argument('--gpu-id', type=str, default='0', help='the gpu id')

    parser.add_argument('--poison-budget', type=float, default=8, help='the poison budget')

    parser.add_argument('--watermark-budget', type=float, default=8, help='the watermark budget')

    parser.add_argument('--wm-length', type=int, default=2000, help='the watermarking length')
    # 'fixed' is the default, so it has to be listed as a choice as well;
    # otherwise the default value could never be passed explicitly.
    parser.add_argument('--mask-type', default='fixed', choices=['fixed', 'random', 'fix-lt',
                                                                 'fix-lb', 'fix-rt', 'fix-rb'],
                        help='the type of mask for pixels')
    parser.add_argument('--post-poisoning', action='store_true',
                        help='if generate post-poisoning watermark')

    parser.add_argument('--detect-type', default='corresponding', choices=['random', 'corresponding'],
                        help='the type of detection key: random signs, or the signs '
                             'corresponding to the watermark')

    arguments = parser.parse_args()
    # Budgets are specified in 0-255 pixel units; rescale them by 1/255.
    arguments.watermark_budget = arguments.watermark_budget / 255
    arguments.poison_budget = arguments.poison_budget / 255

    return arguments


def _load_poison(args, log):
    """Return the poison perturbation selected by ``args.poison_path``, or ``None``.

    The path is first tried as a ``torch.load`` file. If that fails, the
    literal value ``'random'`` selects uniform noise within the poison budget;
    any other value (including ``None``) yields ``None``.
    """
    try:
        poison = torch.load(args.poison_path, map_location='cpu')
        log.info(f'poison: {torch.max(poison)}')
        return poison
    except Exception as err:  # narrower than a bare except: Ctrl-C still propagates
        load_error = err

    if args.poison_path == 'random':
        # Re-seed right before sampling so the noise depends only on the seed
        # (the default seed of 1 reproduces the previously hard-coded value).
        torch.manual_seed(args.seed)
        poison = torch.randn([50000, 3, args.poison_size, args.poison_size]).uniform_(-args.poison_budget, args.poison_budget)
        log.info('use random noise')
        log.info(f'poison:{poison}')
        return poison

    log.info('no poison founded!')
    if args.poison_path is not None:
        log.info(f'could not load poison from {args.poison_path!r}: {load_error}')
    return None


def _build_mask(mask_type, size, wm_length):
    """Build a ``(1, 3, size, size)`` CPU mask: 1 = untouched pixel, 0 = watermark pixel.

    ``'random'`` zeroes ``wm_length`` randomly chosen entries. The ``'fix-*'``
    types zero a corner block (lt/lb/rt/rb = left/right, top/bottom) in every
    channel, with each side scaled by ``sqrt(wm_length / num_pixels)``. Any
    other type returns an all-ones mask (no watermark pixels).
    """
    channels, height, width = 3, size, size
    num_pixels = channels * height * width

    # mask type -> (block touches the bottom edge, block touches the right edge)
    corners = {
        'fix-lt': (False, False),
        'fix-lb': (True, False),
        'fix-rt': (False, True),
        'fix-rb': (True, True),
    }
    if mask_type != 'random' and mask_type not in corners:
        return torch.ones(1, channels, height, width)

    if not 0 <= wm_length <= num_pixels:
        raise ValueError(
            f'wm_length must be between 0 and {num_pixels} for mask type '
            f'{mask_type!r}, got {wm_length}')

    if mask_type == 'random':
        flat_mask = torch.zeros(num_pixels)
        unmasked = torch.randperm(num_pixels)[:int(num_pixels - wm_length)]
        flat_mask[unmasked] = 1
        return flat_mask.view(1, channels, height, width)

    ratio = np.sqrt(wm_length / num_pixels)
    block_h, block_w = int(ratio * height), int(ratio * width)
    at_bottom, at_right = corners[mask_type]
    # Explicit start/stop keeps an empty block empty (a ``-0:`` slice would not).
    rows = slice(height - block_h, height) if at_bottom else slice(0, block_h)
    cols = slice(width - block_w, width) if at_right else slice(0, block_w)

    mask = torch.ones(1, channels, height, width)
    mask[0, :, rows, cols] = 0
    return mask


def _detection_scores(loader, key):
    """Return one score per image: the sum of ``key * image`` over all pixels.

    Each batch is flattened to ``(batch, pixels)``. If the flattened key has a
    different length, it is resampled to that length with nearest-neighbour
    interpolation first.
    """
    key_flat = key.reshape(-1)
    scores = []
    for images, _targets, _indices in loader:
        images = images.reshape(images.size(0), -1)

        if key_flat.size(0) != images.size(1):
            # interpolate() needs a floating-point (batch, channel, length)
            # input for 1-D resizing, so resample the flattened key.
            key_row = torch.nn.functional.interpolate(
                key_flat.float()[None, None, :], size=images.size(1), mode='nearest'
            )[0, 0]
        else:
            key_row = key_flat

        key_row = key_row.to(images.device)
        scores.append(torch.sum(key_row * images, dim=1).cpu().numpy())

    if not scores:
        return np.empty(0, dtype=np.float32)
    return np.concatenate(scores)


def detection(args):
    """Score clean and perturbed training images against a sign key and log the AUROC.

    Steps, all on CPU:
      1. create the output directory under ``eval/<dataset>/...`` and a logger;
      2. load the poison (or sample random noise, see ``_load_poison``);
      3. obtain the watermark: generated here from a pixel mask when
         ``args.post_poisoning`` is set, otherwise loaded from
         ``args.watermark_path``;
      4. derive the key: random +/-1 signs (``detect_type == 'random'``) or
         the signs of the watermark;
      5. score every image of the clean and the perturbed dataset and log
         the AUROC of separating the two (clean = 0, perturbed = 1).

    Args:
        args: namespace as returned by ``get_args()``.

    Raises:
        ValueError: if the watermark cannot be loaded, if post-poisoning is
            requested without a usable poison, or if ``args.wm_length`` does
            not fit the image.
    """
    subdir = 'detection-post-poisoning' if args.post_poisoning else 'detection-poisoning-concurrent'
    save_dir = os.path.join('eval', args.dataset, subdir, args.experiment)
    os.makedirs(save_dir, exist_ok=True)
    log = logger(path=save_dir)
    log.info(str(args))

    # Make the randomly drawn mask / watermark / key reproducible.
    torch.manual_seed(args.seed)

    poison = _load_poison(args, log)
    if args.post_poisoning and poison is None:
        # Previously this surfaced later as an opaque TypeError on ``None + tensor``.
        raise ValueError('post-poisoning detection needs a poison; check --poison-path')

    if args.post_poisoning:
        # The mask stays on the CPU: it is only combined with CPU tensors.
        mask = _build_mask(args.mask_type, args.poison_size, args.wm_length)

        # Random +/-budget values on the pixels the mask marks with 0.
        watermark = torch.where(torch.randn(mask.shape) < 0,
                                -args.watermark_budget, args.watermark_budget) * (1 - mask)
        rand_signs = torch.randint(0, 2, watermark.shape, dtype=watermark.dtype) * 2 - 1
        if args.detect_type == 'random':
            key = rand_signs
        else:
            # Watermark sign where a watermark exists, a random sign elsewhere.
            key = torch.where(
                watermark < -1e-6, -1,
                torch.where(watermark > 1e-6, 1, rand_signs)
            )
        log.info(f'the number of positive unmasked key pixels is:{torch.sum(key > 0)}')
        log.info(f'the number of unmasked key pixels is:{torch.nonzero(key).size(0)}')

        log.info(f'the number of unmasked watermark pixels is:{torch.nonzero(watermark).size(0)}')
    else:
        try:
            # map_location keeps this consistent with the poison load and lets
            # a watermark saved on a GPU be read on a CPU-only machine.
            watermark = torch.load(args.watermark_path, map_location='cpu')
            if args.detect_type == 'random':
                key = torch.randint(0, 2, watermark.shape, dtype=watermark.dtype) * 2 - 1
            else:
                # Watermark sign, and 0 where there is no watermark.
                key = torch.where(watermark < -1e-6, -1, torch.where(watermark > 1e-6, 1, 0))
            log.info(f'the number of positive unmasked key pixels is:{torch.sum(key > 0)}')
            log.info(f'the number of unmasked key pixels is:{torch.nonzero(key).size(0)}')

            log.info(f'the number of unmasked watermark pixels is:{torch.nonzero(watermark).size(0)}')
        except Exception as err:
            log.info('no watermark founded!')
            raise ValueError(
                f'could not load a usable watermark from {args.watermark_path!r}') from err

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    transform = transforms.Compose([transforms.ToTensor(), normalize])

    # NOTE(review): ``delta`` is presumably a per-sample perturbation added by
    # CIFAR10Index; confirm against the dataset class.
    delta = poison + watermark if args.post_poisoning else poison
    eval_dataset_clean = CIFAR10Index(root=args.data, train=True, transform=transform, delta=None, download=True)
    eval_dataset_poison = CIFAR10Index(root=args.data, train=True, transform=transform, delta=delta, download=True)

    loader_clean = DataLoader(eval_dataset_clean, batch_size=args.batch_size, shuffle=False)
    loader_poison = DataLoader(eval_dataset_poison, batch_size=args.batch_size, shuffle=False)

    clean_scores = _detection_scores(loader_clean, key)
    poison_scores = _detection_scores(loader_poison, key)

    all_labels = np.concatenate([np.zeros(len(clean_scores), dtype=int),
                                 np.ones(len(poison_scores), dtype=int)])
    all_scores = np.concatenate([clean_scores, poison_scores])

    auc = roc_auc_score(all_labels, all_scores)
    log.info(f'AUROC: {auc}')



if __name__ == '__main__':
    args = get_args()

    # Image side length used for each dataset; it replaces whatever
    # --poison-size was given. A dataset missing from the table keeps the
    # parsed value.
    size_by_dataset = {
        'cifar10': 32,
        'cifar100': 32,
        'tinyimagenet': 64,
        'miniimagenet': 84,
        'imagenet100': 224,
    }
    args.poison_size = size_by_dataset.get(args.dataset, args.poison_size)

    detection(args)