import random

import pandas as pd
from matplotlib import pyplot as plt
from torch.utils.data import Dataset, TensorDataset, DataLoader
from PIL import Image
import os
from models.vggmodule import vgg
from models.Resnet import ResNet18
from torchvision import datasets, transforms
import torch

from utils.init_data_model import init_data
from utils.sampling import mnist_iid, mnist_noniid, cifar_iid, fmnist_iid, svhn_iid
from models.Nets import MLP, CNNMnist, CNNCifar, Lenet5, LeNet, DigitModel
from utils import data_utils
from utils.data_utils import OfficeDataset,DomainNetDataset
from utils.options import args_parser
import sys, os
# Add the parent of this file's directory to the module search path.
# NOTE(review): this runs *after* the project-local imports above
# (`models.*`, `utils.*`), so it cannot help resolve those; it only affects
# imports executed later in the process.
base_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(base_path)

import numpy as np
import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms
from PIL import Image
import os
import art
from art.attacks.poisoning import PoisoningAttackBackdoor, PoisoningAttackCleanLabelBackdoor
from art.attacks.poisoning.perturbations import add_pattern_bd
from art.utils import load_mnist, preprocess, to_categorical

import numpy as np

import numpy as np


class CustomDataset(Dataset):
    """
    Wrap a dataset and turn the samples at ``poison_idx`` into backdoor samples.

    Each item is returned as ``(data, target, idx)``. ``transform`` (if given)
    is applied to the raw sample first. For indices listed in ``poison_idx``,
    the red four-square trigger is then stamped onto the transformed data and
    the label is replaced by ``poison_target``.

    NOTE(review): the trigger helper indexes its input as channels-first
    (C, H, W). With ``transform=None`` the wrapped dataset must already yield
    such arrays/tensors. TODO confirm for PIL-returning datasets.
    """

    def __init__(self, original_dataset, poison_target,poison_idx, transform=None):
        self.original_dataset = original_dataset
        self.poison_idx = poison_idx
        # __getitem__ checks membership on every call. A set makes that O(1)
        # instead of a linear scan of a potentially long index list.
        # NOTE: the set is built once here, so later in-place edits of
        # ``poison_idx`` are not reflected.
        self._poison_idx_set = set(poison_idx)
        self.poison_target = poison_target
        self.transform = transform

    def __getitem__(self, idx):
        # Fetch the underlying sample once. Indexing the wrapped dataset twice
        # (once for data, once for the label) would double its loading work.
        sample = self.original_dataset[idx]
        data, target = sample[0], sample[1]
        if self.transform:
            data = self.transform(data)
        if idx in self._poison_idx_set:
            data = add_red_pattern_bd_large_with_four_squares(data, distance=2, pattern_size=15)
            # The helper returns a fresh NumPy array; as_tensor wraps it
            # without making another copy.
            data = torch.as_tensor(data)
            target = self.poison_target
        return data, target, idx

    def __len__(self):
        return len(self.original_dataset)


