import sys, os
base_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(base_path)

import numpy as np
import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms
from PIL import Image
import os
import art
from art.attacks.poisoning import PoisoningAttackBackdoor, PoisoningAttackCleanLabelBackdoor
from art.attacks.poisoning.perturbations import add_pattern_bd
from art.utils import load_mnist, preprocess, to_categorical

class DigitsDataset(Dataset):
    """Digit-classification dataset read from pickled ``(images, labels)`` files.

    File selection (paths relative to ``data_path``):

    * ``filename`` given -> that file.
    * ``train`` True -> ``partitions/train_part{k}.pkl``; ``percent >= 0.1``
      concatenates partitions ``0 .. int(percent * 10) - 1``, ``percent == -1``
      loads only partition 9, any other value keeps the leading
      ``int(len * percent * 10)`` samples of partition 0.
    * otherwise -> ``test.pkl``.

    Backdoor (training split only, ``filename`` None): with ``inject_backdoor``
    a copy of the data is poisoned with ART's pattern backdoor.  A random
    ``args.backdoor_percent_poison`` fraction of the samples whose label is not
    ``args.backdoor_target_label`` get the trigger and the target label.  The
    poisoned arrays are cached as ``.npy`` files under
    ``./data/digitdata/{dataset}/`` and reused when ``load_backdoor`` is True.
    ``backdoortest`` restricts the dataset to the poisoned samples (requires a
    successfully injected backdoor).

    Items are ``(image, label, idx)`` tuples; ``image`` is a PIL image in mode
    ``'L'`` (``channels == 1``) or ``'RGB'`` (``channels == 3``) passed through
    ``transform`` when one is given.
    """

    def __init__(self, data_path, channels, percent=0.1, filename=None, train=True, transform=None,
                 inject_backdoor=False, load_backdoor=False, args=None, dataset=None, backdoortest=False):
        # Flags read by __len__/__getitem__ to choose the clean or poisoned view.
        self.backdoor = inject_backdoor
        self.backdoortest = backdoortest
        if filename is None:
            if train:
                self.images, self.labels = self._load_train_partitions(data_path, percent)
                if inject_backdoor:
                    self._setup_backdoor(percent, load_backdoor, args, dataset)
            else:
                self.images, self.labels = np.load(os.path.join(data_path, 'test.pkl'), allow_pickle=True)
        else:
            self.images, self.labels = np.load(os.path.join(data_path, filename), allow_pickle=True)

        self.transform = transform
        self.channels = channels
        # squeeze() drops singleton axes, e.g. turns an (N, 1) column into (N,).
        self.labels = self.labels.astype(np.int64).squeeze()
        print("Dataset size: {}".format(self.images.shape[0]))

    @staticmethod
    def _load_train_partitions(data_path, percent):
        """Return the training ``(images, labels)`` selected by ``percent``."""
        def load_part(part):
            path = os.path.join(data_path, 'partitions/train_part{}.pkl'.format(part))
            return np.load(path, allow_pickle=True)

        if percent >= 0.1:
            image_parts, label_parts = [], []
            for part in range(int(percent * 10)):
                part_images, part_labels = load_part(part)
                image_parts.append(part_images)
                label_parts.append(part_labels)
            if len(image_parts) == 1:
                return image_parts[0], label_parts[0]
            # One concatenation instead of re-copying the growing array per part.
            return np.concatenate(image_parts, axis=0), np.concatenate(label_parts, axis=0)

        if percent == -1:
            # Special mode (used for the MNIST_M unlearning setup): partition 9 only.
            images, labels = load_part(9)
            return images, labels

        images, labels = load_part(0)
        data_len = int(images.shape[0] * percent * 10)
        return images[:data_len], labels[:data_len]

    def _setup_backdoor(self, percent, load_backdoor, args, dataset):
        """Load (or generate and cache) the poisoned training arrays and indices."""
        cache_dir = f'./data/digitdata/{dataset}'
        prefix = f'{cache_dir}/{dataset}_backdoor_{int(percent*10)}p'
        images_file = f'{prefix}_train_image.npy'
        labels_file = f'{prefix}_train_label.npy'
        indices_file = f'{prefix}_poison_indices.npy'

        if load_backdoor:
            self.poisoned_x_train = np.load(images_file, allow_pickle=True)
            self.poisoned_y_train = np.load(labels_file, allow_pickle=True)
            self.poison_selected_indices = np.load(indices_file, allow_pickle=True)
            return

        backdoor = PoisoningAttackBackdoor(add_pattern_bd)
        example_target = args.backdoor_target_label
        percent_poison = args.backdoor_percent_poison

        # Only samples NOT already of the target class are eligible.  Flattening
        # makes this work for both (N,) and (N, 1) label arrays, and flatnonzero
        # yields an integer index array even when no sample is eligible.
        flat_labels = np.asarray(self.labels).reshape(-1)
        target_indices = np.flatnonzero(flat_labels != example_target)
        num_poison = int(percent_poison * len(target_indices))
        print(f'num poison: {num_poison}')
        self.poison_selected_indices = np.random.choice(target_indices, num_poison, replace=False)

        poisoned_data, _ = backdoor.poison(self.images[self.poison_selected_indices], y=example_target,
                                           broadcast=False)
        self.poisoned_x_train = np.copy(self.images)
        self.poisoned_x_train[self.poison_selected_indices] = poisoned_data
        poisoned_labels = np.copy(self.labels)
        poisoned_labels[self.poison_selected_indices] = example_target
        self.poisoned_y_train = poisoned_labels.astype(np.int64).squeeze()

        # np.save does not create missing directories.
        os.makedirs(cache_dir, exist_ok=True)
        np.save(images_file, self.poisoned_x_train)
        np.save(labels_file, self.poisoned_y_train)
        np.save(indices_file, self.poison_selected_indices)

    def __len__(self):
        # In backdoor-test mode only the poisoned samples are exposed.
        if self.backdoortest:
            return len(self.poison_selected_indices)
        return self.images.shape[0]

    def __getitem__(self, idx):
        if self.backdoor:
            # In backdoor-test mode idx enumerates the poisoned samples; map it
            # to the position in the full poisoned arrays.
            source_idx = self.poison_selected_indices[idx] if self.backdoortest else idx
            image = self.poisoned_x_train[source_idx]
            label = self.poisoned_y_train[source_idx]
        else:
            image = self.images[idx]
            label = self.labels[idx]

        if self.channels == 1:
            image = Image.fromarray(image, mode='L')
        elif self.channels == 3:
            image = Image.fromarray(image, mode='RGB')
        else:
            raise ValueError("{} channel is not allowed.".format(self.channels))

        if self.transform is not None:
            image = self.transform(image)

        return image, label, idx


