import argparse
import sys
import os
import torch
import torch.distributed as dist
from torchvision.utils import draw_bounding_boxes
import torchvision.transforms.functional as F
import matplotlib.pyplot as plt

# Make the repository root (two directories above this file) importable so the
# `BackdoorObjectDetection` package import below resolves when this script is
# run directly. The path is appended (lowest precedence), so an installed
# package of the same name would take priority over the in-repo copy.
here = os.path.dirname(__file__)  # directory containing this script
project_root = os.path.abspath(os.path.join(here, os.pardir, os.pardir))
sys.path.append(project_root)

from BackdoorObjectDetection.bd_dataset.wrapper import BDWrapper

def parse_args(parser, argv=None):
    """Register this script's CLI flags on *parser* and parse them.

    Args:
        parser: An ``argparse.ArgumentParser`` to add the flags to. It is
            mutated in place (re-adding the same flags to one parser raises
            ``argparse.ArgumentError``), so pass a fresh parser per call.
        argv: Optional list of argument strings. ``None`` (the default)
            parses ``sys.argv[1:]``, exactly as before; passing a list lets
            callers and tests drive the parser programmatically.

    Returns:
        The parsed ``argparse.Namespace`` with ``attack`` and
        ``bd_dataset_path`` attributes.
    """
    # Attack arguments
    parser.add_argument("--attack", required=True, type=str, help="Attack name")
    parser.add_argument("--bd_dataset_path", required=True, type=str, help="Backdoor dataset path")

    return parser.parse_args(argv)

def show_sample_images(dataset, save_path, num_samples=5, box_key='boxes'):
    """Save the first few samples of *dataset* as PNGs with their boxes drawn in red.

    Each item is fetched with ``dataset[i]`` and unpacked as
    ``(img, ann, img_id)``. If ``img`` is a list, only the first element of
    ``img``, ``ann`` and ``img_id`` is used.

    Args:
        dataset: Indexable dataset that also supports ``len()``.
        save_path: Output directory; created (with parents) if missing.
        num_samples: Maximum number of samples to render. Clamped to
            ``len(dataset)`` so short datasets do not raise ``IndexError``.
        box_key: Key in ``ann`` holding the boxes, expected in
            (xmin, ymin, xmax, ymax) order as ``draw_bounding_boxes`` requires.

    Each image is written to ``<save_path>/sample_<i>_img_id_<img_id>.png``,
    and any poisoning metadata present in ``ann`` is printed to stdout.
    """
    # exist_ok avoids the race between an exists() check and makedirs().
    os.makedirs(save_path, exist_ok=True)

    for i in range(min(num_samples, len(dataset))):
        img, ann, img_id = dataset[i]

        # If img is a list, take the first image together with its annotation and id.
        if isinstance(img, list):
            img, ann, img_id = img[0], ann[0], img_id[0]

        # Older torchvision releases reject non-uint8 images in
        # draw_bounding_boxes; newer ones accept floats in [0, 1]. Converting
        # here works with both. Assumes float images are scaled to [0, 1] --
        # TODO confirm against the dataset's transforms.
        if img.is_floating_point():
            img = F.convert_image_dtype(img, torch.uint8)

        # draw_bounding_boxes indexes boxes[:, k], so it needs an (N, 4) tensor;
        # reshape also turns an empty box list into shape (0, 4) instead of (0,).
        boxes = torch.as_tensor(ann[box_key], dtype=torch.float32).reshape(-1, 4)
        img_with_boxes = draw_bounding_boxes(img, boxes, colors="red", width=2)

        # The annotation dict is probed under both key conventions
        # (plural 'poison_masks'/'target_labels' and singular
        # 'poison_mask'/'target_id'); each gets a distinct label.
        print(
            f"Sample {i}: Image ID: {img_id}, "
            f"Poison Masks: {ann.get('poison_masks', 'N/A')}, "
            f"Target Labels: {ann.get('target_labels', 'N/A')}, "
            f"Poison Mask: {ann.get('poison_mask', 'N/A')}, "
            f"Target ID: {ann.get('target_id', 'N/A')}"
        )

        # Use an explicit figure and always close it, even if saving fails.
        fig, ax = plt.subplots()
        try:
            ax.imshow(F.to_pil_image(img_with_boxes))
            ax.axis("off")
            fig.savefig(
                os.path.join(save_path, f"sample_{i}_img_id_{img_id}.png"),
                bbox_inches='tight',
                pad_inches=0,
            )
        finally:
            plt.close(fig)

