#https://github.com/bboylyg/ABL/blob/main/backdoor_isolation.py#L194
#import sys
#sys.path.append('../')

from dataloader_ffcv import create_dataloader
from train_utils import *


from tqdm import tqdm
import argparse

from pathlib import Path
import torch.nn as nn
from sklearn.metrics import roc_auc_score, average_precision_score
from sklearn.metrics import roc_curve, auc, precision_recall_curve

from torch.cuda.amp import GradScaler, autocast
import torch.nn.functional as F


def strip_entropy(trainimgs, blendingimgs, model,  device, n=100, alpha=0.5, beta=0.5):
    """Return the mean STRIP prediction entropy for every image in `trainimgs`.

    Each input image is superimposed with `n` images drawn (once, with
    replacement, via ``np.random.randint``) from `blendingimgs` as
    ``clamp(alpha * img + beta * blend, 0, 1)``. The model's softmax entropy
    is averaged over those `n` blended copies.

    Args:
        trainimgs: iterable of image tensors to score, consumed one image at a time.
        blendingimgs: tensor of images to blend with, indexed along dim 0.
        model: classifier; assumed to return logits of shape (n, num_classes)
            for a batch of n images -- confirm for non-standard heads.
        device: device on which the blended batch and the model run.
        n: number of blended copies per input image.
        alpha: weight of the input image in the blend.
        beta: weight of the blending image in the blend.

    Returns:
        1-D CPU float tensor with one mean entropy value per input image.
    """
    # The same n blending images are reused for every input, so gather them
    # and move them to the device once instead of once per input.
    idx = np.random.randint(0, blendingimgs.shape[0], size=n)
    blend = blendingimgs[torch.as_tensor(idx)].to(device)

    H = []
    with torch.no_grad():  # pure inference: do not build an autograd graph
        for img in tqdm(trainimgs):
            # Broadcast the single image against the n blending images.
            x = torch.clamp(alpha * img.to(device).unsqueeze(0) + beta * blend, 0, 1)
            with autocast():
                logits = model(x)
            # Compute the entropy from float32 log_softmax. Using
            # prob * log(prob) on half-precision softmax output gives
            # 0 * -inf = NaN whenever a probability underflows to zero.
            log_prob = F.log_softmax(logits.float(), dim=1)
            H_i = -torch.sum(log_prob.exp() * log_prob, dim=1)
            H.append(H_i.mean().item())
    return torch.tensor(H)
    


def _collect_images(loader, batch_size):
    """Copy every batch of `loader` into one CPU tensor scaled to [0, 1].

    Each batch is unpacked as ``(img, _, _, poison_label)`` and pixel values
    are divided by 255. The buffer is sized as ``len(loader) * batch_size``
    using the first batch's per-image shape, so the image size is not
    hard-coded. The result is truncated to the number of samples actually
    yielded, so a short final batch leaves no zero-filled rows.

    Returns:
        (images, poison_labels): float tensors whose first dimension equals
        the number of samples the loader yielded.

    Raises:
        ValueError: if the loader yields no batches.
    """
    capacity = len(loader) * batch_size
    images = labels = None
    filled = 0
    for img, _, _, poison_label in loader:
        count = img.shape[0]
        if images is None:
            images = torch.zeros(capacity, *img.shape[1:])
            labels = torch.zeros(capacity)
        images[filled:filled + count] = img / 255.
        labels[filled:filled + count] = poison_label
        filled += count
    if images is None:
        raise ValueError('data loader yielded no batches')
    return images[:filled], labels[:filled]