class OfficeDataset(Dataset):
    """Office-Caltech-10 dataset for one ``site`` with optional backdoor poisoning.

    Image paths and text labels are read from
    ``./data/office_caltech_10/{site}_train.pkl`` (or ``_test.pkl``); images are
    opened relative to ``{base_path}/office_caltech_10``.

    With ``inject_backdoor`` and an ``args`` object providing
    ``backdoor_target_label`` and ``backdoor_percent_poison``:

    1. Every image is loaded and optionally resized to ``resize x resize``.
    2. A random fraction of the non-target-class samples gets ART's pattern
       backdoor and the target label.
    3. The result is cached as ``.npy`` files.

    ``load_backdoor`` reuses a cache when it exists and matches the split size.
    When poisoning cannot be set up, items fall back to clean images read from
    disk.  ``backdoortest`` restricts the dataset to the poisoned samples.

    Items are ``(image_tensor, label, idx)`` tuples.
    """

    def __init__(self, base_path, site, train=True, transform=None,
                 inject_backdoor=False, load_backdoor=False, args=None,
                 dataset_name=None, backdoortest=False, resize=None):
        self.base_path = os.path.join(base_path, 'office_caltech_10')
        self.site = site
        self.train = train
        self.transform = transform
        self.inject_backdoor = inject_backdoor
        self.backdoortest = backdoortest
        self.args = args
        self.dataset_name = dataset_name
        # Side length used when loading images for poisoning (None = keep size).
        self.resize = resize

        split = 'train' if train else 'test'
        self.paths, self.text_labels = np.load('./data/office_caltech_10/{}_{}.pkl'.format(site, split),
                                               allow_pickle=True)

        label_dict = {'back_pack': 0, 'bike': 1, 'calculator': 2, 'headphones': 3, 'keyboard': 4,
                      'laptop_computer': 5, 'monitor': 6, 'mouse': 7, 'mug': 8, 'projector': 9}
        self.labels = np.array([label_dict[text] for text in self.text_labels])

        # Stay None unless poisoning succeeds; __getitem__ then reads from disk.
        self.poison_selected_indices = None
        self.poisoned_images = None
        self.poisoned_labels = None

        if self.inject_backdoor and self.args is not None:
            self._setup_backdoor(load_backdoor)

        print(f"Dataset {site} size ({'Train' if train else 'Test'})"
              f" ({'Backdoor Injected' if self.inject_backdoor else 'Clean'})"
              f" ({'Backdoor Test Mode' if self.backdoortest else 'Normal Mode'}): "
              f"{self.__len__()}")

    def _setup_backdoor(self, load_backdoor):
        """Fill the poisoned arrays from a matching cache, or generate them."""
        self.backdoor_target_label = self.args.backdoor_target_label
        percent_poison = self.args.backdoor_percent_poison

        # The training split keeps its historical cache name so existing caches
        # stay valid.  The test split needs its own name: with a shared name,
        # poisoning the test split overwrote the training cache and each split
        # could load the other's arrays.
        split_tag = '' if self.train else '_test'
        poison_data_prefix = (f'./data/office_caltech_10/{self.site}{split_tag}_backdoor_'
                              f'{int(percent_poison*100)}p_target{self.backdoor_target_label}'
                              f'_resize{self.resize}')
        poison_indices_file = f'{poison_data_prefix}_indices.npy'
        poisoned_images_file = f'{poison_data_prefix}_images.npy'
        poisoned_labels_file = f'{poison_data_prefix}_labels.npy'

        cache_files = (poisoned_images_file, poisoned_labels_file, poison_indices_file)
        if load_backdoor and all(os.path.exists(path) for path in cache_files):
            print(f"Loading pre-poisoned data for {self.site}...")
            images = np.load(poisoned_images_file, allow_pickle=True)
            labels = np.load(poisoned_labels_file, allow_pickle=True)
            indices = np.load(poison_indices_file, allow_pickle=True)
            if len(images) == len(self.labels) and len(labels) == len(self.labels):
                self.poisoned_images = images
                self.poisoned_labels = labels
                self.poison_selected_indices = indices
                self._poison_indices_set = set(indices)
                return
            print(f"Warning: cached poisoned data for {self.site} does not match the split size "
                  f"({len(images)} vs {len(self.labels)}); regenerating.")

        self._generate_backdoor(percent_poison, poisoned_images_file, poisoned_labels_file,
                                poison_indices_file)

    def _load_resized_images(self):
        """Return every image of the split as a uint8 array of shape (N, H, W, 3)."""
        print(f"Loading and resizing all images for {self.site} to {self.resize}x{self.resize} for poisoning...")
        images = []
        for img_path in self.paths:
            with Image.open(os.path.join(self.base_path, img_path)) as img:
                image_pil = img.convert('RGB')
            if self.resize is not None:
                image_pil = image_pil.resize((self.resize, self.resize))
            images.append(np.array(image_pil))

        if not images:
            print(f"No images found for {self.site}.")
            return np.array([])
        # np.stack requires equal shapes, i.e. a resize when source sizes differ.
        stacked = np.stack(images, axis=0)
        print(f"Finished loading {len(stacked)} images. Applying backdoor...")
        return stacked

    def _generate_backdoor(self, percent_poison, images_file, labels_file, indices_file):
        """Poison a random subset of non-target samples and cache the result."""
        all_images_np = self._load_resized_images()

        all_indices = np.arange(len(self.labels))
        non_target_indices = all_indices[self.labels != self.backdoor_target_label]
        num_poison = int(percent_poison * len(non_target_indices))

        if num_poison == 0 and percent_poison > 0:
            print(f"Warning: No non-target samples found to poison for {self.site} with target {self.backdoor_target_label} and percent {percent_poison}. No backdoor injected.")
            self.inject_backdoor = False
            return
        if num_poison <= 0 or len(all_images_np) == 0:
            print(f"Warning: No samples selected or no images to apply backdoor for {self.site}. No backdoor injected.")
            self.inject_backdoor = False
            return

        print(f"Dataset {self.site}: Selecting {num_poison} samples for poisoning.")
        self.poison_selected_indices = np.random.choice(non_target_indices, num_poison, replace=False)
        self._poison_indices_set = set(self.poison_selected_indices)

        backdoor_attack = PoisoningAttackBackdoor(add_pattern_bd)
        images_to_poison_np = all_images_np[self.poison_selected_indices]
        poisoned_data_subset, poisoned_labels_subset = backdoor_attack.poison(
            images_to_poison_np,
            y=np.array([self.backdoor_target_label] * len(images_to_poison_np)),
            broadcast=False,
        )

        # Full-size arrays: originals everywhere except the poisoned positions.
        self.poisoned_images = np.copy(all_images_np)
        self.poisoned_labels = np.copy(self.labels)
        self.poisoned_images[self.poison_selected_indices] = poisoned_data_subset
        self.poisoned_labels[self.poison_selected_indices] = poisoned_labels_subset

        print(f"Saving poisoned data for {self.site}...")
        np.save(images_file, self.poisoned_images)
        np.save(labels_file, self.poisoned_labels)
        np.save(indices_file, self.poison_selected_indices)
        print("Saved.")

    def __len__(self):
        # Backdoor-test mode with successful poisoning exposes only the poisoned
        # samples; otherwise the whole split.
        if self.backdoortest and self.poisoned_images is not None and self.poison_selected_indices is not None:
            return len(self.poison_selected_indices)
        return len(self.labels)

    def __getitem__(self, idx):
        if self.inject_backdoor and self.poisoned_images is not None and self.poison_selected_indices is not None:
            data_source_idx = idx
            if self.backdoortest:
                # idx enumerates poisoned samples; map it into the full arrays.
                if idx >= len(self.poison_selected_indices):
                    raise IndexError(f"Index {idx} out of bounds for poisoned subset of size {len(self.poison_selected_indices)}")
                data_source_idx = self.poison_selected_indices[idx]

            if data_source_idx >= len(self.poisoned_images):
                raise IndexError(f"Internal data index {data_source_idx} out of bounds for poisoned_images array of size {len(self.poisoned_images)}")

            image_np = self.poisoned_images[data_source_idx]
            label = self.poisoned_labels[data_source_idx]
            image_pil = Image.fromarray(image_np, mode='RGB')
            if self.transform is not None:
                image_tensor = self.transform(image_pil)
            else:
                # HWC uint8 -> CHW float, no scaling.
                image_tensor = torch.from_numpy(image_np).permute(2, 0, 1).float()
            return image_tensor, label, idx

        # Clean path: read the original file from disk.
        if idx >= len(self.paths):
            raise IndexError(f"Index {idx} out of bounds for original dataset paths of size {len(self.paths)}")

        img_path = os.path.join(self.base_path, self.paths[idx])
        label = self.labels[idx]
        with Image.open(img_path) as img:
            image_pil = img.convert('RGB')

        if self.transform is not None:
            image_tensor = self.transform(image_pil)
        else:
            image_tensor = torch.from_numpy(np.array(image_pil)).permute(2, 0, 1).float()
        return image_tensor, label, idx




