import torch
from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms
import torchvision.transforms.functional as fn

import json

import os

from torchvision.datasets import DatasetFolder

def label_to_state(label, tm, label_to_synset, synset_to_node):
    """Return the column of ``tm`` associated with class ``label``.

    The label is resolved in two hops: ``label_to_synset[str(label)]`` gives a
    synset key, and ``synset_to_node[synset]`` gives the column index into
    ``tm``. The label is stringified because mappings loaded with ``json.load``
    always have string keys.

    Args:
        label: class label (e.g. the integer index ImageFolder assigns).
        tm: 2-D indexable matrix (e.g. a torch tensor) sliced as ``tm[:, i]``.
        label_to_synset: mapping from ``str(label)`` to a synset key.
        synset_to_node: mapping from synset key to a column index of ``tm``.

    Returns:
        ``tm[:, node_idx]`` for the resolved column index.
    """
    synset = label_to_synset[str(label)]
    column = synset_to_node[synset]
    return tm[:, column]

def getDataset(tm, l2s_path='./ImageNet/label_to_synset_old.json', s2n_path='./ImageNet/synset_to_node_old.json', imgnet_path="./ImageNet/val"):
    """Build an ImageFolder dataset whose targets are columns of ``tm``.

    Each sample's class label (the integer index ImageFolder assigns) is
    mapped through the two JSON lookup tables by ``label_to_state``, and the
    resulting target is ``tm[:, node_idx]``.

    Args:
        tm: 2-D indexable matrix (e.g. a torch tensor); one column per node.
        l2s_path: JSON file mapping ``str(class_index)`` to a synset key.
        s2n_path: JSON file mapping a synset key to a column index of ``tm``.
        imgnet_path: ImageFolder root, laid out as ``root/<class>/<image>``.

    Returns:
        ``torchvision.datasets.ImageFolder`` yielding
        ``(normalized image tensor, tm column)`` pairs.
    """
    # Local import so this block stays self-contained.
    from functools import partial

    # Forward slashes in the defaults work on both Windows and POSIX. They also
    # avoid the invalid "\I"/"\l"/"\s" escape sequences that the old
    # backslash literals contained.
    with open(l2s_path, 'r', encoding='utf-8') as f:
        label_to_synset = json.load(f)

    with open(s2n_path, 'r', encoding='utf-8') as f:
        synset_to_node = json.load(f)

    # A partial of a module-level function is picklable, unlike a lambda.
    # DataLoader needs to pickle the dataset when num_workers > 0 under the
    # "spawn" start method (the default on Windows and macOS).
    transform_label = partial(
        label_to_state,
        tm=tm,
        label_to_synset=label_to_synset,
        synset_to_node=synset_to_node,
    )

    # NOTE(review): RandomResizedCrop/RandomHorizontalFlip are training-style
    # augmentations, yet the default root is the "val" split. Confirm this is
    # intended (evaluation usually uses a deterministic Resize + CenterCrop).
    image_transform = transforms.Compose([
        transforms.RandomResizedCrop(227),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    dataset = ImageFolder(
        imgnet_path,
        transform=image_transform,
        target_transform=transform_label,
    )

    return dataset