from ...bd_dataset.wrapper import BDWrapper
from ...bd_dataset.evaluator import COCOEvaluator
from ...bd_dataset.evaluator import MTSDEvaluator
from ...bd_dataset.rma.rma_evaluator import RMAEvaluator
from ...bd_dataset.oda.oda_evaluator import ODAEvaluator

import os
import torch

# --------------------------- globals ---------------------------------
# Maps a dataset name (the value of `args.dataset`) to the evaluator class
# used for clean (non-backdoored) evaluation. It is looked up as
# `DATASET_EVAL_CLS[args.dataset]` when the clean / "clean_multi" evaluators
# are built.
# NOTE(review): "gtsdb" is listed here, but build_attack_evaluators only
# selects an attack-scheme table for "coco", "mtsd" and "mtsd_meta".
DATASET_EVAL_CLS = {
    "coco": COCOEvaluator,
    "gtsdb": COCOEvaluator,
    "mtsd": MTSDEvaluator,
    "mtsd_meta": MTSDEvaluator,
}

# Attack evaluation schemes used when `args.dataset == "coco"`, keyed by
# `args.test_attack`. Each value is a 3-tuple
#     (evaluator_cls, add_clean_multi, entries)
# consumed by build_attack_evaluators:
#   * evaluator_cls   - class instantiated once per (entry, split, position);
#                       None for "baseline", which has no entries.
#   * add_clean_multi - when True, the clean val/test sets are additionally
#                       evaluated with the dataset's clean evaluator under
#                       the name "clean_multi".
#   * entries         - list of (name, path_prefix, is_multi):
#                       `name` labels the evaluator record, `path_prefix` is
#                       the leading part of the on-disk directory
#                       "<path_prefix>_<trigger_type>_<position>", and
#                       `is_multi` is forwarded to the evaluator constructor.
ATTACK_SCHEMES_COCO = {
    "baseline": (None, False, []),
    "oda_untarget": (ODAEvaluator, True, [
        ("oda_untarget_single", "oda_untarget_single", False),
        ("oda_untarget_multi",  "oda_untarget_multi",  True),
    ]),
    "oda_target": (ODAEvaluator, False, [
        ("oda_target_single", "oda_target_single", False),
        ("oda_target_multi",  "oda_target_multi",  True),
    ]),
    "rma": (RMAEvaluator, False, [
        ("rma_single", "rma_single", False),
        ("rma_multi",  "rma_multi",  True),
    ]),
    "oda_align": (ODAEvaluator, False, [
        # align fixed/random
        ("oda_align_fixed_single",  "oda_align_fixed_single",  False),
        ("oda_align_fixed_multi",   "oda_align_fixed_multi",   True),
        # add untargeted evaluators too
        ("oda_untarget_single", "oda_untarget_single", False),
        ("oda_untarget_multi",  "oda_untarget_multi",  True),
    ]),
}

# Attack evaluation schemes used when `args.dataset` is "mtsd" or
# "mtsd_meta", keyed by `args.test_attack`. Each value is a 3-tuple
#     (evaluator_cls, add_clean_multi, entries)
# where `entries` is a list of (name, path_prefix, is_multi), interpreted by
# build_attack_evaluators. Every entry in this table has is_multi=False,
# i.e. only the "*_single" datasets are evaluated for these datasets.
ATTACK_SCHEMES_MTSD = {
    "baseline": (None, False, []),
    "oda_untarget": (ODAEvaluator, True, [
        ("oda_untarget_single", "oda_untarget_single", False),
    ]),
    "oda_target": (ODAEvaluator, False, [
        ("oda_target_single", "oda_target_single", False),
    ]),
    "rma": (RMAEvaluator, False, [
        ("rma_single", "rma_single", False),
    ]),
    "oda_align": (ODAEvaluator, False, [
        # align fixed/random
        ("oda_align_fixed_single",  "oda_align_fixed_single",  False),
        # add untargeted evaluators too
        ("oda_untarget_single", "oda_untarget_single", False),
    ]),
}


# --------------------------- collate functions ------------------------

def train_collate_fn(batch):
    """
    Collate training samples into parallel lists.

    Each element of `batch` must unpack into exactly three items
    (image, target, <ignored>); the third item is discarded.

    Returns:
        (images, targets): two lists in the same order as `batch`.
    """
    images = []
    targets = []

    for sample in batch:
        image, target, _ignored = sample
        images.append(image)
        targets.append(target)

    return images, targets

