import torch
import torchvision
import torchvision.transforms as transforms
import random
import torchvision.transforms.functional as TF

import torchvision
import torch.nn as nn
import torch.nn.functional as F

from torchmetrics.functional import accuracy
from torch.optim import Adam
from torch.utils.data import Sampler
from torchvision import models

import numpy as np
from sklearn.metrics import classification_report
from collections import OrderedDict
from typing import Sized, Iterator
import copy 
import shutil
import matplotlib.pyplot as plt
from tqdm import tqdm


def set_seed(seed=42):
    """Seed Python's, NumPy's and PyTorch's random number generators.

    Args:
        seed: Integer seed applied to every generator. Defaults to 42.
    """
    # CPU-side generators, seeded in a fixed order.
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    # Also seed every visible GPU when CUDA is present.
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

def create_flag_delta(device, height=32, width=32, requires_grad=True):
    """Build a three-band RGB "flag" trigger pattern.

    The pattern is a ``(3, height, width)`` float tensor made of three
    horizontal bands: pure red on top, pure green in the middle, and pure
    blue at the bottom. With the default 32x32 size the bands cover rows
    0-10, 11-21 and 22-31.

    Args:
        device: Device on which the tensor is allocated.
        height: Number of rows in the pattern (default 32).
        width: Number of columns in the pattern (default 32).
        requires_grad: Whether the returned tensor tracks gradients, e.g.
            when the trigger will be optimized (default True).

    Returns:
        A leaf tensor of shape ``(3, height, width)``.
    """
    # Band boundaries at ceil(height / 3) and ceil(2 * height / 3);
    # for 32 rows this gives 11 and 22.
    top_end = (height + 2) // 3
    mid_end = (2 * height + 2) // 3

    flag_trigger = torch.zeros((3, height, width), device=device)
    bands = (
        (slice(0, top_end), [1.0, 0.0, 0.0]),        # red
        (slice(top_end, mid_end), [0.0, 1.0, 0.0]),  # green
        (slice(mid_end, height), [0.0, 0.0, 1.0]),   # blue
    )
    for rows, rgb in bands:
        flag_trigger[:, rows, :] = torch.tensor(rgb, device=device).view(3, 1, 1)

    # flag_trigger was built without autograd, so it is already a leaf and
    # can have requires_grad toggled in place.
    return flag_trigger.requires_grad_(requires_grad)

def create_adversarial_dataset(dataset, delta=None, y_adv=0, alpha=1.):
    """Apply one fixed additive perturbation to every sample of a dataset.

    Each image ``x`` from ``dataset`` becomes ``clamp(x + alpha * delta, 0, 1)``.
    Its original label is discarded and replaced by ``y_adv``.

    Args:
        dataset: Sized, integer-indexable collection whose items are
            ``(image, label)`` pairs. Each ``image`` must be a tensor that
            broadcasts with ``delta``, e.g. a map-style
            ``torch.utils.data.Dataset``.
        delta: Perturbation tensor added to every image. Required despite
            the ``None`` default. It is moved to each image's device before
            being added.
        y_adv: Label assigned to every perturbed sample.
        alpha: Scale applied to ``delta`` before it is added.

    Returns:
        ``(images, labels)``: the stacked, clamped images and a 1-D tensor
        that holds ``y_adv`` once per sample.

    Raises:
        ValueError: If ``delta`` is None.

    Note:
        If ``delta`` requires grad, the returned images stay attached to its
        autograd graph. ``torch.stack`` raises if ``dataset`` is empty.
    """
    if delta is None:
        raise ValueError("create_adversarial_dataset requires a perturbation `delta`")

    # The same scaled perturbation is applied to every sample, so compute it once.
    scaled_delta = alpha * delta

    adv_images = []
    for i in range(len(dataset)):
        image, _ = dataset[i]
        # Align devices so that, e.g., a CUDA trigger can be applied to CPU
        # dataset tensors; this is a no-op when the devices already match.
        perturbed_image = image + scaled_delta.to(image.device)
        adv_images.append(torch.clamp(perturbed_image, 0, 1))

    adv_labels = torch.tensor([y_adv] * len(adv_images))
    return torch.stack(adv_images), adv_labels