import json
import random
from typing import List

import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader
from tqdm import tqdm

from datasets import get_dataset, get_dataset_mean_var
from model_mimic_attack_utils import image_clamp_min_max_val
from nn_model_manager import NNModelManager
from test_code import return_model_by_name
from utils import root_dir


class BayesianSubstituteModel(nn.Module):
    """Wrap a base model with input dropout so it can be sampled as an MC-dropout ensemble.

    Args:
        base_model: module applied to the (dropped-out) input.
        dropout_p: dropout probability applied to the input before ``base_model``.
            Defaults to 0.2, the value that used to be hard-coded.
    """

    def __init__(self, base_model, dropout_p=0.2):
        super(BayesianSubstituteModel, self).__init__()
        self.base_model = base_model
        self.dropout_layer = nn.Dropout(p=dropout_p)

    def forward(self, x, mc_dropout=False, mc_samples=10):
        """Run the model once, or average ``mc_samples`` stochastic passes.

        Args:
            x: input passed (after dropout) to ``base_model``.
            mc_dropout: if True, average ``mc_samples`` passes with dropout forced
                on, even when the module is in ``eval()`` mode.
            mc_samples: number of stochastic passes to average (must be >= 1).

        Returns:
            The ``base_model`` output, or the element-wise mean of the sampled outputs.

        Raises:
            ValueError: if ``mc_dropout`` is True and ``mc_samples < 1``.
        """
        if not mc_dropout:
            return self._forward_once(x)
        if mc_samples < 1:
            raise ValueError(f"mc_samples must be >= 1, got {mc_samples}")
        samples = [self._forward_once(x, force_dropout=True) for _ in range(mc_samples)]
        return torch.stack(samples, dim=0).mean(dim=0)

    def _forward_once(self, x, force_dropout=False):
        """Apply input dropout, then ``base_model``.

        Dropout is active in training mode, or whenever ``force_dropout`` is set.
        Calling ``nn.Dropout`` directly would make it the identity after ``.eval()``,
        turning MC averaging into ``mc_samples`` identical passes.
        """
        x = F.dropout(x, p=self.dropout_layer.p, training=self.training or force_dropout)
        return self.base_model(x)


def get_clamp_bounds(mean_dataset, std_dataset):
    """Return per-channel bounds of a normalized image whose raw pixels lie in [0, 1].

    Assumes normalization of the form ``(x - mean) / std``, so raw 0 maps to
    ``-mean / std`` and raw 1 maps to ``(1 - mean) / std``.

    Args:
        mean_dataset: per-channel means (any number of channels).
        std_dataset: per-channel standard deviations, same length as ``mean_dataset``.

    Returns:
        ``(min_val, max_val)`` float tensors of shape ``(C, 1, 1)``, which broadcast
        over ``(C, H, W)`` / ``(B, C, H, W)`` images.

    Raises:
        ValueError: if the two lists have different lengths (``zip`` would
            otherwise silently drop channels).
    """
    if len(mean_dataset) != len(std_dataset):
        raise ValueError(
            f"mean_dataset and std_dataset lengths differ: {len(mean_dataset)} != {len(std_dataset)}"
        )
    pairs = list(zip(mean_dataset, std_dataset))
    # view(-1, 1, 1) instead of view(3, 1, 1) so 1-channel (e.g. MNIST) stats work too.
    min_val = torch.tensor([-m / s for m, s in pairs]).view(-1, 1, 1)
    max_val = torch.tensor([(1 - m) / s for m, s in pairs]).view(-1, 1, 1)
    return min_val, max_val


