import torch
import torch.backends.cudnn as cudnn
import os
import argparse
import time
from utils import AverageMeter, cosine_annealing, logger, save_checkpoint, evaluation, evaluation_detector, setup_seed, classwise_evaluation
from torch.nn import functional as F
from torchvision.models import resnet18, resnet50
from models import vgg19, DenseNet121, simple_classifier, WideResNet, SL_ViT
from torch import nn
from torchvision import transforms
from torch.utils.data import DataLoader
from cutout import Cutout
from datasets.cifar import CIFAR10WatermarkIndex
import kornia.augmentation as K
import numpy as np

def get_args(argv=None):
    """Parse the command-line options for detector training.

    Args:
        argv: optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is what the script does.

    Returns:
        argparse.Namespace holding every option below. ``watermark_budget`` and
        ``poison_budget`` are given on the command line on a 0-255 scale and
        are divided by 255 before being returned.
    """
    parser = argparse.ArgumentParser(description='DETECTOR')
    parser.add_argument('--experiment', type=str, required=True, help='name of experiment')
    # NOTE(review): 'vit-b' is accepted here, but main() has no branch that builds it.
    parser.add_argument('--backbone', type=str, default='resnet18', choices=['resnet18', 'resnet50', 'vgg19',
                                                                             'vit-b', 'densenet121', 'linear',
                                                                             '2nn', '3nn', 'lenet5', 'vit', 'wrn34-10'],
                        help='the model arch used in experiment')

    parser.add_argument('--dataset', default='cifar10', choices=['cifar10', 'cifar100', 'tinyimagenet',
                                                                 'miniimagenet', 'imagenet100'],
                        help='the dataset used in experiment')
    parser.add_argument('--data', type=str, default='data/CIFAR10', help='the directory of dataset')
    # NOTE: the __main__ entry point overrides this to 2.
    parser.add_argument('--num-classes', default=10, type=int, help='the number of classes in the dataset')
    parser.add_argument('--batch-size', type=int, default=128)
    parser.add_argument('--num-workers', type=int, default=4)

    parser.add_argument('--poison-path', type=str, default=None, help='the path of pretrained poison')

    parser.add_argument('--watermark-path', type=str, default=None, help='the path of the watermark')

    parser.add_argument('--poison-ratio', type=float, default=1.0, help='the poisoning ratio')
    parser.add_argument('--poison-size', type=int, default=32,
                        help='the image size of poisons')

    # Only these two optimizers are constructed by the training code, so
    # reject anything else at parse time.
    parser.add_argument('--optimizer', default='sgd', type=str, choices=['sgd', 'adam'],
                        help='the optimizer used in training')
    parser.add_argument('--epochs', default=50, type=int,
                        help='the number of total epochs to run')
    parser.add_argument('--lr', default=0.005, type=float, help='optimizer learning rate')
    parser.add_argument('--seed', default=None, type=int, help='random seed')

    parser.add_argument('--resume', action='store_true', help='if resume training')
    parser.add_argument('--gaussian-smooth', action='store_true', help='if use gaussian smooth')
    parser.add_argument('--random-noise', action='store_true', help='if use random noise')
    parser.add_argument('--gpu-id', type=str, default='0', help='the gpu id')

    parser.add_argument('--poisoned-class', default=-1, type=int,
                        help='which class could be poisoned, if all classes are poisoned, it is set to -1')

    parser.add_argument('--get-lr-process', action='store_true', help='if get learning process')

    parser.add_argument('--post-poisoning', action='store_true',
                        help='if generate post-poisoning watermark')
    parser.add_argument('--wm-length', type=int, default=2000, help='the watermarking length')

    parser.add_argument('--mask-type', default='fixed', choices=['random', 'fixed', 'fix-lt',
                                                                 'fix-lb', 'fix-rt', 'fix-rb'],
                        help='the type of mask for pixels')

    parser.add_argument('--aug_type', default='normal', choices=['none', 'normal', 'cutout', 'cutmix',
                                                                 'mixup', 'jitter', 'grayscale', 'blur',
                                                                 'diffusion', 'vae', 'vae_18'],
                        help='the augmentation type: selects the evaluation transform and, '
                             'for cutmix/mixup, batch mixing during training')

    parser.add_argument('--poison-budget', type=float, default=8, help='the poison budget')

    parser.add_argument('--watermark-budget', type=float, default=8, help='the watermark budget')

    arguments = parser.parse_args(argv)
    # Budgets are given on a 0-255 scale; presumably divided by 255 to match
    # images that ToTensor scales to [0, 1].
    arguments.watermark_budget = arguments.watermark_budget / 255
    arguments.poison_budget = arguments.poison_budget / 255

    return arguments

