import pyrootutils
# Resolve the project root by searching from this file for a `.git` indicator.
# This must stay above the `src.*` import further down.
# NOTE(review): presumably `pythonpath=True` puts the root on sys.path (which
# is what lets `src.datamodules` resolve) and `dotenv=True` loads a `.env`
# file from the root — confirm against the pyrootutils documentation.
root = pyrootutils.setup_root(
    search_from=__file__,
    indicator=[".git"],
    pythonpath=True,
    dotenv=True,
)

import numpy as np
import os
import torch
import torchvision.transforms.v2.functional as TF
from src.datamodules.common import CommonDataset

class ProcGen(CommonDataset):
    """ProcGen trajectories exposed as per-episode video/action sequences.

    Data is read from ``os.path.join(root, 'procgen', variant, split)``, which
    holds one ``.npz`` file per chunk. The keys read here are:

    * ``obs``  -- frames, documented as ``(seq_len, 64, 64, 3)``
    * ``ta``   -- action label per frame, ``(seq_len,)``
    * ``done`` -- truthy on the last frame of an episode, ``(seq_len,)``

    The files also carry ``log_probs``, ``rewards``, ``ep_returns`` and
    ``values``; those are not used by this class.
    """

    name = 'procgen'
    # Episodes with fewer frames than this are discarded while loading.
    min_len = 64
    # Variants loaded when ``variant == 'all'``. The commented-out entries are
    # the ProcGen games that are currently disabled.
    all_variants = [
        "bigfish",
        "bossfight",
        "caveflyer",
        # "chaser",
        # "climber",
        "coinrun",
        # "dodgeball",
        # "fruitbot",
        # "heist",
        # "jumper",
        # "leaper",
        # "maze",
        # "miner",
        # "ninja",
        # "plunder",
        # "starpilot",
    ]

    @property
    def max_len(self):
        """Return the maximum sequence length, which equals ``min_len`` here."""
        return self.min_len

    @staticmethod
    def _chunk_sort_key(filename):
        """Sort key ordering integer-named chunks numerically, others by name."""
        stem = os.path.splitext(filename)[0]
        if stem.isdecimal():
            return (0, int(stem), filename)
        return (1, 0, filename)

    def _load_data(self, root, variant, split):
        """Load one variant/split and cut it into per-episode sequences.

        Args:
            root: Directory that contains the ``procgen`` folder.
            variant: Name of the game sub-directory to read.
            split: Name of the split sub-directory to read.

        Returns:
            A ``(video_list, action_list)`` pair of equally long lists. Entry
            ``i`` of ``video_list`` is the channel-first frame tensor of one
            episode and entry ``i`` of ``action_list`` its action labels.
            Episodes shorter than ``min_len`` frames are dropped, as are any
            trailing frames after the last ``done`` flag.

        Raises:
            RuntimeError: If the directory contains no ``.npz`` chunk.
        """
        data_path = os.path.join(root, self.name, variant, split)

        # os.listdir returns entries in arbitrary order. Sort them so the
        # concatenated stream is reproducible and, in case an episode continues
        # into the next chunk, consecutive chunks stay adjacent.
        chunk_files = sorted(
            (entry for entry in os.listdir(data_path) if entry.endswith('.npz')),
            key=self._chunk_sort_key,
        )
        if not chunk_files:
            raise RuntimeError(f"No .npz chunks found in {data_path}")

        videos = []
        actions = []
        dones = []
        for chunk_file in chunk_files:
            # The context manager closes the archive's file handle; the arrays
            # are read into memory on key access, so they outlive the archive.
            with np.load(os.path.join(data_path, chunk_file)) as data:
                videos.append(torch.from_numpy(data['obs']))
                actions.append(torch.from_numpy(data['ta']))
                dones.append(torch.from_numpy(data['done']))

        # Join all chunks into one stream and move channels first:
        # (N, H, W, C) -> (N, C, H, W).
        videos = torch.cat(videos, dim=0).permute(0, 3, 1, 2)
        actions = torch.cat(actions, dim=0)
        dones = torch.cat(dones, dim=0)

        # dones[t] is truthy when frame t is the last frame of an episode, so
        # each episode covers the half-open range [start, t + 1).
        video_list = []
        action_list = []
        start = 0
        for end in (torch.where(dones)[0] + 1).tolist():
            if end - start >= self.min_len:
                video_list.append(videos[start:end])
                action_list.append(actions[start:end])
            # Always advance, even past a discarded episode; otherwise its
            # frames would be glued onto the front of the next episode.
            start = end

        return video_list, action_list

    def load_data(self, root, variant, split):
        """Load the requested variant(s) and store the matching actions.

        Args:
            root: Directory that contains the ``procgen`` folder.
            variant: A single variant name, or ``'all'`` for every entry of
                ``all_variants``.
            split: Split name. For ``'test'`` the sequences are shuffled with
                a random permutation drawn from torch's global RNG.

        Returns:
            The list of per-episode video tensors. The action tensors, in the
            same order, are stored on ``self.action_list``.
        """
        variants = self.all_variants if variant == 'all' else [variant]
        videos = []
        actions = []
        for variant_name in variants:
            variant_videos, variant_actions = self._load_data(root, variant_name, split)
            videos.extend(variant_videos)
            actions.extend(variant_actions)

        # Shuffle videos and actions together so the pairs stay aligned.
        if split == 'test':
            order = torch.randperm(len(videos)).tolist()
            videos = [videos[i] for i in order]
            actions = [actions[i] for i in order]

        self.action_list = actions
        return videos

    def __getitem__(self, index):
        """Return ``(videos, actions, index)`` for one sample.

        The parent ``__getitem__`` supplies the video clip, a start offset and
        a sequence index; the actions are then taken from that same sequence.
        """
        videos, start, index = super().__getitem__(index)
        # Rotate the sequence so position `start` comes first, then keep the
        # first `seq_len` labels (this wraps around the end of the sequence).
        actions = self.action_list[index].roll(-start, dims=0)[:self.seq_len]
        return videos, actions, index


# Smoke test: build the dataset and print the shapes of the first batch.
if __name__ == "__main__":
    dataset = ProcGen(split="train", seq_len=32, variant='bigfish')
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True)
    # Pull at most one batch; an empty loader prints nothing.
    first_batch = next(iter(dataloader), None)
    if first_batch is not None:
        videos, actions, index = first_batch
        print(videos.shape, actions.shape, index)
    