import torch
from torch.nn.utils.rnn import pad_sequence


def data_collator(batch):
    """
    Collate a list of per-sample dicts into one batch dict for pose models.

    ``None`` entries are dropped first. Which optional keys are collated is
    decided by the first remaining sample only, so every sample is expected
    to carry the same keys (a key present in the first sample but missing
    from a later one raises ``KeyError``).

    Args:
        batch (list): Samples, each a ``dict`` or ``None``. Required keys:
            ``"feature"`` (array-like pose sequence; sequences may differ in
            length along their first axis) and ``"length"``. Optional keys:
            ``"text"``, ``"tasks"``, ``"keypoint_scores"``, ``"video_path"``,
            ``"audio_path"``, ``"pose_path"``, ``"audio_data"``,
            ``"audio_tokens"``, ``"audio_lengths"``, ``"hidden_states_path"``.

    Returns:
        dict | None: ``None`` when no non-``None`` sample remains. Otherwise:

        - ``"poses"``: float32 tensor, batch-first, zero-padded along each
          sample's first axis to the longest sequence in the batch.
        - ``"lengths"``: list of the samples' ``"length"`` values.
        - ``"captions"``: list of ``"text"`` values, or ``""`` for every
          sample when the first sample has no ``"text"``.
        - ``"keypoint_scores"`` / ``"audio_tokens"`` (if present):
          batch-first, zero-padded float32 tensors.
        - ``"tasks"``, ``"video_path"``, ``"audio_path"``, ``"pose_path"``,
          ``"audio_data"``, ``"audio_lengths"`` (if present): plain lists.
        - ``"hidden_states_paths"`` (if present): list of the samples'
          ``"hidden_states_path"`` values (note the pluralised output key).
    """
    # Drop samples the dataset chose to skip.
    batch = [b for b in batch if b is not None]
    if not batch:
        return None

    first = batch[0]

    def _pad_float(key):
        # Convert each value to float32 and zero-pad to a common length
        # (pad_sequence pads, it does not simply stack).
        return pad_sequence(
            [torch.tensor(b[key], dtype=torch.float32) for b in batch],
            batch_first=True,
        )

    def _gather(key):
        # Collect the raw per-sample values into a list, in batch order.
        return [b[key] for b in batch]

    adapted_batch = {
        "poses": _pad_float("feature"),
        "lengths": _gather("length"),
        "captions": _gather("text") if "text" in first else ["" for _ in batch],
    }

    # (input key, output key, padded?) in the order the keys are emitted.
    # NOTE(review): "audio_tokens" is padded as float32 here while
    # data_collator_token uses torch.long for the same key -- confirm which
    # dtype downstream consumers expect.
    optional_fields = (
        ("tasks", "tasks", False),
        ("keypoint_scores", "keypoint_scores", True),
        ("video_path", "video_path", False),
        ("audio_path", "audio_path", False),
        ("pose_path", "pose_path", False),
        ("audio_data", "audio_data", False),
        ("audio_tokens", "audio_tokens", True),
        ("audio_lengths", "audio_lengths", False),
        ("hidden_states_path", "hidden_states_paths", False),
    )
    for src_key, dst_key, padded in optional_fields:
        if src_key in first:
            adapted_batch[dst_key] = (
                _pad_float(src_key) if padded else _gather(src_key)
            )

    return adapted_batch


