import argparse
import os
import shutil
import sys

import torch
import torch.distributed as dist

# Make the repository root importable so the ``BackdoorObjectDetection.*``
# imports below resolve when this file is run directly as a script.
# Directory containing this file (may be "" when run from its own directory;
# os.path.abspath below resolves that against the current working directory).
here = os.path.dirname(__file__)  
# Two directory levels above this file's directory — presumably the folder that
# contains the ``BackdoorObjectDetection`` package (TODO confirm against repo layout).
project_root = os.path.abspath(os.path.join(here, os.pardir, os.pardir))
# Appended (not prepended), so entries already on sys.path are searched first.
sys.path.append(project_root)

from BackdoorObjectDetection.bd_dataset.dataset.coco import COCOWrapper
from BackdoorObjectDetection.bd_dataset.dataset.mtsd import MTSDWrapper
from BackdoorObjectDetection.bd_dataset.dataset.mtsd_meta import MTSDMetaWrapper
from BackdoorObjectDetection.bd_dataset.dataset.gtsdb import GTSDBWrapper

from BackdoorObjectDetection.bd_dataset.rma.rma_test_dataset import RMATestWrapper
from BackdoorObjectDetection.bd_dataset.rma.rma_dataset import RMAWrapper
from BackdoorObjectDetection.bd_dataset.rma.rma_dataset_morph import RMAMorphWrapper

from BackdoorObjectDetection.bd_dataset.oda.oda_tba_dataset import ODATBAWrapper
from BackdoorObjectDetection.bd_dataset.oda.oda_uba_dataset import ODAUBAWrapper
from BackdoorObjectDetection.bd_dataset.oda.oda_uba_box_dataset import ODAUBABoxWrapper
from BackdoorObjectDetection.bd_dataset.oda.oda_suba_dataset import ODASUBAWrapper
from BackdoorObjectDetection.bd_dataset.oda.oda_test_dataset import ODATestWrapper
from BackdoorObjectDetection.bd_dataset.oda.oda_align_random_dataset import ODAAlignRandomWrapper
from BackdoorObjectDetection.bd_dataset.oda.oda_align_fixed_dataset import ODAAlignFixedWrapper
from BackdoorObjectDetection.bd_dataset.oda.oda_morph_dataset import ODAMorphWrapper

def parse_args(parser, argv=None):
    """
    Register this script's command-line options on ``parser`` and parse them.

    Args:
        parser: an ``argparse.ArgumentParser`` to add the options to.
        argv: optional list of argument strings to parse instead of
            ``sys.argv[1:]``. ``None`` (the default) reads the real command line,
            exactly as before.

    Returns:
        The ``argparse.Namespace`` holding the parsed options.
    """
    # Attack arguments
    parser.add_argument("--attack", required=True, type=str, help="Attack name")
    parser.add_argument("--attack_config_path", required=True, type=str, help="Attack config path")
    # These two options carry defaults, so they must not be ``required=True``:
    # argparse never applies the default of a required option, which made the
    # defaults below dead. Explicitly passing them still works as before.
    parser.add_argument("--trigger_name", type=str, default="trigger.png", help="Trigger name")
    parser.add_argument("--trigger_position", type=str, default="center", help="Trigger position (center, low, high, both, random)")

    # Dataset arguments
    parser.add_argument("--dataset", required=True, type=str, help="Dataset name")
    parser.add_argument("--dataset_path", required=True, type=str, help="Dataset path")
    parser.add_argument("--bd_dataset_path", required=True, type=str, help="Backdoor dataset path")
    parser.add_argument("--save_dir", required=True, type=str, help="Directory to save the backdoor dataset")
    parser.add_argument("--is_test", action='store_true', help="If true, the dataset is for testing")

    return parser.parse_args(argv)

def load_dataset(args, split='train'):
    """
    Instantiate the clean dataset wrapper selected by ``args.dataset``.

    Args:
        args: parsed CLI namespace; reads ``args.dataset`` and ``args.dataset_path``.
        split: split name forwarded to the wrapper as ``data_split``
            (this script passes 'train', 'val' or 'test').

    Returns:
        The wrapper instance for the requested dataset and split.

    Raises:
        ValueError: if ``args.dataset`` is not one of 'coco', 'mtsd',
            'mtsd_meta' or 'gtsdb'.
    """
    if args.dataset == "coco":
        wrapper_cls = COCOWrapper
    elif args.dataset == "mtsd":
        wrapper_cls = MTSDWrapper
    elif args.dataset == "mtsd_meta":
        wrapper_cls = MTSDMetaWrapper
    elif args.dataset == "gtsdb":
        wrapper_cls = GTSDBWrapper
    else:
        # Name the offending value and the valid choices so CLI typos are obvious.
        raise ValueError(
            f"Dataset not supported: {args.dataset!r} "
            "(expected one of 'coco', 'mtsd', 'mtsd_meta', 'gtsdb')"
        )
    return wrapper_cls(args.dataset_path, data_split=split)