def loss_mix(y, logits):
    """Cross-entropy against mixed (CutMix/MixUp) targets.

    Each row of ``y`` is ``[label_a, label_b, lam]``. The per-sample losses
    against the two labels are blended as ``(1 - lam) * loss_a + lam * loss_b``
    and averaged over the batch.
    """
    label_a = y[:, 0].long()
    label_b = y[:, 1].long()
    lam = y[:, 2]

    per_sample_a = F.cross_entropy(logits, label_a, reduction="none")
    per_sample_b = F.cross_entropy(logits, label_b, reduction="none")

    blended = (1 - lam) * per_sample_a + lam * per_sample_b
    return blended.mean()

def train_epoch(train_loader, model, optimizer, scheduler, epoch, log):
    """Train ``model`` for one epoch, log a summary, then step ``scheduler``.

    Reads the module-level ``args``:

    * ``aug_type`` selects the loss path. ``'cutmix'`` / ``'mixup'`` apply
      batch-level mixing and train with ``loss_mix``; anything else uses
      plain cross-entropy.
    * ``epochs`` is used only in the log line.

    Args:
        train_loader: iterable of ``(data, target, _)`` batches; the third item
            is ignored.
        model: network returning class logits; put in training mode here.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: LR scheduler, stepped once after the epoch.
        epoch: current epoch number (for logging only).
        log: logger exposing ``info``.
    """
    losses = AverageMeter()
    data_time_meter = AverageMeter()
    train_time_meter = AverageMeter()
    # LR used for this epoch; the scheduler only steps after the loop.
    current_lr = optimizer.param_groups[0]['lr']

    # Batch-level mixing augmentation, or None for plain cross-entropy.
    if args.aug_type == 'cutmix':
        mixer = K.RandomCutMixV2(data_keys=["input", "class"])
    elif args.aug_type == 'mixup':
        mixer = K.RandomMixUpV2(data_keys=["input", "class"])
    else:
        mixer = None

    # Training mode only needs to be set once, not on every batch.
    model.train()
    start = time.time()
    for data, target, _ in train_loader:
        data = data.cuda()
        target = target.cuda()
        data_time_meter.update(time.time() - start)

        if mixer is None:
            logits = model(data)
            loss = F.cross_entropy(logits, target)
        else:
            data, target = mixer(data, target)
            if args.aug_type == 'cutmix':
                # RandomCutMixV2 labels carry an extra leading (per-mix)
                # dimension; drop it so loss_mix sees one row per sample.
                target = target.squeeze(0)
            logits = model(data)
            loss = loss_mix(target, logits)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.update(loss.item(), data.shape[0])

        # Per-batch wall time (data loading included); summed for the log.
        train_time_meter.update(time.time() - start)
        start = time.time()

    log.info(
        f'TRAINING\t'
        f'Epoch[{epoch}/{args.epochs}]\t'
        f'avg loss = {losses.avg:.4f}\t'
        f'epoch time = {train_time_meter.sum:.2f}\t'
        f'data time = {data_time_meter.sum:.2f}\t'
        f'current lr = {current_lr:.4f}'
    )
    scheduler.step()


def _load_poison(args, log):
    """Return the poison selected by ``args.poison_path``, or None.

    * ``None``: no poison.
    * ``'random'``: re-seed the global torch RNG with 1, then draw uniform noise
      in ``[-poison_budget, poison_budget]`` of shape
      ``(50000, 3, poison_size, poison_size)``.
    * any other value: ``torch.load`` onto the CPU. A failed load is logged and
      treated as "no poison".
    """
    if args.poison_path is None:
        log.info('no poison founded!')
        return None

    if args.poison_path == 'random':
        torch.manual_seed(1)
        # uniform_ overwrites the randn samples, but drawing them still advances
        # the RNG; kept so that seeded runs reproduce the same noise.
        # NOTE(review): 50000 presumably matches the CIFAR-10 training split.
        poison = torch.randn([50000, 3, args.poison_size, args.poison_size]).uniform_(
            -args.poison_budget, args.poison_budget)
        log.info('use random noise')
        log.info(f'poison:{poison}')
        return poison

    try:
        poison = torch.load(args.poison_path, map_location='cpu')
    except Exception as err:  # best effort: training can proceed without a poison
        log.info(f'failed to load poison from {args.poison_path!r}: {err!r}')
        log.info('no poison founded!')
        return None
    log.info(f'poison: {torch.max(poison)}')
    return poison