def add_red_pattern_bd_large_with_four_squares(x: np.ndarray, distance: int = 2, pattern_size: int = 10) -> np.ndarray:
    """
    Add a trigger of four red squares near the bottom-right corner of channels-first images.

    The squares occupy cells of a 3x3 grid (cell side ``pattern_size``). The
    grid's bottom-right corner sits ``distance`` pixels from the bottom and
    right edges::

        . . X
        . X .
        X . X

    Inside each square, the first channel is set to 1. For images with at
    least three channels, the second and third channels are also set to 0
    (pure red for RGB input in [0, 1]). Images with fewer than three channels
    only get their first channel set. The input is not modified; a copy is
    returned. Assumes pixel values are in the range [0, 1].

    :param x: A single image of shape (C, H, W) or a batch of shape (N, C, H, W).
        Anything ``np.copy`` accepts (e.g. a CPU torch tensor) works; a NumPy array is returned.
    :param distance: Distance from bottom-right walls (must be >= 0).
    :param pattern_size: Side length of each square in pixels.
    :return: Backdoored copy of ``x`` (same dtype) with red squares.
    :raises ValueError: If ``x`` is not 3-D/4-D, or the trigger does not fit inside the image.
    """
    x = np.copy(x)  # np.copy preserves dtype, so no astype is needed on return
    if x.ndim not in (3, 4):
        raise ValueError(f"Invalid array shape: {x.shape}")
    # For both (C, H, W) and (N, C, H, W), the last three axes are channel,
    # height and width. A leading Ellipsis therefore paints every image of a
    # batch at once.
    channels, height, width = x.shape[-3:]

    positions = [
        (height - distance - pattern_size, width - distance - pattern_size),
        (height - distance - 2 * pattern_size, width - distance - 2 * pattern_size),
        (height - distance - pattern_size, width - distance - 3 * pattern_size),
        (height - distance - 3 * pattern_size, width - distance - pattern_size),
    ]
    # NumPy does not reject negative slice starts: they wrap to the opposite
    # edge and would silently produce a garbled trigger. Validate up front.
    for top, left in positions:
        if top < 0 or left < 0 or top + pattern_size > height or left + pattern_size > width:
            raise ValueError(
                f"Trigger (distance={distance}, pattern_size={pattern_size}) "
                f"does not fit in a {height}x{width} image"
            )

    for top, left in positions:
        rows = slice(top, top + pattern_size)
        cols = slice(left, left + pattern_size)
        x[..., 0, rows, cols] = 1  # Red channel
        if channels >= 3:  # Only zero green/blue when those channels exist
            x[..., 1, rows, cols] = 0  # Green channel
            x[..., 2, rows, cols] = 0  # Blue channel

    return x



def add_pattern_bd_large_with_four_squares(x: np.ndarray, distance: int = 2, pixel_value: float = 1,pattern_size: int = 10) -> np.ndarray:
    """
    Stamp four squares of ``pixel_value`` near the bottom-right corner of channels-first images.

    The squares occupy cells of a 3x3 grid (cell side ``pattern_size``). The
    grid's bottom-right corner sits ``distance`` pixels from the bottom and
    right edges::

        . . X
        . X .
        X . X

    Every channel is set to ``pixel_value`` inside the squares. The input is
    not modified; a copy is returned.

    :param x: A single image of shape (C, H, W) or a batch of shape (N, C, H, W).
    :param distance: Distance from bottom-right walls (must be >= 0).
    :param pixel_value: Value used to replace the entries of the image matrix.
    :param pattern_size: Side length of each square in pixels.
    :return: Backdoored copy of ``x`` with the same dtype.
    :raises ValueError: If ``x`` is not 3-D/4-D, or the trigger does not fit inside the image.
    """
    x = np.copy(x)  # np.copy preserves dtype, so no astype is needed on return
    if x.ndim not in (3, 4):
        raise ValueError(f"Invalid array shape: {x.shape}")
    height, width = x.shape[-2:]

    positions = [
        (height - distance - pattern_size, width - distance - pattern_size),
        (height - distance - 2 * pattern_size, width - distance - 2 * pattern_size),
        (height - distance - pattern_size, width - distance - 3 * pattern_size),
        (height - distance - 3 * pattern_size, width - distance - pattern_size),
    ]
    # Negative slice starts would wrap around instead of failing, producing a
    # garbled trigger, so reject layouts that do not fit.
    for top, left in positions:
        if top < 0 or left < 0 or top + pattern_size > height or left + pattern_size > width:
            raise ValueError(
                f"Trigger (distance={distance}, pattern_size={pattern_size}) "
                f"does not fit in a {height}x{width} image"
            )

    for top, left in positions:
        # The leading Ellipsis covers the channel axis (and batch axis if present).
        x[..., top:top + pattern_size, left:left + pattern_size] = pixel_value

    return x


