import pyrootutils
# Project bootstrap, executed at import time (before the remaining imports):
# locate the project root by searching from this file for a directory that
# contains the ".git" indicator. Per the pyrootutils documentation,
# `pythonpath=True` adds that root to the Python import path and `dotenv=True`
# loads environment variables from a `.env` file at the root -- confirm
# against the installed pyrootutils version.
# NOTE: the module-level name `root` is shadowed by the `root` parameter of
# `AllProcGen.__init__` below; the two are unrelated.
root = pyrootutils.setup_root(
    search_from=__file__,
    indicator=[".git"],
    pythonpath=True,
    dotenv=True,
)

import random
import numpy as np
import os
import torch
from glob import glob
import pfio


class AllProcGen(torch.utils.data.Dataset):
    """Map-style dataset over the ``allprocgen`` collection of ``.npz`` files.

    Files are expected under ``{root}/allprocgen/{split_dir}/{variant}/*.npz``
    (created by make_allprocgen.py), where ``split_dir`` is ``train`` or
    ``eval`` for those two splits and ``test`` for every other split. Each
    file must provide a ``video`` and an ``action`` array.

    Every item is a ``(video, action, index)`` tuple; see ``__getitem__``.
    """

    name = 'allprocgen'
    # Not referenced inside this class; presumably the number of frames stored
    # per file -- TODO confirm against make_allprocgen.py.
    max_len = 64

    def __init__(self, split, seq_len, root, unseen=None, unplayed=None, env=None,
                 transform=None, variant=None, scs_url=None, single_sample=False, fix_start=False):
        """Collect the list of ``.npz`` files that make up the requested split.

        Args:
            split: One of ``'train'``, ``'eval'``, ``'test'``, ``'unseen'`` or
                ``'unplayed'``.
            seq_len: Number of leading-axis entries to crop from each file;
                a value ``<= 0`` disables cropping.
            root: Base directory that contains the ``allprocgen`` folder.
            unseen: Variant names used for the ``'unseen'`` split and excluded
                from ``'train'``/``'eval'``/``'test'``. Defaults to no variants.
            unplayed: Variant names used for the ``'unplayed'`` split and
                excluded from ``'train'``/``'eval'``/``'test'``. Defaults to
                no variants.
            env: If given (``'train'``/``'eval'``/``'test'`` only), restrict
                the dataset to this single variant.
            transform: Optional callable applied to the cropped video tensor.
            variant: Accepted for interface compatibility; not used here.
            scs_url: Forwarded to ``pfio.v2.from_url`` as ``http_cache``.
            single_sample: If true, keep one randomly chosen file per variant
                instead of all of them.
            fix_start: If true, crops always start at index 0 instead of a
                random offset.

        Raises:
            ValueError: If ``split`` is unknown, or ``env`` is not one of the
                variant directories found on disk.
        """
        # Avoid mutable default arguments; ``None`` means "no variants".
        unseen = [] if unseen is None else list(unseen)
        unplayed = [] if unplayed is None else list(unplayed)

        self.seq_len = seq_len
        root = os.path.join(root, self.name, split if split in ['train', 'eval'] else 'test')
        self.fs = pfio.v2.from_url(root, http_cache=scs_url)
        self.transform = transform
        self.fix_start = fix_start

        # Variant names are the sub-directories of the split directory.
        # The directory may also contain plain files; those are ignored.
        self.variants = sorted(
            os.path.basename(p) for p in glob(os.path.join(root, '*')) if os.path.isdir(p)
        )

        if split == 'unseen':
            variants = unseen
        elif split == 'unplayed':
            variants = unplayed
        elif split in ['train', 'test', 'eval']:
            if env is not None:
                if env not in self.variants:
                    raise ValueError(f"Invalid env: {env}")
                variants = [env]
            else:
                # Exclude the held-out (unseen / unplayed) variants. Filter the
                # sorted list rather than iterating a set difference: set
                # order depends on per-process string hashing, which would
                # make the index -> file mapping differ between runs/workers.
                excluded = set(unseen) | set(unplayed)
                variants = [v for v in self.variants if v not in excluded]
        else:
            raise ValueError(f"Invalid split: {split}")

        self.all_files = []
        for variant_name in variants:
            files = sorted(glob(os.path.join(root, variant_name, '*.npz')))
            if not files:
                # Nothing to add; also keeps random.choice from failing on an
                # empty list when single_sample is set.
                continue
            if single_sample:
                # Randomly select one file from each variant
                self.all_files.append(random.choice(files))
            else:
                self.all_files.extend(files)

        # Shuffle the file order for the test-directory splits.
        if split in ['test', 'unseen', 'unplayed']:
            random.shuffle(self.all_files)

    def __len__(self):
        """Return the number of files in the dataset."""
        return len(self.all_files)

    def __getitem__(self, index):
        """Load one file and return ``(video, action, index)``.

        ``video`` is converted to a float tensor and ``action`` is returned
        with the dtype stored in the file. When ``seq_len > 0`` both are
        cropped along their first axis to the same ``seq_len``-long window,
        starting at 0 if ``fix_start`` is set and at a random offset
        otherwise. ``transform`` (if any) is applied to the cropped video.

        Raises:
            ValueError: If a random crop is requested and the clip has fewer
                than ``seq_len`` entries.
        """
        # np.load returns a lazy NpzFile for .npz archives: members are only
        # read from the underlying file on access, so both arrays must be
        # materialised before the file handle is closed.
        with self.fs.open(self.all_files[index]) as fp:
            with np.load(fp) as arrays:
                video = torch.from_numpy(arrays['video']).float()
                action = torch.from_numpy(arrays['action'])

        if self.seq_len > 0:
            num_frames = video.shape[0]
            if self.fix_start:
                start = 0
            else:
                last_start = num_frames - self.seq_len
                if last_start < 0:
                    raise ValueError(
                        f"seq_len ({self.seq_len}) exceeds clip length ({num_frames}) "
                        f"in {self.all_files[index]}"
                    )
                start = random.randrange(last_start + 1)
            end = start + self.seq_len
            video = video[start:end]
            action = action[start:end]

        # apply user defined transform (e.g. normalization)
        if self.transform:
            video = self.transform(video)

        return video, action, index