import os
import pickle
import sys

import lmdb
from torch.utils.data import DataLoader

# Extend the module search path with the current and parent directories
# (both are relative, so they resolve against the process working directory).
# NOTE(review): presumably so project modules such as `dataset_coop` resolve
# when this file is run as a script — confirm.
sys.path.append(".")
sys.path.append("..")

def raw_reader(path):
    """Return the entire contents of the file at `path` as raw bytes."""
    with open(path, mode='rb') as stream:
        return stream.read()


def dumps_data(obj):
    """
    Pickle an object into a byte string.
    Args:
        obj: any picklable Python object.
    Returns:
        The bytes produced by `pickle.dumps` (default protocol).
    """
    serialized = pickle.dumps(obj)
    return serialized


def loads_data(buf):
    """
    Rebuild the object stored in a pickled buffer.
    Args:
        buf: bytes-like object, e.g. the output of `dumps_data`.
    Returns:
        The unpickled Python object.
    """
    # NOTE: unpickling can execute arbitrary code; only pass trusted buffers.
    restored = pickle.loads(buf)
    return restored


def _identity_collate(batch):
    """Return the batch unchanged.

    Defined at module level (instead of a lambda) so it can be pickled when
    DataLoader worker processes are started with the "spawn" method.
    """
    return batch


def dataset2lmdb(dataset, lmdb_path, write_frequency=5000, num_workers=16):
    """
    Write every sample of `dataset` into an LMDB database at `lmdb_path`.

    Sample number i is stored under the ASCII key str(i) as a pickled
    `(image, label)` tuple. Two extra entries are written at the end:
    `b'__keys__'` (pickled list of all sample keys) and `b'__len__'`
    (pickled number of samples).

    Args:
        dataset: dataset accepted by `DataLoader`; each sample must unpack
            into `(image, label)`.
        lmdb_path: destination path. If it is an existing directory the
            environment is opened with `subdir=True`, otherwise with
            `subdir=False`.
        write_frequency: the write transaction is committed and reopened
            whenever the sample index is a multiple of this value.
        num_workers: number of `DataLoader` worker processes.
    """
    data_loader = DataLoader(dataset, num_workers=num_workers,
                             collate_fn=_identity_collate)
    total = len(data_loader)

    isdir = os.path.isdir(lmdb_path)
    print("Generate LMDB to %s" % lmdb_path)
    # map_size is 2 * 2**40 bytes (2 TiB).
    db = lmdb.open(lmdb_path, subdir=isdir,
                   map_size=1099511627776 * 2, readonly=False,
                   meminit=False, map_async=True)

    try:
        # Track the number of stored samples explicitly so that an empty
        # dataset yields an empty key list instead of a NameError.
        count = 0
        txn = db.begin(write=True)
        for idx, data in enumerate(data_loader):
            # The identity collate function yields a list of samples; with the
            # DataLoader's default batch size the first element is the sample.
            image, label = data[0]

            txn.put(u'{}'.format(idx).encode('ascii'), dumps_data((image, label)))
            count = idx + 1
            if idx % write_frequency == 0:
                print("[%d/%d]" % (idx, total))
                # Commit periodically and continue in a fresh transaction.
                txn.commit()
                txn = db.begin(write=True)

        # finish iterating through dataset
        txn.commit()
        keys = [u'{}'.format(k).encode('ascii') for k in range(count)]
        with db.begin(write=True) as txn:
            txn.put(b'__keys__', dumps_data(keys))
            txn.put(b'__len__', dumps_data(len(keys)))

        print("Flushing database ...")
        db.sync()
    finally:
        # Always release the environment, even if iteration or a write failed.
        db.close()

if __name__ == "__main__":
    # Imported only when run as a script, so importing this module does not
    # require `dataset_coop`.
    from dataset_coop import COOPDataset

    # Convert every split of every listed dataset into <data_path>/<split>.lmdb.
    for dataset_name in ["caltech101"]:
        data_path = os.path.join('./data', dataset_name)
        for split in ['train', 'val', 'test']:
            # NOTE(review): loader='raw' presumably makes the dataset return
            # undecoded image bytes — confirm against COOPDataset.
            dataset = COOPDataset(root=data_path, split=split, loader='raw')
            dataset2lmdb(dataset, os.path.join(data_path, f"{split}.lmdb"))