

import pathlib 
import uuid
import os
import datetime
import io
import numpy as np
import jax.numpy as jnp

import torch
from jax.tree_util import tree_map
import yaml
try:
    from yaml import CLoader as Loader, CDumper as Dumper
except ImportError:
    from yaml import Loader, Dumper
from tqdm import tqdm

from torch.utils.data import Dataset


def numpy_collate(batch):
  """Collate a list of samples into a batch whose leaves are NumPy arrays.

  The samples are first batched by torch's ``default_collate``. Then every
  leaf of the resulting (possibly nested) structure is converted with
  ``np.asarray``, so callers such as JAX code never see torch tensors.
  """
  tensor_batch = torch.utils.data.default_collate(batch)
  return tree_map(np.asarray, tensor_batch)

class NumpyLoader(torch.utils.data.DataLoader):
  """``torch.utils.data.DataLoader`` that yields NumPy batches.

  Behaves like a regular DataLoader except that ``collate_fn`` is always
  ``numpy_collate``, so every leaf of each batch is a NumPy array rather
  than a torch tensor.
  """

  def __init__(self, dataset, batch_size=1,
               shuffle=False, sampler=None,
               batch_sampler=None, num_workers=0,
               pin_memory=False, drop_last=False,
               timeout=0, worker_init_fn=None, **kwargs):
    """Create the loader.

    The explicit parameters are forwarded unchanged to
    ``torch.utils.data.DataLoader``. Any extra keyword arguments (for
    example ``generator``, ``persistent_workers`` or ``prefetch_factor``)
    are forwarded as well. Passing ``collate_fn`` raises ``TypeError``,
    because this class fixes it to ``numpy_collate``.
    """
    # Zero-argument super(): the former super(self.__class__, self) form
    # re-enters this __init__ endlessly once NumpyLoader is subclassed,
    # because self.__class__ is then the subclass.
    super().__init__(dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        sampler=sampler,
        batch_sampler=batch_sampler,
        num_workers=num_workers,
        collate_fn=numpy_collate,
        pin_memory=pin_memory,
        drop_last=drop_last,
        timeout=timeout,
        worker_init_fn=worker_init_fn,
        **kwargs)
    