def test_collate_fn(batch):
    
    test_img, test_target, img_ids = [], [], []
    
    for img_list, target_list, img_id_list in batch:

        for img, target, img_id in zip(img_list, target_list, img_id_list):
            test_img.append(img)
            test_target.append(target)
            img_ids.append(img_id)

    return test_img, test_target, img_ids

# ---------------------------------------------------------------------

def convert_bbox_format(bbox_format):
    """
    Translate a coordinate-layout name into the format name used by the model
    transforms.

    'xywh' -> 'coco', 'xyxy' -> 'pascal_voc', 'cxcywh' -> 'yolo'.

    Raises:
        ValueError: if `bbox_format` is not one of the three supported names.
    """
    known_formats = (
        ('xywh', 'coco'),
        ('xyxy', 'pascal_voc'),
        ('cxcywh', 'yolo'),
    )

    # Compare with == (rather than hashing) so any input type is handled the
    # same way: a non-matching value always ends in the ValueError below.
    for coordinate_name, converted_name in known_formats:
        if bbox_format == coordinate_name:
            return converted_name

    raise ValueError(f"Unsupported bounding box format: {bbox_format}. Supported formats are 'xywh', 'xyxy', and 'cxcywh'.")

def build_loader(shared_base_path, split_name, filename, bd_flag, transform_fn, args, is_test, distributed, rank, world_size):
    """
    Returns a DataLoader for one split (train/val/test) and one mode (clean or backdoor).

    Args:
        shared_base_path: directory that `filename` is joined onto.
        split_name: logical split. "train" enables shuffling, uses
            `args.batch_size` and `train_collate_fn`; any other value uses
            batch size 1 and `test_collate_fn`.
        filename: path of the serialized dataset wrapper, joined onto
            `shared_base_path` with `os.path.join`.
        bd_flag: value passed to `BDWrapper.__get_bd__`.
        transform_fn: callable that takes the converted bbox format name and
            returns `(transform, bbox_return_format)`.
        args: namespace providing `batch_size` and `num_workers`.
        is_test: when true the loader runs in the main process (no workers)
            without pinned memory.
        distributed: when true a `DistributedSampler` is used and the train
            batch size is divided by `world_size`.
        rank: rank of this process (sampler rank and error messages).
        world_size: number of replicas.

    Raises:
        ValueError: if the stored bbox format is unsupported, or if
            `args.batch_size // world_size` is 0 for a distributed train loader.
    """
    # 1) Load and wrap
    # NOTE(review): weights_only=False unpickles arbitrary Python objects;
    # only load dataset files that come from a trusted source.
    path = os.path.join(shared_base_path, filename)
    ds_wrapper = torch.load(path, weights_only=False)

    # 2) Convert bbox format
    bbox_fmt = convert_bbox_format(ds_wrapper.bbox_current_format)

    # 3) Get transform and final bbox format
    data_transform, bbox_return_fmt = transform_fn(bbox_fmt)

    # 4) Wrap in BDWrapper and set poison flag
    ds = BDWrapper(ds_wrapper, bbox_return_format=bbox_return_fmt, data_split=split_name, transform=data_transform)
    ds.__get_bd__(bd_flag)

    # 5) Build sampler (if distributed)
    is_train = (split_name == "train")
    sampler = None
    if distributed:
        sampler = torch.utils.data.distributed.DistributedSampler(ds, num_replicas=world_size, rank=rank, shuffle=is_train, drop_last=is_train)
        sampler.set_epoch(0)

    # 6) Choose collate & batch_size
    if is_train:
        batch_size = args.batch_size

        # If distributed, the global batch size is split across the replicas
        if distributed:
            batch_size = args.batch_size // world_size
            if batch_size == 0:
                raise ValueError(f"[Rank {rank}] [TrainModel] Batch size {args.batch_size} is too small for distributed training with world size {world_size}. Please increase the batch size.")

        collate = train_collate_fn
    else:
        batch_size = 1
        collate = test_collate_fn

    # 7) Build DataLoader
    if is_test:
        num_workers = 0  # No workers for test dataset
        pin_memory = False  # No pin memory for test dataset
    else:
        num_workers = args.num_workers
        pin_memory = True  # Pin memory for training dataset

    # DataLoader raises ValueError for persistent_workers=True when
    # num_workers == 0, so only request persistent workers if there are any.
    persistent_workers = num_workers > 0

    # shuffle and sampler are mutually exclusive in DataLoader: when a sampler
    # is present it is the sampler that shuffles.
    loader = torch.utils.data.DataLoader(ds, batch_size=batch_size, shuffle=(sampler is None and is_train), sampler=sampler, num_workers=num_workers, collate_fn=collate, pin_memory=pin_memory, persistent_workers=persistent_workers)
    return loader