def create_and_save_test_dataset(args):
    """
    Build the validation and test splits for ``args.attack`` and save them with ``torch.save``.

    Output directory: ``<bd_dataset_path>/<dataset>/test_val/<save_dir>``, or
    ``<bd_dataset_path>/<dataset>/test_val/baseline`` when ``args.attack == 'baseline'``.
    Files written: ``bd_val_dataset.pth`` / ``bd_test_dataset.pth`` for an attack, or
    ``clean_val_dataset.pth`` / ``clean_test_dataset.pth`` (the unwrapped datasets)
    for the baseline.

    Supported attacks:
        * 'baseline'                                    -> datasets saved unwrapped
        * 'oda_align_fixed_single', 'oda_align_fixed_multi' -> ODAAlignFixedWrapper
        * 'oda_target_single', 'oda_target_multi',
          'oda_untarget_single', 'oda_untarget_multi'   -> ODATestWrapper
        * 'rma_single', 'rma_multi'                     -> RMATestWrapper

    Raises:
        ValueError: if the attack is unsupported or the output directory already exists.
    """
    align_fixed_attacks = ('oda_align_fixed_single', 'oda_align_fixed_multi')
    oda_test_attacks = ('oda_target_single', 'oda_target_multi', 'oda_untarget_single', 'oda_untarget_multi')
    rma_test_attacks = ('rma_single', 'rma_multi')
    is_baseline = args.attack == 'baseline'

    # Validate before touching the filesystem. Previously an unsupported attack was
    # only detected after the output directory had been created, leaving an empty
    # directory behind that made every later run fail with "already exists".
    if not is_baseline and args.attack not in align_fixed_attacks + oda_test_attacks + rma_test_attacks:
        raise ValueError(f"[CREATE DATASET] Attack {args.attack} not supported for test dataset creation.")

    leaf_dir = 'baseline' if is_baseline else args.save_dir
    shared_base_path = os.path.join(args.bd_dataset_path, args.dataset, 'test_val', leaf_dir)
    if os.path.exists(shared_base_path):
        raise ValueError(f"[CREATE DATASET] The directory {shared_base_path} already exists. Please remove it before creating a new dataset.")
    print(f"[CREATE DATASET] The directory {shared_base_path} does not exist. Creating it...")
    os.makedirs(shared_base_path)

    try:
        val_dataset = load_dataset(args, split='val')
        test_dataset = load_dataset(args, split='test')

        if is_baseline:
            print("[CREATE DATASET] Using baseline dataset for validation and test...")
            bd_val_dataset, bd_test_dataset = val_dataset, test_dataset
            val_file, test_file = "clean_val_dataset.pth", "clean_test_dataset.pth"
        else:
            if args.attack in align_fixed_attacks:
                print("[CREATE DATASET] Using ODA Align Fixed Wrapper for validation and test datasets...")
                wrapper_cls = ODAAlignFixedWrapper
            elif args.attack in oda_test_attacks:
                print("[CREATE DATASET] Using ODA Test Wrapper for validation and test datasets...")
                wrapper_cls = ODATestWrapper
            else:  # one of rma_test_attacks (validated above)
                print("[CREATE DATASET] Using RMA Wrapper for validation and test datasets...")
                wrapper_cls = RMATestWrapper
            bd_val_dataset = wrapper_cls(val_dataset, shared_base_path, args.attack_config_path, args.trigger_position, val_dataset.bbox_current_format, data_split="val")
            bd_test_dataset = wrapper_cls(test_dataset, shared_base_path, args.attack_config_path, args.trigger_position, test_dataset.bbox_current_format, data_split="test")
            val_file, test_file = "bd_val_dataset.pth", "bd_test_dataset.pth"

        print("[CREATE DATASET] Saving dataset to disk...")
        torch.save(bd_val_dataset, os.path.join(shared_base_path, val_file))
        torch.save(bd_test_dataset, os.path.join(shared_base_path, test_file))
    except BaseException:
        # The directory was created by this call a few lines above, so it holds only
        # partial output. Remove it (also on Ctrl-C) so the user can simply re-run
        # instead of deleting it by hand; then propagate the original error.
        shutil.rmtree(shared_base_path, ignore_errors=True)
        raise