class StoredDataset(Dataset):
    """Map-style dataset held in memory as fixed-size chunks, persistable to disk.

    Samples live in ``self.chunks``, a list of sequences. Every chunk except
    possibly the last is expected to hold exactly ``chunk_size`` samples;
    that invariant is what lets ``__getitem__`` locate a sample with a single
    ``divmod``. When saving, each sample is treated as a sequence whose i-th
    field is written under ``datanames[i]`` (see ``save_chunk``).
    """

    # Field names used when the caller does not pass ``datanames``. It is a
    # tuple so the shared default can never be mutated; every instance gets
    # its own list copy.
    DEFAULT_DATANAMES = ('obs', 'action', 'next_obs', 'reward', 'state', 'next_state')

    def __init__(self, data=None, chunk_size=256, chunks=None, datanames=None, description=None):
        """Build the dataset either from a flat sequence or from existing chunks.

        Args:
            data: sliceable sequence of samples (list, NumPy array, ...). It is
                split into consecutive chunks of ``chunk_size`` samples.
            chunk_size: number of samples per chunk.
            chunks: pre-built list of chunks, used when ``data`` is None.
            datanames: names of the per-sample fields, used by ``save``/``load``.
                Defaults to ``DEFAULT_DATANAMES``.
            description: free-form text exposed by the ``description`` property.

        Raises:
            ValueError: if neither ``data`` nor ``chunks`` is given, if both are
                given, or if ``chunk_size`` is not positive.
        """
        if chunk_size < 1:
            raise ValueError(f'chunk_size must be positive, got {chunk_size}')
        self.chunk_size = chunk_size
        self.datanames = list(self.DEFAULT_DATANAMES if datanames is None else datanames)
        self._description = description
        # Use `data is None` rather than `not data`: truth-testing a NumPy array
        # raises, and an empty sequence is a valid (empty) dataset.
        if data is None:
            if chunks is None:
                raise ValueError('No data provided!')
            self.chunks = chunks
            # Without this, len() fails unless a caller (such as `load`) sets it.
            self.length = sum(len(chunk) for chunk in chunks)
        else:
            if chunks:
                raise ValueError('New Dataset created. Cannot provide chunks and data at the same time')
            self.chunks = [data[start:start + chunk_size] for start in range(0, len(data), chunk_size)]
            self.length = len(data)

    @classmethod
    def load(cls, path):
        """Load a dataset previously written by ``save``.

        Args:
            path: directory containing ``dataset.yaml`` and the chunk ``.npz``
                files. ``~`` is expanded.

        Returns:
            A new instance whose ``length`` is the ``size`` recorded in the YAML.

        Raises:
            ValueError: if ``path`` does not exist or a chunk cannot be read.
        """
        path = pathlib.Path(path).expanduser()
        if not path.exists():
            raise ValueError(f'Dataset at {path} does not exists')
        # NOTE(review): `Loader` is PyYAML's full (unsafe) loader, or its C
        # variant. It can construct arbitrary Python objects, so only load
        # dataset directories from trusted sources.
        with open(path / 'dataset.yaml', 'r', encoding='utf-8') as f:
            config = yaml.load(f, Loader=Loader)
        chunk_size = config['chunk_size']
        length = config['size']
        filenames = config['files']
        datanames = config['datanames']
        description = config.get('description')
        print(f'Dataset ({description}) exists at {path}. Loading {len(filenames)} chunks')
        # Use `cls` so subclasses can override load_chunk.
        chunks = [cls.load_chunk(path / filename, datanames) for filename in tqdm(filenames)]

        dataset = cls(chunk_size=chunk_size, chunks=chunks, datanames=datanames, description=description)
        dataset.length = length
        return dataset

    @staticmethod
    def save_chunk(chunk, path, datanames):
        """Write one chunk to a new compressed ``.npz`` file inside ``path``.

        The chunk is transposed from a list of samples into one column per
        field: field i of every sample is stored under ``datanames[i]``. Since
        ``zip`` truncates, fields beyond ``len(datanames)`` are dropped. The
        file name embeds a timestamp and a random UUID, so repeated saves into
        one directory do not overwrite each other.

        Args:
            chunk: sequence of samples, each a sequence of fields.
            path: ``pathlib.Path`` of the target directory.
            datanames: names to store the fields under.

        Returns:
            The bare file name (not the full path) of the written file.
        """
        timestamp = datetime.datetime.now().strftime('%Y%m%dT%H%M%S')
        filename = path / f'{timestamp}-{uuid.uuid4().hex}.npz'
        columns = dict(zip(datanames, zip(*chunk)))
        # Compress into memory first, so a failure during serialization does
        # not leave a partially written file behind.
        with io.BytesIO() as buffer:
            np.savez_compressed(buffer, **columns)
            filename.write_bytes(buffer.getvalue())
        return filename.name

    @staticmethod
    def load_chunk(filename, datanames, allow_pickle=True):
        """Read a chunk written by ``save_chunk`` back into a list of samples.

        Args:
            filename: ``pathlib.Path`` of the ``.npz`` file.
            datanames: fields to read; each returned sample is a tuple in this order.
            allow_pickle: forwarded to ``np.load``. NOTE(review): unpickling
                executes arbitrary code for untrusted files. It is presumably
                enabled for object-dtype (ragged) fields — confirm before
                disabling.

        Returns:
            A list of tuples, one per sample.

        Raises:
            ValueError: wrapping any error raised while reading the file.
        """
        try:
            # Also close the NpzFile, not just the underlying file handle.
            with filename.open('rb') as f, np.load(f, allow_pickle=allow_pickle) as npz:
                columns = [npz[name] for name in datanames]
            return list(zip(*columns))
        except Exception as e:
            raise ValueError(f'Could not load chunk {str(filename)}: {e}') from e

    def save(self, dataset_path, description=None):
        """Write every chunk plus a ``dataset.yaml`` index into ``dataset_path``.

        Args:
            dataset_path: target directory. It is created if missing, and ``~``
                is expanded.
            description: text stored in the YAML. It defaults to this dataset's
                own ``description`` (or ``''`` if it has none), so a loaded
                dataset keeps its description when re-saved.
        """
        if description is None:
            description = '' if self._description is None else self._description
        path = pathlib.Path(dataset_path).expanduser()
        path.mkdir(parents=True, exist_ok=True)
        filenames = [str(self.save_chunk(chunk, path, self.datanames)) for chunk in tqdm(self.chunks)]
        config = {
            'chunk_size': self.chunk_size,
            'size': self.length,
            # A plain list keeps the YAML free of Python-specific tags.
            'datanames': list(self.datanames),
            'description': description,
            'files': filenames
        }
        with open(path / 'dataset.yaml', 'w', encoding='utf-8') as f:
            yaml.dump(config, f)
        print(f'Dataset ({description}) save at {path}')

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        """Return one sample for an integer index, or a list of samples for a slice.

        Negative integers count from the end, as with a list.

        Raises:
            IndexError: if an integer index is out of range.
        """
        if isinstance(idx, slice):
            return [self._get_single(i) for i in range(*idx.indices(len(self)))]
        return self._get_single(idx)

    def _get_single(self, idx):
        """Map a possibly negative flat index to its chunk and in-chunk position."""
        length = len(self)
        position = idx + length if idx < 0 else idx
        # Normalize and bounds-check before divmod. A raw negative index would
        # select the last chunk and assume it is full, returning the wrong
        # sample (or failing) whenever that chunk is short.
        if not 0 <= position < length:
            raise IndexError(f'index {idx} out of range for dataset of length {length}')
        chunk_idx, elem_idx = divmod(position, self.chunk_size)
        return self.chunks[chunk_idx][elem_idx]

    @property
    def description(self):
        return self._description