import torch
import torch.distributed as dist
import sys
import os

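# Legacy path setup (disabled): put the project root, two directories above
# this file, on sys.path so absolute imports resolve. Superseded by the
# package-relative imports below.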
# here = os.path.dirname(__file__)  
# project_root = os.path.abspath(os.path.join(here, os.pardir, os.pardir))
# sys.path.append(project_root)

from ...bd_models.models.build import build_model
from ...bd_models.utils.load_utils import initialize_loaders
from ...bd_dataset.wrapper import DefenseWrapper

# from BackdoorObjectDetection.bd_dataset.wrapper import DefenseWrapper
# from BackdoorObjectDetection.bd_models.models.build import build_model
# from BackdoorObjectDetection.bd_models.utils.load_utils import initialize_loaders

def get_defense_loader_model(args, device, distributed=False, rank=0, local_rank=0, world_size=1):
    """
    Initialize the data loaders and model for fine-tuning defense.

    The model is built from ``args``, its weights are restored from
    ``<args.record_path>/<args.save_dir>/checkpoint.pth``, and a fresh
    cosine-annealing scheduler is attached. The training loader produced by
    ``initialize_loaders`` is then rebuilt around a ``DefenseWrapper`` over
    ``args.num_training_samples`` samples (intended to contain clean samples only).

    Args:
        args: Argument namespace with all necessary configurations. This function
            reads ``model``, ``dataset``, ``model_config_path``, ``save_path``,
            ``record_path``, ``save_dir``, ``training_epochs``,
            ``num_training_samples`` and ``random_seed``, and passes ``args`` on to
            ``initialize_loaders``.
        device: Device passed to ``build_model`` and used as ``map_location``
            when loading the checkpoint.
        distributed (bool): Whether to use distributed training.
        rank (int): Rank of the current process in distributed training.
        local_rank (int): Local rank of the current process.
        world_size (int): Total number of processes in distributed training.

    Returns:
        train_loader: DataLoader over the wrapped, subsampled training dataset.
        evaluators: List of evaluators for validation/testing.
        model_wrapper: The model wrapped for training.

    Raises:
        FileNotFoundError: If the checkpoint file does not exist.
        KeyError: If the checkpoint lacks the state-dict key required for ``args.model``.
    """
    # Step 1: Build the model
    print(f'[Rank {rank}] [TrainModel] Building model: {args.model}')
    model_wrapper = build_model(args.model, args.dataset, args.model_config_path, device, args.save_path, distributed=distributed, local_rank=local_rank)

    # Step 2: Update the model weights to the weights in the original saved checkpoint
    _load_checkpoint_weights(model_wrapper, args, device, distributed, rank)

    # Fine-tuning restarts at epoch 0 with its own cosine schedule; this
    # overwrites any scheduler already stored on model_wrapper.
    model_wrapper.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(model_wrapper.optimizer, T_max=args.training_epochs)
    model_wrapper.current_epoch = 0
    model_wrapper.epochs = args.training_epochs

    print(f'[Rank {rank}] [TrainModel] Model and optimizer initialized.')

    # Step 3: Initialize data loaders
    train_loader, evaluators = initialize_loaders(
        args, model_wrapper, args.model, distributed=distributed, rank=rank, world_size=world_size
    )

    print(f'[Rank {rank}] [TrainModel] Dataset size: {len(train_loader.dataset)} samples')
    print(f'[Rank {rank}] [TrainModel] Number of batches in training loader: {len(train_loader)}')

    # Step 4: Wrap the dataset with DefenseWrapper to ensure clean samples only
    train_loader = _build_clean_train_loader(train_loader, args, distributed, rank, world_size)

    print(f'[Rank {rank}] [TrainModel] Subsampled training dataset size: {len(train_loader.dataset)} samples')
    print(f'[Rank {rank}] [TrainModel] Number of batches in training loader: {len(train_loader)}')
    print(f'[Rank {rank}] [TrainModel] Batch size: {train_loader.batch_size}')

    # Keep all ranks in lockstep before training starts.
    if distributed:
        dist.barrier()

    return train_loader, evaluators, model_wrapper


def _load_checkpoint_weights(model_wrapper, args, device, distributed, rank):
    """Load ``<args.record_path>/<args.save_dir>/checkpoint.pth`` into ``model_wrapper``'s network.

    Raises:
        FileNotFoundError: If the checkpoint file does not exist.
        KeyError: If the checkpoint lacks the key used for ``args.model``
            (``'model_state'`` for YOLO, ``'model'`` otherwise).
    """
    weight_path = os.path.join(args.record_path, args.save_dir, 'checkpoint.pth')
    if not os.path.exists(weight_path):
        raise FileNotFoundError(f"[Rank {rank}] [TrainModel] Checkpoint file not found at {weight_path}")

    print(f'[Rank {rank}] [TrainModel] Loading model weights from: {weight_path}')
    # SECURITY: weights_only=False unpickles arbitrary Python objects, so only
    # load checkpoints from a trusted source (here, this project's own records).
    checkpoint = torch.load(weight_path, map_location=device, weights_only=False)

    # YOLO checkpoints store their state dict under 'model_state' and it is
    # loaded into the inner `.model`; every other model uses 'model'.
    is_yolo = args.model == 'yolo'
    state_key = 'model_state' if is_yolo else 'model'
    if state_key not in checkpoint:
        raise KeyError(
            f"[Rank {rank}] [TrainModel] Checkpoint for model '{args.model}' does not contain the '{state_key}' key."
        )

    # In distributed mode the network is wrapped (presumably in
    # DistributedDataParallel), so the real module sits under `.module`.
    net = model_wrapper.model.module if distributed else model_wrapper.model
    if is_yolo:
        net = net.model
    net.load_state_dict(checkpoint[state_key])


def _build_clean_train_loader(old_loader, args, distributed, rank, world_size):
    """Return a new DataLoader over ``DefenseWrapper(old_loader.dataset, ...)`` reusing ``old_loader``'s settings."""
    print(f'[Rank {rank}] [TrainModel] Wrapping training dataset to ensure clean samples only')
    wrapped_dataset = DefenseWrapper(old_loader.dataset, args.num_training_samples, random_seed=args.random_seed)

    # In distributed mode each rank must read its own shard of the wrapped
    # dataset, so a fresh DistributedSampler is always built for it.
    # NOTE: for per-epoch reshuffling the training loop must call
    # train_loader.sampler.set_epoch(epoch); that is not done here.
    if distributed:
        sampler = torch.utils.data.distributed.DistributedSampler(
            wrapped_dataset,
            num_replicas=world_size,
            rank=rank,
            shuffle=True,
            seed=args.random_seed,
        )
    else:
        sampler = None

    return torch.utils.data.DataLoader(
        wrapped_dataset,
        batch_size=old_loader.batch_size,
        sampler=sampler,
        # DataLoader rejects shuffle=True combined with a sampler; the
        # DistributedSampler does its own shuffling.
        shuffle=sampler is None,
        num_workers=old_loader.num_workers,
        pin_memory=old_loader.pin_memory,
        drop_last=old_loader.drop_last,
        collate_fn=old_loader.collate_fn,
        # Carry over worker seeding and worker lifetime so they match the
        # loader built by initialize_loaders.
        worker_init_fn=getattr(old_loader, 'worker_init_fn', None),
        persistent_workers=getattr(old_loader, 'persistent_workers', False),
    )
