import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader

class PytorchFullEpisodeDataset(Dataset):
    """Map-style dataset that returns whole episodes as dicts of tensors.

    Episodes are read from ``pytorch_episodes_db``, every value is converted
    with ``torch.tensor``, the relabellers are applied in order, and the
    result is kept in a bounded first-in-first-out cache.

    Args:
        pytorch_episodes_db: object exposing ``get_ids()`` (iterable of
            episode ids) and ``db[episode_id]`` returning a mapping with an
            ``items()`` method.
        cache_size: maximum number of processed episodes kept in memory.
        episode_relabellers: optional sequence of callables. Each one is
            invoked as ``relabeller(episode, episode_id, None, None, None)``
            and its return value replaces the episode. ``None`` (the default)
            means no relabelling.
    """

    def __init__(
        self,
        pytorch_episodes_db,
        cache_size=100000,
        episode_relabellers=None
    ):
        self._episodes_db = pytorch_episodes_db
        self._index = []
        self._cache = {}
        # Episode ids in insertion order; the oldest entry is evicted first.
        self._idx_cache = []
        self._cache_size = cache_size
        # A fresh list per instance avoids the shared-mutable-default pitfall.
        self._episode_relabellers = (
            [] if episode_relabellers is None else episode_relabellers
        )
        self._run_index()

    def _get_episode(self, idx):
        """Return the processed episode ``idx``, loading and caching it on a miss."""
        if idx in self._cache:
            return self._cache[idx]

        episode = self._episodes_db[idx]
        episode = {k: torch.tensor(v) for k, v in episode.items()}
        for relabel in self._episode_relabellers:
            episode = relabel(episode, idx, None, None, None)
        episode = dict(episode.items())

        self._cache[idx] = episode
        self._idx_cache.append(idx)
        # Evict the oldest entries. A separate name is used for the evicted id
        # so the requested ``idx`` is never clobbered, and the episode is
        # returned from the local variable so the call still succeeds when
        # the new entry itself is evicted (e.g. ``cache_size == 0``).
        while len(self._idx_cache) > self._cache_size:
            evicted = self._idx_cache.pop(0)
            del self._cache[evicted]
        return episode

    @torch.no_grad()
    def _run_index(self):
        """Build the position -> episode id table from the database."""
        self._index = list(self._episodes_db.get_ids())

    @torch.no_grad()
    def __getitem__(self, idx):
        """Return a shallow copy of the episode stored at position ``idx``."""
        episode = self._get_episode(self._index[idx])
        # Copy the dict so callers adding/removing keys do not alter the cache.
        return dict(episode.items())

    def __len__(self):
        """Return the number of indexed episodes."""
        return len(self._index)
    
class _collate_fn_full_episode:
    def __init__(self,max_episode_size=1000):
        self._fk=None
        self._padding={}
        self._true=None
        self._false=None
        self._max_episode_size=max_episode_size
    
    def _prepare_padding(self,episode):
        self._fk=next(episode.__iter__())
        device=episode[self._fk].device
        for k, v in episode.items():
            s=list(v.size()[1:])
            self._padding[k]=torch.zeros(self._max_episode_size,*s,dtype=v.dtype,device=device)
        self._true=torch.ones(self._max_episode_size,device=device).bool()
        self._false=torch.zeros(self._max_episode_size,device=device).bool()
        
    def __call__(self,episodes):
        max_T=0

        for e in episodes:
            if self._fk is None: self._prepare_padding(e)
            T=e[self._fk].size()[0]
            max_T=max(max_T,T)
            assert T<=self._max_episode_size,"Episode is larger than the max episode size of the collate function: "+str(T)+" vs "+str(self._max_episode_size)  
        nepisodes=[]
        for e in episodes:            
            T=e[self._fk].size()[0]
            rT=max_T-T
            if rT>0:
                ne={}
                for k,v in e.items():
                    ne[k]=torch.cat([v,self._padding[k][:rT]],dim=0).unsqueeze(0)
                ne["_is_padding"]=torch.cat([self._false[:T],self._true[:rT]],dim=0).unsqueeze(0)
                ne['timesteps'] = torch.cat([torch.arange(0, T),torch.zeros(rT)],dim=0).int().unsqueeze(0)
                nepisodes.append(ne)
            else:
                nepisodes.append({k:v.unsqueeze(0) for k,v in e.items()})
                nepisodes[-1]["_is_padding"]=self._false[:T].unsqueeze(0)
                nepisodes[-1]['timesteps']=torch.arange(0, T).int().unsqueeze(0)

        results={}
        for k in nepisodes[0]:
            t=[e[k] for e in nepisodes]
            v=torch.cat(t, dim=0)
            results[k]=v

        return results

class FullEpisodeLoader(DataLoader):
    """``DataLoader`` whose ``collate_fn`` is ``_collate_fn_full_episode(max_episode_size)``.

    ``max_episode_size`` is passed to the collate object; every other
    argument is forwarded unchanged to ``DataLoader.__init__``.
    """

    def __init__(
        self,
        dataset,
        max_episode_size,
        batch_size=1,
        shuffle=False,
        sampler=None,
        batch_sampler=None,
        num_workers=0,
        pin_memory=False,
        drop_last=False,
        timeout=0,
        worker_init_fn=None,
        persistent_workers=False,
    ):
        collate = _collate_fn_full_episode(max_episode_size)
        super().__init__(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            sampler=sampler,
            batch_sampler=batch_sampler,
            num_workers=num_workers,
            collate_fn=collate,
            pin_memory=pin_memory,
            drop_last=drop_last,
            timeout=timeout,
            worker_init_fn=worker_init_fn,
            persistent_workers=persistent_workers,
        )