def compute_strip(args, device):
    """Score every training sample with STRIP and report the poison-detection AUROC.

    Loads the trained model from
    ``Results/<dataset>/<attack>/Poisonratio_<ratio>/<arch>/Trial <trialno>/model.pt``.
    It then computes a STRIP score per training image (``1 - mean entropy``,
    so low-entropy inputs rank as more likely poisoned) and writes these files
    to the ``STRIP`` subdirectory:

    * ``poisonlab_true.pt``: ground-truth poison labels
    * ``STRIP_pred.pt``: the STRIP scores
    * ``AUROC_STRIP``: the AUROC, as JSON

    The number and shape of training images are taken from the data loader
    instead of a per-dataset table.

    Args:
        args: parsed CLI namespace (dataset, attack, poison_ratio, arch, trialno,
            plus whatever ``create_dataloader`` reads).
        device: torch device used for the model and the STRIP forward passes.
    """
    import json  # local import: do not rely on the star import providing it

    model_path = f'Results/{args.dataset}/{args.attack}/Poisonratio_{args.poison_ratio}/{args.arch}/Trial {args.trialno}'
    # NOTE(review): torch.load unpickles a full model object -- only load
    # trusted checkpoints. map_location lets GPU-saved models load on CPU hosts.
    model = torch.load(f'{model_path}/model.pt', map_location=device)
    model.to(device)

    pathname = f'{model_path}/STRIP'
    # exist_ok: re-running the evaluation overwrites previous STRIP outputs
    # instead of crashing with FileExistsError.
    Path(pathname).mkdir(parents=True, exist_ok=True)

    model.eval()

    batch_size = 200
    train_data_loader, testcleanloader, _ = create_dataloader(args, batch_size, '', device, partition='None', seq=True)

    train_data, poison_label_full = _collect_images(train_data_loader, batch_size)
    # Held-out clean test images, intended as the STRIP blending pool.
    blendingimgs, _ = _collect_images(testcleanloader, batch_size)

    # NOTE(review): blending currently uses the training set itself (which
    # contains poisoned samples), not the clean `blendingimgs`. The held-out
    # variant would be strip_entropy(train_data, blendingimgs, model, device).
    entropy = strip_entropy(train_data, train_data, model, device)
    # Negate the entropy so higher score = more likely poisoned. AUROC is
    # rank-based, so no normalization is needed.
    STRIP_label_full = 1 - entropy.detach().cpu().numpy()

    roc_auc = roc_auc_score(poison_label_full.numpy(), STRIP_label_full)
    print(roc_auc)

    torch.save(poison_label_full, f'{pathname}/poisonlab_true.pt')
    torch.save(STRIP_label_full, f'{pathname}/STRIP_pred.pt')
    with open(f'{pathname}/AUROC_STRIP', 'w') as f:
        json.dump(float(roc_auc), f, indent=2)

        
def min_max_normalization(x):
    """Linearly rescale tensor `x` so its minimum maps to 0 and its maximum to 1.

    A constant tensor (max == min) maps to all zeros instead of the NaNs a
    0/0 division would produce. The result is clamped to [0, 1] to absorb
    floating-point round-off.
    """
    x_min = torch.min(x)
    x_max = torch.max(x)
    span = x_max - x_min
    if span == 0:
        # Divide by one instead: (x - x_min) is all zeros, and true division
        # still yields the same float result dtype as the normal path.
        span = torch.ones_like(span)
    norm = (x - x_min) / span
    norm = torch.clamp(norm, 0, 1)
    return norm

def main(opt, device):
    """Run the STRIP evaluation configured by the parsed options `opt` on `device`."""
    compute_strip(args=opt, device=device)



if __name__ == '__main__':

    parser = argparse.ArgumentParser()

    # Options that select which trained model under Results/ is evaluated.
    parser.add_argument('--dataset', type=str, default='cifar10', help='dataset')
    parser.add_argument('--arch', type=str, default='res18', help='model architecture')

    parser.add_argument('--poison_ratio', default=0.1, type=float, help='Poison Ratio')
    # --attack and --trialno have no sensible default. Without them the model
    # path would contain the literal "None" and loading the model would fail.
    parser.add_argument('--attack', type=str, required=True, help='Give attack name')
    parser.add_argument('--save_samples', type=str, default='False', help='Save samples flag, given as a string')
    parser.add_argument('--trialno', type=int, required=True, help='Trial number of the trained model')

    opt = parser.parse_args()

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    main(opt, device)

