import os
import pdb
import time 
import pickle
import random
import shutil
import argparse
import numpy as np  
from copy import deepcopy
import matplotlib.pyplot as plt

import torch
import torch.optim
import torch.nn as nn
import torch.utils.data
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from resnet import resnet18

from train_utils import *
from dataloader_ffcv import create_dataloader
from write_dataset_ffcv import write_dataset

from tqdm import tqdm
import time
from plot_utils import plot_SPC
from torch.cuda.amp import GradScaler, autocast
from torchvision import transforms
from PIL import Image

def l_SPC(images, model, scales, device):
    """Return, per sample, the fraction of scales that keep the prediction.

    The prediction on ``images`` is taken as the reference.  For every factor
    in ``scales`` the images are multiplied by that factor, clamped to
    [0, 1] and classified again; a sample scores one point each time the
    scaled prediction equals the reference prediction.

    Args:
        images: batch of inputs fed directly to ``model``.
        model: callable returning per-class scores along dim 1.
        scales: sequence of multiplicative factors to test.
        device: device on which the score accumulator is placed.

    Returns:
        1-D tensor with ``len(images)`` entries: matches / ``len(scales)``.
    """
    agreement = torch.zeros(len(images)).to(device)

    with torch.no_grad(), autocast():
        reference = model(images).argmax(dim=1)
        for factor in scales:
            amplified = (factor * images).clamp(0, 1)
            same_class = model(amplified).argmax(dim=1) == reference
            agreement = agreement + same_class

    return agreement / len(scales)




def SF_loss(output, output_scale):
    """Per-sample KL divergence KL(P || Q) between two batches of logits.

    ``P = softmax(output)`` and ``Q = softmax(output_scale)``, both taken
    over dim 1.  The result is ``sum_c P_c * (log P_c - log Q_c)``.

    Args:
        output: logits whose class scores lie along dim 1.
        output_scale: logits with the same shape as ``output``.

    Returns:
        Tensor holding one divergence value per row (dim 1 summed out).
    """
    # With log_target=True: kl_loss(log_Q, log_P) = P * (log_P - log_Q).
    kl_loss = nn.KLDivLoss(reduction="none", log_target=True)

    # log_softmax stays finite where log(softmax(x)) underflows to log(0) = -inf,
    # which would otherwise turn the loss into NaN through 0 * -inf.
    log_P = F.log_softmax(output, dim=1)
    log_Q = F.log_softmax(output_scale, dim=1)

    return torch.sum(kl_loss(log_Q, log_P), dim=1)