def help_build_loader(base, dataset, dir_split, subdir, *, bd, transform,
                      args, distributed, rank, world_size):
    """
    Resolve the on-disk dataset file and logical split, then build its loader.

    The dataset file lives at `<base>/<dataset>/<dir_split>/<subdir>`.

    Args:
        base: root directory of the stored datasets.
        dataset: dataset name (sub-directory of `base`).
        dir_split: folder name on disk ("train" or "test_val").
        subdir: e.g. "baseline/clean_val_dataset.pth" or ".../bd_test_dataset.pth".
        bd: poison flag forwarded to `build_loader` as `bd_flag`.
        transform: forwarded to `build_loader` as `transform_fn`.
        args, distributed, rank, world_size: forwarded to `build_loader`.

    Returns:
        The DataLoader produced by `build_loader`.

    Raises:
        ValueError: if `dir_split` is not "train" and `subdir` contains
            neither "val_dataset" nor "test_dataset".
    """
    # derive logical split for BDWrapper
    if dir_split == "train":
        logical_split = "train"
    elif "val_dataset" in subdir:  # dir_split == "test_val"
        logical_split = "val"
    elif "test_dataset" in subdir:
        logical_split = "test"
    else:
        raise ValueError(f"Cannot infer split from subdir '{subdir}'.")

    # build_loader joins its filename onto `base` itself, so pass a path that
    # is relative to `base`. Pre-joining `base` here only worked for absolute
    # bases; a relative base ended up duplicated ("base/base/...").
    relative_file = os.path.join(dataset, dir_split, subdir)
    return build_loader(
        base,
        logical_split,          # what BDWrapper expects
        relative_file,
        bd_flag=bd,
        transform_fn=transform,
        args=args,
        is_test=(logical_split != "train"),
        distributed=distributed,
        rank=rank,
        world_size=world_size,
    )

def _resolve_trigger_positions(args, rank):
    """Return the trigger positions to evaluate: `args.multi_position` when set, else `[args.trigger_position]`."""
    if args.multi_position is None:
        return [args.trigger_position]

    if not isinstance(args.multi_position, list) or len(args.multi_position) == 0:
        raise ValueError(f"[Rank {rank}] [TrainModel] multi_position must be a non-empty list when specified, got {args.multi_position}")

    return args.multi_position


def build_attack_evaluators(args, model_name, model, model_wrapper, device, distributed, rank, world_size):
    """
    Create every (loader, evaluator) pair required by `args.test_attack`
    using the ATTACK_SCHEMES_* table that matches `args.dataset`.

    For each scheme entry, each split ("val", "test") and each trigger
    position, the backdoored dataset
    "<path_prefix>_<trigger_type>_<position>/bd_<split>_dataset.pth" is loaded
    and paired with a fresh evaluator. When the scheme requests it, the clean
    val/test sets are appended under the name "clean_multi".

    Returns:
        A list of dicts with keys "type" (split), "name", "loader",
        "evaluator" and "position".

    Raises:
        ValueError: if no attack schemes exist for `args.dataset`, or if
            `args.multi_position` is set but is not a non-empty list.
        KeyError: if `args.test_attack` is not a key of the scheme table.
    """
    if args.dataset == "coco":
        scheme_table = ATTACK_SCHEMES_COCO
    elif args.dataset in ("mtsd", "mtsd_meta"):
        scheme_table = ATTACK_SCHEMES_MTSD
    else:
        # Previously this fell through and failed later with UnboundLocalError.
        raise ValueError(f"[Rank {rank}] [TrainModel] No attack schemes are defined for dataset '{args.dataset}'. Supported datasets are 'coco', 'mtsd', and 'mtsd_meta'.")

    scheme, add_clean_multi, name_pairs = scheme_table[args.test_attack]

    base, ds = args.bd_base_path, args.dataset
    trans = model_wrapper.transform_test
    ev_kwargs = dict(model_name=model_name, device=device, distributed=distributed, rank=rank, world_size=world_size)

    # Positions do not depend on the scheme entry or split, so resolve (and
    # validate) them once. Skipped when there is nothing to iterate over.
    positions = _resolve_trigger_positions(args, rank) if name_pairs else []

    evaluators = []
    for name, path_prefix, is_multi in name_pairs:
        for split in ("val", "test"):
            for pos in positions:
                sub = f"{path_prefix}_{args.trigger_type}_{pos}/bd_{split}_dataset.pth"
                loader = help_build_loader(base, ds, "test_val", sub, bd=True, transform=trans,
                                           args=args, distributed=distributed, rank=rank,
                                           world_size=world_size)
                evaluator = scheme(model, is_test=(split == "test"), is_multi=is_multi, **ev_kwargs)
                evaluators.append(dict(type=split, name=name, loader=loader, evaluator=evaluator, position=pos))

    # Optionally evaluate the *clean* sets as well, using the dataset's clean evaluator
    if add_clean_multi:
        clean_eval_cls = DATASET_EVAL_CLS[ds]
        for split in ("val", "test"):
            sub = f"baseline/clean_{split}_dataset.pth"
            loader = help_build_loader(base, ds, "test_val", sub, bd=False, transform=trans, args=args, distributed=distributed, rank=rank, world_size=world_size)
            clean_evaluator = clean_eval_cls(model, device, distributed=distributed, rank=rank, world_size=world_size)
            evaluators.append(dict(type=split, name="clean_multi", loader=loader, evaluator=clean_evaluator, position="none"))
    return evaluators


