import torch
import torch.nn.functional as F
import operator as op
from functools import reduce
import numpy as np
import torchvision.transforms as transforms
import json
import os
import random
from utils.dataload import DeepFakeDatasetPathList
from torch.utils.data import Dataset, DataLoader
import sys
from torchvision.utils import save_image

def ncr(n, r):
    """Return the binomial coefficient "n choose r" with exact integer arithmetic.

    The smaller of ``r`` and ``n - r`` is used as the number of factors, which
    keeps the intermediate products small.
    """
    k = min(r, n - r)

    numerator = 1
    for factor in range(n, n - k, -1):
        numerator *= factor

    denominator = 1
    for factor in range(1, k + 1):
        denominator *= factor

    return numerator // denominator  # or / in Python 2


def quantize_image(image, level=1):
    """Snap an image tensor onto a discrete pixel grid and rescale it.

    The input is multiplied by 255 and rounded, floored to a multiple of
    ``level``, and divided by 255 again. A new tensor is returned; ``image``
    itself is not modified.
    """
    pixel_values = torch.round(image * 255)
    snapped = (pixel_values // level) * level
    snapped /= 255.0
    return snapped


def predict(model_list, x, quantize=True):
    """Run every model in ``model_list`` on ``x`` and return their raw outputs.

    When ``quantize`` is true the input is first passed through
    ``quantize_image``. The returned list holds one output per model, in the
    same order as ``model_list``.
    """
    inputs = quantize_image(x) if quantize else x
    return [model(inputs) for model in model_list]


def predict_ensemble(model_list, x, device, quantize=True):
    """Return each model's class-1 probability for ``x`` as CPU tensors.

    Every model is moved to ``device`` for its forward pass and moved back to
    the CPU afterwards. The softmax is taken over dim 1 and column 1 is kept,
    so each list entry has one value per row of the model output.
    """
    inputs = quantize_image(x) if quantize else x
    cpu = torch.device("cpu")

    positive_probs = []
    for member in model_list:
        member.to(device)
        class_probs = F.softmax(member(inputs.to(device)), dim=1)
        positive_probs.append(class_probs.data[:, 1].detach().cpu())
        member.to(cpu)
    return positive_probs


def voting_mech(logits, threshold, mode="hard_label"):
    """Combine per-model logits into binary predictions.

    Args:
        logits: list with one logit tensor per model. The class axis is the
            last one, so both a single vector of shape ``(2,)`` and a batch of
            shape ``(N, 2)`` are accepted.
        threshold: sequence of per-model thresholds on the class-1
            probability. ``"avg_prob"`` only uses ``threshold[0]``.
        mode: ``"hard_label"`` predicts 1 when at least one model exceeds its
            threshold; ``"majority_vote"`` predicts 1 only when every model
            does (the vote count must reach ``len(logits)``); ``"avg_prob"``
            thresholds the class-1 probability of the averaged softmax.

    Returns:
        A flat Python list of 0.0/1.0 predictions, one per sample.

    Raises:
        ValueError: if ``mode`` is not one of the three supported names.
    """
    # Index the class axis (last) instead of the first axis: with batched
    # logits the first axis is the sample axis, and `[1]` would pick sample 1
    # rather than the class-1 probability.
    prob_list = [F.softmax(l, dim=-1).data[..., 1] for l in logits]

    if mode in ("hard_label", "majority_vote"):
        votes = torch.stack([(p > t).double() for p, t in zip(prob_list, threshold)])
        positive_votes = torch.sum(votes, dim=0)
        if mode == "hard_label":
            predictions = (positive_votes > 0).float()
        else:
            predictions = (positive_votes >= len(prob_list)).float()
    elif mode == "avg_prob":
        avg_logits = torch.log(torch.mean(torch.stack([F.softmax(l, dim=-1) for l in logits]), dim=0))
        predictions = (F.softmax(avg_logits, dim=-1).data[..., 1] > threshold[0]).double()
    else:
        raise ValueError(
            "unknown voting mode %r; expected 'hard_label', 'majority_vote' or 'avg_prob'" % (mode,))
    return predictions.detach().cpu().numpy().ravel().tolist()

def voting_mech2(prob_list, threshold, mode="hard_label"):
    """Combine per-model class-1 probabilities into binary predictions.

    Unlike ``voting_mech`` this takes probabilities, not logits, so no softmax
    is applied here.

    Args:
        prob_list: list with one probability tensor per model, all of the same
            shape.
        threshold: sequence of per-model thresholds. In ``"avg_prob"`` mode a
            list/tuple contributes only its first entry (as in
            ``voting_mech``); a scalar or tensor is used as given.
        mode: ``"hard_label"`` predicts 1 when at least one model exceeds its
            threshold; ``"majority_vote"`` predicts 1 only when every model
            does; ``"avg_prob"`` thresholds the mean probability.

    Returns:
        A flat Python list of 0.0/1.0 predictions, one per sample.

    Raises:
        ValueError: if ``mode`` is not one of the three supported names.
    """
    if mode in ("hard_label", "majority_vote"):
        votes = torch.stack([(p > t).double() for p, t in zip(prob_list, threshold)])
        positive_votes = torch.sum(votes, dim=0)
        if mode == "hard_label":
            predictions = (positive_votes > 0).float()
        else:
            predictions = (positive_votes >= len(prob_list)).float()
    elif mode == "avg_prob":
        avg_probs = torch.mean(torch.stack(prob_list).double(), dim=0)
        # A tensor cannot be compared against a Python list of thresholds;
        # use the first entry, which is what voting_mech does in this mode.
        limit = threshold[0] if isinstance(threshold, (list, tuple)) else threshold
        predictions = (avg_probs > limit).double()
    else:
        raise ValueError(
            "unknown voting mode %r; expected 'hard_label', 'majority_vote' or 'avg_prob'" % (mode,))
    return predictions.detach().cpu().numpy().ravel().tolist()

def get_train_loader(data_dir, real_dir, train_input_dim, num_train, split_from_file):
    """Build a batch-size-1 DataLoader over a random subset of real training images.

    Args:
        data_dir: prefix used both to locate ``train_real.json`` and to build
            the image paths listed in it (plain string concatenation, so it is
            expected to end with a path separator).
        real_dir: directory walked recursively for ``.png``/``.jpg`` files when
            ``split_from_file`` is false.
        train_input_dim: size passed to ``Resize`` and ``CenterCrop``.
        num_train: maximum number of images kept after shuffling.
        split_from_file: when true, read the image list from
            ``data_dir + 'train_real.json'`` instead of walking ``real_dir``.

    Returns:
        A shuffling ``DataLoader`` with batch size 1; every sample is labelled 0.
    """
    data_transform = transforms.Compose([
        transforms.Resize(train_input_dim),
        transforms.CenterCrop(train_input_dim),
        transforms.ToTensor()
    ])

    if not split_from_file:
        real_images_train = []
        # Walk the directory given by the `real_dir` parameter (the previous
        # code referenced an undefined name here).
        for path, _subdirs, files in os.walk(real_dir):
            for name in files:
                if name.endswith(('.png', '.jpg')):
                    real_images_train.append((os.path.join(path, name), 0))
    else:
        with open(data_dir + 'train_real.json', 'r') as split_file:
            real_images_train = [(data_dir + x, 0) for x in json.load(split_file)]

    random.shuffle(real_images_train)
    real_images_train = real_images_train[:num_train]

    real_image_dataset = DeepFakeDatasetPathList(real_images_train, [], data_transform)

    real_loader = DataLoader(real_image_dataset, 1, shuffle=True) # batch size 1 for sign opt
    return real_loader


class BlackboxModelWrapper(torch.nn.Module):
    """Expose an ensemble plus its voting rule as a single hard-label classifier.

    The wrapped models are queried through ``predict_ensemble`` (quantized
    input, class-1 probabilities) and the probabilities are combined with
    ``voting_mech2`` using ``thresholds`` and ``mode``.
    """

    # NOTE: x is indexed as a batch (x.size()[0] predictions are produced);
    # the original note asked for one image at a time, in 4D.
    def __init__(self, model_list, thresholds, device, mode = "hard_label"):
        super(BlackboxModelWrapper, self).__init__()
        self.model_list = model_list
        self.thresholds = thresholds
        self.mode = mode
        self.device = device

    def _vote(self, x):
        """Return the ensemble's 0.0/1.0 decisions for ``x`` as a Python list."""
        y_probs = predict_ensemble(self.model_list, x, self.device, quantize = True)
        # predict_ensemble already returns probabilities, so they must go to
        # the probability-based voter; voting_mech would apply a second
        # softmax across the batch axis.
        return voting_mech2(y_probs, self.thresholds, self.mode)

    def forward(self, x):
        """Return one-hot predictions of shape ``(x.size()[0], 2)`` on the CPU."""
        preds = self._vote(x)
        preds_tensor = torch.zeros((x.size()[0], 2))
        for i, pred in enumerate(preds):
            preds_tensor[i, int(pred)] = 1.0
        return preds_tensor


    def predict_label(self, x):
        """Return the predicted labels as a tensor on the same device as ``x``."""
        return torch.tensor(self._vote(x)).to(x.device)


def fgsm(model_list, image, device, threshold=0.5, loss_type="mean", epsilon=0.031,
         labels=None, keep_attacking=True, inference_models_num=1):
    """Take one signed-gradient step against an ensemble and report vote totals.

    Args:
        model_list: models called directly on the image batch; their outputs are
            treated as class logits (softmax over dim 1, column 1 is attacked).
        image: input batch; moved to ``device``.
        device: device used for the image and the loss targets.
        threshold: per-model thresholds handed to ``voting_mech``, which zips
            them with the model outputs.
            NOTE(review): the scalar default 0.5 cannot be zipped; a list with
            one entry per model looks expected -- confirm.
        loss_type: which scalar is back-propagated: ``"mean"``, ``"min"`` or
            ``"max"`` of the per-model losses, ``"min_success"``,
            ``"pmin_success_<k>"`` or ``"specific_<index>"``. Any other value
            runs no backward pass, leaving ``pert_image.grad`` as ``None``.
        epsilon: step size and L-infinity radius of the perturbation.
        labels: moved to ``device`` and not used afterwards.
            NOTE(review): the default ``None`` has no ``.to`` and would raise.
        keep_attacking: not used in this function.
        inference_models_num: number of models sampled per image when the
            perturbed batch is scored.

    Returns:
        A two-element list: the sum of ``voting_mech`` predictions on the
        unperturbed image, and the sum of the per-sample predictions on the
        perturbed image averaged over the sampling rounds.
    """
    # The starting perturbation is all zeros, so the attack starts from the
    # (clamped) clean image.
    pert = torch.zeros(image.size()).to(device)
    criterion = torch.nn.CrossEntropyLoss()
    image = image.to(device)
    labels = labels.to(device)
    # Not read again in this function.
    final_pert_images = torch.clone(image.cpu())

    full_pert_set = torch.clamp(image + pert, 0.0, 1.0)

    # Baseline: raw model outputs on the clean image (predict quantizes by default).
    y_probs = predict(model_list, image)

    track_preds = voting_mech(y_probs, threshold)

    total_succ_track = [sum(track_preds)]

    pert_image = full_pert_set
    pert_image.requires_grad_()
    loss = []
    probs = []

    # One cross-entropy loss per model, always against target class 1.
    for model in model_list:
        outputs = model(pert_image)
        outputs_softmax = F.softmax(outputs, dim=1)

        prob_1 = outputs_softmax[:, 1]
        probs.append(prob_1)

        loss.append(criterion(outputs, torch.ones(pert_image.size()[0], dtype=torch.long).to(device)))

    loss_s = torch.stack(loss)

    # Select which scalar loss drives the gradient.
    if loss_type == "mean":
        torch.mean(loss_s).backward()
    elif loss_type == "min":
        torch.min(loss_s).backward()
    elif loss_type == "max":
        torch.max(loss_s).backward()
    elif loss_type == "min_success":
        # Attack the model with the highest vote total.
        # NOTE(review): `prob` is already a class-1 probability vector, but
        # voting_mech applies a softmax to its inputs -- confirm this is intended.
        min_success_index = np.argmax(
            np.array([np.sum(np.array(voting_mech([prob], [thresh]))) for prob, thresh in zip(probs, threshold)]),
            axis=0)
        loss_s[min_success_index].backward()
    elif loss_type.startswith("pmin_success"):
        # Attack the model ranked k-th by vote total (descending).
        # NOTE(review): this uses `kth_index - 1` while pgd_ensemble uses
        # `kth_index` for the same option -- confirm which is intended.
        kth_index = int(loss_type.split("_")[2])
        attack_index = np.array(
            [np.sum(np.array(voting_mech([prob], [thresh]))) for prob, thresh in zip(probs, threshold)]).argsort()[
                       ::-1][kth_index - 1]
        loss_s[attack_index].backward()
    elif loss_type.startswith("specific"):
        # "specific_<index>": attack exactly that model.
        loss_s[int(loss_type.split("_")[1])].backward()

    grad = pert_image.grad

    # Ascent step: increases the cross-entropy with respect to class 1.
    pert_image = pert_image.detach() + epsilon * grad.sign()

    # Project back into the epsilon ball around the image, then into [0, 1].
    pert_image = torch.min(torch.max(pert_image, image - epsilon),
                           image + epsilon)
    pert_image = torch.clamp(pert_image, 0.0, 1)

    full_pert_set = pert_image

    y_probs = predict(model_list, full_pert_set)

    # Score the perturbed batch with randomly sampled model subsets; the number
    # of rounds is the number of possible subsets, capped at 10.
    avg_preds = []
    for r in range(min(ncr(len(model_list), inference_models_num), 10)):
        # One random subset of model indices per sample.
        inference_model_ids = torch.stack(
            [torch.randperm(len(model_list))[:inference_models_num] for i in range(y_probs[0].size()[0])])

        # Swap the first two axes so iteration yields, per sample, the stacked
        # per-model outputs (assumes model outputs are batch-first).
        y_probs_t = torch.stack(y_probs).transpose(0, 1)

        preds = []

        for prob, inf_id in zip(y_probs_t, inference_model_ids):
            # Vote using only the sampled models and their thresholds.
            preds += voting_mech([x for i, x in enumerate(prob) if i in inf_id],
                                 [x for i, x in enumerate(threshold) if i in inf_id])

        avg_preds.append(preds)

    # Average each sample's prediction over the sampling rounds.
    avg_preds = np.array(avg_preds).mean(axis=0)

    # track_preds = [min(x,y) for i,(x,y) in enumerate(zip(track_preds,avg_preds))]
    total_succ_track.append(sum(avg_preds))

    return total_succ_track


def pgd_single(model, image, device, steps=1500, epsilon=0.031, step_size=0.001, labels=None,
               keep_attacking=True):
    """Run L-infinity PGD against a single model, pushing inputs away from class 1.

    The attack starts from a uniformly random point in the epsilon ball,
    repeatedly ascends the cross-entropy loss with respect to target class 1
    using the sign of the input gradient, and projects back into the epsilon
    ball and the [0, 1] range after every step.

    Args:
        model: callable returning class logits for an image batch.
        image: clean image batch.
        device: device on which the attack is run.
        steps: number of gradient steps; with ``steps <= 0`` the clean image
            is returned (after quantization).
        epsilon: L-infinity radius of the perturbation.
        step_size: size of each signed-gradient step.
        labels: optional; moved to ``device`` when given. Not otherwise used:
            per-step best-iterate tracking, the only consumer, was disabled
            for speed and the last iterate is returned instead.
        keep_attacking: unused for the same reason as ``labels``.

    Returns:
        ``(perturbation, adversarial_image)``: both CPU tensors; the
        adversarial image is quantized with ``quantize_image`` and the
        perturbation is its difference from the clean image.
    """
    pert = torch.FloatTensor(*image.shape).uniform_(-epsilon, epsilon).to(device)
    criterion = torch.nn.CrossEntropyLoss()
    image = image.to(device)
    if labels is not None:
        labels = labels.to(device)
    final_pert_images = torch.clone(image.cpu())

    full_pert_set = torch.clamp(image + pert, 0.0, 1.0)

    # The loss target is always class 1, so build it once.
    target = torch.ones(image.size()[0], dtype=torch.long).to(device)

    for _ in range(steps):
        pert_image = full_pert_set
        pert_image.requires_grad_()

        loss_s = criterion(model(pert_image), target)

        # Differentiate with respect to the input only, so no gradients are
        # accumulated in the model's parameters.
        grad = torch.autograd.grad(loss_s, pert_image)[0]

        pert_image = pert_image.detach() + step_size * grad.sign()

        # Project into the epsilon ball around the clean image, then into [0, 1].
        pert_image = torch.min(torch.max(pert_image, image - epsilon),
                               image + epsilon)
        full_pert_set = torch.clamp(pert_image, 0.0, 1)

    # Only the last iterate is returned, so copy it to the CPU once.
    if steps > 0:
        final_pert_images = full_pert_set.cpu()

    final_pert_images = quantize_image(final_pert_images)

    last_pert, last_image = (final_pert_images - image.cpu()), final_pert_images
    return last_pert, last_image


def pgd_max_single_bpda(model, image, device, steps=1500, epsilon=0.031, step_size=0.001, labels=None):
    """Run PGD through the quantization step and keep the highest-loss iterate.

    Each iterate is quantized before it is fed to the model. The gradient is
    taken with respect to the quantized tensor and applied to the unquantized
    iterate, i.e. quantization is treated as the identity in the backward
    pass. For every sample, the iterate with the largest cross-entropy loss
    against target class 1 seen so far (including the clean image) is kept.

    Args:
        model: callable returning class logits for an image batch.
        image: clean image batch.
        device: device on which the model is evaluated.
        steps: number of gradient steps.
        epsilon: L-infinity radius of the perturbation.
        step_size: size of each signed-gradient step.
        labels: optional; moved to ``device`` when given, not otherwise used.

    Returns:
        ``(perturbation, adversarial_image)``: both CPU tensors; the
        adversarial image is the quantized best iterate and the perturbation
        is its difference from the clean image.
    """
    criterion = torch.nn.CrossEntropyLoss(reduction = 'none')
    target = torch.ones(image.size()[0], dtype=torch.long).to(device)

    # All best-so-far bookkeeping lives on the CPU so that masks, values and
    # the indexed tensor always share a device.
    with torch.no_grad():
        best_attack = image.detach().clone().cpu()
        best_attack_loss = criterion(model(quantize_image(best_attack.to(device))), target).cpu()

    pert = torch.FloatTensor(*image.shape).uniform_(-epsilon, epsilon).to(device)
    image = image.to(device)
    if labels is not None:
        labels = labels.to(device)
    final_pert_images = torch.clone(image.cpu())

    full_pert_set = torch.clamp(image + pert, 0.0, 1.0)

    for _ in range(steps):
        pert_image = full_pert_set
        pert_image.requires_grad_()

        quan_image = quantize_image(pert_image)
        loss_s = criterion(model(quan_image), target)

        # Detach before bookkeeping so the running maximum does not keep the
        # autograd graph of every previous step alive.
        step_loss = loss_s.detach().cpu()
        improved = step_loss > best_attack_loss
        best_attack[improved] = pert_image.detach().cpu()[improved]
        best_attack_loss = torch.max(step_loss, best_attack_loss)

        grad = torch.autograd.grad(torch.mean(loss_s), quan_image)[0]

        pert_image = pert_image.detach() + step_size * grad.sign()

        # Project into the epsilon ball around the clean image, then into [0, 1].
        pert_image = torch.min(torch.max(pert_image, image - epsilon),
                               image + epsilon)
        full_pert_set = torch.clamp(pert_image, 0.0, 1)

    if steps > 0:
        final_pert_images = full_pert_set.detach().cpu()

    # The last iterate was produced after the final in-loop comparison, so
    # score it once more.
    with torch.no_grad():
        final_loss_s = criterion(model(quantize_image(final_pert_images.to(device))), target).cpu()
        improved = final_loss_s > best_attack_loss
        best_attack[improved] = final_pert_images[improved]

    final_pert_images = quantize_image(best_attack)

    last_pert, last_image = (final_pert_images - image.cpu()), final_pert_images
    return last_pert, last_image

def pgd_ensemble(model_list, image, device, threshold=0.5, loss_type="mean", steps=1500, epsilon=0.031, step_size=0.001,
                 labels=None, keep_attacking=True, inference_models_num=1):
    """Run L-infinity PGD against an ensemble and track vote totals per step.

    Args:
        model_list: models called directly on the image batch; their outputs are
            treated as class logits (softmax over dim 1, column 1 is attacked).
        image: input batch; moved to ``device``.
        device: device used for the image and the loss targets.
        threshold: per-model thresholds handed to ``voting_mech``, which zips
            them with the model outputs.
            NOTE(review): the scalar default 0.5 cannot be zipped; a list with
            one entry per model looks expected -- confirm.
        loss_type: which scalar is back-propagated each step: ``"mean"``,
            ``"min"`` or ``"max"`` of the per-model losses, ``"min_success"``,
            ``"pmin_success_<k>"`` or ``"specific_<index>"``. Any other value
            runs no backward pass, leaving ``pert_image.grad`` as ``None``.
        steps: number of gradient steps.
        epsilon: L-infinity radius of the perturbation.
        step_size: size of each signed-gradient step.
        labels: moved to ``device`` and not used afterwards.
            NOTE(review): the default ``None`` has no ``.to`` and would raise.
        keep_attacking: not used in this function.
        inference_models_num: number of models sampled per image when an
            iterate is scored.

    Returns:
        A list with ``steps + 1`` entries: the sum of ``voting_mech``
        predictions on the unperturbed image, followed by one vote total per
        step (per-sample predictions averaged over the sampling rounds, then
        summed). The perturbed images themselves are not returned.
    """
    # Random start inside the epsilon ball.
    pert = torch.FloatTensor(*image.shape).uniform_(-epsilon, epsilon).to(device)
    criterion = torch.nn.CrossEntropyLoss()
    image = image.to(device)
    labels = labels.to(device)
    # Not read again in this function.
    final_pert_images = torch.clone(image.cpu())

    full_pert_set = torch.clamp(image + pert, 0.0, 1.0)

    # Baseline: raw model outputs on the clean image (predict quantizes by default).
    y_probs = predict(model_list, image)

    track_preds = voting_mech(y_probs, threshold)

    total_succ_track = [sum(track_preds)]

    for i in range(steps):
        pert_image = full_pert_set
        pert_image.requires_grad_()
        loss = []
        probs = []

        # One cross-entropy loss per model, always against target class 1.
        for model in model_list:
            outputs = model(pert_image)
            outputs_softmax = F.softmax(outputs, dim=1)

            prob_1 = outputs_softmax[:, 1]
            probs.append(prob_1)

            loss.append(criterion(outputs, torch.ones(pert_image.size()[0], dtype=torch.long).to(device)))

        loss_s = torch.stack(loss)

        # Select which scalar loss drives the gradient for this step.
        if loss_type == "mean":
            torch.mean(loss_s).backward()
        elif loss_type == "min":
            torch.min(loss_s).backward()
        elif loss_type == "max":
            torch.max(loss_s).backward()
        elif loss_type == "min_success":
            # Attack the model with the highest vote total.
            # NOTE(review): `prob` is already a class-1 probability vector, but
            # voting_mech applies a softmax to its inputs -- confirm this is intended.
            min_success_index = np.argmax(
                np.array([np.sum(np.array(voting_mech([prob], [thresh]))) for prob, thresh in zip(probs, threshold)]),
                axis=0)
            loss_s[min_success_index].backward()
        elif loss_type.startswith("pmin_success"):
            # Attack the model ranked k-th by vote total (descending).
            # NOTE(review): this uses `kth_index` while fgsm uses
            # `kth_index - 1` for the same option -- confirm which is intended.
            kth_index = int(loss_type.split("_")[2])
            attack_index = np.array(
                [np.sum(np.array(voting_mech([prob], [thresh]))) for prob, thresh in zip(probs, threshold)]).argsort()[
                           ::-1][kth_index]
            loss_s[attack_index].backward()
        elif loss_type.startswith("specific"):
            # "specific_<index>": attack exactly that model.
            loss_s[int(loss_type.split("_")[1])].backward()

        grad = pert_image.grad

        # Ascent step: increases the cross-entropy with respect to class 1.
        pert_image = pert_image.detach() + step_size * grad.sign()

        # Project back into the epsilon ball around the image, then into [0, 1].
        pert_image = torch.min(torch.max(pert_image, image - epsilon),
                               image + epsilon)
        pert_image = torch.clamp(pert_image, 0.0, 1)

        full_pert_set = pert_image

        y_probs = predict(model_list, full_pert_set)

        # Score the current iterate with randomly sampled model subsets; the
        # number of rounds is the number of possible subsets, capped at 10.
        avg_preds = []
        for r in range(min(ncr(len(model_list), inference_models_num), 10)):
            # One random subset of model indices per sample.
            inference_model_ids = torch.stack(
                [torch.randperm(len(model_list))[:inference_models_num] for i in range(y_probs[0].size()[0])])

            # Swap the first two axes so iteration yields, per sample, the
            # stacked per-model outputs (assumes model outputs are batch-first).
            y_probs_t = torch.stack(y_probs).transpose(0, 1)

            preds = []

            for prob, inf_id in zip(y_probs_t, inference_model_ids):
                # Vote using only the sampled models and their thresholds.
                preds += voting_mech([x for i, x in enumerate(prob) if i in inf_id],
                                     [x for i, x in enumerate(threshold) if i in inf_id])

            avg_preds.append(preds)

        # Average each sample's prediction over the sampling rounds.
        avg_preds = np.array(avg_preds).mean(axis=0)

        # track_preds = [min(x,y) for i,(x,y) in enumerate(zip(track_preds,avg_preds))]
        total_succ_track.append(sum(avg_preds))

    return total_succ_track



def pgd_max_ensemble_bpda(model_list, model_sampler, image, device, thresholds, loss_type="mean", steps=1500, epsilon=0.031, step_size=0.001,
                 labels=None, vote_type="hard_label", random_restart=True, type="pgd"):  ## model_list is not on gpu
                 ## model_list = [[M11, M12, M13], [M21, M22, M23], [M31, M32, M33], [M41, M42, M43]]
                 ## attack_everything = [[True,True,True,True], [True,True,True,True], [True,True,True,True], [True,True,True,True]]
                 ## EOT = [[False, False, True, False], [False, True, False, False], [False, False, False, True]]
    """Run PGD through quantization against several ensembles, keeping the best iterate.

    Every model is moved to ``device`` for its own forward/backward pass and
    moved back to the CPU right after. Gradients are taken with respect to the
    quantized iterate and applied to the unquantized one, i.e. quantization is
    treated as the identity in the backward pass.

    Args:
        model_list: list of ensembles, each a list of models (see the comment
            above the docstring).
        model_sampler: indexed as ``model_sampler[step][ensemble_index]``; a
            truthy flag means the gradients of that ensemble's models are used
            at that step.
        image: clean image batch.
        device: device on which models are evaluated.
        thresholds: nested like ``model_list``; one threshold per model.
        loss_type: how the selected gradients are merged. Handled values are
            ``"mean2"``, ``"mean_norm"``, ``"mean_norm2"`` and ``"min"``.
            NOTE(review): the default ``"mean"`` matches none of them, which
            leaves ``combined_grad`` as the integer 0; the later
            ``zip(combined_grad, ...)`` cannot iterate an int -- confirm the
            intended default.
        steps: number of gradient steps.
        epsilon: L-infinity radius of the perturbation.
        step_size: size of each signed-gradient step.
        labels: moved to ``device`` and not used afterwards.
            NOTE(review): the default ``None`` has no ``.to`` and would raise.
        vote_type: not used in this function.
        random_restart: start from a uniformly random point in the epsilon
            ball instead of the clean image.
        type: ``"fgsm"`` re-draws the starting point on every even step; any
            other value keeps iterating from the previous step.

    Returns:
        A 7-tuple ``(last_pert, last_image, total_succ_track,
        confidence_score_track, loss_track, grad_magnitude_track,
        grad_alignment)``: the perturbation and quantized best image (CPU),
        the per-step sum of best predictions, the stacked per-model class-1
        probabilities and losses per step, the gradient norms per step, and
        the per-step cosine similarity of every model's gradient with the
        combined gradient.
    """
    criterion = torch.nn.CrossEntropyLoss(reduction='none')
                 
    confidence_score_track, loss_track = [], []
    grad_magnitude_track = []

    # Baseline on the clean image: per-model confidences and losses, no gradients.
    with torch.no_grad():
        best_attack = image.detach().clone()

        confidence_scores = []
        attack_loss = []
        for model in [m for ensemble in model_list for m in ensemble]:
            model.to(device)
            logits = model(quantize_image(best_attack.to(device)))
            confidence_scores.append(F.softmax(logits, dim=1).data[:, 1].detach().cpu())
            attack_loss.append(criterion(logits,
                                     torch.ones(image.size()[0], dtype=torch.long).to(device)).detach().cpu())
            model.to(torch.device("cpu"))
        # Per-sample loss averaged over all models.
        best_attack_loss = torch.stack(attack_loss).mean(axis=0)
        combined_thresholds = [m for ensemble in thresholds for m in ensemble]
        # Initial predictions vote over all models at once (default voting mode).
        best_attack_predictions = torch.tensor(voting_mech2(confidence_scores, combined_thresholds))

    confidence_score_track.append(torch.stack(confidence_scores))
    loss_track.append(torch.stack(attack_loss))


    best_attack = best_attack.cpu()

    total_succ_track = [sum(best_attack_predictions).item()]

    if random_restart:
        pert = torch.FloatTensor(*image.shape).uniform_(-epsilon, epsilon).to(device)
    else:
        # uniform_(0, 0) yields an all-zero perturbation.
        pert = torch.FloatTensor(*image.shape).uniform_(0, 0).to(device)

    image = image.to(device)
    labels = labels.to(device)

    full_pert_set = torch.clamp(image + pert, 0.0, 1.0)

    grad_alignment = []

    for i in range(steps):
        # In "fgsm" mode every even step starts again from a fresh point.
        if type == "fgsm" and i % 2 == 0:
            if random_restart:
                pert = torch.FloatTensor(*image.shape).uniform_(-epsilon, epsilon).to(device)
            else:
                pert = torch.FloatTensor(*image.shape).uniform_(0, 0).to(device)
            full_pert_set = torch.clamp(image + pert, 0.0, 1.0)
        pert_image = full_pert_set
        pert_image.requires_grad_()

        quan_image = quantize_image(pert_image)
        #quan_image = pert_image

        confidence_scores = []
        attack_loss = []
        grads = []
        # One forward/backward pass per model; the gradient is taken with
        # respect to the quantized tensor before the model returns to the CPU.
        for model in [m for ensemble in model_list for m in ensemble]:
            model.to(device)
            logits = model(quan_image)
            confidence_scores.append(F.softmax(logits, dim=1).data[:, 1].detach().cpu())
            loss = criterion(logits,
                               torch.ones(image.size()[0], dtype=torch.long).to(device))

            mean_loss = torch.mean(loss)

            #mean_loss.backward()
            #grad = quan_image.grad
            grad = torch.autograd.grad(mean_loss, quan_image)[0]
            #print(grad)

            grads.append(grad.detach())

            attack_loss.append(loss.detach().cpu())

            model.to(torch.device("cpu"))

        confidence_score_track.append(torch.stack(confidence_scores))
        loss_track.append(torch.stack(attack_loss))

        combined_attack_loss = torch.stack(attack_loss).mean(axis=0)
        # Only referenced by the commented-out voting line further down.
        combined_thresholds = [m for ens_index, ensemble in enumerate(thresholds) for m in ensemble if model_sampler[i][ens_index]]

        # Vote within each ensemble, then average the ensemble votes per sample.
        prediction_aggregator = [0]*image.size()[0]
        # NOTE(review): slicing by len(model_list[0]) assumes every ensemble
        # has the same number of models -- confirm.
        ensemble_size = len(model_list[0])
        for ens_index in range(len(model_list)):
            attack_predictions = voting_mech2(confidence_scores[ens_index*ensemble_size:(ens_index+1)*ensemble_size], thresholds[ens_index])
            prediction_aggregator = [x+y for x,y in zip(prediction_aggregator,attack_predictions)]

        # Fraction of ensembles that predicted class 1 for each sample.
        combined_attack_predictions = [x/len(model_list) for x in prediction_aggregator]
        #combined_attack_predictions = voting_mech(confidence_scores, combined_thresholds)
        combined_attack_predictions = torch.tensor(combined_attack_predictions)

        best_attack_condition = torch.zeros(best_attack.size()[0]).bool()

        # An iterate replaces the stored best when its combined prediction is
        # lower, or when the prediction ties and its combined loss is higher.
        best_attack_condition = torch.logical_or(best_attack_condition, (combined_attack_predictions < best_attack_predictions))
        best_attack_condition = torch.logical_or(best_attack_condition, torch.logical_and(combined_attack_loss > best_attack_loss, combined_attack_predictions == best_attack_predictions))

        best_attack[best_attack_condition] = pert_image.detach().clone().cpu()[best_attack_condition]
        best_attack_predictions[best_attack_condition] = combined_attack_predictions[best_attack_condition]
        best_attack_loss[best_attack_condition] = combined_attack_loss[best_attack_condition]

        total_succ_track.append(sum(best_attack_predictions).item())
        #print(sum(best_attack_predictions))

        #print(best_attack_condition)

        # Expand the per-ensemble flags to one flag per model, then keep only
        # the gradients and losses of the selected models.
        models_to_attack = []
        for v in model_sampler[i]:
            models_to_attack += [v]*ensemble_size
        grads_to_attack = [g for k,g in enumerate(grads) if models_to_attack[k]]
        filtered_attack_losses = [g for k, g in enumerate(attack_loss) if models_to_attack[k]]

        combined_grad = 0
        if loss_type=="mean2":
            # Plain mean of the selected gradients.
            grad_magnitude_track.append([grad.norm().item() for grad in grads_to_attack])
            combined_grad = torch.stack(grads_to_attack).mean(dim=0)
        elif loss_type=="mean_norm":
            # Each gradient is scaled by its norm over the whole batch before averaging.
            grad_magnitude_track.append([grad.norm().item() for grad in grads_to_attack])
            grads_to_attack = [grad / grad.norm() for grad in grads_to_attack]
            combined_grad = torch.stack(grads_to_attack).mean(dim=0)
        elif loss_type=="mean_norm2":
            # Each gradient is scaled per image before averaging.
            grad_magnitude_track.append([grad.norm().item() for grad in grads_to_attack])
            grads_to_attack_new = []
            for grad in grads_to_attack:
                batch_grad = []
                for img in grad:
                    batch_grad.append(img/img.norm())
                grads_to_attack_new.append(torch.stack(batch_grad))
            combined_grad = torch.stack(grads_to_attack_new).mean(dim=0)
        elif loss_type=="min":
            # For every image, use the gradient of the selected model with the
            # smallest loss on that image.
            grad_magnitude_track.append([grad.norm().item() for grad in grads_to_attack])
            #print(torch.stack(attack_loss).size(), torch.stack(grads).size())
            min_indices = torch.argmin(torch.stack(filtered_attack_losses), dim=0)

            stacked_grads = torch.stack(grads_to_attack)
            selected_grads = []
            for r,min_index in enumerate(min_indices):
                selected_grads.append(stacked_grads[min_index,r])
            combined_grad = torch.stack(selected_grads)
            #print(combined_grad.size())

        # For every model (selected or not): mean over images of the cosine
        # similarity between its gradient and the combined gradient.
        cos_sim = torch.nn.CosineSimilarity(dim=0)
        grad_alignment.append([np.array([cos_sim(cg_img.flatten(), g_img.flatten()).detach().cpu().item() for cg_img,g_img in zip(combined_grad,gradi)]).mean() for gradi in grads])

        # Ascent step on the unquantized iterate using the combined gradient.
        pert_image = pert_image.detach() + step_size * combined_grad.sign()

        # Project into the epsilon ball around the clean image, then into [0, 1].
        pert_image = torch.min(torch.max(pert_image, image - epsilon),
                               image + epsilon)
        pert_image = torch.clamp(pert_image, 0.0, 1)

        full_pert_set = pert_image

    # The string literal below is disabled code kept by the author: a final
    # re-evaluation of the last iterate. It is never executed.
    '''
    with torch.no_grad():
        confidence_scores = []
        attack_loss = []
        for model in model_list:
            model.to(device)
            logits = model(quantize_image(best_attack.to(device)))
            confidence_scores.append(F.softmax(logits, dim=1).data[:, 1])
            attack_loss.append(criterion(logits),
                               torch.ones(image.size()[0], dtype=torch.long))
            model.to(torch.device("cpu"))
        combined_attack_loss = torch.stack(attack_loss).mean(axis=0)
        combined_attack_predictions = voting_mech(confidence_scores, thresholds)

        best_attack_condition = torch.zeros(best_attack.size()).bool()

        best_attack_condition = best_attack_condition or (
                    combined_attack_predictions.cpu() < best_attack_predictions.cpu())
        best_attack_condition = best_attack_condition or ((combined_attack_loss.cpu() > best_attack_loss.cpu()) and (
                    combined_attack_predictions.cpu() == best_attack_predictions.cpu()))

        best_attack[best_attack_condition] = pert_image.detach().clone().cpu()[best_attack_condition]
        best_attack_prediction[best_attack_condition] = combined_attack_predictions[best_attack_condition]
        best_attack_loss[best_attack_condition] = combined_attack_loss[best_attack_condition]
    '''

    final_pert_images = quantize_image(best_attack)

    last_pert, last_image = (final_pert_images - image.cpu()), final_pert_images
    return last_pert, last_image, total_succ_track, confidence_score_track, loss_track, grad_magnitude_track, grad_alignment

def pgd_max_ensemble_bpda_simple(model_list, model_sampler, image, device, thresholds, loss_type="mean",
                                           steps=1500, epsilon=0.031, step_size=0.001,
                                           labels=None, vote_type="hard_label", random_restart=True, type="pgd"):  ## model_list is not on gpu
    """Run PGD through quantization against the averaged-softmax output of all models.

    The attacked quantity is the cross-entropy (target class 1) of the log of
    the mean softmax over every model. Gradients are taken with respect to the
    quantized iterate and applied to the unquantized one, i.e. quantization is
    treated as the identity in the backward pass. Per-model gradients are also
    computed, but only to report their alignment with the ensemble gradient.

    Models are moved to ``device`` and left there, because the ensemble
    gradient is computed after all forward passes.

    Args:
        model_list: list of ensembles, each a list of models, e.g.
            ``[[M11, M12, M13], [M21, M22, M23]]``.
        model_sampler: indexed as ``model_sampler[step][ensemble_index]``; here
            it only selects which thresholds are candidates for the decision
            threshold (the first selected one is used).
        image: clean image batch.
        device: device on which models are evaluated.
        thresholds: nested like ``model_list``; one threshold per model.
        loss_type: not used in this function.
        steps: number of gradient steps.
        epsilon: L-infinity radius of the perturbation.
        step_size: size of each signed-gradient step.
        labels: optional; moved to ``device`` when given, not otherwise used.
        vote_type: not used in this function.
        random_restart: start from a uniformly random point in the epsilon
            ball instead of the clean image.
        type: ``"fgsm"`` re-draws the starting point on every even step; any
            other value keeps iterating from the previous step.

    Returns:
        A 7-tuple ``(last_pert, last_image, total_succ_track,
        confidence_score_track, loss_track, grad_magnitude_track,
        grad_alignment)``. ``grad_magnitude_track`` is always empty here; it
        is kept so the result has the same layout as ``pgd_max_ensemble_bpda``.
    """
    criterion = torch.nn.CrossEntropyLoss(reduction='none')
    models = [m for ensemble in model_list for m in ensemble]
    # The loss target is always class 1, so build it once.
    target = torch.ones(image.size()[0], dtype=torch.long).to(device)

    confidence_score_track, loss_track = [], []
    grad_magnitude_track = []

    # Baseline on the clean image, without gradients.
    with torch.no_grad():
        best_attack = image.detach().clone()
        clean_input = quantize_image(best_attack.to(device))

        confidence_scores = []
        attack_loss = []
        softmax_list = []
        for model in models:
            model.to(device)
            logits = model(clean_input)
            probs = F.softmax(logits, dim=-1)
            softmax_list.append(probs)
            confidence_scores.append(probs.data[:, 1].detach().cpu())
            attack_loss.append(criterion(logits, target).detach().cpu())

        final_logits = torch.log(torch.stack(softmax_list).mean(dim=0))
        best_attack_loss = criterion(final_logits, target).detach().cpu()
        combined_thresholds = [m for ensemble in thresholds for m in ensemble]
        best_attack_predictions = (
            F.softmax(final_logits, dim=1).data[:, 1].detach().cpu() > combined_thresholds[0])

    confidence_score_track.append(torch.stack(confidence_scores))
    loss_track.append(torch.stack(attack_loss))

    best_attack = best_attack.cpu()

    total_succ_track = [sum(best_attack_predictions).item()]

    image = image.to(device)
    if labels is not None:
        labels = labels.to(device)

    def _starting_point():
        """Return a fresh clamped starting iterate inside the epsilon ball."""
        if random_restart:
            pert = torch.FloatTensor(*image.shape).uniform_(-epsilon, epsilon).to(device)
        else:
            # uniform_(0, 0) yields an all-zero perturbation.
            pert = torch.FloatTensor(*image.shape).uniform_(0, 0).to(device)
        return torch.clamp(image + pert, 0.0, 1.0)

    grad_alignment = []
    cos_sim = torch.nn.CosineSimilarity(dim=0)

    full_pert_set = _starting_point()

    for i in range(steps):
        # In "fgsm" mode every even step starts again from a fresh point.
        if type == "fgsm" and i % 2 == 0:
            full_pert_set = _starting_point()
        pert_image = full_pert_set
        pert_image.requires_grad_()

        quan_image = quantize_image(pert_image)

        confidence_scores = []
        attack_loss = []
        grads = []
        softmax_list = []
        for model in models:
            model.to(device)
            logits = model(quan_image)
            probs = F.softmax(logits, dim=-1)
            confidence_scores.append(probs.data[:, 1].detach().cpu())
            softmax_list.append(probs)

            model_loss = criterion(logits, target)

            # Keep the graph: it is needed again for the ensemble gradient.
            model_grad = torch.autograd.grad(torch.mean(model_loss), quan_image, retain_graph=True)[0]

            grads.append(model_grad.detach())
            attack_loss.append(model_loss.detach().cpu())

        final_logits = torch.log(torch.stack(softmax_list).mean(dim=0))
        ensemble_loss = criterion(final_logits, target)

        grad = torch.autograd.grad(ensemble_loss.mean(), quan_image)[0]

        confidence_score_track.append(torch.stack(confidence_scores))
        loss_track.append(torch.stack(attack_loss))

        # Compare like with like: best_attack_loss was initialised with the
        # ensemble loss, so the candidate must be the ensemble loss too (the
        # previous code used the loss of whichever model was evaluated last).
        combined_attack_loss = ensemble_loss.detach().cpu()
        combined_thresholds = [m for ens_index, ensemble in enumerate(thresholds) for m in ensemble if
                               model_sampler[i][ens_index]]

        combined_attack_predictions = (
            F.softmax(final_logits, dim=1).data[:, 1].detach().cpu() > combined_thresholds[0])

        # An iterate replaces the stored best when its prediction is lower, or
        # when the prediction ties and its loss is higher.
        best_attack_condition = torch.zeros(best_attack.size()[0]).bool()
        best_attack_condition = torch.logical_or(best_attack_condition,
                                                 (combined_attack_predictions < best_attack_predictions))
        best_attack_condition = torch.logical_or(best_attack_condition,
                                                 torch.logical_and(combined_attack_loss > best_attack_loss,
                                                                   combined_attack_predictions == best_attack_predictions))

        best_attack[best_attack_condition] = pert_image.detach().clone().cpu()[best_attack_condition]
        best_attack_predictions[best_attack_condition] = combined_attack_predictions[best_attack_condition]
        best_attack_loss[best_attack_condition] = combined_attack_loss[best_attack_condition]

        total_succ_track.append(sum(best_attack_predictions).item())

        # For every model: mean over images of the cosine similarity between
        # its own gradient and the ensemble gradient.
        grad_alignment.append([
            np.array([cos_sim(cg_img.flatten(), g_img.flatten()).detach().cpu().item()
                      for cg_img, g_img in zip(grad, model_grad)]).mean()
            for model_grad in grads])

        pert_image = pert_image.detach() + step_size * grad.sign()

        # Project into the epsilon ball around the clean image, then into [0, 1].
        pert_image = torch.min(torch.max(pert_image, image - epsilon),
                               image + epsilon)
        full_pert_set = torch.clamp(pert_image, 0.0, 1)

    final_pert_images = quantize_image(best_attack)

    last_pert, last_image = (final_pert_images - image.cpu()), final_pert_images
    return last_pert, last_image, total_succ_track, confidence_score_track, loss_track, grad_magnitude_track, grad_alignment

def run_clean_predictions(models, real_image_loader, fake_image_loader, device):
    """Collect every model's outputs and the labels for the real and fake loaders.

    The real loader is consumed completely before the fake loader. For each
    loader the result is a list with one entry per model (that model's outputs
    for all batches, concatenated as nested Python lists) plus the flat list
    of labels.

    Returns:
        ``(real_probs, real_actual, fake_probs, fake_actual)``.
    """
    def _collect(loader):
        per_model_outputs = [[] for _ in range(len(models))]
        targets = []
        for batch, target in loader:
            batch = batch.to(device)
            target = target.to(device)

            outputs = predict(models, batch)
            for k, out in enumerate(outputs):
                per_model_outputs[k].extend(out.detach().cpu().numpy().tolist())

            targets.extend(target.detach().cpu().numpy().tolist())
        return per_model_outputs, targets

    real_probs, real_actual = _collect(real_image_loader)
    fake_probs, fake_actual = _collect(fake_image_loader)

    return real_probs, real_actual, fake_probs, fake_actual


def hsja(ensemble_list, image, device, threshold_values=[0.5], labels=None, data_dir=None, train_dir=None,
         train_input_dim=128, num_train=50000, split_from_file=True, norm = 'l2'):
    """Run the HSJA decision-based attack against the combined ensembles.

    All models of all ensembles are wrapped in one ``BlackboxModelWrapper``.
    For every image yielded by the training loader, that image is used as the
    attack's starting target and the resulting adversarial example is kept per
    sample when it changes the wrapper's prediction and has a smaller
    distortion than the best one found so far.

    Args:
        ensemble_list: list of ensembles, each a list of models.
        image: clean image batch.
        device: device handed to the wrapper and used for ``image``/``labels``.
        threshold_values: nested like ``ensemble_list``; one threshold per
            model. The flat default ``[0.5]`` cannot be flattened and is
            therefore not usable as-is.
        labels: integer class labels of ``image``; required.
        data_dir, train_dir, train_input_dim, num_train, split_from_file:
            forwarded to ``get_train_loader`` to build the target-image loader.
        norm: ``'l2'`` or ``'linf'``; used by the attack and for measuring the
            distortion.

    Returns:
        ``(succ_track, best_distortion, image, preds)``: the sum of the
        wrapper's predictions on the best adversarial images, the per-sample
        distortion (1000000 where no adversarial example was found), the best
        adversarial images (zeros where none was found), and the wrapper's
        predictions for them.

    Raises:
        ValueError: if ``labels`` is missing, ``norm`` is unsupported, or no
            training-data location is given.
    """
    if labels is None:
        raise ValueError("hsja requires labels for the images being attacked")
    if norm not in ('l2', 'linf'):
        raise ValueError("unsupported norm %r; expected 'l2' or 'linf'" % (norm,))
    # Without a loader there are no target images to start the attack from.
    if (train_dir is None and data_dir is None) or (train_dir == "" and data_dir == ""):
        raise ValueError("hsja needs data_dir/train_dir to load target images from")

    model_list = [m for ensemble in ensemble_list for m in ensemble]
    combined_thresholds = [m for ensemble in threshold_values for m in ensemble]
    model_wrapper = BlackboxModelWrapper(model_list, combined_thresholds, device)

    image = image.to(device)
    labels = labels.to(device)

    train_set = get_train_loader(data_dir, train_dir, train_input_dim, num_train, split_from_file)

    from attacks.hsja import HSJAttack
    batch_size = image.size()[0]
    attack = HSJAttack(model_wrapper, norm = norm, batch_size = batch_size)

    one_hot_labels = torch.zeros((batch_size, 2))
    for i in range(labels.size()[0]):
        one_hot_labels[i, labels[i].item()] = 1

    best_distortion = torch.ones((batch_size)) * 1000000
    best_adv = torch.zeros(image.size())
    for xi, _yi in train_set:
        # Use the same training image as the target for every sample in the batch.
        x_tar = xi.to(image.device).repeat(batch_size, 1, 1, 1)
        adv = torch.from_numpy(attack.perturb(image.detach().cpu().numpy(), one_hot_labels.cpu().numpy(), x_tar.detach().cpu().numpy()))

        flat_diff = (adv - image.cpu()).view(batch_size, -1)
        if norm == 'l2':
            distortion = torch.linalg.norm(flat_diff, dim = 1)
        else:
            distortion = torch.linalg.norm(flat_diff, dim = 1, ord=np.inf)

        # Results that keep the original label are not adversarial; give them
        # the sentinel distortion so they never replace a stored result.
        preds = model_wrapper.predict_label(adv)
        distortion[preds == labels.cpu()] = 1000000

        improved = distortion < best_distortion
        best_adv[improved] = adv.clone()[improved]
        best_distortion[improved] = distortion[improved]

    image = best_adv.clone()

    # Query the ensemble once and derive both outputs from the same result.
    preds = model_wrapper.predict_label(image)
    succ_track = sum(preds)

    return succ_track, best_distortion, image, preds