def main(args, device):
    """Optimize a trigger patch against a frozen, previously saved model.

    Loads the clean data loaders, the saved model and a saved mask, builds a
    trainable trigger initialised from ``data/triggers/htbd.png``, reports
    SPC / MSPC / ASR before optimisation and after every epoch, and finally
    saves the trigger tensor to ``data/triggers/Badnet_adaptive4_trigger.pt``.

    Args:
        args: parsed command-line namespace (reads ``dataset``, ``arch``,
            ``batch_size``, ``patch_size``, ``lr``, ``momentum``,
            ``weight_decay``, ``decreasing_lr`` and ``epochs``).
        device: torch device used for the model and the trigger.
    """
    ## clean data and poisoned model
    train_clean_loader, test_clean_loader, _ = create_dataloader(args, args.batch_size, '', device, partition='None')
    model_path = f'Results/{args.dataset}/Badnet/Poisonratio_0.1/{args.arch}/Trial 1'
    # NOTE(review): torch.load unpickles arbitrary objects; only load trusted checkpoints.
    model = torch.load(f'{model_path}/model.pt')
    model.to(device)
    for param in model.parameters():
        param.requires_grad = False
    # The model is frozen: keep BatchNorm/Dropout in inference mode from the
    # very first evaluation so no running statistics are modified.
    model.eval()

    mask = torch.load(f'{model_path}/Bilevel/0.1/mask_4.pt')
    mask.requires_grad = False

    tau = 0.1

    #### Create trainable trigger initialized as the badnet trigger
    # The transform scales to 0-255 and the result is divided by 255 again,
    # so trigger_data ends up back in ToTensor's value range.
    trans_trigger = transforms.Compose([transforms.Resize((args.patch_size, args.patch_size)), transforms.ToTensor(), lambda x: x * 255])
    trigger_data = Image.open("data/triggers/htbd.png").convert("RGB")
    trigger_data = trans_trigger(trigger_data).to(device)
    trigger_data = trigger_data/255.

    ## Add the trigger: patch sits 3 pixels in from the bottom-right corner
    image_size = 32
    start_x = image_size - args.patch_size - 3
    start_y = image_size - args.patch_size - 3

    trigger = torch.zeros((3,32,32), device= device)
    trigger[:,start_x: start_x + args.patch_size, start_y: start_y + args.patch_size] = trigger_data
    # The trigger tensor is the only quantity handed to the optimizer.
    trigger.requires_grad = True

    scales = [2,3,4,5,6,7,8,9,10,11,12]

    ## criterion and optimizer
    decreasing_lr = list(map(int, args.decreasing_lr.split(',')))
    optimizer = torch.optim.SGD([trigger], args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=decreasing_lr, gamma=0.1)
    scaler = GradScaler()

    ## Baseline metrics before any optimisation
    test_clean_SPC = validate_SPC(test_clean_loader, model, scales, args.patch_size, args, device, trigger=None)
    test_trigger_SPC = validate_SPC(test_clean_loader, model, scales, args.patch_size, args, device, trigger)
    test_ASR = validate_ASR(test_clean_loader, model, scales, args.patch_size, args, device, trigger)

    test_clean_MSPC = validate_MSPC(test_clean_loader, model,  mask, tau, scales, args.patch_size, args, device, trigger=None)
    test_trigger_MSPC = validate_MSPC(test_clean_loader, model,  mask, tau, scales, args.patch_size, args, device, trigger)

    print(f'test_trigger_SPC: {test_trigger_SPC}| Clean SPC : {test_clean_SPC} | ASR : {test_ASR}')
    print(f'test_trigger_MSPC: {test_trigger_MSPC}| Clean MSPC : {test_clean_MSPC} | ASR : {test_ASR}')

    ## Optimize trigger
    for epoch in range(args.epochs):

        trigger = train(train_clean_loader, model, trigger, mask, tau, scales, args.patch_size, optimizer,scheduler, scaler,  epoch, args, device)
        test_trigger_SPC = validate_SPC(test_clean_loader, model, scales, args.patch_size, args, device, trigger)
        test_trigger_MSPC = validate_MSPC(test_clean_loader, model,  mask, tau, scales, args.patch_size, args, device, trigger)
        test_ASR = validate_ASR(test_clean_loader, model, scales, args.patch_size, args, device, trigger)

        print(f'test_trigger_SPC: {test_trigger_SPC}| test_trigger_MSPC: {test_trigger_MSPC}| ASR : {test_ASR}')
        scheduler.step()

    torch.save(trigger, "data/triggers/Badnet_adaptive4_trigger.pt")




def train(train_loader, model, trigger, optimized_mask, tau, scales, patch_size, optimizer, scheduler, scaler, epoch, args, device):
    """Run one epoch of trigger optimisation and return the trigger.

    For every batch the trigger patch is pasted onto the images, and for each
    scale the negated batch-mean ``SF_loss`` between the model output on the
    triggered images and its output on the scaled, masked images is
    back-propagated.  Gradients from all scales are accumulated before a
    single optimizer step.

    Args:
        train_loader: iterable yielding 4-tuples whose first item is the
            image batch (divided by 255 here — presumably raw 0-255 pixels,
            TODO confirm against the loader).
        model: frozen network being evaluated.
        trigger: tensor updated by ``optimizer``.
        optimized_mask: mask multiplied onto ``triggered_image - tau``.
        tau: offset subtracted before masking.
        scales: multiplicative factors applied to the masked images.
        patch_size: side length of the square trigger patch region.
        optimizer: optimizer stepped once per batch through ``scaler``.
        scheduler: accepted for signature compatibility; not used here.
        scaler: ``GradScaler`` used for loss scaling and stepping.
        epoch: epoch index shown in the progress bar.
        args: namespace; ``args.patch_size`` positions the patch and is
            expected to equal ``patch_size``.
        device: device on which the paste mask is created.

    Returns:
        The same ``trigger`` tensor that was passed in.
    """
    losses = AverageMeter()

    # Only the trigger is optimised; keep BatchNorm/Dropout in inference mode
    # so the frozen model's running statistics are never touched.
    model.eval()

    iterator = tqdm(enumerate(train_loader), total=len(train_loader))

    for i, (image, target, _, _) in iterator:

        batch_size = image.shape[0]
        # Derive the spatial size from the batch, as the validate_* functions do.
        image_size = image.shape[-1]
        start_x = image_size - args.patch_size - 3
        start_y = image_size - args.patch_size - 3

        # mask == 0 inside the patch region, 1 elsewhere
        mask = torch.ones(image.shape, requires_grad=False, device=device)
        mask[:, :, start_x: start_x + patch_size, start_y: start_y + patch_size] = 0

        image = image/255.

        optimizer.zero_grad(set_to_none=True)

        l_rec = 0
        for scale in scales:

            # Rebuilt for every scale on purpose: each backward() frees its
            # graph, so the trigger-dependent inputs need a fresh one.
            triggered_image = image*mask + (1-mask)*torch.stack([trigger]*batch_size)
            masked_poison_images = (triggered_image-tau)*optimized_mask

            with autocast():
                output = model(triggered_image)
                output_scale = model(torch.clamp(scale*masked_poison_images, 0, 1))

                loss_fdiv = (1/batch_size)*torch.sum(SF_loss(output, output_scale))
                # Negated: the optimizer minimises, so the divergence is maximised.
                loss_fdiv = -loss_fdiv / len(scales)

            # Backward runs outside autocast, as recommended by the AMP docs.
            scaler.scale(loss_fdiv).backward()

            l_rec = l_rec + loss_fdiv.detach()

        scaler.step(optimizer)
        scaler.update()

        # Single host sync per batch for logging.
        loss_value = l_rec.float().item()

        losses.update(loss_value, batch_size)

        iterator.set_description(f"Epoch {epoch} | LR {optimizer.param_groups[0]['lr']:.2f}") ## FIND LR!!!!!
        iterator.set_postfix(loss=loss_value)
        iterator.refresh()

    return trigger
        
    # save_metrics(losses.avg, pathname, "Training Loss")

 


