from torch.utils.data import Dataset
from typing import *
from utils.tasks import CausalQATask

# You should implement your own dataset following the structure of the base dataset class below.

class example_dataset_base(Dataset):
    """Template dataset illustrating the interface a custom dataset should expose.

    The class loads no data of its own: it records the constructor arguments,
    reports a length of zero, and returns a placeholder sample whose keys show
    the layout a real implementation is expected to produce.
    """

    def __init__(
        self,
        dir: Optional[str] = None,
        d_max: int = -1, # Max dimension for A matrix
        split: Optional[str] = None, # 'split' can be 'train', 'eval', or 'test', None means all data
        shuffle: bool = True,
        length_limit: Optional[int] = None,
        exp_tasks: Optional[List[CausalQATask]] = None, # Prepare different data for different tasks
    ):
        """Record the dataset configuration.

        Args:
            dir: Stored as-is; the template does not read from it.
            d_max: Max dimension for the A matrix.
            split: 'train', 'eval', or 'test'; None means all data.
            shuffle: Stored as-is; the template does not use it.
            length_limit: Stored as-is; the template does not use it.
            exp_tasks: Tasks to prepare data for. None is treated as
                "no tasks" and stored as an empty list.
        """
        super().__init__()

        # Keep the configuration on the instance so the other methods
        # (and subclasses) can read it.
        self.dir = dir
        self.d_max = d_max
        self.split = split
        self.shuffle = shuffle
        self.length_limit = length_limit
        # Normalise None to an empty list so membership tests in
        # __getitem__ never run against None; copy to avoid aliasing the
        # caller's list.
        self.exp_tasks = list(exp_tasks) if exp_tasks is not None else []

        # Load your data here

    def __len__(self) -> int:
        """Return the number of samples.

        The template loads no data, so it reports an empty dataset. Replace
        this with the real sample count in your own dataset.
        """
        return 0

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Return a single data point as a dictionary.

        The template ignores ``idx`` and returns placeholder values; only the
        set of keys is meaningful.

        Args:
            idx: Index of the requested sample.

        Returns:
            A dictionary with the base sample fields, plus the alignment
            fields when ``CausalQATask.ALIGNMENT`` is among ``exp_tasks``.
        """
        sample = {
            'id': "id",
            'images': [], # List of PIL images
            'prompt': "prompt",
            'A_star': None, # Padded causal adjacency matrix as torch.Tensor of shape (d_max, d_max)
            'A_mask': None, # Binary mask indicating valid entries in the padded matrix
            'answer': "answer",
            'label': "label",
        }

        # Check for a non-empty task list first: with no tasks configured
        # there is nothing to look up.
        if self.exp_tasks and CausalQATask.ALIGNMENT in self.exp_tasks:
            # If ALIGNMENT task is included, provide extra fields
            sample.update({
                'node_prompt_emb': None, # Node-level text embeddings for each node in the causal graph
                'explanation_prompt_emb': None, # Explanation-level text embeddings for the entire causal graph (or edges)
            })

        return sample


        