from box import Box


def get_dataset_class(dataset_name):
    """Look up a dataset configuration object by its module-level name.

    The lookup is performed against this module's global namespace, so
    ``dataset_name`` must match the identifier of something defined (or
    imported) in this file exactly, including case.

    Args:
        dataset_name: Name of the module-level object to fetch.

    Returns:
        Whatever object is bound to ``dataset_name`` in this module.

    Raises:
        NotImplementedError: If no module-level name matches.
    """
    module_namespace = globals()
    # Guard clause: fail fast on unknown names with a descriptive error.
    if dataset_name not in module_namespace:
        message = "Dataset not found: {}".format(dataset_name)
        raise NotImplementedError(message)
    return module_namespace[dataset_name]


class AMAZON_REVIEWS(object):
    """Configuration values for the Amazon Reviews dataset.

    Constructing an instance only assigns attributes; no data is loaded
    here. How each value is consumed is decided by code outside this class.

    Attributes:
        class_names: The two label names, ``'positive'`` and ``'negative'``.
        sequence_len: Fixed at 128.
        scenarios: All 12 ordered pairs of distinct domains.
        num_classes: Number of entries in ``class_names`` (2).
        shuffle: ``True``.
        drop_last: ``False``.
        normalize: ``True``.
        dataloader: ``Box`` holding the active dataset key and that
            dataset's loader options (file name, feature count, worker
            count, normalization flag, domain list and data root).
    """

    def __init__(self):
        super(AMAZON_REVIEWS, self).__init__()

        # Single source of truth for the domain names used below.
        domains = ["books", "dvd", "electronics", "kitchen"]
        dataset_key = "AmazonReviews"

        # NOTE(review): if list position is used as the label id, confirm
        # that 'positive' is really meant to be index 0.
        self.class_names = ['positive', 'negative']
        self.sequence_len = 128

        # Every ordered pair of two different domains, grouped by the first
        # element -- presumably (source, target) adaptation pairs; confirm
        # against the training code.
        self.scenarios = [
            (src, tgt)
            for src in domains
            for tgt in domains
            if src != tgt
        ]

        self.num_classes = len(self.class_names)
        self.shuffle = True
        self.drop_last = False
        self.normalize = True

        # Options specific to the Amazon Reviews loader, nested in the Box
        # under the dataset's own key.
        reviews_options = {
            "filename": "amazon.mat",
            "n_features": 5000,
            "num_workers": 8,
            "normalize": True,
            "domains": list(domains),
            "data_root": "data/amazon_reviews",
        }
        self.dataloader = Box({
            "reset_and_reload_memory": False,
            "dataset": dataset_key,
            dataset_key: reviews_options,
        })