def bayesian_attack(
        substitute_model, black_box_model, images, labels,
        mean_dataset: List[float], std_dataset: List[float],
        num_iterations=50, learning_rate=0.01, mc_samples=10
):
    """Transfer attack: iterative sign-gradient ascent on an MC-dropout substitute model.

    At each iteration, the cross-entropy of the true ``labels`` under the substitute
    (averaged over ``mc_samples`` dropout passes) is increased by one signed-gradient
    step of size ``learning_rate``. The result is clamped to the pixel bounds and
    then queried on ``black_box_model``. The loop stops early once every sample in
    the batch is misclassified by the black box.

    Args:
        substitute_model: differentiable model whose ``forward`` accepts
            ``mc_dropout`` and ``mc_samples`` keywords (e.g. ``BayesianSubstituteModel``).
        black_box_model: model that is only queried for predictions.
        images: input batch; its device is used for ``labels``.
        labels: true labels of ``images``.
        mean_dataset, std_dataset: dataset normalization stats, passed to
            ``image_clamp_min_max_val`` to obtain the clamp bounds.
        num_iterations: maximum number of attack steps.
        learning_rate: step size of each signed-gradient update.
        mc_samples: dropout passes averaged per substitute forward.

    Returns:
        ``(adversarial_images, attack_ss_iter)``. ``attack_ss_iter`` is the number of
        iterations performed (0 if ``num_iterations == 0``).
    """
    device = images.device
    min_pixel_val, max_pixel_val = image_clamp_min_max_val(mean_dataset, std_dataset)
    labels = labels.to(device)
    perturbed_images = images.clone().detach()
    attack_ss_iter = 0

    for i in tqdm(range(num_iterations)):
        perturbed_images.requires_grad_(True)
        outputs = substitute_model(perturbed_images, mc_dropout=True, mc_samples=mc_samples)
        loss = F.cross_entropy(outputs, labels)
        # autograd.grad returns only the input gradient. This avoids accumulating .grad
        # into the substitute's parameters on every iteration/call, as loss.backward() would.
        (input_grad,) = torch.autograd.grad(loss, perturbed_images)
        with torch.no_grad():
            perturbed_images = perturbed_images + learning_rate * input_grad.sign()
            perturbed_images = torch.clamp(perturbed_images, min_pixel_val, max_pixel_val)
        perturbed_images = perturbed_images.detach()

        with torch.no_grad():
            black_box_preds = black_box_model(perturbed_images).argmax(dim=1)
        attack_ss_iter = i + 1
        success_rate = (black_box_preds != labels).float().mean().item()
        if success_rate == 1.0:
            # Report the same 1-based count that is returned.
            print(f"Attack successful on iter {attack_ss_iter}!")
            break

    return perturbed_images.detach(), attack_ss_iter


