import argparse
import torch
from torchvision import transforms
from sklearn.metrics import roc_auc_score
from torch.utils.data import DataLoader
import numpy as np
import os
import random


class logger(object):
    """Minimal logger that echoes messages to stdout and appends them to a file.

    Args:
        path: Directory that holds the log file. It must already exist.
        name: File name of the log inside ``path``.
    """

    def __init__(self, path, name='log.txt'):
        self.path = path
        self.name = name

    def info(self, msg):
        """Print ``msg`` and append it, followed by a newline, to the log file.

        ``msg`` may be any object; it is written using its ``str()`` form so
        that non-string messages (numbers, namespaces, tensors) do not fail
        after they have already been printed.
        """
        print(msg)
        with open(os.path.join(self.path, self.name), 'a', encoding='utf-8') as f:
            f.write(f"{msg}\n")

def get_args():
    """Parse the command-line options of the detection script.

    Returns:
        argparse.Namespace: the parsed options. ``watermark_budget`` is
        returned already divided by 25 (see the note at the end of the body).
    """
    parser = argparse.ArgumentParser(description='DETECTION-NEW-BACKDOOR')
    parser.add_argument('--experiment', type=str, required=True, help='name of experiment')

    # Data options.
    parser.add_argument('--dataset', default='cifar10',
                        choices=['cifar10', 'cifar100', 'tinyimagenet', 'miniimagenet', 'imagenet100'],
                        help='the dataset used in experiment')
    parser.add_argument('--data', type=str, default='data/CIFAR10', help='the directory of dataset')
    parser.add_argument('--num-classes', default=10, type=int, help='the number of classes in the dataset')
    parser.add_argument('--batch-size', type=int, default=128)
    parser.add_argument('--num-workers', type=int, default=4)

    # Inputs produced by earlier stages.
    parser.add_argument('--poison-path', type=str, default=None, help='the path of pretrained poison')
    parser.add_argument('--watermark-path', type=str, default=None, help='the path of the watermark')
    parser.add_argument('--poison-size', type=int, default=32,
                        help='the image size of poisons')

    # Runtime options.
    parser.add_argument('--seed', default=1, type=int, help='random seed')
    parser.add_argument('--gpu-id', type=str, default='0', help='the gpu id')

    # Watermark options.
    parser.add_argument('--watermark-budget', type=float, default=8, help='the watermark budget')
    parser.add_argument('--wm-length', type=int, default=2000, help='the watermarking length')
    # 'fixed' is the default, so it has to be listed in ``choices`` as well;
    # otherwise the default mode could never be requested explicitly.
    parser.add_argument('--mask-type', default='fixed',
                        choices=['fixed', 'random', 'fix-lt', 'fix-lb', 'fix-rt', 'fix-rb'],
                        help='the type of mask for pixels')
    parser.add_argument('--post-poisoning', action='store_true',
                        help='if generate post-poisoning watermark')
    parser.add_argument('--detect-type', default='corresponding', choices=['random', 'corresponding'],
                        help='the type of detection key: random signs, or the signs '
                             'corresponding to the watermark')

    arguments = parser.parse_args()
    # NOTE(review): the budget is rescaled by 25 here. Perturbation budgets are
    # often expressed on a 0-255 pixel scale, so confirm that 25 (and not 255)
    # is the intended divisor before changing it.
    arguments.watermark_budget = arguments.watermark_budget / 25

    return arguments

def setup_seed(seed: int):
    """Seed Python, NumPy and PyTorch RNGs and switch cuDNN to deterministic mode.

    Args:
        seed: value passed unchanged to every random number generator.
    """
    # Standard-library and NumPy generators.
    random.seed(seed)
    np.random.seed(seed)

    # PyTorch generators: default one, current CUDA device, all CUDA devices.
    for seed_torch in (torch.manual_seed,
                       torch.cuda.manual_seed,
                       torch.cuda.manual_seed_all):
        seed_torch(seed)

    # Disable the cuDNN auto-tuner and request deterministic kernels.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def _build_pixel_mask(mask_type, size, wm_length):
    """Build the (1, 3, size, size) CPU pixel mask: 0 where the watermark lives, 1 elsewhere.

    Args:
        mask_type: 'random' scatters the watermark positions; 'fix-lt',
            'fix-lb', 'fix-rt' and 'fix-rb' place a rectangle in the
            left/right ('l'/'r') top/bottom ('t'/'b') corner; any other value
            yields an all-ones mask (no watermark pixels).
        size: image height and width.
        wm_length: requested number of watermark entries (over all channels).
    """
    C, H, W = 3, size, size
    num_pixels = C * H * W

    if mask_type == 'random':
        # Clamp at zero so an over-long watermark covers every pixel instead
        # of turning into a negative slice bound.
        num_unmasked = max(int(num_pixels - wm_length), 0)
        flat_mask = torch.zeros(num_pixels)
        flat_mask[torch.randperm(num_pixels)[:num_unmasked]] = 1
        return flat_mask.view(1, C, H, W)

    mask = torch.ones(1, C, H, W)
    if mask_type in ('fix-lt', 'fix-lb', 'fix-rt', 'fix-rb'):
        # Side ratio of the corner rectangle whose area (over all channels)
        # is approximately wm_length; clamp it to the image bounds.
        rat = np.sqrt(wm_length / num_pixels)
        h_len = min(int(rat * H), H)
        w_len = min(int(rat * W), W)
        horizontal, vertical = mask_type[-2], mask_type[-1]
        rows = slice(0, h_len) if vertical == 't' else slice(H - h_len, H)
        cols = slice(0, w_len) if horizontal == 'l' else slice(W - w_len, W)
        mask[0, :, rows, cols] = 0
    return mask


