import torch
import numpy as np
from einops import rearrange

def load_images(dataset_name: str, dataset_root: str):
    """Load a stimulus image set as a batch-first, channels-first torch tensor.

    Args:
        dataset_name: Which stimulus set to load. Supported values are
            ``"185"`` and ``"NSD1000"``.
        dataset_root: Directory prefix that the dataset file paths are
            appended to (plain string concatenation).

    Returns:
        A ``torch.Tensor`` built with ``torch.from_numpy`` from the loaded
        array.

        * ``"185"``: array rearranged to (b, c, h, w) and divided by 255.0.
        * ``"NSD1000"``: values exactly as stored in the ``.pth`` file (no
          rescaling); the stored tensor must be 4-D.

    Raises:
        ValueError: If ``dataset_name`` is not one of the supported names.
    """
    if dataset_name == "185":
        images = np.load(dataset_root + "/datasets/murty_185/185_stims.npy")
        # NOTE(review): this pattern moves channels first AND swaps the two
        # spatial axes, i.e. it assumes the file stores (b, w, h, c). If the
        # file is actually (b, h, w, c), images come out transposed -- confirm
        # against the data.
        bchw_images = rearrange(images, "b w h c -> b c h w")
        # Presumably 8-bit pixel values; scale to [0, 1] -- TODO confirm dtype.
        bchw_images = bchw_images / 255.0
    elif dataset_name == "NSD1000":
        images = torch.load(
            dataset_root + "/nsd_stimuli_1000_updated.pth", weights_only=True
        ).detach().cpu().numpy()
        # Identity pattern: the layout is unchanged, but einops raises if the
        # array is not 4-D, so this doubles as a rank check.
        bchw_images = rearrange(images, "b c h w -> b c h w")
    else:
        # Without this branch an unknown name fell through and crashed with a
        # confusing UnboundLocalError on `bchw_images`.
        raise ValueError(
            f"Unknown dataset_name {dataset_name!r}; expected '185' or 'NSD1000'."
        )

    return torch.from_numpy(bchw_images)