import random
import pickle
import numpy as np
import pandas as pd
import functools
from tqdm import tqdm
from itertools import islice
import gc
import functools


# Featurization setup: the blocks table and the symbol alphabet used to map
# each character of an input string to a one-hot position.
blocks_file = 'datasets/sehstr/block_18.json'
blocks_df = pd.read_json(blocks_file)
# Backslashes are spelled as explicit `\\` escapes. The resulting string is
# character-for-character the same as before (so every symbol keeps its
# index), but the literal no longer relies on invalid escape sequences.
symbols = '0123456789abcdefghijklmnopqrstuvwxyz' + \
              'ABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&\\()*+,-./:;<=>?@[\\]^_`{|}~'
# Width of a single one-hot vector: one slot per row of the blocks table.
num_blocks = len(blocks_df)

@functools.cache
def symbol_ohe(symbol):
    """Return the one-hot vector for a single symbol.

    The vector has length ``num_blocks`` and holds 1.0 at the position of
    ``symbol`` within ``symbols`` and 0.0 everywhere else.

    Results are memoised, so repeated calls with the same symbol return the
    very same array object. The array is therefore marked read-only: an
    in-place modification by one caller would otherwise silently change the
    encoding seen by every later caller.

    Raises:
        ValueError: if ``symbol`` does not occur in ``symbols``.
        IndexError: if the symbol's position is not below ``num_blocks``.
    """
    zs = np.zeros(num_blocks)
    zs[symbols.index(symbol)] = 1.0
    # Protect the shared, cached array against accidental mutation.
    zs.setflags(write=False)
    return zs

def featurize(x):
    """Encode ``x`` as the concatenation of its per-symbol one-hot vectors.

    Each element of ``x`` is passed to ``symbol_ohe`` and the resulting
    vectors are joined end to end, in order, into one flat array.
    """
    encoded = []
    for ch in x:
        encoded.append(symbol_ohe(ch))
    return np.concatenate(encoded)

def chunked_items(datasets, chunk_size):
    """Yield the ``(key, value)`` pairs of ``datasets`` in successive lists.

    Every yielded list holds at most ``chunk_size`` pairs, taken in the
    mapping's iteration order; only the final list may be shorter. Nothing
    is yielded once the items are exhausted.
    """
    remaining = iter(datasets.items())
    # An empty batch means the iterator is drained, which ends the loop.
    while batch := list(islice(remaining, chunk_size)):
        yield batch
        


if __name__ == "__main__":
    # NOTE(review): pickle.load can execute arbitrary code embedded in the
    # file; only run this script on dataset/model files from a trusted source.
    dataset_file = 'datasets/sehstr/block_18_stop6.pkl'
    with open(dataset_file, 'rb') as f:
        datasets = pickle.load(f)

    # Featurize every key of the loaded mapping, one chunk of items at a time.
    X_all = []
    # The context manager closes the progress bar (and finishes its output
    # line) even if featurization raises part-way through.
    with tqdm(total=len(datasets), desc="featurizing") as pbar:
        for chunk in chunked_items(datasets, 100_000):
            # Only the keys are featurized; the mapped values are ignored.
            x_feat = [featurize(x) for x, _ in chunk]
            X_all.extend(x_feat)
            pbar.update(len(chunk))
            # Drop the per-chunk temporaries before building the next chunk.
            del x_feat, chunk
            gc.collect()

    X_all = np.array(X_all)

    # Run each saved proxy model over all features and pickle its predictions.
    for bsize in [200, 400]:
        with open(f'datasets/proxy/proxy_sample{bsize}.pkl', 'rb') as f:
            model = pickle.load(f)
        ys = model.predict(X_all)
        # NOTE(review): models are read from datasets/proxy/ but predictions
        # are written directly under datasets/ -- confirm this is intended.
        with open(f'datasets/proxy_sample{bsize}_allpreds.pkl', 'wb') as f:
            pickle.dump(ys, f)
        print('Saved to file.')