import os
import torch
from src.datamodules.common import CommonDataset

class MOVi(CommonDataset):
    """MOVi dataset backed by pre-serialised tensors, one file per variant/split.

    The tensor for variant ``v`` and split ``s`` is stored at
    ``<root>/movi-<v>/<s>.pt``. When several variants are requested their
    tensors are concatenated along dim 0 (presumably the sample/batch
    dimension -- TODO confirm against the saved files).
    """

    # Variants loaded when ``variant == 'all'``.
    # NOTE(review): an earlier revision also listed 'b'; the reason it was
    # dropped is not recorded here.
    all_variants = ['a', 'c', 'e']

    @staticmethod
    def _data_path(root, variant, split):
        """Return the path of the ``.pt`` file for one variant and split."""
        return os.path.join(root, f'movi-{variant}', f'{split}.pt')

    def load_single_data(self, root, variant, split):
        """Load the tensor stored for a single ``variant`` and ``split``.

        Args:
            root: Directory containing the ``movi-<variant>`` sub-directories.
            variant: Variant name, e.g. ``'a'``.
            split: Split name, e.g. ``'test'`` or ``'validation'``.

        Returns:
            Whatever object was saved in the file via ``torch.load``.

        Raises:
            FileNotFoundError: if the file does not exist.
        """
        return torch.load(self._data_path(root, variant, split))

    def load_data(self, root, variant, split):
        """Load ``split`` for one variant, or for every entry of ``all_variants``.

        Args:
            root: Directory containing the ``movi-<variant>`` sub-directories.
            variant: A single variant name, or ``'all'`` to load and
                concatenate every variant listed in ``all_variants``.
            split: Split name, e.g. ``'test'`` or ``'validation'``.

        Returns:
            The per-variant tensors concatenated along dim 0. For the
            ``'test'`` split the rows are then shuffled with
            ``torch.randperm`` (driven by the global torch RNG).

        Raises:
            FileNotFoundError: if a required file is missing. A missing
                ``test.pt`` falls back to ``validation.pt`` for that variant;
                a missing file for any other split is an error.
        """
        names = self.all_variants if variant == 'all' else [variant]

        tensors = []
        for name in names:
            try:
                tensors.append(self.load_single_data(root, name, split))
            except FileNotFoundError as err:
                if split != 'test':
                    # Chain explicitly so the traceback shows the underlying
                    # OS error as the direct cause.
                    raise FileNotFoundError(
                        f"File not found: {self._data_path(root, name, split)}"
                    ) from err
                # test.pt may not exist for some variants; use validation instead.
                tensors.append(self.load_single_data(root, name, 'validation'))

        data = torch.cat(tensors, dim=0)

        # Shuffle the test split along dim 0 (presumably so samples from
        # different variants are interleaved rather than grouped).
        if split == 'test':
            data = data[torch.randperm(data.size(0))]
        return data