def validate_SPC(val_loader, model, scales, patch_size, args, device, trigger):
    """Return the average ``l_SPC`` score over ``val_loader``.

    Args:
        val_loader: iterable yielding 4-tuples whose first item is the image
            batch (divided by 255 here — presumably raw 0-255 pixels,
            TODO confirm against the loader).
        model: network being evaluated.
        scales: multiplicative factors forwarded to ``l_SPC``.
        patch_size: side length of the square trigger patch region.
        args: namespace; ``args.patch_size`` positions the patch and is
            expected to equal ``patch_size``.
        device: device on which the paste mask is created.
        trigger: trigger tensor to paste onto every image, or ``None`` to
            score the unmodified images.

    Returns:
        ``losses.avg`` after updating the meter with each batch's mean score,
        weighted by batch size.
    """
    losses = AverageMeter()
    # Evaluate in inference mode, consistent with validate_ASR.
    model.eval()

    for image, _target, _, _ in val_loader:
        n_samples = image.size(0)

        # No gradients are needed here, even though the trigger may require grad.
        with torch.no_grad():
            if trigger is None:
                triggered_image = image / 255.
            else:
                image_size = image.shape[-1]
                start_x = image_size - args.patch_size - 3
                start_y = image_size - args.patch_size - 3

                # mask == 0 inside the patch region, 1 elsewhere
                mask = torch.ones(image.shape, requires_grad=False, device=device)
                mask[:, :, start_x: start_x + patch_size, start_y: start_y + patch_size] = 0

                image = image/255.
                triggered_image = image*mask + (1-mask)*torch.stack([trigger]*n_samples)

            # compute the batch-mean consistency score
            with autocast():
                loss = torch.mean(l_SPC(triggered_image, model, scales, device))

        loss = loss.float()
        losses.update(loss.item(), n_samples)

    return losses.avg




def validate_MSPC(val_loader, model,  optimized_mask, tau, scales, patch_size, args, device, trigger):
    """Return the average masked scaled-prediction-consistency score.

    The reference prediction is taken on the (optionally triggered) images.
    For each scale the model is run on ``clamp(scale * (images - tau) *
    optimized_mask, 0, 1)``; a sample gets +1 when that prediction equals the
    reference and -1 otherwise.  Scores are averaged over scales and samples,
    so each batch value lies in [-1, 1].

    Args:
        val_loader: iterable yielding 4-tuples whose first item is the image
            batch (divided by 255 here — presumably raw 0-255 pixels,
            TODO confirm against the loader).
        model: network being evaluated.
        optimized_mask: mask multiplied onto ``images - tau``.
        tau: offset subtracted before masking.
        scales: multiplicative factors to test.
        patch_size: side length of the square trigger patch region.
        args: namespace; ``args.patch_size`` positions the patch and is
            expected to equal ``patch_size``.
        device: device for the paste mask and the score accumulator.
        trigger: trigger tensor to paste onto every image, or ``None`` to
            score the unmodified images.

    Returns:
        ``losses.avg`` after updating the meter with each batch's mean score,
        weighted by batch size.
    """
    losses = AverageMeter()
    # Evaluate in inference mode, consistent with validate_ASR.
    model.eval()

    for image, _target, _, _ in val_loader:
        n_samples = image.size(0)

        # No gradients are needed here, even though the trigger may require grad.
        with torch.no_grad():
            if trigger is None:
                images = image / 255.
            else:
                image_size = image.shape[-1]
                start_x = image_size - args.patch_size - 3
                start_y = image_size - args.patch_size - 3

                # mask == 0 inside the patch region, 1 elsewhere
                mask = torch.ones(image.shape, requires_grad=False, device=device)
                mask[:, :, start_x: start_x + patch_size, start_y: start_y + patch_size] = 0

                image = image/255.
                images = image*mask + (1-mask)*torch.stack([trigger]*n_samples)

            images_mask = (images-tau)*optimized_mask

            consistency = torch.zeros(len(images)).to(device)
            with autocast():
                base_pred = torch.argmax(model(images), dim=1)
                for scale in scales:
                    scale_pred = torch.argmax(model(torch.clamp(scale*images_mask, 0, 1)), dim=1)
                    # +1 for an unchanged prediction, -1 for a changed one
                    consistency = consistency + (2*(scale_pred == base_pred) - 1)

        loss = torch.mean(consistency/len(scales))
        loss = loss.float()
        losses.update(loss.item(), n_samples)

    return losses.avg




