from pathlib import Path
import shutil
from typing import Dict, Optional, Union

import numpy as np
import torch

from .episode import Episode
from .segment import Segment, SegmentId
from .utils import make_segment


class EpisodeDataset(torch.utils.data.Dataset):
    """Collection of episodes persisted under ``directory``, with optional RAM caching.

    Bookkeeping (episode count, step count, per-episode start index and length) is
    kept in memory and persisted to ``info.pt`` by ``save_info``. Each episode is
    stored in its own ``.pt`` file in a nested folder layout (see ``_get_episode_path``).
    """

    def __init__(self, directory: Path, name: str, cache_in_ram: bool = False) -> None:
        """Open the dataset at ``directory``, creating an empty one if it does not exist.

        Args:
            directory: dataset root (``~`` is expanded).
            name: label used in the summary printout and in ``repr``/``str``.
            cache_in_ram: keep every loaded/added episode in memory (on CPU).
        """
        super().__init__()
        self.directory = Path(directory).expanduser()
        self.name = name
        self.cache_in_ram = cache_in_ram
        self.num_episodes, self.num_steps, self.start_idx, self.lengths, self.cache = None, None, None, None, None
        if not self.directory.is_dir():
            self._init_empty()
        else:
            self._load_info()
            print(f'({name}) {self.num_episodes} episodes, {self.num_steps} steps.')

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}({self.directory.absolute()}, {self.name})'

    def __str__(self) -> str:
        return f'{self.__class__.__name__}\nName: {self.name}\nDirectory: {self.directory.absolute()}\nNum steps: {self.info["num_steps"]}\nNum episode: {self.info["num_episodes"]}'

    @property
    def info_path(self) -> Path:
        """Path of the file holding the dataset bookkeeping."""
        return self.directory / 'info.pt'

    @property
    def info(self) -> Dict[str, Union[int, np.ndarray]]:
        """Bookkeeping dict as written to ``info_path`` by ``save_info``."""
        return {'num_episodes': self.num_episodes,
                'num_steps': self.num_steps,
                'start_idx': self.start_idx,
                'lengths': self.lengths}

    def __len__(self) -> int:
        """Total number of steps across all episodes (not the number of episodes).

        Note that ``__getitem__`` is keyed by ``SegmentId``, not by an integer in
        ``range(len(self))``.
        """
        return self.num_steps

    def __getitem__(self, segment_id: SegmentId) -> Segment:
        """Load the segment described by ``segment_id`` (padded, see ``_load_segment``)."""
        return self._load_segment(segment_id)

    def _init_empty(self) -> None:
        """Create the (non-existing) directory and write bookkeeping for an empty dataset."""
        self.directory.mkdir(parents=True, exist_ok=False)
        self.num_episodes, self.num_steps = 0, 0
        self.start_idx = np.array([], dtype=np.int64)
        self.lengths = np.array([], dtype=np.int64)
        self.cache = []
        self.save_info()

    def _load_info(self) -> None:
        """Restore bookkeeping from ``info_path`` and reset the RAM cache."""
        # info.pt contains numpy arrays, which the default ``weights_only=True``
        # unpickler of torch>=2.6 rejects. The file is written by ``save_info``;
        # NOTE: ``weights_only=False`` unpickles arbitrary objects, so only open
        # dataset directories you trust.
        try:
            info = torch.load(self.info_path, weights_only=False)
        except TypeError:
            # torch<1.13 does not know the ``weights_only`` keyword.
            info = torch.load(self.info_path)
        self.num_steps = info['num_steps']
        self.num_episodes = info['num_episodes']
        self.start_idx = info['start_idx']
        self.lengths = info['lengths']
        self.cache = [None] * self.num_episodes

    def save_info(self) -> None:
        """Persist the current bookkeeping to ``info_path``."""
        torch.save(self.info, self.info_path)

    def clear(self) -> None:
        """Delete the whole dataset directory and start over with an empty dataset."""
        shutil.rmtree(self.directory)
        self._init_empty()

    def _get_episode_path(self, episode_id: int) -> Path:
        """Return the file path of an episode inside a three-level folder hierarchy.

        The last three decimal digits of ``episode_id`` (hundreds, tens, units, each
        kept at its place value and zero-padded) name the folders, e.g.
        episode 1234 -> ``<directory>/200/30/4/1234.pt``.
        """
        n = 3  # number of hierarchy levels
        # Digit p of episode_id (units first), scaled back to its place value.
        place_values = [(int(episode_id) // 10 ** p) % 10 * 10 ** p for p in range(n)]
        # Most significant level first; level i is zero-padded to n - i characters.
        subfolders = [f'{value:0{n - i}d}' for i, value in enumerate(reversed(place_values))]
        return self.directory.joinpath(*subfolders, f'{episode_id}.pt')

    def _load_segment(self, segment_id: SegmentId, should_pad: bool = True) -> Segment:
        """Load the episode referenced by ``segment_id`` and cut the segment out of it."""
        episode = self.load_episode(segment_id.episode_id)
        return make_segment(episode, segment_id, should_pad)

    def load_episode(self, episode_id: int) -> Episode:
        """Return an episode, from the RAM cache when enabled and populated, else from disk."""
        if self.cache_in_ram and self.cache[episode_id] is not None:
            return self.cache[episode_id]
        episode = Episode.load(self._get_episode_path(episode_id), map_location='cpu')
        if self.cache_in_ram:
            self.cache[episode_id] = episode
        return episode

    def add_episode(self, episode: Episode, *, episode_id: Optional[int] = None, save_on_disk: bool = True) -> int:
        """Append a new episode, or replace an existing one.

        Args:
            episode: the episode to store.
            episode_id: ``None`` to append as a new episode; otherwise the id of an
                existing episode to replace. If its length changes, the start indices
                of all later episodes are shifted accordingly.
            save_on_disk: write the episode to its file, via a temporary file that is
                then moved into place.

        Returns:
            The id under which the episode is stored.

        Raises:
            IndexError: if ``episode_id`` is given but is not an existing episode id.

        Note: ``info.pt`` is not rewritten here; call ``save_info`` to persist the
        updated bookkeeping.
        """
        if episode_id is None:
            episode_id = self.num_episodes
            self.start_idx = np.concatenate((self.start_idx, np.array([self.num_steps], dtype=np.int64)))
            self.lengths = np.concatenate((self.lengths, np.array([len(episode)], dtype=np.int64)))
            self.cache.append(None)
            self.num_steps += len(episode)
            self.num_episodes += 1
        else:
            # Explicit check (not assert): survives ``python -O`` and rejects negative
            # ids, which would otherwise corrupt lengths/start_idx via negative indexing.
            if not 0 <= episode_id < self.num_episodes:
                raise IndexError(f'episode_id {episode_id} out of range for {self.num_episodes} episodes')
            # The stored length matches the old episode; no need to load it from disk.
            incr_num_steps = len(episode) - int(self.lengths[episode_id])
            self.lengths[episode_id] = len(episode)
            self.start_idx[episode_id + 1:] += incr_num_steps
            self.num_steps += incr_num_steps

        if save_on_disk:
            episode_path = self._get_episode_path(episode_id)
            episode_path.parent.mkdir(parents=True, exist_ok=True)
            tmp_path = episode_path.with_suffix('.tmp')
            episode.save(tmp_path)
            # replace() (unlike rename()) also overwrites an existing target on
            # Windows, which happens whenever an existing episode is updated.
            tmp_path.replace(episode_path)

        if self.cache_in_ram:
            self.cache[episode_id] = episode.to('cpu')

        return episode_id

    def get_episode_id_from_global_idx(self, global_idx: np.ndarray) -> np.ndarray:
        """Map global step indices to the ids of the episodes containing them.

        ``start_idx`` is non-decreasing, so the owning episode is the last one whose
        start index is <= the global index: a right-sided binary search minus one.
        Indices at or past the end map to the last episode; the result always has
        at least one dimension.

        Raises:
            ValueError: if the dataset has no episodes.
        """
        if self.num_episodes == 0:
            raise ValueError('cannot map global indices on an empty dataset')
        global_idx = np.atleast_1d(global_idx)
        return (np.searchsorted(self.start_idx, global_idx, side='right') - 1) % self.num_episodes

    def get_global_idx_from_segment_id(self, segment_id: SegmentId) -> np.ndarray:
        """Return the global step indices covered by ``segment_id`` (``stop`` exclusive)."""
        start_idx = self.start_idx[segment_id.episode_id]
        return np.arange(start_idx + segment_id.start, start_idx + segment_id.stop)