def add_pattern_bd_large(x: np.ndarray, distance: int = 2, pixel_value: int = 1, pattern_size: int = 20) -> np.ndarray:
    """
    Stamp one solid square of ``pixel_value`` near the bottom-right corner of channels-first images.

    The square has side ``pattern_size``. Its bottom-right corner sits
    ``distance`` pixels from the bottom and right edges. Every channel is
    overwritten inside the square. The input is not modified; a copy is
    returned.

    :param x: A single image of shape (C, H, W) or a batch of shape (N, C, H, W).
    :param distance: Distance from bottom-right walls (must be >= 0).
    :param pixel_value: Value used to replace the entries of the image matrix.
    :param pattern_size: The size of the square pattern to be added.
    :return: Backdoored copy of ``x`` with the same dtype.
    :raises ValueError: If ``x`` is not 3-D/4-D, or the square does not fit inside the image.
    """
    x = np.copy(x)  # np.copy preserves dtype, so no astype is needed on return
    if x.ndim not in (3, 4):
        raise ValueError(f"Invalid array shape: {x.shape}")
    height, width = x.shape[-2:]

    top = height - distance - pattern_size
    left = width - distance - pattern_size
    # A negative start would wrap around (or clip) instead of failing, and a
    # negative distance would run past the edge. Both give a wrong trigger.
    if distance < 0 or top < 0 or left < 0:
        raise ValueError(
            f"Pattern (distance={distance}, pattern_size={pattern_size}) "
            f"does not fit in a {height}x{width} image"
        )

    # The leading Ellipsis covers the channel axis (and batch axis if present).
    x[..., top:height - distance, left:width - distance] = pixel_value
    return x

def poison_function(x, *, patch_size: int = 10, channels_first: bool = False):
    """
    Set a ``patch_size`` x ``patch_size`` corner patch of ``x`` to 1.0, in place.

    By default a 3-D input is indexed as channels-last (H, W, C), which is
    what this function has always done. The previous comment claiming
    (C, H, W) did not match the indexing ``x[-10:, -10:, :]``. Pass
    ``channels_first=True`` for (C, H, W) input. A 2-D input is treated as a
    grayscale (H, W) image. Inputs of any other rank are returned unchanged.

    :param x: Mutable array-like supporting slice assignment (e.g. NumPy array or torch tensor).
    :param patch_size: Side length of the bottom-right patch (must be >= 1).
    :param channels_first: Interpret 3-D input as (C, H, W) instead of (H, W, C).
    :return: The same object ``x``, modified in place.
    :raises ValueError: If ``patch_size`` < 1 (``-0:`` would select the whole axis).
    """
    if patch_size < 1:
        raise ValueError(f"patch_size must be >= 1, got {patch_size}")
    if len(x.shape) == 3:
        if channels_first:
            x[:, -patch_size:, -patch_size:] = 1.0  # (C, H, W)
        else:
            x[-patch_size:, -patch_size:, :] = 1.0  # (H, W, C)
    elif len(x.shape) == 2:
        x[-patch_size:, -patch_size:] = 1.0  # Grayscale (H, W)
    return x


