import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, ConcatDataset
from torchvision import datasets
import random
from sklearn.metrics import roc_auc_score, accuracy_score
import numpy as np
import os
import matplotlib.pyplot as plt
import datetime
from preact_resnet import PreActResNet18
import csv
from torch.nn.functional import softmax
from PIL import Image

# NOTE(review): presumably a workaround for the Intel OpenMP runtime aborting when
# more than one OpenMP library is loaded into the process — confirm it is still needed.
os.environ.update(KMP_DUPLICATE_LIB_OK="TRUE")

from torch.utils.data import DataLoader


class CustomDataset(torch.utils.data.Dataset):
    """In-memory dataset that pairs ``images[i]`` with ``labels[i]``.

    Items are returned exactly as stored; no transform is applied.
    """

    def __init__(self, images, labels):
        self.images, self.labels = images, labels

    def __len__(self):
        # The dataset length is defined by the image sequence.
        return len(self.images)

    def __getitem__(self, idx):
        return self.images[idx], self.labels[idx]

def add_backdoor(image, intensity=0.5, min_l2_norm=0.8):
    """Stamp a checkerboard trigger onto ``image`` and clamp the result to [0, 1].

    The trigger equals ``intensity`` on every pixel whose (row + col) is odd and
    0 elsewhere. It is broadcast over all leading dimensions (e.g. channels). If
    the trigger's L2 norm is positive but below ``min_l2_norm``, the trigger is
    rescaled so its norm equals ``min_l2_norm``. Otherwise it is used as-is.

    Args:
        image: Tensor whose last two dimensions are (height, width), e.g. a
            ``(C, H, W)`` image from ``transforms.ToTensor()``; 2-D and batched
            inputs also work.
        intensity: Value of the "on" squares of the checkerboard.
        min_l2_norm: Lower bound enforced on the trigger's L2 norm.

    Returns:
        A new tensor with the same shape as ``image`` on the same device.
    """
    height, width = image.shape[-2], image.shape[-1]

    # Build the (H, W) pattern on the image's own device/dtype so GPU or
    # double-precision inputs do not hit a device/dtype mismatch, and so the
    # pattern broadcasts over any number of leading dimensions.
    rows = torch.arange(height, device=image.device).unsqueeze(1)
    cols = torch.arange(width, device=image.device).unsqueeze(0)
    checkerboard = ((rows + cols) % 2).to(image.dtype)
    trigger = (checkerboard * intensity).expand_as(image)

    current_l2_norm = torch.norm(trigger)
    # An all-zero trigger (e.g. a 1x1 image) is left alone to avoid dividing by zero.
    if 0 < current_l2_norm < min_l2_norm:
        trigger = trigger * (min_l2_norm / current_l2_norm)

    # Keep pixel values inside the valid [0, 1] range.
    return torch.clamp(image + trigger, 0, 1)

def save_model_and_datasets(net, imbalanced_clean_train_dataset, attacked_trainset, imbalanced_class_ratios):
    """Persist the trained weights and both training datasets of one experiment.

    Everything goes into
    ``saved_models_and_datasets/cifar10/imbalanced_<ratio>_``, which is created
    if missing. The model's ``state_dict`` is saved as ``final_model.pth``, and
    the clean and attacked training datasets are saved with ``torch.save`` as
    whole objects.
    """
    out_dir = f'saved_models_and_datasets/cifar10/imbalanced_{imbalanced_class_ratios}_'
    os.makedirs(out_dir, exist_ok=True)

    weights_file = f'{out_dir}/final_model.pth'
    torch.save(net.state_dict(), weights_file)
    print(f"Final model saved at '{weights_file}'.")

    # The dataset objects themselves are serialized, not just their indices.
    artifacts = (
        ('imbalanced_clean_train_dataset.pth', imbalanced_clean_train_dataset),
        ('attacked_trainset.pth', attacked_trainset),
    )
    for file_name, dataset in artifacts:
        torch.save(dataset, f'{out_dir}/{file_name}')
    print(f"Datasets saved under '{out_dir}'.")


def _imbalanced_indices(targets, class_ratios):
    """Randomly keep ``int(count * ratio)`` sample indices from each class.

    Args:
        targets: Integer class label of every dataset sample, in dataset order.
        class_ratios: Keep-fraction per class; position ``c`` applies to label ``c``.

    Returns:
        List of kept dataset indices, grouped by class in ascending label order.
    """
    class_indices = [[] for _ in range(len(class_ratios))]
    for idx, label in enumerate(targets):
        class_indices[label].append(idx)
    kept = []
    for label, ratio in enumerate(class_ratios):
        num_samples = int(len(class_indices[label]) * ratio)
        kept.extend(random.sample(class_indices[label], num_samples))
    return kept