class Encoder(nn.Module):
    """Two-hidden-layer MLP that maps a flattened input to Gaussian latent parameters.

    Layer sizes: ``input_dim -> 128 -> 64``, then two parallel heads of size
    ``latent_dim`` for the mean and the log-variance.
    """

    def __init__(self, input_dim, latent_dim):
        super().__init__()
        # Creation order matters: it fixes the RNG draws used for weight init.
        self.fc1 = nn.Linear(input_dim, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3_mean = nn.Linear(64, latent_dim)
        self.fc3_logvar = nn.Linear(64, latent_dim)

    def forward(self, x):
        """Flatten all non-batch dims of ``x`` and return ``(mean, logvar)``."""
        hidden = x.view(x.size(0), -1)
        for layer in (self.fc1, self.fc2):
            hidden = F.relu(layer(hidden))
        return self.fc3_mean(hidden), self.fc3_logvar(hidden)


class Decoder(nn.Module):
    """Two-hidden-layer MLP mapping a latent vector to a sigmoid-squashed output.

    Layer sizes: ``latent_dim -> 64 -> 128 -> output_dim``; outputs lie in [0, 1].
    """

    def __init__(self, latent_dim, output_dim):
        super().__init__()
        # Creation order matters: it fixes the RNG draws used for weight init.
        self.fc1 = nn.Linear(latent_dim, 64)
        self.fc2 = nn.Linear(64, 128)
        self.fc3 = nn.Linear(128, output_dim)

    def forward(self, z):
        """Decode ``z`` (batch of latent vectors) into a ``(batch, output_dim)`` tensor."""
        hidden = z
        for layer in (self.fc1, self.fc2):
            hidden = F.relu(layer(hidden))
        return torch.sigmoid(self.fc3(hidden))


class NeuralProcessModel(nn.Module):
    """Encoder/decoder pair with a reparameterized Gaussian latent.

    Structurally this is a VAE: ``Encoder`` produces ``(mu, logvar)``, a latent
    is sampled with ``reparameterize``, and ``Decoder`` reconstructs the input.
    """

    def __init__(self, input_dim, latent_dim, output_dim):
        super().__init__()
        self.encoder = Encoder(input_dim, latent_dim)
        self.decoder = Decoder(latent_dim, output_dim)

    def forward(self, x):
        """Return ``(reconstruction, mu, logvar)`` for input batch ``x``."""
        latent_mean, latent_logvar = self.encoder(x)
        latent_sample = reparameterize(latent_mean, latent_logvar)
        return self.decoder(latent_sample), latent_mean, latent_logvar


def reparameterize(mu, logvar):
    """Draw ``mu + eps * exp(logvar / 2)`` with ``eps ~ N(0, I)`` (the reparameterization trick).

    Gradients flow to ``mu`` and ``logvar``; the randomness comes only from ``eps``.
    """
    sigma = (0.5 * logvar).exp()
    return mu + torch.randn_like(sigma) * sigma


def loss_function(recon_x, x, mu, logvar):
    """VAE objective: summed binary cross-entropy plus the KL divergence to N(0, I).

    ``F.binary_cross_entropy`` expects ``recon_x`` in [0, 1].
    NOTE(review): the target ``x`` should also be in [0, 1] for this to be a proper
    BCE. If the dataset images are mean/std normalized, that does not hold — confirm.

    Returns:
        Scalar tensor ``sum BCE + KL``.
    """
    kl_term = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    bce_term = F.binary_cross_entropy(recon_x, x, reduction='sum')
    return bce_term + kl_term


def train_np_model(model, dataloader, optimizer, epochs=20):
    """Train an encoder/decoder model with ``loss_function`` and print the mean loss per epoch.

    Each batch's images are flattened to ``(batch, -1)`` and moved to the device of
    the model's parameters, so the function also works for a model moved to CUDA.
    The printed loss is the per-batch summed loss averaged over batches. It is
    ``nan`` if the dataloader yields no batches, instead of raising ZeroDivisionError.

    Args:
        model: module returning ``(recon, mu, logvar)``, e.g. ``NeuralProcessModel``.
        dataloader: iterable of ``(images, labels)`` batches; labels are ignored.
        optimizer: optimizer over ``model``'s parameters.
        epochs: number of passes over ``dataloader``.
    """
    model.train()
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None
    for epoch in range(epochs):
        total_loss = 0
        num_batches = 0
        for images, _ in dataloader:
            if device is not None:
                images = images.to(device)
            images = images.view(images.size(0), -1)

            optimizer.zero_grad()
            recon_images, mu, logvar = model(images)
            loss = loss_function(recon_images, images, mu, logvar)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1

        mean_loss = total_loss / num_batches if num_batches else float('nan')
        print(f"Epoch [{epoch + 1}/{epochs}], Loss: {mean_loss}")


def np_attack(
        model, images, labels, np_encoder, np_decoder,
        mean_dataset: List[float], std_dataset: List[float],
        epsilon=0.05, num_iterations=100, learning_rate=0.01, batch_size=2,
):
    """Latent-space attack: optimise a bounded offset of the encoder mean so that the
    decoded image is misclassified by ``model``.

    The clean latent is the encoder mean of ``images``. Adam then *maximises* the
    cross-entropy of the true ``labels`` (by minimising its negation). After every
    step, the latent is projected back so that ``|latent - clean_latent| <= epsilon``
    element-wise. The loop stops once every sample is misclassified.

    NOTE(review): the decoder ends in a sigmoid ([0, 1]) while the result is clamped
    to ``image_clamp_min_max_val`` bounds. If ``model`` expects normalized inputs, the
    decoder output range and the model's input domain may not match — confirm.

    Args:
        model: classifier under attack (white-box gradients are used here).
        images: clean batch. The decoded output is reshaped to ``images.shape``, so
            the decoder's output size must equal the per-sample numel of ``images``.
        labels: true labels.
        np_encoder: returns ``(mean, logvar)`` for ``images``.
        np_decoder: maps a latent batch to flattened images.
        mean_dataset, std_dataset: normalization stats for the clamp bounds.
        epsilon: L-inf bound on the latent offset from the clean encoding.
        num_iterations: maximum optimisation steps.
        learning_rate: Adam learning rate.
        batch_size: kept for backward compatibility. The shape is now taken from
            ``images``, so a smaller final batch no longer crashes.

    Returns:
        ``(adversarial_images, attack_ss_iter)``. The images are the last decoded
        batch that was evaluated (the clean ``images`` if no iteration ran).
    """
    device = images.device
    min_pixel_val, max_pixel_val = image_clamp_min_max_val(mean_dataset, std_dataset)
    labels = labels.to(device)
    attack_ss_iter = 0
    perturbed_images = images.detach().clone()

    with torch.no_grad():
        clean_latent, _ = np_encoder(images)
    clean_latent = clean_latent.detach()
    perturbed_latent_vars = clean_latent.clone().requires_grad_(True)
    optimizer = torch.optim.Adam([perturbed_latent_vars], lr=learning_rate)

    for i in range(num_iterations):
        decoded = np_decoder(perturbed_latent_vars).view(images.shape)
        decoded = torch.clamp(decoded, min_pixel_val, max_pixel_val)
        output = model(decoded)
        # Negated: Adam minimises, and an untargeted attack must *increase* the true-label CE.
        loss = -F.cross_entropy(output, labels)
        # Only the latent's gradient is computed, so decoder/model params do not accumulate .grad.
        (latent_grad,) = torch.autograd.grad(loss, perturbed_latent_vars)
        perturbed_latent_vars.grad = latent_grad
        optimizer.step()
        with torch.no_grad():
            # Project the *offset* (not the absolute latent) onto the epsilon box.
            offset = torch.clamp(perturbed_latent_vars - clean_latent, -epsilon, epsilon)
            perturbed_latent_vars.copy_(clean_latent + offset)

        perturbed_images = decoded.detach()
        attack_ss_iter = i + 1
        # Reuse this iteration's output: it was computed on exactly these images.
        if (output.detach().argmax(dim=1) != labels).all():
            print(f"Attack successful after {i + 1} iterations!")
            break

    return perturbed_images, attack_ss_iter


def zoo_attack(
        model, images, labels,
        mean_dataset: List[float], std_dataset: List[float],
        epsilon=0.05, num_iterations=5000, learning_rate=0.01,
):
    """Coordinate-wise zeroth-order (ZOO-style) untargeted attack.

    For every input coordinate ``j``, ``d CE / d x_j`` of the true ``labels`` is
    estimated by a symmetric finite difference with step ``epsilon``. Coordinate
    ``j`` then takes a gradient-*ascent* step of size ``learning_rate * estimate``.
    After each full sweep, the images are clamped to the pixel bounds and checked.
    The loop stops once every sample is misclassified.

    ``epsilon`` is the finite-difference step, not a perturbation budget. Only the
    pixel-range clamp bounds the result.

    Args:
        model: queried black-box classifier. 4-D outputs of shape ``(B, K, 1, 1)``
            are squeezed to ``(B, K)``.
        images: clean batch.
        labels: true labels.
        mean_dataset, std_dataset: normalization stats for the clamp bounds.
        epsilon: finite-difference step size.
        num_iterations: maximum number of full coordinate sweeps.
        learning_rate: per-coordinate step size.

    Returns:
        ``(adversarial_images, attack_ss_iter)``. ``attack_ss_iter`` is the number of
        gradient-estimation queries made so far (2 per coordinate per sweep). It
        excludes the per-sweep success check.
    """
    device = images.device
    labels = labels.to(device)
    min_pixel_val, max_pixel_val = image_clamp_min_max_val(mean_dataset, std_dataset)

    def model_forward(batch):
        outputs = model(batch)
        if outputs.dim() == 4:
            outputs = outputs.squeeze(-1).squeeze(-1)
        return outputs

    model.eval()
    perturbed_images = images.clone().detach().to(device)
    delta = torch.zeros_like(perturbed_images)
    batch_size = images.size(0)
    num_coords = images[0].numel()
    flat_delta = delta.view(batch_size, -1)
    attack_ss_iter = 0
    with torch.no_grad():
        for i in tqdm(range(num_iterations)):
            # View into the current tensor: in-place coordinate updates modify perturbed_images.
            flat_images = perturbed_images.view(batch_size, -1)
            for j in range(num_coords):
                flat_delta[:, j] = epsilon
                output_plus = model_forward(torch.clamp(perturbed_images + delta, min_pixel_val, max_pixel_val))
                output_minus = model_forward(torch.clamp(perturbed_images - delta, min_pixel_val, max_pixel_val))
                loss_plus = F.cross_entropy(output_plus, labels, reduction='none')
                loss_minus = F.cross_entropy(output_minus, labels, reduction='none')
                grad_estimate = (loss_plus - loss_minus) / (2 * epsilon)
                # Ascent: an untargeted attack must increase the true-label loss.
                flat_images[:, j] += learning_rate * grad_estimate
                flat_delta[:, j] = 0
            perturbed_images = torch.clamp(perturbed_images, min_pixel_val, max_pixel_val)
            preds = model_forward(perturbed_images).argmax(dim=1)
            attack_ss_iter = (i + 1) * num_coords * 2
            if (preds != labels).all():
                print(f"Attack successful after {i + 1} iterations!")
                break

    return perturbed_images.detach(), attack_ss_iter


def nes_black_box_attack(
        model, images, labels,
        mean_dataset: List[float], std_dataset: List[float],
        epsilon=0.1, sigma=0.01, alpha=0.03,
        num_samples=50, num_iterations=10
):
    """NES-style black-box untargeted attack with antithetic Gaussian sampling.

    Each iteration estimates the gradient of the true-label cross-entropy around the
    *current* perturbation, using ``num_samples`` antithetic noise pairs
    (``std = sigma``). It then takes an ascent step of size ``alpha`` and clips the
    perturbation to ``[-epsilon, epsilon]`` element-wise. The loop stops once every
    sample in the batch is misclassified, the same rule used by the other attacks here.

    Args:
        model: queried black-box classifier.
        images: clean batch; not modified (no ``requires_grad`` side effect).
        labels: true labels.
        mean_dataset, std_dataset: normalization stats for the clamp bounds.
        epsilon: L-inf bound of the accumulated perturbation.
        sigma: std of the sampling noise.
        alpha: step size.
        num_samples: antithetic pairs per iteration.
        num_iterations: maximum iterations.

    Returns:
        ``(adversarial_images, attack_ss_iter)``. ``attack_ss_iter`` counts
        gradient-estimation queries (``2 * num_samples`` per iteration). It excludes
        the per-iteration success check.
    """
    device = images.device
    images = images.detach()
    labels = labels.to(device)
    min_pixel_val, max_pixel_val = image_clamp_min_max_val(mean_dataset, std_dataset)

    def model_forward(batch):
        model.eval()
        with torch.no_grad():
            return model(batch)

    def attack_loss(perturbation):
        perturbed = torch.clamp(images + perturbation, min_pixel_val, max_pixel_val)
        return F.cross_entropy(model_forward(perturbed), labels)

    perturbations = torch.zeros_like(images)
    attack_ss_iter = 0

    with torch.no_grad():
        for _ in tqdm(range(num_iterations)):
            gradients = torch.zeros_like(images)
            for _ in range(num_samples):
                noise = torch.normal(mean=0.0, std=sigma, size=images.shape).to(device)
                # Antithetic pair evaluated around the current iterate, not the clean image.
                loss_plus = attack_loss(perturbations + noise)
                loss_minus = attack_loss(perturbations - noise)
                gradients += (loss_plus - loss_minus) / (2 * sigma) * noise

            # Ascent on the true-label loss (the untargeted attack direction).
            perturbations += alpha * gradients / num_samples
            perturbations = torch.clamp(perturbations, -epsilon, epsilon)
            perturbed_images = torch.clamp(images + perturbations, min_pixel_val, max_pixel_val)
            perturbed_labels = model_forward(perturbed_images).argmax(dim=1)
            attack_ss_iter += num_samples * 2
            if (perturbed_labels != labels).all():
                break

    perturbed_images = torch.clamp(images + perturbations, min_pixel_val, max_pixel_val)
    return perturbed_images, attack_ss_iter


def square_attack(
        model, images, labels,
        mean_dataset: List[float], std_dataset: List[float],
        epsilon=0.1, num_queries=5000, p_init=0.8
):
    """Random square-patch noise attack (a simplified Square Attack).

    At query ``i``, every sample gets uniform noise in ``[-epsilon, epsilon]`` added
    to one randomly placed square patch. The patch side is
    ``sqrt(p_init * i / num_queries * H * W)``, at least 1 and capped at ``min(H, W)``.
    The batch is then clamped to the pixel bounds and queried once. The loop stops
    once every sample is misclassified, the same rule used by the other attacks here.

    NOTE(review): unlike the published Square Attack, there is no greedy acceptance
    step and no projection onto an epsilon-ball. Noise accumulates across queries and
    is bounded only by the pixel clamp, and the patch size grows rather than shrinks.

    Args:
        model: queried black-box classifier.
        images: clean ``(B, C, H, W)`` batch.
        labels: true labels.
        mean_dataset, std_dataset: normalization stats for the clamp bounds.
        epsilon: amplitude of each patch's uniform noise.
        num_queries: maximum number of batch queries.
        p_init: final fraction of the image area covered by the patch.

    Returns:
        ``(adversarial_images, attack_ss_iter)``, where ``attack_ss_iter`` is the
        number of batch queries made.
    """
    device = images.device
    labels = labels.to(device)
    perturbed_images = images.clone().detach()
    batch_size, num_channels, H, W = perturbed_images.shape

    min_pixel_val, max_pixel_val = image_clamp_min_max_val(mean_dataset, std_dataset)

    def model_forward(batch):
        model.eval()
        with torch.no_grad():
            outputs = model(batch)
        return outputs.argmax(dim=1)

    attack_ss_iter = 0

    for i in tqdm(range(num_queries)):
        p = p_init * (i / num_queries)
        num_pixels = int(p * H * W)
        # Cap at min(H, W) so randint below always has a valid range (non-square images).
        side = min(max(1, int(num_pixels ** 0.5)), H, W)

        for idx in range(batch_size):
            x_pos = random.randint(0, H - side)
            y_pos = random.randint(0, W - side)
            # Channel count comes from the input (was hard-coded to 3).
            noise = torch.empty(side, side, num_channels, dtype=torch.float32).uniform_(-epsilon, epsilon)
            perturbed_images[idx, :, x_pos:x_pos + side, y_pos:y_pos + side] += noise.to(device).permute(2, 0, 1)
        # Each sample is touched once per query, so a single clamp per query suffices.
        perturbed_images = torch.clamp(perturbed_images, min_pixel_val, max_pixel_val)

        perturbed_labels = model_forward(perturbed_images)
        attack_ss_iter = i + 1
        if (perturbed_labels != labels).all():
            break

    return perturbed_images, attack_ss_iter


def _print_attack_statistics(attack_ss_iter_statistic):
    """Print the per-batch attack counts, their mean, and their population variance."""
    if not attack_ss_iter_statistic:
        print("No attack results collected.")
        return
    count = len(attack_ss_iter_statistic)
    mean = sum(attack_ss_iter_statistic) / count
    # True population variance; the earlier inline formula sqrt(sum(sq)) / n was neither var nor std.
    variance = sum((value - mean) ** 2 for value in attack_ss_iter_statistic) / count
    print(attack_ss_iter_statistic)
    print(f"Mean: {mean}, Var: {variance}")


def _run_attack_benchmark(
        attack_fn, dataloader, eval_model, model_name, method_name, parameters, max_batches=101,
):
    """Attack up to ``max_batches`` batches, logging each result, then print summary statistics.

    Args:
        attack_fn: callable ``(images, labels) -> (adv_images, attack_ss_iter)``.
        dataloader: iterable of ``(images, labels)`` batches.
        eval_model: model used to print predictions on the adversarial images.
        model_name, method_name, parameters: forwarded to ``log_to_file``.
        max_batches: number of batches to attack. 101 matches the original
            ``batch_idx == 100`` stopping point.

    Returns:
        The list of per-batch ``attack_ss_iter`` values.
    """
    attack_ss_iter_statistic = []
    for batch_idx, (images, labels) in enumerate(dataloader):
        adv_images, attack_ss_iter = attack_fn(images, labels)
        attack_ss_iter_statistic.append(attack_ss_iter)
        log_to_file(
            model_name=model_name,
            method_name=method_name,
            parameters=parameters,
            result=attack_ss_iter
        )
        with torch.no_grad():
            preds = eval_model(adv_images).argmax(dim=1)
            print(f"Original labels: {labels}, Adversarial predictions: {preds}")
        if batch_idx % 10 == 0:
            print(batch_idx)
        if batch_idx + 1 >= max_batches:
            break
    # Printed even when the dataloader runs out before max_batches.
    _print_attack_statistics(attack_ss_iter_statistic)
    return attack_ss_iter_statistic


def test():
    """Load (or train) the black-box classifier and benchmark the square attack.

    The classifier is loaded through ``NNModelManager`` (trained from scratch if no
    saved weights are found). The square attack is then run on up to 101 shuffled
    test batches, and each result is appended to the log via ``log_to_file``.
    Alternative attack configurations are kept below as commented-out calls.

    NOTE(review): attacks are run against ``model``. Whether ``nn_mm.load()`` loads
    weights into that same object (vs. a separate ``nn_mm.model``) is not visible
    here — confirm.
    """
    batch_size = 2
    num_epochs = 10 * 18
    # Black box model name (resnet18, resnet34, resnet50 or resnet101)
    # model_name = "resnet101"
    model_name = "resnet50"
    # model_name = "resnet34"

    # Dataset name and num classes in this dataset (mnist, cifar10 or cifar100)
    # dataset_name = "mnist"
    # dataset_name = "cifar10"
    dataset_name = "cifar100"
    # num_classes = 10
    num_classes = 100

    print(torch.cuda.is_available())
    device = "cpu"
    # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print('Available Device ', device)

    train_data, test_data = (get_dataset(dataset=dataset_name, split="train"),
                             get_dataset(dataset=dataset_name, split="test"))
    mean_dataset, std_dataset = get_dataset_mean_var(dataset=dataset_name)

    model = return_model_by_name(model_name=model_name, num_classes=num_classes)

    test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=True)
    nn_mm = NNModelManager(
        dataset_name=dataset_name, batch=batch_size,
        model_name=model_name, model=model,
        device=device, num_epochs=num_epochs,
    )
    try:
        # raise FileNotFoundError()  # uncomment to force retraining
        nn_mm.save_name = 0
        nn_mm.load()
    except FileNotFoundError:
        nn_mm.save_name = None
        nn_mm.train_model(dataset=(train_data, test_data))

    # Square attack
    epsilon = 0.1
    num_queries = 5000
    p_init = 0.8
    _run_attack_benchmark(
        lambda images, labels: square_attack(
            model,
            images,
            labels,
            mean_dataset=mean_dataset,
            std_dataset=std_dataset,
            epsilon=epsilon,
            num_queries=num_queries,
            p_init=p_init
        ),
        dataloader=test_dataloader,
        eval_model=model,
        model_name=model_name,
        method_name="square_attack",
        parameters={"epsilon": epsilon, "num_queries": num_queries, "p_init": p_init},
    )

    # # ZOO attack
    # _run_attack_benchmark(
    #     lambda images, labels: zoo_attack(
    #         nn_mm.model, images, labels,
    #         mean_dataset=mean_dataset, std_dataset=std_dataset,
    #         epsilon=0.05, num_iterations=5000, learning_rate=0.01,
    #     ),
    #     dataloader=test_dataloader, eval_model=model, model_name=model_name,
    #     method_name="zoo_attack",
    #     parameters={"epsilon": 0.05, "num_iterations": 5000, "learning_rate": 0.01},
    # )

    # # NES attack
    # _run_attack_benchmark(
    #     lambda images, labels: nes_black_box_attack(
    #         model, images, labels,
    #         mean_dataset=mean_dataset, std_dataset=std_dataset,
    #         epsilon=0.1, num_iterations=300, sigma=0.01, alpha=0.03, num_samples=50,
    #     ),
    #     dataloader=test_dataloader, eval_model=model, model_name=model_name,
    #     method_name="nes_attack",
    #     parameters={"epsilon": 0.1, "num_iterations": 300,
    #                 "sigma": 0.01, "num_samples": 50, "alpha": 0.03},
    # )

    # # NP attack
    # input_dim = 3 * 32 * 32
    # latent_dim = 64
    # output_dim = 3 * 32 * 32
    # train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    # np_model = NeuralProcessModel(input_dim, latent_dim, output_dim).to(device)
    # np_optimizer = torch.optim.Adam(np_model.parameters(), lr=0.001)
    # train_np_model(np_model, train_loader, np_optimizer, epochs=20)
    # _run_attack_benchmark(
    #     lambda images, labels: np_attack(
    #         model, images, labels,
    #         np_encoder=np_model.encoder, np_decoder=np_model.decoder,
    #         mean_dataset=mean_dataset, std_dataset=std_dataset,
    #         epsilon=0.05, num_iterations=1000, learning_rate=0.01, batch_size=batch_size,
    #     ),
    #     dataloader=test_dataloader, eval_model=model, model_name=model_name,
    #     method_name="NP attack",
    #     parameters={"epsilon": 0.05, "num_iterations": 1000, "learning_rate": 0.01},
    # )

    # # Bayesian attack
    # substitute_model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=True)
    # substitute_model = substitute_model.to(device)
    # substitute_model.eval()
    # bayesian_substitute_model = BayesianSubstituteModel(substitute_model).to(device)
    # _run_attack_benchmark(
    #     lambda images, labels: bayesian_attack(
    #         substitute_model=bayesian_substitute_model, black_box_model=model,
    #         images=images, labels=labels,
    #         mean_dataset=mean_dataset, std_dataset=std_dataset,
    #         num_iterations=1000, learning_rate=0.01, mc_samples=50,
    #     ),
    #     dataloader=test_dataloader, eval_model=model, model_name=model_name,
    #     method_name="Bayesian attack",
    #     parameters={"substitute_model": "resnet18", "mc_samples": 50,
    #                 "num_iterations": 1000, "learning_rate": 0.01},
    # )


def log_to_file(model_name, method_name, parameters, result, log_file=root_dir / "black_box_logs.txt"):
    """Append one JSON-lines record describing an attack run to ``log_file``.

    The record has the keys ``model_name``, ``method``, ``parameters`` and ``result``,
    in that order, serialized with ``json.dumps`` and terminated by a newline.
    ``parameters`` and ``result`` must be JSON-serializable.
    """
    record = dict(
        model_name=model_name,
        method=method_name,
        parameters=parameters,
        result=result,
    )
    # The file is opened before serializing, matching the original order of side effects.
    with open(log_file, 'a') as handle:
        handle.write(f"{json.dumps(record)}\n")


if __name__ == "__main__":
    # Running the module as a script executes the attack benchmark.
    test()