def detection(args):
    """Score every sample of a poisoned dataset against a watermark key and log the AUROC.

    The detection score of a sample is the inner product between the
    flattened, normalised image and the flattened key. Scores and
    ``1 - label`` are passed to ``roc_auc_score``.

    Args:
        args: namespace as produced by ``get_args``. The attributes read here
            are ``gpu_id``, ``post_poisoning``, ``experiment``,
            ``poison_path``, ``seed``, ``mask_type``, ``poison_size``,
            ``wm_length``, ``watermark_budget``, ``detect_type`` and
            ``watermark_path``.

    Raises:
        ValueError: if the poisoned dataset or (when ``post_poisoning`` is
            off) the watermark file cannot be loaded.
    """
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id
    root = './detection-post-poisoning' if args.post_poisoning else './detection-poisoning-concurrent'
    save_dir = os.path.join(root, args.experiment)
    os.makedirs(save_dir, exist_ok=True)
    log = logger(path=save_dir)
    log.info(str(args))

    # NOTE: torch.load unpickles arbitrary objects; only point it at trusted files.
    try:
        poisoned_dataset = torch.load(args.poison_path, map_location='cpu')
    except Exception as err:
        log.info('no poisoned dataset founded!')
        raise ValueError('no poisoned dataset founded!') from err
    log.info(f'poison: {torch.max(poisoned_dataset[0][0])}')

    setup_seed(args.seed)

    if args.post_poisoning:
        # RNG call order (mask -> watermark signs -> random key signs) is kept
        # fixed so a given seed always produces the same watermark and key.
        C, H, W = 3, args.poison_size, args.poison_size
        mask = _build_pixel_mask(args.mask_type, args.poison_size, args.wm_length)

        # +/- budget on watermark pixels (mask == 0), exactly 0 elsewhere.
        watermark = torch.where(torch.randn(1, C, H, W) < 0,
                                -args.watermark_budget,
                                args.watermark_budget) * (torch.ones(1, C, H, W) - mask)
        rand_signs = torch.randint(0, 2, watermark.shape, dtype=watermark.dtype) * 2 - 1
        if args.detect_type == 'random':
            key = rand_signs
        else:
            # Sign of the watermark where it is non-zero, random sign elsewhere.
            key = torch.where(
                watermark < -1e-6, -1,
                torch.where(watermark > 1e-6, 1, rand_signs)
            )
    else:
        try:
            watermark = torch.load(args.watermark_path, map_location='cpu')
        except Exception as err:
            log.info('no watermark founded!')
            raise ValueError('no watermark founded!') from err

        if args.detect_type == 'random':
            key = torch.randint(0, 2, watermark.shape, dtype=watermark.dtype) * 2 - 1
        else:
            # Sign of the watermark where it is non-zero, 0 elsewhere.
            key = torch.where(watermark < -1e-6, -1, torch.where(watermark > 1e-6, 1, 0))
        log.info(f'the number of positive unmasked key pixels is:{torch.sum(key > 0)}')
        log.info(f'the number of unmasked key pixels is:{torch.nonzero(key).size(0)}')

    # Fixed channel-wise normalisation applied to every image before scoring.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    loader = DataLoader(poisoned_dataset, batch_size=1, shuffle=False)

    all_labels = []
    all_scores = []
    key_flat = key.reshape(-1)

    # Each dataset item is unpacked as (image, <ignored>, label).
    for images, _, label in loader:
        images = normalize(images)
        images = images.reshape(images.size(0), -1)

        if key_flat.size(0) != images.size(1):
            # Stretch the flattened key to the image length. interpolate needs
            # a floating-point (N, C, L) input, hence the cast and the two
            # leading singleton dimensions on the 1-D key.
            key_resized = torch.nn.functional.interpolate(
                key_flat.float().unsqueeze(0).unsqueeze(0), size=images.size(1), mode='nearest'
            ).squeeze(0).squeeze(0)
        else:
            key_resized = key_flat

        key_resized = key_resized.to(images.device)

        detection_scores = torch.sum(key_resized * images, dim=1).cpu().numpy()

        # The AUROC target is 1 - label.
        # NOTE(review): presumably label == 0 marks the samples expected to
        # score high (watermarked) - confirm against the dataset builder.
        all_labels.extend((1 - label).cpu().numpy())
        all_scores.extend(detection_scores)

    all_labels = np.array(all_labels)
    all_scores = np.array(all_scores)
    print(all_labels)
    print(all_scores)

    auc = roc_auc_score(all_labels, all_scores)
    log.info(f'AUROC: {auc}')



if __name__ == '__main__':
    args = get_args()

    # Image side length used for each supported dataset; this replaces
    # whatever value --poison-size carried.
    size_by_dataset = {
        'cifar10': 32,
        'cifar100': 32,
        'tinyimagenet': 64,
        'miniimagenet': 84,
        'imagenet100': 224,
    }
    if args.dataset in size_by_dataset:
        args.poison_size = size_by_dataset[args.dataset]

    detection(args)