class ContrastiveTripletsDataset(Dataset):
    """Dataset yielding ``(sample, positive, negative)`` time-step triplets.

    For the requested episode a random time step is drawn as the sample, a
    positive is drawn from the steps within ``window`` of it (the sample
    itself included), and a negative is drawn from the steps further than
    ``window`` away. Only episodes longer than ``2 * window + 1`` steps are
    indexed, so a negative always exists.

    Args:
        pytorch_episodes_db: object exposing ``get_ids()`` and
            ``db[episode_id]`` returning a mapping of tensors.
        window: half-width (in time steps) of the positive neighbourhood.
        cache_size: maximum number of episodes kept in memory.
        device: device the cached tensors are moved to.
        d4rl: must be true; the other branch raises ``NotImplementedError``.
        keys_to_tokenize: episode keys concatenated along dim 1 to form the
            per-step features. Defaults to ``['obs/pos']``.
        is_godot: when true, only columns 0 and 2 of the concatenated
            features are kept.
    """

    def __init__(
        self,
        pytorch_episodes_db,
        window,
        cache_size=100000,
        device="cpu",
        d4rl=True,
        keys_to_tokenize=None,
        is_godot=True
    ):
        self._device = device
        self._episodes_db = pytorch_episodes_db
        self._index = []
        self._cache = {}
        # Episode ids in insertion order; the oldest entry is evicted first.
        self._idx_cache = []
        self.window = window
        self._cache_size = cache_size
        self.d4rl = d4rl
        # A fresh list per instance avoids the shared-mutable-default pitfall.
        self.keys_to_tokenize = (
            ['obs/pos'] if keys_to_tokenize is None else keys_to_tokenize
        )
        self.is_godot = is_godot
        self._run_index()

    def _get_episode(self, idx):
        """Return episode ``idx`` on ``self._device``, caching it on a miss."""
        if idx in self._cache:
            return self._cache[idx]

        raw = self._episodes_db[idx]
        episode = {k: v.to(self._device).contiguous() for k, v in raw.items()}
        self._cache[idx] = episode
        self._idx_cache.append(idx)
        # Evict the oldest entries. A separate name is used for the evicted id
        # so the requested ``idx`` is never clobbered, and the episode is
        # returned from the local variable so the call still succeeds when
        # the new entry itself is evicted (e.g. ``cache_size == 0``).
        while len(self._idx_cache) > self._cache_size:
            evicted = self._idx_cache.pop(0)
            del self._cache[evicted]
        return episode

    @torch.no_grad()
    def _run_index(self):
        """Index the episodes that are long enough to provide a negative."""
        length_key = self.keys_to_tokenize[0]
        min_length = 2 * self.window + 1
        self._index = [
            i for i in self._episodes_db.get_ids()
            if self._episodes_db[i][length_key].shape[0] > min_length
        ]

    @torch.no_grad()
    def __getitem__(self, idx):
        """Return a random ``(sample, positive, negative)`` triplet from episode ``idx``.

        Raises:
            NotImplementedError: if ``d4rl`` is false.
            ValueError: if the episode has no time step outside the window.
        """
        r = self._get_episode(self._index[idx])
        if not self.d4rl:
            raise NotImplementedError

        episode = torch.cat([r[key] for key in self.keys_to_tokenize], dim=1)
        if self.is_godot:
            episode = episode[:, [0, 2]]
        n_steps = len(episode)

        sample_index = np.random.randint(n_steps)
        sample = episode[sample_index]

        # Positive: any step in [sample - window, sample + window], clipped
        # to the episode bounds (randint's upper bound is exclusive).
        window_start = max(0, sample_index - self.window)
        window_stop = min(sample_index + self.window + 1, n_steps)
        positive = episode[np.random.randint(window_start, window_stop)]

        # Negative: any step strictly outside the window. ``range`` with a
        # non-positive stop is empty, so no clipping is needed on the left.
        negative_indices = (
            list(range(sample_index - self.window))
            + list(range(sample_index + self.window + 1, n_steps))
        )
        if not negative_indices:
            raise ValueError(
                "No negative candidate: episode of length " + str(n_steps)
                + " is not longer than 2*window+1 with window=" + str(self.window)
            )
        negative = episode[np.random.choice(negative_indices)]

        return sample, positive, negative

    def __len__(self):
        """Return the number of indexed episodes."""
        return len(self._index)
    
import os
def get_leaf_paths(root_folder, keys):
    """
    Return the relative path of every folder under ``root_folder`` (the root
    itself included, reported as ``'./'``) that directly contains a file or
    folder whose name is in the ``keys`` list.

    Paths are built relative to ``root_folder`` and start with ``'./'``.
    Each matching folder is listed exactly once, even when it contains
    several entries named in ``keys``. Result order follows the traversal
    order of ``os.listdir``.
    """
    leaf_paths = []
    # Paths already reported; compares joined strings (the entries actually
    # stored in ``leaf_paths``) so the duplicate check is effective.
    seen = set()

    def dfs(current_folder, path):
        relative = os.path.join(*path)
        for item in os.listdir(current_folder):
            if item in keys and relative not in seen:
                seen.add(relative)
                leaf_paths.append(relative)
            item_path = os.path.join(current_folder, item)
            if os.path.isdir(item_path):
                dfs(item_path, path + [item])

    # Seed the path with './' so os.path.join always receives an argument.
    dfs(root_folder, ['./'])
    return leaf_paths