import os
import sys

# Add parent directory to sys.path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))

from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
import numpy as np
from config import get_default_config



class MullerBrownDataset(Dataset):
    """Dataset serving the ``'R'`` entry of a NumPy file as single-column rows.

    The array stored under the key ``'R'`` is reshaped to ``(N, 1)``, so every
    scalar in that array becomes one sample.
    """

    def __init__(self, file_path):
        """Load the ``'R'`` entry from ``file_path`` into memory.

        Args:
            file_path: Path handed to ``numpy.load``. The loaded object must
                support lookup by the key ``'R'``.
        """
        # NOTE(review): allow_pickle=True lets numpy unpickle object arrays,
        # which can execute arbitrary code. Only load files from a trusted
        # source; confirm whether pickling is actually required here.
        data = np.load(file_path, allow_pickle=True)
        try:
            self.positions = data['R'].reshape(-1, 1)  # shape: (N, 1)
        finally:
            # ``.npz`` archives come back as an NpzFile that keeps the file
            # open; close it explicitly instead of relying on garbage
            # collection. Plain arrays have no ``close`` and are skipped.
            close = getattr(data, "close", None)
            if close is not None:
                close()

    def __len__(self):
        """Return the number of rows (N) in ``positions``."""
        return len(self.positions)

    def __getitem__(self, idx):
        """Return the row(s) of ``positions`` selected by ``idx``."""
        return self.positions[idx]
    
def get_mb_dataloader(config):
    """Build a DataLoader over a random subset of the Muller-Brown dataset.

    Args:
        config: Nested mapping. Reads ``config["dataset"]`` keys
            ``mb_datafile``, ``num_samples``, ``shuffle`` and ``drop_last``,
            and ``config["trainer"]["batch"]`` for the batch size.

    Returns:
        A ``DataLoader`` that draws from ``num_samples`` distinct, randomly
        chosen dataset indices. When ``shuffle`` is truthy the subset is
        re-shuffled on every pass; otherwise it is visited in one fixed order.

    Raises:
        ValueError: If ``num_samples`` is larger than the dataset.
    """
    dataset_cfg = config["dataset"]
    file_path = dataset_cfg["mb_datafile"]
    num_samples = dataset_cfg["num_samples"]
    shuffle = dataset_cfg["shuffle"]
    drop_last = dataset_cfg["drop_last"]
    batch_size = config["trainer"]["batch"]

    dataset = MullerBrownDataset(file_path=file_path)
    print(f"Dataset size: {len(dataset)}")

    if num_samples > len(dataset):
        raise ValueError(
            f"num_samples ({num_samples}) exceeds dataset size ({len(dataset)})"
        )

    # Sample without replacement so every selected index is unique; convert
    # to plain Python ints for use as dataset indices.
    indices = np.random.choice(len(dataset), num_samples, replace=False).tolist()

    # DataLoader rejects ``shuffle=True`` together with an explicit sampler,
    # so the shuffle flag is expressed through the sampler itself and never
    # passed alongside it.
    if shuffle:
        sampler = SubsetRandomSampler(indices)
    else:
        # Any iterable of indices is a valid sampler; a list gives a fixed order.
        sampler = indices

    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        drop_last=drop_last,
    )
    return dataloader

if __name__ == "__main__":
    # Smoke test: build the loader from the default config and show the shape
    # of the first batch.
    config = get_default_config()
    dataloader = get_mb_dataloader(config)
    # The dataset yields a single positions row per item, so each batch is one
    # object; unpacking it into two names would split it along the batch axis.
    for x_batch in dataloader:
        print(x_batch.shape)
        break