def _build_mask(mask_type, size, wm_length, channels=3):
    """Return a CPU pixel mask of shape ``(1, channels, size, size)``.

    1 marks pixels left untouched; 0 marks pixels that carry the watermark.
    About ``wm_length`` pixels in total are set to 0:

    * ``'random'``: exactly ``wm_length`` pixels, chosen with ``torch.randperm``.
    * ``'fixed'``: all pixels outside a top-left square per channel.
    * ``'fix-lt'/'fix-lb'/'fix-rt'/'fix-rb'``: a square in the left-top /
      left-bottom / right-top / right-bottom corner of every channel.
    * any other type: nothing is watermarked (all ones).

    Raises:
        ValueError: if ``wm_length`` is outside ``[0, channels * size * size]``.
    """
    shape = (1, channels, size, size)
    num_pixels = channels * size * size
    corner_types = ('fix-lt', 'fix-lb', 'fix-rt', 'fix-rb')
    if mask_type not in ('random', 'fixed') + corner_types:
        return torch.ones(shape)
    if not 0 <= wm_length <= num_pixels:
        raise ValueError(f'wm_length must be in [0, {num_pixels}], got {wm_length}')

    if mask_type == 'random':
        flat = torch.zeros(num_pixels)
        # Unmask every pixel except wm_length randomly chosen ones.
        flat[torch.randperm(num_pixels)[:num_pixels - wm_length]] = 1
        return flat.view(shape)

    if mask_type == 'fixed':
        mask = torch.zeros(shape)
        rat = np.sqrt(1.0 - (wm_length / num_pixels))
        mask[:, :, :int(rat * size), :int(rat * size)] = 1
        return mask

    mask = torch.ones(shape)
    rat = np.sqrt(wm_length / num_pixels)
    side = int(rat * size)
    rows = slice(size - side, size) if mask_type in ('fix-lb', 'fix-rb') else slice(0, side)
    cols = slice(size - side, size) if mask_type in ('fix-rt', 'fix-rb') else slice(0, side)
    mask[:, :, rows, cols] = 0
    return mask


def _make_post_poisoning_watermark(args):
    """Build a random +/-watermark_budget sign pattern on the masked pixels.

    Returns a CPU tensor of shape ``(1, 3, poison_size, poison_size)`` that is
    zero wherever the mask from ``_build_mask`` is 1.
    """
    size = args.poison_size
    mask = _build_mask(args.mask_type, size, args.wm_length)
    negative = torch.randn(1, 3, size, size) < 0
    return torch.where(negative, -args.watermark_budget, args.watermark_budget) * (
        torch.ones(1, 3, size, size) - mask)


def _load_watermark(args, log):
    """Load the watermark tensor from ``args.watermark_path`` onto the CPU.

    Raises:
        ValueError: if no path is given or the file cannot be loaded.
    """
    if args.watermark_path is None:
        log.info('no watermark founded!')
        raise ValueError('--watermark-path is required unless --post-poisoning is set')
    try:
        # CPU, matching how the poison is loaded.
        return torch.load(args.watermark_path, map_location='cpu')
    except Exception as err:
        log.info('no watermark founded!')
        raise ValueError(f'could not load watermark from {args.watermark_path!r}') from err


def _build_train_transform(s=1.0):
    """Training augmentation: crop/flip, colour jitter, grayscale, then Cutout(1, 16)."""
    # NOTE(review): the crop size is fixed at 32 regardless of --poison-size.
    return transforms.Compose([
        transforms.RandomCrop(32, 4),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomApply([transforms.ColorJitter(0.4 * s, 0.4 * s, 0.4 * s, 0.1 * s)], p=0.8 * s),
        transforms.RandomGrayscale(p=0.2 * s),
        transforms.ToTensor(),
        Cutout(1, 16),
    ])