def validate_ASR(val_loader, model, scales, patch_size, args, device, trigger):
    """Return the attack success rate of ``trigger`` over ``val_loader``.

    The trigger is pasted onto every image and top-1 accuracy is measured
    against a constant target label instead of the true labels.

    Args:
        val_loader: iterable yielding 4-tuples; the first item is the image
            batch (divided by 255 here — presumably raw 0-255 pixels,
            TODO confirm against the loader) and the second supplies the
            shape of the label tensor.
        model: network being evaluated; switched to eval mode.
        scales: unused; kept for a signature parallel to the other validators.
        patch_size: side length of the square trigger patch region.
        args: namespace; ``args.patch_size`` positions the patch and
            ``args.target`` (default 1 when absent) is the attack target label.
        device: device for the paste mask and the target tensor.
        trigger: trigger tensor to paste onto every image.

    Returns:
        ``top1.avg``, the batch-size-weighted average of ``accuracy(...)[0]``.
    """
    top1 = AverageMeter()
    model.eval()

    # Honour --target instead of a hard-coded label; 1 matches the old behaviour.
    target_label = getattr(args, 'target', 1)

    for image, var, _, _ in val_loader:
        n_samples = image.size(0)

        # No gradients are needed here, even though the trigger may require grad.
        with torch.no_grad():
            image_size = image.shape[-1]
            start_x = image_size - args.patch_size - 3
            start_y = image_size - args.patch_size - 3

            # mask == 0 inside the patch region, 1 elsewhere
            mask = torch.ones(image.shape, requires_grad=False, device=device)
            mask[:, :, start_x: start_x + patch_size, start_y: start_y + patch_size] = 0

            image = image/255.
            triggered_image = image*mask + (1-mask)*torch.stack([trigger]*n_samples)

            # Float fill keeps the dtype the previous torch.ones(...) target had.
            target = torch.full(var.shape, float(target_label), device=device)

            # compute output
            with autocast():
                output = model(triggered_image)

        output = output.float()

        # fraction of triggered samples classified as the target label
        prec1 = accuracy(output, target)[0]
        top1.update(prec1.item(), n_samples)

    return top1.avg



if __name__ == '__main__':
    # Script entry point: parse the command line, pick a device, run main().
    parser = argparse.ArgumentParser()

    ##################################### general setting #################################################
    parser.add_argument('--data', type=str, default='../data0', help='location of the data corpus')
    parser.add_argument('--dataset', type=str, default='cifar10', help='dataset')
    parser.add_argument('--arch', type=str, default='res18', help='model architecture')

    ##################################### training setting #################################################
    parser.add_argument('--batch_size', type=int, default=100, help='batch size')
    parser.add_argument('--poison_ratio', default=0.0, type=float, help='Poison Ratio')
    parser.add_argument('--lr', default=1.0, type=float, help='initial learning rate')
    parser.add_argument('--momentum', default=0.9, type=float, help='momentum')
    parser.add_argument('--weight_decay', default=5e-4, type=float, help='weight decay')
    parser.add_argument('--epochs', default=2, type=int, help='number of total epochs to run')
    parser.add_argument('--decreasing_lr', default='50,80', help='decreasing strategy')
    parser.add_argument('--attack', default='Badnet', type=str, help='Give attack name')
    parser.add_argument('--target', default=1, type=int, help= 'Target label')
    # The two help strings below used to be copy-pasted from other options.
    parser.add_argument('--save_samples', type=str, default='False', help="Save samples flag ('True' or 'False')")
    parser.add_argument('--patch_size', default=5, type=int, help='Side length of the square trigger patch in pixels')

    opt = parser.parse_args()

    # Prefer CUDA when it is available, otherwise fall back to the CPU.
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    print(device)

    main(opt, device)