def _load_bd_split(base_path, split, poisoned):
    """Load ``bd_<split>_dataset.pth`` from *base_path* and wrap it in a ``BDWrapper``.

    The file is re-read from disk on every call, so separate calls do not
    share the loaded object.

    Args:
        base_path: Directory containing the ``bd_*_dataset.pth`` files.
        split: Split name ("train", "val" or "test" in this script); selects
            the file and is passed to ``BDWrapper`` as ``data_split``.
        poisoned: Flag forwarded to ``BDWrapper.__get_bd__`` (True for the
            backdoored view, False for the clean view, judging by the callers).

    Returns:
        The configured ``BDWrapper`` instance.
    """
    # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
    # Python objects, i.e. run code embedded in the file. Only point
    # --bd_dataset_path at dataset files you trust.
    raw_dataset = torch.load(os.path.join(base_path, f"bd_{split}_dataset.pth"), weights_only=False)
    wrapped = BDWrapper(raw_dataset, bbox_return_format="pascal_voc", data_split=split)
    wrapped.__get_bd__(poisoned)
    return wrapped


def load_dataset(args):
    """Load all backdoor/clean dataset splits, report their sizes and save sample images.

    Args:
        args: Parsed CLI namespace; only ``args.bd_dataset_path`` is read.

    Returns:
        A dict with keys ``"bd_train"``, ``"bd_val"``, ``"clean_val"``,
        ``"bd_test"`` and ``"clean_test"`` mapping to the loaded
        ``BDWrapper`` instances.

    Side effects:
        Prints dataset sizes and per-sample metadata, and writes PNGs into the
        relative directories ``sample_bd_train``, ``sample_bd_val``,
        ``sample_clean_val``, ``sample_bd_test`` and ``sample_clean_test``
        (resolved against the current working directory).
    """
    # Using bd_dataset_path, we can load the dataset
    base_path = args.bd_dataset_path

    # The clean val/test views come from the same .pth files as their bd
    # counterparts, but each is loaded as a separate object. Presumably
    # __get_bd__ toggles state on the wrapper or the underlying dataset, so a
    # shared object would make both views identical -- TODO confirm in BDWrapper.
    datasets = {
        "bd_train": _load_bd_split(base_path, "train", True),
        "bd_val": _load_bd_split(base_path, "val", True),
        "clean_val": _load_bd_split(base_path, "val", False),
        "bd_test": _load_bd_split(base_path, "test", True),
        "clean_test": _load_bd_split(base_path, "test", False),
    }

    print(f"[LOAD DATASET] Loaded {len(datasets['bd_train'])} training images")
    print(f"[LOAD DATASET] Loaded {len(datasets['bd_val'])} bd validation images")
    print(f"[LOAD DATASET] Loaded {len(datasets['clean_val'])} clean validation images")
    print(f"[LOAD DATASET] Loaded {len(datasets['bd_test'])} bd test images")
    print(f"[LOAD DATASET] Loaded {len(datasets['clean_test'])} clean test images")

    # Show some sample images. Train annotations are read from the default
    # 'boxes' key while val/test use 'bbox'. Presumably BDWrapper returns
    # differently-keyed annotations per data_split -- confirm.
    print("[LOAD DATASET] Showing Train Samples")
    show_sample_images(datasets["bd_train"], "sample_bd_train", num_samples=5)

    print("[LOAD DATASET] Showing Validation Samples (BD)")
    show_sample_images(datasets["bd_val"], "sample_bd_val", num_samples=5, box_key='bbox')

    print("[LOAD DATASET] Showing Validation Samples (Clean)")
    show_sample_images(datasets["clean_val"], "sample_clean_val", num_samples=5, box_key='bbox')

    print("[LOAD DATASET] Showing Test Samples (BD)")
    show_sample_images(datasets["bd_test"], "sample_bd_test", num_samples=5, box_key='bbox')

    print("[LOAD DATASET] Showing Test Samples (Clean)")
    show_sample_images(datasets["clean_test"], "sample_clean_test", num_samples=5, box_key='bbox')

    return datasets

if __name__ == "__main__":
    # Script entry point: parse the CLI flags, then load every split and dump
    # the sample visualisations.
    load_dataset(parse_args(argparse.ArgumentParser()))