class DomainNetDataset(Dataset):
    """Ten-class DomainNet subset for one ``site``.

    Image paths and text labels come from
    ``./data/domainnet/{site}_train.pkl`` (or ``_test.pkl``).  Each image is
    opened at ``os.path.join(base_path, path)``; ``base_path`` defaults to
    ``'./data'`` when None.  Items are ``(image, label, idx)`` tuples with the
    image converted to RGB before ``transform`` is applied.
    """

    def __init__(self, base_path, site, train=True, transform=None):
        split = 'train' if train else 'test'
        self.paths, self.text_labels = np.load('./data/domainnet/{}_{}.pkl'.format(site, split),
                                               allow_pickle=True)
        print("Dataset {} {}set size: {}".format(site, 'Train' if train else 'Test', len(self.text_labels)))

        label_dict = {'bird': 0, 'feather': 1, 'headphones': 2, 'ice_cream': 3, 'teapot': 4,
                      'tiger': 5, 'whale': 6, 'windmill': 7, 'wine_glass': 8, 'zebra': 9}
        self.labels = [label_dict[text] for text in self.text_labels]
        self.transform = transform
        # Image paths are joined onto this root.  (A former earlier assignment,
        # os.path.join(base_path, 'domainnet'), was always overwritten here and
        # crashed for base_path=None, making this fallback unreachable.)
        self.base_path = base_path if base_path is not None else './data'

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.base_path, self.paths[idx])
        label = self.labels[idx]
        # convert('RGB') replicates single-band images to 3 channels (as the
        # previous Grayscale(3) did) but keeps the colours of RGBA/P/CMYK images,
        # which Grayscale would have discarded.  convert() also loads the pixels,
        # so the file can be closed right away.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform is not None:
            image = self.transform(image)

        return image, label, idx


