"""
Revised from https://github.com/rixwew/pytorch-fm/blob/master/torchfm/dataset/criteo.py
"""

import math
import shutil
import struct
from collections import defaultdict
from functools import lru_cache
from pathlib import Path

import os
import lmdb
import numpy as np
import torch.utils.data
from tqdm import tqdm


class AvitoDataset(torch.utils.data.Dataset):
    """
    Avito Dataset served from an lmdb cache.

    Data preparation:
        * Remove the infrequent features (appearing in fewer than
          ``min_threshold`` instances) and treat them as a single feature.

    Every cached record is a ``uint32`` vector laid out as
    ``[label, feat_1, ..., feat_NUM_FEATS, split_flag]``. ``split_flag`` is 2
    when the third-from-last source column equals ``'2015-05-20'``, 1 when it
    equals ``'2015-05-19'`` and 0 otherwise.

    :param dataset_path: path of the comma-separated source file (its first
        line is skipped as a header). Only required when ``rebuild_cache`` is
        True.
    :param cache_path: lmdb cache path. It is the build target when
        ``rebuild_cache`` is True; it is opened for reading only when ``mode``
        is not one of the recognised modes.
    :param rebuild_cache: If True, lmdb cache is refreshed.
    :param min_threshold: infrequent feature threshold.
    :param mode: one of ``'train'``, ``'test'``, ``'valid'`` or
        ``'train_valid'``; selects the hard-coded lmdb directory that is
        opened, overriding ``cache_path``. Any other value leaves
        ``cache_path`` untouched.

    """

    # Number of feature columns that follow the label column in the source.
    NUM_FEATS = 16

    # lmdb directory opened for each recognised ``mode``.
    _MODE_CACHE_PATHS = {
        'train': '/home/data/avito/.train',
        'test': '/home/data/avito/.test',
        'valid': '/home/data/avito/.valid',
        'train_valid': '/home/data/avito/.train_valid',
    }

    def __init__(self, dataset_path=None, cache_path='/home/data/avito/data', rebuild_cache=False, min_threshold=1, mode='train'):
        self.min_threshold = min_threshold
        if rebuild_cache:
            # Validate before deleting anything, so a bad call cannot destroy
            # an existing cache and then fail.
            if dataset_path is None:
                raise ValueError('create cache: failed: dataset_path is None')
            shutil.rmtree(cache_path, ignore_errors=True)
            self.__build_cache(dataset_path, cache_path)
        # A recognised mode takes precedence over the cache_path argument.
        cache_path = self._MODE_CACHE_PATHS.get(mode, cache_path)
        print(cache_path)
        self.env = lmdb.open(cache_path, create=False, lock=False, readonly=True)
        with self.env.begin(write=False) as txn:
            # One entry is the b'field_dims' metadata record, not a sample.
            self.length = txn.stat()['entries'] - 1
            self.field_dims = np.frombuffer(txn.get(b'field_dims'), dtype=np.uint32)

        self.field_id = ['AdID', 'Position', 'UserID', 'IsUserLoggedOn',
               'SearchQuery', 'LocationID_x', 'CategoryID_x', 'Price', 'IsContext',
               'UserAgentID', 'UserAgentOSID', 'UserDeviceID', 'UserAgentFamilyID',
               'day', 'hour', 'minute']

    def __getitem__(self, index):
        """
        Fetch one cached record.

        :param index: non-negative record number (packed as a big-endian
            unsigned 32-bit lmdb key).
        :return: ``(index, values, label)`` where ``label`` is the first
            element of the stored vector and ``values`` is everything after
            it, both as ``int64``.
        :raises IndexError: if no record is stored under ``index``.
        """
        with self.env.begin(write=False) as txn:
            raw = txn.get(struct.pack('>I', index))
        if raw is None:
            # Without this, np.frombuffer(None) fails with an opaque TypeError
            # and plain ``for x in dataset`` iteration never terminates cleanly.
            raise IndexError('AvitoDataset index out of range: {}'.format(index))
        # np.int64 instead of np.long: the alias is missing from NumPy
        # 1.24-1.26 and is the platform C long (not always 64-bit) in 2.x.
        np_array = np.frombuffer(raw, dtype=np.uint32).astype(np.int64)
        return index, np_array[1:], np_array[0]

    def __len__(self):
        """Return the number of sample records in the opened cache."""
        return self.length

    def __build_cache(self, path, cache_path):
        """
        Create the lmdb cache at ``cache_path`` from the source file ``path``.

        Stores the per-field vocabulary sizes under ``b'field_dims'`` and then
        writes the encoded rows, one write transaction per buffered batch.
        """
        feat_mapper, defaults = self.__get_feat_mapper(path)
        with lmdb.open(cache_path, map_size=int(1e11)) as env:
            field_dims = np.zeros(self.NUM_FEATS, dtype=np.uint32)
            for i, fm in feat_mapper.items():
                # +1 reserves the shared id used for infrequent/unseen values.
                field_dims[i - 1] = len(fm) + 1
            with env.begin(write=True) as txn:
                txn.put(b'field_dims', field_dims.tobytes())
            for buffer in self.__yield_buffer(path, feat_mapper, defaults):
                with env.begin(write=True) as txn:
                    for key, value in buffer:
                        txn.put(key, value)

    def __get_feat_mapper(self, path):
        """
        Count feature values in ``path`` and build the value -> id mappings.

        :return: ``(feat_mapper, defaults)``. ``feat_mapper[i]`` maps each raw
            value of column ``i`` seen at least ``min_threshold`` times to a
            dense id; ``defaults[i]`` is the id shared by all other values.
        """
        feat_cnts = defaultdict(lambda: defaultdict(int))
        num_columns = self.NUM_FEATS + 1
        with open(path) as f:
            f.readline()  # skip the header line
            pbar = tqdm(f, mininterval=1, smoothing=0.1)
            pbar.set_description('Create Avito dataset cache: counting features')
            for line in pbar:
                values = line.rstrip('\n').split(',')
                if len(values) != num_columns:
                    continue  # skip rows with an unexpected column count
                for i in range(1, num_columns):
                    feat_cnts[i][values[i]] += 1

        # Sort the kept values so ids are reproducible across rebuilds instead
        # of depending on string-hash randomisation of set iteration order.
        feat_mapper = {
            i: {
                feat: idx
                for idx, feat in enumerate(
                    sorted(feat for feat, c in cnt.items() if c >= self.min_threshold)
                )
            }
            for i, cnt in feat_cnts.items()
        }
        defaults = {i: len(mapping) for i, mapping in feat_mapper.items()}
        return feat_mapper, defaults

    def __yield_buffer(self, path, feat_mapper, defaults, buffer_size=int(1e5)):
        """
        Encode the rows of ``path`` and yield them in batches.

        :param feat_mapper: per-column value -> id mappings.
        :param defaults: per-column id used for values missing from the mapping.
        :param buffer_size: number of rows per yielded batch.
        :return: generator of lists of ``(key, value)`` byte pairs, where the
            key is the big-endian uint32 row number and the value is the
            ``uint32`` record ``[label, features..., split_flag]``.
        """
        item_idx = 0
        buffer = []
        num_columns = self.NUM_FEATS + 1
        with open(path) as f:
            f.readline()  # skip the header line
            pbar = tqdm(f, mininterval=1, smoothing=0.1)
            pbar.set_description('Create Avito dataset cache: setup lmdb')
            for line in pbar:
                values = line.rstrip('\n').split(',')
                if len(values) != num_columns:
                    continue  # skip rows with an unexpected column count
                np_array = np.zeros(self.NUM_FEATS + 2, dtype=np.uint32)
                # The label is taken from the first character of column 0.
                np_array[0] = int(values[0][0])
                for i in range(1, num_columns):
                    np_array[i] = feat_mapper[i].get(values[i], defaults[i])
                # Trailing slot flags the row by the date in values[-3].
                if values[-3] == '2015-05-20':
                    np_array[-1] = 2
                elif values[-3] == '2015-05-19':
                    np_array[-1] = 1
                else:
                    np_array[-1] = 0
                buffer.append((struct.pack('>I', item_idx), np_array.tobytes()))
                item_idx += 1
                if item_idx % buffer_size == 0:
                    yield buffer
                    # Fresh list: never mutate a batch already handed out.
                    buffer = []
            if buffer:
                yield buffer