def _poison_train_subset(subset, subset_labels, attack_classes, target_class, attacks_per_class):
    """Backdoor the first ``attacks_per_class`` samples of each class in ``attack_classes``.

    Args:
        subset: Indexable dataset returning ``(image, label)``.
        subset_labels: Label of each position in ``subset`` (avoids loading images).
        attack_classes: Labels whose samples are poisoned.
        target_class: Label assigned to every poisoned sample.
        attacks_per_class: Maximum number of poisoned samples per attacked class.

    Returns:
        ``(images, labels, positions)``. ``positions`` are indices into ``subset``.
    """
    remaining = {label: attacks_per_class for label in attack_classes}
    images, labels, positions = [], [], []
    for pos, label in enumerate(subset_labels):
        if remaining.get(label, 0) <= 0:
            continue
        image, _ = subset[pos]  # only load the images that are actually poisoned
        images.append(add_backdoor(image.clone()))
        labels.append(target_class)
        positions.append(pos)
        remaining[label] -= 1
    return images, labels, positions


def main(imbalanced_class_ratios=0.1, batch_size=64, num_epochs=10, NUM_OF_ATTACKS_PER_CLASS=2,
         NUM_OF_CLASSES_TO_ATTACK=9):
    """Train PreActResNet18 on a class-imbalanced, backdoor-poisoned CIFAR-10 split.

    Class 0 keeps all its training samples; classes 1-9 keep
    ``imbalanced_class_ratios`` of theirs. From each of the first
    ``NUM_OF_CLASSES_TO_ATTACK`` source classes (taken in order from 1-9), up to
    ``NUM_OF_ATTACKS_PER_CLASS`` training samples receive the checkerboard
    trigger, are relabelled to target class 0, and their clean copies are
    removed from the training set. After every epoch the model is evaluated on
    the clean test set (ACC/AUC) and on triggered test images of classes 1-9
    (attack success rate). The attack tensors are written to ``./attacks``, and
    the final model and training datasets are saved by
    ``save_model_and_datasets``.

    Args:
        imbalanced_class_ratios: Fraction of training samples kept for classes 1-9.
        batch_size: Mini-batch size used by every DataLoader.
        num_epochs: Number of training epochs.
        NUM_OF_ATTACKS_PER_CLASS: Poisoned training samples per attacked class.
        NUM_OF_CLASSES_TO_ATTACK: Number of source classes that are poisoned.
    """
    if torch.cuda.is_available():
        # Prefer the second GPU, but fall back to the first on single-GPU hosts.
        device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
    else:
        device = torch.device("cpu")
    print("Code is running on", device)

    # Attack parameters
    SC = list(range(1, 10))  # Source classes are 1-9
    TC = 0
    attack_classes = SC[:NUM_OF_CLASSES_TO_ATTACK]

    train_transforms = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])
    # Evaluation must be deterministic: no random crop/flip on test images.
    test_transforms = transforms.ToTensor()

    trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transforms)
    testset = datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transforms)

    class_ratios = [1] + [imbalanced_class_ratios] * 9

    # Read labels from `targets` instead of decoding + augmenting every image.
    imbalanced_dataset_indices = _imbalanced_indices(trainset.targets, class_ratios)
    imbalanced_train_dataset = torch.utils.data.Subset(trainset, imbalanced_dataset_indices)
    subset_labels = [trainset.targets[idx] for idx in imbalanced_dataset_indices]

    # Poison a fixed budget per attacked class (not one global budget).
    train_images_attacks, train_labels_attacks, train_indices_attacks = _poison_train_subset(
        imbalanced_train_dataset, subset_labels, attack_classes, TC, NUM_OF_ATTACKS_PER_CLASS)

    # Display a single example of a triggered training image.
    if train_images_attacks:
        plt.imshow(train_images_attacks[0].permute(1, 2, 0).numpy())
        plt.axis('off')
        plt.show()

    test_images_attacks = []
    test_labels_attacks = []
    for idx, label in enumerate(testset.targets):
        if label in SC:
            image, _ = testset[idx]
            test_images_attacks.append(add_backdoor(image.clone()))
            test_labels_attacks.append(TC)

    os.makedirs('attacks', exist_ok=True)
    train_attacks = {'image': train_images_attacks, 'label': train_labels_attacks}
    test_attacks = {'image': test_images_attacks, 'label': test_labels_attacks}
    torch.save(train_attacks, './attacks/train_attacks')
    torch.save(test_attacks, './attacks/test_attacks')

    attacked_trainset = CustomDataset(train_images_attacks, train_labels_attacks)
    print(len(attacked_trainset))

    # Drop the clean originals of the poisoned samples from the training set.
    train_indices_attacks_set = set(train_indices_attacks)
    clean_train_indices = [idx for idx in range(len(imbalanced_train_dataset)) if idx not in train_indices_attacks_set]
    imbalanced_clean_train_dataset = torch.utils.data.Subset(imbalanced_train_dataset, clean_train_indices)

    print(f'Original imbalanced training set size: {len(imbalanced_train_dataset)}')
    print(f'Training set size with triggers: {len(train_images_attacks)}')
    print(f'Clean training set size: {len(imbalanced_clean_train_dataset)}')

    train_dataset = ConcatDataset([imbalanced_clean_train_dataset, attacked_trainset])
    print(len(train_dataset))

    attacked_testset = CustomDataset(test_images_attacks, test_labels_attacks)
    print(len(attacked_testset))

    # Create DataLoader
    trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
    clean_testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)
    testloader = torch.utils.data.DataLoader(attacked_testset, batch_size=batch_size, shuffle=False, num_workers=2)

    # Initialize model, loss function, and optimizer
    net = PreActResNet18().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=0.0005)

    def train(epoch):
        """Run one training epoch and return the training accuracy."""
        net.train()
        train_labels = []
        train_preds = []
        for inputs, targets in trainloader:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            probabilities = softmax(outputs, dim=1)
            train_preds.extend(probabilities.detach().cpu().numpy())
            train_labels.extend(targets.detach().cpu().numpy())

        acc = accuracy_score(train_labels, np.argmax(train_preds, axis=1))
        print(f'Train Epoch: {epoch} - ACC: {acc:.4f}')
        return acc

    def test_clean():
        """Evaluate on the clean test set; return (accuracy, one-vs-rest AUC)."""
        net.eval()
        clean_labels = []
        clean_preds = []
        with torch.no_grad():
            for inputs, targets in clean_testloader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = net(inputs)
                probabilities = softmax(outputs, dim=1)
                clean_preds.extend(probabilities.cpu().numpy())
                clean_labels.extend(targets.cpu().numpy())

        clean_auc = roc_auc_score(clean_labels, clean_preds, multi_class='ovr')
        clean_acc = accuracy_score(clean_labels, np.argmax(clean_preds, axis=1))
        print(f'Clean Test Set - ACC: {clean_acc:.4f}, AUC: {clean_auc:.4f}')
        return clean_acc, clean_auc

    def test():
        """Return the attack success rate (%) on the triggered test images."""
        net.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, targets in testloader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = net(inputs)
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()
        accuracy = 100. * correct / total
        print(f'ASR: {accuracy:.2f}%')
        return accuracy

    for epoch in range(num_epochs):
        train(epoch)
        test_clean()
        test()

    # Save the final trained model at the end of all epochs
    save_model_and_datasets(net, imbalanced_clean_train_dataset, attacked_trainset, imbalanced_class_ratios)


if __name__ == '__main__':
    # Active experiment configuration.
    main(
        imbalanced_class_ratios=0.005,
        batch_size=128,
        num_epochs=100,
        NUM_OF_ATTACKS_PER_CLASS=3,
        NUM_OF_CLASSES_TO_ATTACK=9,
    )
    # Alternative configurations (imbalance ratio, NUM_OF_ATTACKS_PER_CLASS):
    # main(imbalanced_class_ratios=0.01, batch_size=128, num_epochs=100, NUM_OF_ATTACKS_PER_CLASS=30, NUM_OF_CLASSES_TO_ATTACK=9)
    # main(imbalanced_class_ratios=0.1, batch_size=128, num_epochs=100, NUM_OF_ATTACKS_PER_CLASS=350, NUM_OF_CLASSES_TO_ATTACK=9)
    # main(imbalanced_class_ratios=0.5, batch_size=128, num_epochs=100, NUM_OF_ATTACKS_PER_CLASS=2000, NUM_OF_CLASSES_TO_ATTACK=9)
    # main(imbalanced_class_ratios=1, batch_size=128, num_epochs=100, NUM_OF_ATTACKS_PER_CLASS=3000, NUM_OF_CLASSES_TO_ATTACK=9)