class PACSDataset(Dataset):
    """Loader for the PACS dataset with optional backdoor injection.

    Images live in ``{base_path}/PACS/{domain}/{class}/*.jpg|jpeg|png``.  Class
    indices follow the sorted class-directory names.  Hidden entries (names
    starting with ``'.'``, e.g. ``.ipynb_checkpoints`` or macOS ``._*`` files)
    are ignored.  Each class is split roughly 80/20 into train/test, with
    ``n_train`` as an optional upper bound on the training count.

    With ``inject_backdoor``, a seeded random ``percent_poison`` fraction of the
    non-target samples gets a 3x3 white trigger in the bottom-right corner and
    label ``target_label``.  With ``backdoortest`` the trigger and target label
    are applied to every sample.  NOTE(review): that includes target-class
    samples, which may inflate attack-success metrics — confirm intended.

    Items are ``(image, label, idx)`` tuples.
    """

    def __init__(self, base_path, domain, train=True, transform=None,
                 n_train=None, inject_backdoor=False, backdoortest=False,
                 target_label=0, percent_poison=0.5, seed=1234):
        self.transform = transform
        self.domain = domain
        self.inject_backdoor = inject_backdoor
        self.backdoortest = backdoortest
        self.target_label = target_label
        self.percent_poison = percent_poison

        self.base_path = os.path.join(base_path, 'PACS', domain)

        # Hidden directories must not become classes: they would shift every
        # label index that follows them in sorted order.
        self.classes = sorted(
            d for d in os.listdir(self.base_path)
            if not d.startswith('.') and os.path.isdir(os.path.join(self.base_path, d))
        )
        self.class_to_idx = {c: i for i, c in enumerate(self.classes)}

        self.paths = []
        self.labels = []

        for cls in self.classes:
            class_dir = os.path.join(self.base_path, cls)
            # Skip hidden files (e.g. AppleDouble "._x.jpg") that match the
            # extension filter but are not images.
            imgs = sorted(
                os.path.join(class_dir, f) for f in os.listdir(class_dir)
                if not f.startswith('.') and f.lower().endswith(('.jpg', '.jpeg', '.png'))
            )

            # Roughly 8:2 train/test split per class; n_train caps the training
            # part, and both splits keep at least one sample when total > 1.
            total = len(imgs)
            limit = int(total * 0.8)
            if n_train is not None:
                limit = min(limit, n_train)
            limit = min(max(1, limit), total - 1) if total > 1 else 0

            selected = imgs[:limit] if train else imgs[limit:]
            label = self.class_to_idx[cls]
            self.paths.extend(selected)
            self.labels.extend([label] * len(selected))

        if inject_backdoor:
            rng = np.random.RandomState(seed)
            all_indices = np.arange(len(self.labels))
            non_target = all_indices[np.array(self.labels) != target_label]
            num_poison = int(len(non_target) * percent_poison)
            self.poison_indices = set(rng.choice(non_target, num_poison, replace=False))
        else:
            self.poison_indices = set()

        print(f"Dataset {domain} {'train' if train else 'test'} size: {len(self.labels)}")

    def _add_trigger(self, image):
        """Return a copy of ``image`` with a white 3x3 patch in the bottom-right corner."""
        image = image.copy()
        w, h = image.size
        image.paste((255, 255, 255), (max(0, w - 3), max(0, h - 3), w, h))
        return image

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        img_path = self.paths[idx]
        label = self.labels[idx]
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.backdoortest or (self.inject_backdoor and idx in self.poison_indices):
            image = self._add_trigger(image)
            label = self.target_label

        if self.transform is not None:
            image = self.transform(image)

        return image, label, idx