def _build_bd_train_dataset(args, train_dataset, shared_base_path):
    """
    Return ``train_dataset`` wrapped by the backdoor wrapper selected by ``args.attack``.

    Expects ``args.attack`` (and, for morph attacks, ``args.dataset``) to have been
    validated by ``create_and_save_train_dataset``; an unknown attack raises ``KeyError``.
    For 'baseline' the dataset is returned unchanged.
    """
    attack = args.attack
    if attack == 'baseline':
        print("[CREATE DATASET] Using baseline dataset for training...")
        return train_dataset

    if attack in ('oda_morph', 'rma_morph'):
        # Morph wrappers additionally take the dataset's class ids and whether the
        # MTSD meta variant is in use (dataset was validated to be mtsd / mtsd_meta).
        wrapper_cls, label = (ODAMorphWrapper, "ODA Morph") if attack == 'oda_morph' else (RMAMorphWrapper, "RMA Morph")
        class_ids = train_dataset.class_ids
        print(f"[CREATE DATASET] Using {label} Wrapper for training dataset...")
        return wrapper_cls(train_dataset, shared_base_path, args.attack_config_path, args.trigger_position, class_ids,
                           is_meta=(args.dataset == 'mtsd_meta'), bbox_current_format=train_dataset.bbox_current_format,
                           data_split="train")

    # All remaining wrappers share the same constructor signature.
    wrapper_cls, label = {
        'oda_align_random': (ODAAlignRandomWrapper, "ODA Align Random"),
        'oda_align_fixed': (ODAAlignFixedWrapper, "ODA Align Fixed"),
        'oda_uba': (ODAUBAWrapper, "ODA UBA"),
        'oda_uba_box': (ODAUBABoxWrapper, "ODA UBA Box"),
        'oda_tba': (ODATBAWrapper, "ODA TBA"),
        'oda_suba': (ODASUBAWrapper, "ODA SUBA"),
        'rma': (RMAWrapper, "RMA"),
    }[attack]
    print(f"[CREATE DATASET] Using {label} Wrapper for training dataset...")
    return wrapper_cls(train_dataset, shared_base_path, args.attack_config_path, args.trigger_position,
                       train_dataset.bbox_current_format, data_split="train")


def create_and_save_train_dataset(args):
    """
    Build the training split for ``args.attack`` and save it with ``torch.save``.

    Output directory: ``<bd_dataset_path>/<dataset>/train/<save_dir>``, or
    ``<bd_dataset_path>/<dataset>/train/baseline`` when ``args.attack == 'baseline'``.
    File written: ``bd_train_dataset.pth`` for an attack, or
    ``clean_train_dataset.pth`` (the unwrapped dataset) for the baseline.

    Supported attacks: 'oda_align_random', 'oda_align_fixed', 'oda_morph', 'oda_uba',
    'oda_uba_box', 'oda_tba', 'oda_suba', 'rma', 'rma_morph', 'baseline'.
    'oda_morph' and 'rma_morph' additionally require ``args.dataset`` to be
    'mtsd' or 'mtsd_meta'.

    Raises:
        ValueError: if the attack is unsupported, a morph attack is paired with an
            unsupported dataset, or the output directory already exists.
    """
    # Keep in sync with the dispatch in _build_bd_train_dataset.
    supported_attacks = ('oda_align_random', 'oda_align_fixed', 'oda_morph', 'oda_uba', 'oda_uba_box',
                         'oda_tba', 'oda_suba', 'rma', 'rma_morph', 'baseline')
    morph_labels = {'oda_morph': "ODA Morph", 'rma_morph': "RMA Morph"}

    # Validate before touching the filesystem. Previously these errors surfaced only
    # after the output directory existed, leaving it behind and blocking re-runs.
    if args.attack not in supported_attacks:
        raise ValueError(f"[CREATE DATASET] Attack {args.attack} not supported for training dataset creation.")
    if args.attack in morph_labels and args.dataset not in ('mtsd', 'mtsd_meta'):
        raise ValueError(f"[CREATE DATASET] {morph_labels[args.attack]} attack is only supported for MTSD and MTSD Meta datasets.")

    is_baseline = args.attack == 'baseline'
    leaf_dir = 'baseline' if is_baseline else args.save_dir
    shared_base_path = os.path.join(args.bd_dataset_path, args.dataset, 'train', leaf_dir)
    if os.path.exists(shared_base_path):
        raise ValueError(f"[CREATE DATASET] The directory {shared_base_path} already exists. Please remove it before creating a new dataset.")
    print(f"[CREATE DATASET] The directory {shared_base_path} does not exist. Creating it...")
    os.makedirs(shared_base_path)

    try:
        train_dataset = load_dataset(args, split='train')
        bd_train_dataset = _build_bd_train_dataset(args, train_dataset, shared_base_path)

        print("[CREATE DATASET] Saving dataset to disk...")
        file_name = "clean_train_dataset.pth" if is_baseline else "bd_train_dataset.pth"
        torch.save(bd_train_dataset, os.path.join(shared_base_path, file_name))
    except BaseException:
        # The directory was created by this call a few lines above, so it holds only
        # partial output. Remove it (also on Ctrl-C) so the user can simply re-run
        # instead of deleting it by hand; then propagate the original error.
        shutil.rmtree(shared_base_path, ignore_errors=True)
        raise
        
if __name__ == "__main__":
    cli_args = parse_args(argparse.ArgumentParser())

    # --is_test selects the validation/test pipeline; without it the training split is built.
    if not cli_args.is_test:
        print("[CREATE DATASET] Creating training dataset...")
        create_and_save_train_dataset(cli_args)
    else:
        print("[CREATE DATASET] Creating test dataset...")
        create_and_save_test_dataset(cli_args)