def _build_eval_transform(aug_type, s=1.0):
    """Evaluation transform for ``aug_type``; unlisted types use ToTensor only."""
    if aug_type in ('normal', 'cutmix', 'mixup'):
        return transforms.Compose([transforms.RandomCrop(32, 4),
                                   transforms.RandomHorizontalFlip(p=0.5),
                                   transforms.ToTensor()])
    if aug_type == 'cutout':
        return transforms.Compose([transforms.RandomCrop(32, 4),
                                   transforms.RandomHorizontalFlip(p=0.5),
                                   transforms.ToTensor(),
                                   Cutout(1, 16)])
    if aug_type == 'jitter':
        return transforms.Compose([transforms.RandomCrop(32, 4),
                                   transforms.RandomHorizontalFlip(p=0.5),
                                   transforms.RandomApply(
                                       [transforms.ColorJitter(0.4 * s, 0.4 * s, 0.4 * s, 0.1 * s)],
                                       p=0.8 * s),
                                   transforms.ToTensor()])
    if aug_type == 'grayscale':
        return transforms.Compose([transforms.RandomCrop(32, 4),
                                   transforms.RandomHorizontalFlip(p=0.5),
                                   transforms.RandomGrayscale(p=0.2 * s),
                                   transforms.ToTensor()])
    if aug_type == 'blur':
        return transforms.Compose([transforms.GaussianBlur(kernel_size=3),
                                   transforms.RandomCrop(32, 4),
                                   transforms.RandomHorizontalFlip(p=0.5),
                                   transforms.ToTensor()])
    return transforms.ToTensor()


def _build_model(args):
    """Instantiate ``args.backbone`` on the GPU; wrapped in DataParallel for imagenet100.

    Raises:
        AssertionError: for a backbone with no branch here (e.g. 'vit-b').
    """
    im100 = args.dataset == 'imagenet100'
    if args.backbone in ('resnet18', 'resnet50'):
        constructor = resnet18 if args.backbone == 'resnet18' else resnet50
        model = constructor(num_classes=args.num_classes).cuda()
        if args.dataset in ['cifar10', 'cifar100']:
            # Replace the stem with a 3x3, stride-1 conv and drop the max-pool.
            model.conv1 = nn.Conv2d(3, 64, 3, 1, 1, bias=False).cuda()
            model.maxpool = nn.Identity().cuda()
    elif args.backbone == 'vgg19':
        model = vgg19(num_classes=args.num_classes, im100=im100).cuda()
    elif args.backbone == 'densenet121':
        model = DenseNet121(num_classes=args.num_classes, im100=im100).cuda()
    elif args.backbone == 'wrn34-10':
        model = WideResNet(num_classes=args.num_classes, im100=im100).cuda()
    elif args.backbone == 'vit':
        patch_size = 16 if im100 else 4
        model = SL_ViT(image_size=args.poison_size, num_classes=args.num_classes, patch_size=patch_size).cuda()
    elif args.backbone == 'linear':
        model = simple_classifier.Linear(n_classes=args.num_classes).cuda()
    elif args.backbone == '2nn':
        model = simple_classifier.two_NN(n_classes=args.num_classes).cuda()
    elif args.backbone == '3nn':
        model = simple_classifier.three_NN(n_classes=args.num_classes).cuda()
    elif args.backbone == 'lenet5':
        model = simple_classifier.LeNet5(num_classes=args.num_classes).cuda()
    else:
        raise AssertionError('model is not defined')

    if im100:
        model = nn.DataParallel(model)
    return model


