import pickle
import numpy as np, pandas as pd
from sklearn.ensemble import HistGradientBoostingRegressor


class sEH_GBR_Proxy:
  def __init__(self, args):
    with open(f'datasets/tfbind10/proxy/proxy_sample{args.Bsize}.pkl', 'rb') as f:
      self.model = pickle.load(f)

    self.num_blocks = 4
    
    self.symbols = '0123456789abcdefghijklmnopqrstuvwxyz' + \
              'ABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&\()*+,-./:;<=>?@[\]^_`{|}~'

  def predict_state(self, state):
    x_ft = self.featurize(state.content)
    return self.model.predict(x_ft)[0]

  def featurize(self, string):
    x_ft = np.concatenate([self.symbol_ohe(c) for c in string])
    return x_ft.reshape(1, -1)

  def symbol_ohe(self, symbol):
    zs = np.zeros(self.num_blocks)
    zs[self.symbols.index(symbol)] = 1.0
    return zs


def test():
  from attrdict import AttrDict
  args = {'blocks_file': 'datasets/sehstr/block_18.json'}
  args = AttrDict(args)
  model = sEH_GBR_Proxy(args)
  
  test_string = '012345'
  pred = model.model.predict(model.featurize(test_string))
  print(pred)
  return

if __name__ == '__main__':
  test()