import numpy as np
import joblib
from tqdm import tqdm


class CalvinEpisodeContainer:
    """In-memory store of episodes loaded from joblib-serialized files.

    Attributes:
        data: Maps each modularity name to a list with one entry per loaded
            episode, in the order the paths were given.
        epi_lengths: Length of each episode, measured on the first modularity.
    """

    def __init__(self, data_pths, modularities=('observations', 'actions')):
        """Load every episode file and keep the requested modularities.

        Args:
            data_pths: Iterable of file paths. Each file is read with
                ``joblib.load`` and must produce an object that can be
                indexed by every name in ``modularities``.
            modularities: Non-empty sequence of keys to keep from each
                episode. The first key determines the recorded length.
        """
        self.data = {key: [] for key in modularities}
        self.epi_lengths = []
        for pth in tqdm(data_pths):
            # Open through a context manager so the handle is closed as soon
            # as the episode is deserialized, instead of leaking one file
            # descriptor per episode.
            # NOTE(review): joblib.load is pickle-based; only load files from
            # a trusted source.
            with open(pth, 'rb') as episode_file:
                episode = joblib.load(episode_file)
            self.epi_lengths.append(len(episode[modularities[0]]))
            for key in modularities:
                self.data[key].append(episode[key])

    def __getitem__(self, idx):
        """Return one timestep of one episode.

        Args:
            idx: A pair ``(epi_idx, timestep)``.

        Returns:
            A dict mapping every stored modularity name to that episode's
            entry at ``timestep``.
        """
        epi_idx, timestep = idx
        return {
            key: episodes[epi_idx][timestep]
            for key, episodes in self.data.items()
        }


class D4rlEpisodeContainer:
    """Split a flat transition dataset into per-episode slices.

    Episode boundaries are taken from the dataset's ``'timeouts'`` flags:
    every truthy flag marks the last step of an episode, and the final step
    of the dataset always closes the last episode.

    Attributes:
        data: Maps each modularity name to a list with one slice of
            ``dataset[name]`` per episode.
        epi_lengths: Number of steps in each episode.
    """

    def __init__(self, env=None, dataset=None, modularities=('observations', 'actions')):
        """Build the container from a dataset, or from ``env.get_dataset()``.

        Args:
            env: Object exposing ``get_dataset()``; only used when
                ``dataset`` is None.
            dataset: Mapping that contains ``'timeouts'`` plus every name in
                ``modularities``, each sliceable along the step axis.
            modularities: Keys to keep from the dataset.

        Raises:
            ValueError: If both ``env`` and ``dataset`` are None.
        """
        self.data = {key: [] for key in modularities}
        self.epi_lengths = []

        if dataset is None:
            if env is None:
                raise ValueError('either env or dataset must be provided')
            dataset = env.get_dataset()

        # Work on a private boolean copy so that forcing the final flag does
        # not write into the caller's dataset.
        timeouts = np.array(dataset['timeouts'], dtype=bool)
        if timeouts.size:
            # Whatever trails the last timeout still forms an episode.
            timeouts[-1] = True

        epi_start = 0
        for epi_end in np.where(timeouts)[0]:
            epi_stop = epi_end + 1  # exclusive end of the slice
            for key in modularities:
                self.data[key].append(dataset[key][epi_start:epi_stop])
            self.epi_lengths.append(epi_stop - epi_start)
            epi_start = epi_stop

    def __getitem__(self, idx):
        """Return one timestep of one episode.

        Args:
            idx: A pair ``(epi_idx, timestep)``.

        Returns:
            A dict mapping every stored modularity name to that episode's
            entry at ``timestep``.
        """
        epi_idx, timestep = idx
        return {
            key: episodes[epi_idx][timestep]
            for key, episodes in self.data.items()
        }