def main():
    """Train the detector end to end using the module-level ``args``.

    Steps:

    1. Load or create the poison and the watermark.
    2. Build a training set that carries only the watermark (``poison=None``)
       and an evaluation set that carries the poison. With ``--post-poisoning``
       the poison has the watermark added. Both sets use ``train=True``.
    3. Train ``args.backbone`` under a cosine-annealed LR.
    4. Checkpoint to ``eval/<dataset>/<backbone>/detector/<experiment>/model.pt``
       and log the scores returned by ``evaluation_detector``.

    Raises:
        ValueError: if ``--post-poisoning`` is set without a usable poison, or
            if the watermark cannot be obtained.
    """
    # Set before seeding or any CUDA call so the device restriction applies.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id
    if args.seed is not None:
        setup_seed(args.seed)
    save_dir = os.path.join('eval', args.dataset, args.backbone, 'detector', args.experiment)
    os.makedirs(save_dir, exist_ok=True)
    checkpoint_path = os.path.join(save_dir, 'model.pt')
    log = logger(path=save_dir)
    log.info(str(args))

    poison = _load_poison(args, log)
    if args.post_poisoning and poison is None:
        # Fail fast instead of a TypeError at `None + watermark` after download.
        raise ValueError('--post-poisoning adds the watermark to a poison; provide a loadable --poison-path')

    if args.post_poisoning:
        watermark = _make_post_poisoning_watermark(args)
    else:
        watermark = _load_watermark(args, log)
    log.info(f'the number of unmasked watermark pixels is:{torch.nonzero(watermark).size(0)}')

    s = 1.0  # strength of the colour-jitter / grayscale augmentations
    transform_train = _build_train_transform(s)
    transform_eval = _build_eval_transform(args.aug_type, s)

    train_dataset = CIFAR10WatermarkIndex(root=args.data, train=True, transform=transform_train, poison=None,
                                          watermark=watermark, download=True)

    if args.post_poisoning:
        # The watermark is (1, 3, H, W) and broadcasts over the poison's first dim.
        poison = poison + watermark

    # NOTE(review): the evaluation set also uses train=True; confirm this is intended.
    test_dataset = CIFAR10WatermarkIndex(root=args.data, train=True, transform=transform_eval, poison=poison,
                                         watermark=None, download=True)

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=False)

    print(len(train_loader))
    print(len(test_loader))

    model = _build_model(args)

    if args.optimizer == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    elif args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, weight_decay=1e-4, momentum=0.9)
    else:
        raise AssertionError('optimizer is not defined')

    # Stepped once per epoch (inside train_epoch), annealing towards 1e-6.
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lr_lambda=lambda step: cosine_annealing(step,
                                                args.epochs,
                                                1,
                                                1e-6 / args.lr,
                                                warmup_steps=0)
    )

    start_epoch = 1
    if args.resume:
        checkpoint = torch.load(checkpoint_path)
        start_epoch = checkpoint['epoch'] + 1
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optim'])
        # Fast-forward the per-epoch schedule to where the checkpoint stopped.
        for _ in range(start_epoch - 1):
            scheduler.step()
        log.info(f"RESUME FROM EPOCH {start_epoch-1}")

    def save(epoch):
        """Overwrite the single checkpoint file with the current state."""
        save_checkpoint({
            'epoch': epoch,
            'state_dict': model.state_dict(),
            'optim': optimizer.state_dict(),
        }, filename=checkpoint_path)

    val_acc_record = []
    test_acc_record = []

    for epoch in range(start_epoch, args.epochs + 1):
        train_epoch(train_loader, model, optimizer, scheduler, epoch, log)

        if args.get_lr_process:
            # Evaluate every epoch and keep the full history.
            save(epoch)
            val_auroc = evaluation_detector(train_loader, model)
            test_auroc = evaluation_detector(test_loader, model, wmattacker=args.aug_type)
            val_acc_record.append(val_auroc)
            test_acc_record.append(test_auroc)
            print(val_acc_record, test_acc_record)
            log.info(
                f'val accuracy = {val_auroc:.4f}\t'
                f'test accuracy = {test_auroc:.4f}'
            )
        elif epoch % 10 == 0:
            save(epoch)
            val_auroc = evaluation_detector(train_loader, model)
            # wmattacker is passed only for the final epoch's evaluation.
            if epoch == args.epochs:
                test_auroc = evaluation_detector(test_loader, model, wmattacker=args.aug_type)
            else:
                test_auroc = evaluation_detector(test_loader, model)
            log.info(
                f'val auroc = {val_auroc:.4f}\t'
                f'test auroc = {test_auroc:.4f}'
            )

    if args.get_lr_process:
        log.info(
            f'val accuracy record = {val_acc_record}\t'
            f'test accuracy record = {test_acc_record}\t'
        )


if __name__ == '__main__':
    # main() and train_epoch() read this module-level name as a global.
    args = get_args()
    # The detector model always has two output classes, whatever --num-classes says.
    args.num_classes = 2

    # Image side length per dataset; a dataset not listed keeps --poison-size.
    poison_sizes = {
        'cifar10': 32,
        'cifar100': 32,
        'tinyimagenet': 64,
        'miniimagenet': 84,
        'imagenet100': 224,
    }
    if args.dataset in poison_sizes:
        args.poison_size = poison_sizes[args.dataset]

    cudnn.benchmark = True
    main()
