# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
import torch
import warnings
from dataclasses import dataclass
from typing import Any, List

from megatron.energon import Sample
from megatron.energon.epathlib.epath import EPath
from megatron.energon.flavors.webdataset import DefaultDecoderWebdatasetFactory


@dataclass
class SampleListSample(Sample):
    """Sample type for a list of samples of any type which need to be packed together.

    This is useful for datasets which are packed offline.
    """

    #: The packed samples; each element may be of any type
    samples: List[Any]


class SampleListWebdataset(DefaultDecoderWebdatasetFactory[SampleListSample]):
    """Deprecated webdataset factory yielding :class:`SampleListSample` items.

    Instantiating this class emits a ``DeprecationWarning`` that explains how to
    switch the dataset's ``.nv-meta/dataset.yaml`` to the default factory with an
    explicit ``sample_type`` entry.
    """

    __sample_type__ = SampleListSample

    def __init__(self, path: EPath, **kwargs):
        """Warn about the deprecation, then delegate to the parent factory.

        Args:
            path: Dataset root. It is interpolated into the migration hint of the
                warning and passed on to the parent constructor.
            **kwargs: Forwarded unchanged to the parent constructor.
        """
        warnings.warn(
            f"{type(self)} is deprecated, use the default instead and set the sample_type:\n"
            f"To convert, update your {path}/.nv-meta/dataset.yaml to:\n"
            "# remove top-level __module__ and __class__\n"
            "sample_type:\n"
            "  __module__: megatron.energon\n"
            f"  __class__: {self.__sample_type__.__name__}\n"
            "# Keep the remaining content",
            DeprecationWarning,
            # Attribute the warning to the code instantiating this class rather
            # than to this __init__, so the reported location is actionable.
            stacklevel=2,
        )
        super().__init__(path, **kwargs)


@dataclass
class OfflineTargetAspectRatioSample(Sample):
    """Sample type for image + text samples with target aspect ratio computed offline."""

    #: The images of the sequence
    images: List[torch.Tensor]
    #: The texts of the sequence
    texts: List[str]
    #: The target aspect ratios computed offline
    #: NOTE(review): presumably one inner list per image -- confirm against the data producer
    target_aspect_ratio: List[List]

    # Only when the base `Sample` class has no `__subflavor__` attribute, declare
    # it here as an additional annotated (default-less) member of this class.
    # NOTE(review): presumably a compatibility shim for energon versions whose
    # `Sample` lacks this attribute -- confirm.
    if not hasattr(Sample, "__subflavor__"):
        __subflavor__: str
