import torch
from dataclasses import dataclass
from typing import Optional
import numpy as np


@dataclass
class FrameData:
    """Container for one sample that FrameDataCollator batches.

    The first dimension of ``x0`` defines the sample's length; the collator
    pads along that dimension and derives its output mask from it.
    NOTE(review): the physical meaning of the fields (coordinates, lag time,
    temperature, ...) is inferred from their names only -- confirm against
    the dataset code that builds these objects.
    """

    # Sample identifier; collected into a NumPy array by the collator.
    id: str
    # Padded along dim 0 and stacked by the collator. ``x0.shape[0]`` is used
    # as the sample length. Presumably the start/reference frame -- TODO confirm.
    x0: torch.Tensor
    # Padded along dim 0 and stacked by the collator; assumed to have the same
    # first-dimension length as ``x0`` -- TODO confirm.
    xt: torch.Tensor
    # Batched with ``torch.tensor`` as one value per sample.
    lag: int
    # Batched with ``torch.tensor`` as one value per sample.
    temp: int
    # NOTE(review): not read by FrameDataCollator, which builds its own
    # length-based mask and returns it under the "mask" key.
    mask: Optional[torch.Tensor] = None
    # The four optional fields below are padded along dim 0 and stacked, but
    # only when every sample in the batch provides them.
    residue_ids: Optional[torch.Tensor] = None
    sequence_emb: Optional[torch.Tensor] = None
    chain_breaks_per_residue: Optional[torch.Tensor] = None
    residue_pdb_idx: Optional[torch.Tensor] = None
    # Collected with ``np.array`` by the collator (not padded).
    # NOTE(review): annotated as a tensor but handled like a plain list value
    # -- confirm the intended type.
    cath_code: Optional[torch.Tensor] = None
    # Batched with ``torch.tensor`` as one value per sample.
    # NOTE(review): annotated as tensors; presumably scalars or 0-dim tensors
    # -- confirm.
    deepseek_classification: Optional[torch.Tensor] = None
    deepseek_confidence: Optional[torch.Tensor] = None


class FrameDataCollator:
    """Collate FrameData-like items into a single dict of padded batches.

    Items only need to expose the attributes named in the field tuples below;
    they are read with ``getattr``.
    """

    # Tensors padded along dim 0 to a common length and then stacked.
    _TENSOR_FIELDS = (
        "x0",
        "xt",
        "residue_ids",
        "sequence_emb",
        "chain_breaks_per_residue",
        "residue_pdb_idx",
    )
    # One value per item, gathered with ``torch.tensor``.
    _SCALAR_FIELDS = (
        "lag",
        "temp",
        "deepseek_classification",
        "deepseek_confidence",
    )
    # One value per item, gathered with ``np.array`` (no padding).
    _LIST_FIELDS = (
        "id",
        "cath_code",
    )

    def __init__(self, pad_to=None):
        """
        Collator for FrameData objects.
        Args:
            pad_to: int or None. If set, pad to this length. If None, pad to
                max length in batch.
        """
        self.pad_to = pad_to

    @staticmethod
    def _pad_first_dim(t, target_len, pad_value=0):
        """Return ``t`` extended along dim 0 to ``target_len`` with ``pad_value``.

        The input tensor is returned unchanged (same object) when it already
        has the target length. Padding keeps the dtype and device of ``t``.
        """
        missing = target_len - t.shape[0]
        if missing == 0:
            return t
        # NOTE: a tensor longer than target_len makes ``missing`` negative and
        # torch.full raises; every tensor field is assumed to share x0's
        # first-dimension length.
        filler = torch.full(
            [missing, *t.shape[1:]], pad_value, dtype=t.dtype, device=t.device
        )
        return torch.cat([t, filler], dim=0)

    def __call__(self, batch):
        """
        Pads every tensor field along its first dimension to a common length
        (max in batch or pad_to), and returns a mask indicating non-padded
        regions. A field is included in the output only if no item in the
        batch has it set to None.
        Args:
            batch: non-empty list of FrameData objects
        Returns:
            dict with padded tensors and a 'mask' key (bool tensor of shape
            (len(batch), padded_length), True for real, False for pad).
        Raises:
            ValueError: if batch is empty, or if pad_to is smaller than the
                longest item in the batch.
        """
        if len(batch) == 0:
            raise ValueError("FrameDataCollator received an empty batch")

        # The length of each item is defined by the first dimension of x0.
        lengths = [item.x0.shape[0] for item in batch]
        longest = max(lengths)
        if self.pad_to is None:
            max_len = longest
        else:
            max_len = self.pad_to
            if max_len < longest:
                # Fail early with a clear message instead of an opaque
                # negative-dimension error from torch.full.
                raise ValueError(
                    f"pad_to={self.pad_to} is smaller than the longest item "
                    f"in the batch ({longest})"
                )

        out = {}

        for field in self._TENSOR_FIELDS:
            values = [getattr(item, field) for item in batch]
            if all(v is not None for v in values):
                out[field] = torch.stack(
                    [self._pad_first_dim(v, max_len) for v in values]
                )

        for field in self._SCALAR_FIELDS:
            values = [getattr(item, field) for item in batch]
            if all(v is not None for v in values):
                out[field] = torch.tensor(values)

        for field in self._LIST_FIELDS:
            values = [getattr(item, field) for item in batch]
            if all(v is not None for v in values):
                out[field] = np.array(values)

        # Row i is True for positions [0, lengths[i]) and False for padding.
        positions = torch.arange(max_len).unsqueeze(0)
        out["mask"] = positions < torch.tensor(lengths).unsqueeze(1)

        return out