def initialize_loaders(args, model_wrapper, model_name, distributed=False, rank=0, world_size=1):
    """
    Assemble the training DataLoader and the list of evaluator records
    (clean val/test plus any attack-specific ones) for the current experiment.

    Returns:
        (train_loader, evaluators), where each evaluator record is a dict
        with keys "type", "name", "loader", "evaluator" and "position".

    Raises:
        ValueError: if `args.use_p_ratio` is set but `args.p_ratio` is None
            (only checked when `args.data_attack` is not "baseline").
    """
    base = args.bd_base_path
    ds = args.dataset
    # In distributed mode the wrapped model is reached through `.module`.
    if distributed:
        model = model_wrapper.model.module
    else:
        model = model_wrapper.model
    device = model_wrapper.device
    dist_kwargs = dict(distributed=distributed, rank=rank, world_size=world_size)

    # ---------- training loader ----------
    if args.data_attack == "baseline":
        train_subdir = "baseline/clean_train_dataset.pth"
        print(f"[Rank {rank}] [TrainModel] Using clean training data.")
    elif not args.use_p_ratio:
        train_subdir = f"{args.data_attack}_{args.trigger_type}_{args.trigger_position}/bd_train_dataset.pth"
    elif args.p_ratio is None:
        raise ValueError("Poisoning ratio 'p_ratio' must be specified when 'use_p_ratio' is set.")
    else:
        print(f"[Rank {rank}] [TrainModel] Using poisoning ratio: {args.p_ratio} Dataset: {args.data_attack}, Trigger: {args.trigger_type}, Position: {args.trigger_position}")
        train_subdir = f"{args.data_attack}_{args.trigger_type}_{args.trigger_position}_{args.p_ratio}/bd_train_dataset.pth"

    train_loader = help_build_loader(
        base, ds, "train", train_subdir,
        bd=(args.data_attack != "baseline"),
        transform=model_wrapper.transform_train,
        args=args,
        **dist_kwargs,
    )

    # ---------- clean evaluator ----------
    # A single evaluator instance is shared by the clean val and test records.
    clean_evaluator = DATASET_EVAL_CLS[ds](model, device, **dist_kwargs)

    clean_sets = (
        ("val", "baseline/clean_val_dataset.pth"),
        ("test", "baseline/clean_test_dataset.pth"),
    )
    evaluators = []
    for split, clean_subdir in clean_sets:
        clean_loader = help_build_loader(
            base, ds, "test_val", clean_subdir,
            bd=False,
            transform=model_wrapper.transform_test,
            args=args,
            **dist_kwargs,
        )
        evaluators.append(dict(type=split, name="clean", loader=clean_loader, evaluator=clean_evaluator, position="none"))

    # ---------- attack-specific evaluators ----------
    if args.test_attack != "baseline":
        evaluators += build_attack_evaluators(args, model_name, model, model_wrapper, device, **dist_kwargs)

    print(f"[Rank {rank}] [TrainModel] Initialized {len(evaluators)} evaluators.")
    return train_loader, evaluators
