import torch
import numpy as np
from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler

def store(obs, memory):
    """
    Append one observation to memory.

    For every key in ``obs``, the value ``obs[key]`` is appended to the list
    held at ``memory[key]``; a new empty list is created first when the key
    has not been seen before. ``memory`` is modified in place and nothing is
    returned.
    """
    for name in obs:
        bucket = memory.setdefault(name, [])
        bucket.append(obs[name])

def get(memory):
    """
    Return a new dict with each list in memory converted by ``np.array``.

    ``memory`` itself is left untouched. Entries stored under one key
    presumably share a shape (as in the demo below), so each value becomes a
    regular array whose first axis indexes the stored samples.
    """
    arrays = {}
    for name, values in memory.items():
        arrays[name] = np.array(values)
    return arrays

def sample(memory, batch_size=64):
    """
    Yield random mini-batches drawn from memory.

    The stored observations are converted with ``get``; sample indices are
    shuffled and split into batches of exactly ``batch_size`` (leftover
    samples that cannot fill a whole batch are dropped). For each key, the
    selected observations are stacked along dim 1 and dim 2 is then squeezed
    (a no-op when dim 2 is not of size 1). For example, observations of shape
    (5, 1, 21, 79) yield batch tensors of shape (5, batch_size, 21, 79).

    Args:
        memory: mapping from key to a list of per-sample observations, as
            filled by ``store``. Every key must hold the same number of
            samples, and observations must be at least 2-D for the
            stack/squeeze above to be valid.
        batch_size: number of samples in each yielded batch.

    Yields:
        dict mapping each key to a contiguous torch tensor for one batch.

    Raises:
        ValueError: if memory is empty, if keys hold different numbers of
            samples, or if there are fewer samples than ``batch_size``.
            (Validation runs on the first ``next()`` since this is a generator.)
    """
    obs_dict = get(memory)
    if not obs_dict:
        raise ValueError("Memory is empty; nothing to sample")

    lengths = {key: len(val) for key, val in obs_dict.items()}
    n = next(iter(lengths.values()))  # Number of samples
    if any(length != n for length in lengths.values()):
        raise ValueError(
            f"All memory keys must hold the same number of samples, got {lengths}"
        )

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if n < batch_size:
        raise ValueError("Not enough samples to create a batch")

    # SubsetRandomSampler reshuffles on every iteration, so a plain index
    # range is enough; pre-permuting the indices would be redundant.
    sampler = BatchSampler(SubsetRandomSampler(range(n)), batch_size, drop_last=True)

    for indices in sampler:
        # val[indices] gathers the whole batch at once as (B, *obs_shape);
        # swapping dims 0 and 1 equals torch.stack(per-sample, dim=1).
        # .contiguous() keeps the same memory layout torch.stack produced.
        yield {
            key: torch.as_tensor(val[indices]).transpose(0, 1).squeeze(2).contiguous()
            for key, val in obs_dict.items()
        }

if __name__ == "__main__":
    n = 32
    batch_size = 8

    # Populate memory through store(), so the demo exercises the module's
    # own API instead of pre-seeding keys and appending by hand.
    memory = {}
    for _ in range(n):
        store(
            {
                "glyphs": torch.randn(5, 1, 21, 79),
                "blstats": torch.randn(5, 1, 27),
            },
            memory,
        )

    # Draw every full batch and print the per-key tensor shapes.
    for batch in sample(memory, batch_size):
        print("Batch:")
        for key, value in batch.items():
            print(f"{key}: {value.shape}")