def data_collator_token(batch):
    """
    Collate per-sample dicts (pose features and/or discrete tokens) into a batch dict.

    ``None`` entries are dropped first. Which optional keys are collated is
    decided by the first remaining sample only, so every sample is expected
    to carry the same keys (a key present in the first sample but missing
    from a later one raises ``KeyError``).

    Args:
        batch (list): Samples, each a ``dict`` or ``None``. All keys are
            optional: ``"feature"``, ``"hidden_states"`` (must already be
            tensors), ``"length"``, ``"text"``, ``"tasks"``,
            ``"keypoint_scores"``, ``"ignore_mask"``, ``"video_size"``,
            ``"bodies_subset"``, ``"video_path"``, ``"audio_path"``,
            ``"motion_tokens"``, ``"mot_token_lengths"``, ``"audio_tokens"``,
            ``"audio_byte"``, ``"audio_lengths"``, ``"instruction"``,
            ``"response"``, ``"hidden_states_path"``.

    Returns:
        dict | None: ``None`` when no non-``None`` sample remains. Otherwise:

        - ``"poses"``: batch-first, zero-padded float32 tensor built from
          ``"feature"``, or ``None`` when the first sample has no ``"feature"``.
        - ``"lengths_poses"``: list of ``"length"`` values, or ``None``; when
          present the same list is also stored under ``"lengths"`` (which is
          absent otherwise).
        - ``"captions"``: list of ``"text"`` values, or ``""`` per sample.
        - Padded tensors (if present): ``"hidden_states"`` (dtype kept),
          ``"keypoint_scores"`` (float32), ``"ignore_mask"`` (bool),
          ``"bodies_subset"``, ``"motion_tokens"``, ``"audio_tokens"`` (long).
        - Plain lists for the remaining keys; ``"audio_byte"``,
          ``"instruction"``, ``"response"`` and ``"hidden_states_path"`` are
          emitted under the plural keys ``"audio_bytes"``, ``"instructions"``,
          ``"responses"`` and ``"hidden_states_paths"``.
    """
    # Drop samples the dataset chose to skip.
    batch = [b for b in batch if b is not None]
    if not batch:
        return None

    first = batch[0]

    def _pad(key, dtype):
        # Convert each value to a tensor of `dtype` and zero-pad (False for
        # bool) to the longest sequence in the batch.
        return pad_sequence(
            [torch.tensor(b[key], dtype=dtype) for b in batch],
            batch_first=True,
        )

    def _gather(key):
        # Collect the raw per-sample values into a list, in batch order.
        return [b[key] for b in batch]

    adapted_batch = {}
    adapted_batch["poses"] = (
        _pad("feature", torch.float32) if "feature" in first else None
    )

    if "hidden_states" in first:
        # Values are passed to pad_sequence as-is (no torch.tensor conversion).
        adapted_batch["hidden_states"] = pad_sequence(
            _gather("hidden_states"), batch_first=True
        )

    if "length" in first:
        lengths = _gather("length")
        adapted_batch["lengths_poses"] = lengths
        adapted_batch["lengths"] = lengths
    else:
        adapted_batch["lengths_poses"] = None

    adapted_batch["captions"] = (
        _gather("text") if "text" in first else ["" for _ in batch]
    )

    # (input key, output key, dtype) in the order the keys are emitted.
    # dtype None -> collect as a plain list; otherwise convert and pad.
    # A single table means each key is collated exactly once (the previous
    # version padded "motion_tokens" twice).
    optional_fields = (
        ("tasks", "tasks", None),
        ("keypoint_scores", "keypoint_scores", torch.float32),
        ("ignore_mask", "ignore_mask", torch.bool),
        ("video_size", "video_size", None),
        ("bodies_subset", "bodies_subset", torch.long),
        ("video_path", "video_path", None),
        ("audio_path", "audio_path", None),
        ("motion_tokens", "motion_tokens", torch.long),
        ("mot_token_lengths", "mot_token_lengths", None),
        ("audio_tokens", "audio_tokens", torch.long),
        ("audio_byte", "audio_bytes", None),
        ("audio_lengths", "audio_lengths", None),
        ("instruction", "instructions", None),
        ("response", "responses", None),
        ("hidden_states_path", "hidden_states_paths", None),
    )
    for src_key, dst_key, dtype in optional_fields:
        if src_key in first:
            adapted_batch[dst_key] = (
                _gather(src_key) if dtype is None else _pad(src_key, dtype)
            )

    return adapted_batch


def data_collator_seed_eval(batch):
    """
    Gather the fields of seed-evaluation samples into per-field lists.

    ``None`` samples are discarded. Returns ``None`` when nothing remains;
    otherwise a dict with keys ``"utts"``, ``"texts"``, ``"prompt_texts"``,
    ``"audio_paths"`` and ``"prompt_audio_paths"``, each a list in sample
    order taken from the per-sample keys ``"utt"``, ``"text"``,
    ``"prompt_text"``, ``"audio_path"`` and ``"prompt_audio_path"``.
    Every remaining sample must provide all five keys (``KeyError`` otherwise).
    """
    samples = list(filter(lambda item: item is not None, batch))
    if not samples:
        return None

    # (output key, per-sample key), in emission order.
    field_map = (
        ("utts", "utt"),
        ("texts", "text"),
        ("prompt_texts", "prompt_text"),
        ("audio_paths", "audio_path"),
        ("prompt_audio_paths", "prompt_audio_path"),
    )
    return {
        out_key: [sample[in_key] for sample in samples]
        for out_key, in_key in field_map
    }
