import os
try:
    import ujson as json
except:
    import json 

import torch

from bbo.datasets.base import SimpleDataset
from bbo.datasets.utils import X_transform, Y_transform


class DatasetBase:
    """
    Raw HPO-B meta-dataset splits loaded from JSON.

    After construction the instance exposes three nested dicts:
    ``meta_train_data``, ``meta_test_data`` and ``meta_validation_data``.
    """

    def __init__(self, root_dir='hpob-data/'):
        """
        Load the three split files found under ``root_dir``.

        Each split file is expected to hold a nested JSON object of the form::

            {
                "search_space_id_1": {
                    "dataset_id_1": {"X": [[1, 1], [0, 2]], "y": [[0.9], [0.1]]},
                    "dataset_id_2": {...}
                },
                "search_space_id_2": {...}
            }
        """
        self._load_data(root_dir)

    def _load_data(self, root_dir):
        """
        Read the train, test and validation JSON files (in that order).

        Each parsed object is stored on ``self.meta_<split>_data``.
        """
        split_files = (
            ('meta_train_data', 'meta-train-dataset.json'),
            ('meta_test_data', 'meta-test-dataset.json'),
            ('meta_validation_data', 'meta-validation-dataset.json'),
        )
        for attr_name, file_name in split_files:
            with open(os.path.join(root_dir, file_name), 'rb') as handle:
                setattr(self, attr_name, json.load(handle))

    def get_search_spaces(self):
        """Return the search-space ids of the meta-test split, in file order."""
        return [space_id for space_id in self.meta_test_data]

    def get_datasets(self, search_space_id, mode):
        """
        Return the dataset ids of ``search_space_id`` within one split.

        ``mode`` selects the split and must be 'train', 'test' or 'validation'.
        """
        assert mode in ('train', 'test', 'validation')
        split = getattr(self, f'meta_{mode}_data')
        return [dataset_id for dataset_id in split[search_space_id]]


def _wrap_hpob_split(meta_data):
    """Convert one raw split dict into ``{search_space_id: {dataset_id: SimpleDataset}}``."""
    return {
        search_space_id: {
            dataset_id: SimpleDataset(torch.tensor(record['X']), torch.tensor(record['y']))
            for dataset_id, record in id2record.items()
        }
        for search_space_id, id2record in meta_data.items()
    }


def load_hpob_datasets(root_dir):
    """
    Load every HPO-B split and wrap each dataset in a ``SimpleDataset``.

    Args:
        root_dir: directory containing ``meta-train-dataset.json``,
            ``meta-test-dataset.json`` and ``meta-validation-dataset.json``.
            A leading ``~`` is expanded to the user's home directory.

    Returns:
        A ``(train, test, validation)`` tuple. Each element is a nested dict
        ``{search_space_id: {dataset_id: SimpleDataset(X, Y)}}`` built from the
        corresponding split file, with the JSON ``'X'`` / ``'y'`` lists
        converted via ``torch.tensor``.
    """
    # expanduser leaves paths that do not start with '~' unchanged.
    dataset_base = DatasetBase(os.path.expanduser(root_dir))

    # Each split is keyed by its own search spaces: a search space missing from
    # one split no longer aborts the whole load with KeyError, and search spaces
    # that only appear in test/validation are no longer silently dropped.
    return (
        _wrap_hpob_split(dataset_base.meta_train_data),
        _wrap_hpob_split(dataset_base.meta_test_data),
        _wrap_hpob_split(dataset_base.meta_validation_data),
    )


def _preprocess_split(id2dataset, lb, ub, device):
    """
    Return a new ``{dataset_id: SimpleDataset}`` with transformed X and Y.

    X goes through ``X_transform(X, lb, ub, device)`` and Y through
    ``Y_transform(Y, device=device)``. Each dataset is transformed on its own;
    the extra two values returned by ``Y_transform`` (mean, std) are discarded.
    """
    processed = {}
    for dataset_id, dataset in id2dataset.items():
        X = X_transform(dataset.X, lb, ub, device)
        Y, _, _ = Y_transform(dataset.Y, device=device)
        processed[dataset_id] = SimpleDataset(X, Y)
    return processed


def load_hpob_pretrain_dataset(root_dir, search_space_id, lb, ub, device):
    """
    Load and preprocess the HPO-B train/validation data of one search space.

    Args:
        root_dir: parent directory; the split files are read from
            ``<root_dir>/hpob-data``.
        search_space_id: key selecting the search space in each split.
        lb, ub: forwarded to ``X_transform`` (presumably the lower/upper
            bounds of the input space — confirm against ``X_transform``).
        device: forwarded to ``X_transform`` and ``Y_transform``.

    Returns:
        ``(train_id2dataset, None, validation_id2dataset)`` where both dicts
        map ``dataset_id`` to a ``SimpleDataset`` of transformed X and Y.
        The test slot is always ``None`` for now.

    Raises:
        KeyError: if ``search_space_id`` is absent from the train or
            validation split.
    """
    train, _, validation = load_hpob_datasets(os.path.join(root_dir, 'hpob-data'))
    # Look up both splits before doing any transform work so a missing
    # search space fails fast.
    train_id2dataset = train[search_space_id]
    validation_id2dataset = validation[search_space_id]

    # TODO: preprocess the test split as well; it is currently not returned.
    return (
        _preprocess_split(train_id2dataset, lb, ub, device),
        None,
        _preprocess_split(validation_id2dataset, lb, ub, device),
    )