import torch
import os
import einops
from tqdm import tqdm
import h5py
from torch.utils.data import Dataset
import seaborn as sns
import umap
import numpy as np
from torch.utils.data import DataLoader, TensorDataset, random_split

class DynamicsAnalysisDataset(Dataset):
    """Dataset of ``(s_t, a_t, s_t_plus_1)`` triples stored in HDF5 files.

    Every top-level group whose name starts with ``sample_`` is one item and
    is expected to contain the datasets ``s_t``, ``a_t`` and ``s_t_plus_1``.
    All files are scanned once at construction time to build the index; the
    read handles used for item access are opened lazily and cached per file.

    Cached handles are dropped when the dataset is pickled, so a copy of the
    dataset reopens the files it needs on first access.
    """

    # Only groups with this name prefix are treated as samples.
    _SAMPLE_PREFIX = 'sample_'

    def __init__(self, h5_path):
        """
        Args:
            h5_path (str): Path to an HDF5 file or a directory containing HDF5 files.

        Raises:
            ValueError: If ``h5_path`` is neither an existing file nor a directory.
        """
        self.h5_paths = []
        self.index_map = []  # List of tuples: (file_index, group_name)
        self.h5_files = {}   # Lazy-opened files, keyed by index in self.h5_paths

        # Handle single file or directory
        if os.path.isdir(h5_path):
            # Sorted so that file indices are stable across runs.
            self.h5_paths = sorted(
                os.path.join(h5_path, f) for f in os.listdir(h5_path)
                if f.endswith(('.h5', '.hdf5'))
            )
        elif os.path.isfile(h5_path):
            self.h5_paths = [h5_path]
        else:
            raise ValueError(f"Invalid path: {h5_path}")

        # Build the index map. Files are closed again right away; the handles
        # used for reading are opened on demand by _init_file().
        for file_idx, path in enumerate(self.h5_paths):
            with h5py.File(path, 'r') as f:
                self.index_map.extend(
                    (file_idx, name) for name in f.keys()
                    if name.startswith(self._SAMPLE_PREFIX)
                )

    def __len__(self):
        """Return the total number of sample groups across all files."""
        return len(self.index_map)

    def __getstate__(self):
        """Return the instance state for pickling, without open file handles.

        Open h5py handles cannot be pickled, so the copy starts with an empty
        handle cache and reopens files lazily. The original instance keeps its
        own handles.
        """
        state = self.__dict__.copy()
        state['h5_files'] = {}
        return state

    def _init_file(self, file_idx):
        """Return the cached read-only handle for ``file_idx``, opening it if needed."""
        if file_idx not in self.h5_files:
            self.h5_files[file_idx] = h5py.File(self.h5_paths[file_idx], 'r')
        return self.h5_files[file_idx]

    def _load_sample(self, grp):
        """Read one sample group and return its three datasets as float32 tensors."""
        s_t = torch.tensor(grp['s_t'][()], dtype=torch.float32)
        a_t = torch.tensor(grp['a_t'][()], dtype=torch.float32)
        s_t_plus_1 = torch.tensor(grp['s_t_plus_1'][()], dtype=torch.float32)
        return s_t, a_t, s_t_plus_1

    def __getitem__(self, idx):
        """Return the ``(s_t, a_t, s_t_plus_1)`` float32 tensors of sample ``idx``."""
        file_idx, group_name = self.index_map[idx]
        h5_file = self._init_file(file_idx)
        return self._load_sample(h5_file[group_name])

    def get_batch_from_file_by_index(self, file_idx, num_samples=1):
        """
        Get a batch of data randomly from a specific file by its index.

        Groups are drawn uniformly with replacement via ``np.random.choice``,
        so the same sample may appear more than once in a batch.

        Args:
            file_idx (int): Index into ``self.h5_paths``.
            num_samples (int): Number of samples to draw; must be at least 1.

        Returns:
            Tuple ``(s_t, a_t, s_t_plus_1)`` of float32 tensors, each stacked
            along a new leading dimension of size ``num_samples``.

        Raises:
            IndexError: If ``file_idx`` is out of range.
            ValueError: If ``num_samples`` is smaller than 1, or the file
                contains no ``sample_*`` groups.
        """
        if file_idx < 0 or file_idx >= len(self.h5_paths):
            raise IndexError(f"File index {file_idx} out of range.")

        # Fail early with a clear message instead of letting torch.stack
        # choke on an empty list further down.
        if num_samples < 1:
            raise ValueError(f"num_samples must be at least 1, got {num_samples}.")

        h5_file = self._init_file(file_idx)
        group_names = [
            name for name in h5_file.keys()
            if name.startswith(self._SAMPLE_PREFIX)
        ]

        if not group_names:
            raise ValueError(f"No valid groups found in file {self.h5_paths[file_idx]}.")

        selected_groups = np.random.choice(group_names, size=num_samples, replace=True)
        samples = [self._load_sample(h5_file[name]) for name in selected_groups]

        # Transpose the list of per-sample triples into three per-field
        # columns, then stack each column into one batch tensor.
        s_t, a_t, s_t_plus_1 = (torch.stack(column) for column in zip(*samples))
        return s_t, a_t, s_t_plus_1

    def close(self):
        """Close every cached file handle and empty the cache.

        Safe to call more than once; files are reopened lazily on next access.
        """
        handles = getattr(self, 'h5_files', None)
        if not handles:
            return
        for handle in list(handles.values()):
            handle.close()
        handles.clear()

    def __del__(self):
        # Finalisers can run on half-initialised objects or during interpreter
        # shutdown; cleanup here is best effort and must never raise.
        try:
            self.close()
        except Exception:
            pass



if __name__ == "__main__":
    # Placeholder entry point: running this module as a script does nothing.
    pass