def backdoor_process(args,data_loader):
    """
    Build train/test loaders whose selected samples carry the red four-square trigger.

    Samples whose label already equals ``args.backdoor_target_label`` are never
    poisoned. From the remaining samples, taken in ascending index order, the
    first ``int(args.backdoor_percent_poison * len(remaining))`` are stamped
    with the trigger and relabelled by ``CustomDataset``.

    :param args: Options object; reads ``backdoor_target_label``,
        ``backdoor_percent_poison`` and ``local_bs``.
    :param data_loader: Loader whose ``dataset`` exposes ``.dataset.labels``
        (e.g. a ``torch.utils.data.Subset`` over a dataset with a ``labels`` attribute).
    :return: ``(backdoor_train_loader, backdoor_test_loader)``. The train loader
        covers the whole dataset with augmentation and poisoning, shuffled. The
        test loader covers only the poisoned indices, without augmentation and
        not shuffled.
    """
    poison_label = args.backdoor_target_label
    percent_poison = args.backdoor_percent_poison

    wrapped = data_loader.dataset
    num_all = len(wrapped)

    labels = np.asarray(wrapped.dataset.labels)
    if isinstance(wrapped, torch.utils.data.Subset):
        # Position i of a Subset maps to labels[indices[i]]. Slicing the first
        # num_all labels is only correct when the subset happens to be a prefix.
        labels = labels[np.asarray(wrapped.indices, dtype=np.int64)]
    else:
        labels = labels[:num_all]

    # Ascending indices of samples NOT already of the target class. Using a mask
    # makes the order explicit instead of relying on set iteration order.
    remain_idx = np.flatnonzero(labels != poison_label).tolist()

    num_poison = int(percent_poison * len(remain_idx))
    poison_idx = remain_idx[:num_poison]

    print(f'Poison Idx {poison_idx}  | Sum Poison Idx {sum(poison_idx)}')
    print(f'Num All  {num_all} | Num Poison: {num_poison}  | Num Percent {percent_poison}')

    transform_train = transforms.Compose([
            transforms.Resize([256, 256]),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation((-30, 30)),
            transforms.ToTensor(),
        ])
    transform_test = transforms.Compose([
            transforms.Resize([256, 256]),
            transforms.ToTensor()
        ])

    bkd_dataset = CustomDataset(wrapped, poison_label, poison_idx, transform_train)
    bkd_testset = CustomDataset(wrapped, poison_label, poison_idx, transform_test)
    # Evaluate the attack only on the samples that carry the trigger.
    bkd_testset = torch.utils.data.Subset(bkd_testset, poison_idx)

    backdoor_train_loader = DataLoader(bkd_dataset, batch_size=args.local_bs, shuffle=True)
    backdoor_test_loader = DataLoader(bkd_testset, batch_size=args.local_bs, shuffle=False)

    return backdoor_train_loader, backdoor_test_loader


def plt_img(args,loader, plt_list,datasets_name):
    """
    Save images from ``loader.dataset`` as PNG files for visual inspection.

    Files are written to
    ``./show_img/<args.dataset>/<datasets_name[args.backdoor_client_idx]>/``
    and named ``<args.save>_<i>_<args.verify>.png``.

    :param args: Options object; reads ``dataset``, ``backdoor_client_idx``, ``save`` and ``verify``.
    :param loader: Loader whose dataset items are tuples with a (C, H, W) image tensor first.
    :param plt_list: Either an explicit sequence of dataset indices (list, tuple,
        range or NumPy array), or an integer ``n`` meaning indices ``0..n-1``.
    :param datasets_name: Names indexed by ``args.backdoor_client_idx``, used as a sub-directory.
    """
    out_dir = f'./show_img/{args.dataset}/{datasets_name[args.backdoor_client_idx]}'
    # exist_ok replaces the separate exists() check and avoids its race.
    os.makedirs(out_dir, exist_ok=True)
    # The previous check `plt_list is list` compared the argument with the type
    # object itself, so it was always False and a list crashed inside range().
    if isinstance(plt_list, (list, tuple, range, np.ndarray)):
        indices = plt_list
    else:
        indices = range(plt_list)
    for i in indices:
        img = loader.dataset[i][0].detach().cpu()
        # plt.imsave rejects float RGB values outside [0, 1]; integer images use [0, 255].
        img = img.clamp(0, 1) if img.is_floating_point() else img.clamp(0, 255)
        plt.imsave(f'{out_dir}/{args.save}_{i}_{args.verify}.png', img.permute(1, 2, 0).numpy())


if __name__ == "__main__":
    # Script entry point: parse the options, pick a device, load the data and
    # build the backdoor loaders (the returned loaders are not used further here).
    args = args_parser()
    # Seeding for reproducible runs is currently disabled; re-enable if needed:
    # torch.manual_seed(args.seed)
    # torch.cuda.manual_seed_all(args.seed)
    # np.random.seed(args.seed)
    # random.seed(args.seed)
    # torch.backends.cudnn.deterministic = True
    want_gpu = torch.cuda.is_available() and args.gpu != -1
    args.device = torch.device(f'cuda:{args.gpu}' if want_gpu else 'cpu')
    print(args)
    _, _, backdoor_loader = init_data(args)
    backdoor_process(args, backdoor_loader)