"""
  Gradient boosted regressor sEH proxy model.
  Trained on neural net proxy's predictions on
  34M molecules from block18, stop6.
  Attains pearsonr=0.90 on data set.
"""

import pickle
import numpy as np, pandas as pd
from sklearn.ensemble import HistGradientBoostingRegressor


# NOTE: running this script requires scipy==1.0.2, but the main script needs a
# newer scipy. Downgrade scipy, run this script, then upgrade scipy again.
class sEH_GBR_Proxy:
  """Gradient-boosted-regressor proxy that scores block strings.

  A state string has one character per position. Each character is looked up
  in ``self.symbols``, one-hot encoded over ``self.num_blocks`` slots, and the
  concatenated encoding is passed to the unpickled regressor.
  """

  def __init__(self, model_path='sehstr_gbtr.pkl', blocks_path='block_18.json'):
    """Load the pickled regressor and the block table.

    Args:
      model_path: path to the pickled regressor. Only load trusted files:
        ``pickle.load`` can execute arbitrary code.
      blocks_path: JSON file readable by ``pd.read_json``; only its row count
        is used, as the one-hot width per position.
    """
    with open(model_path, 'rb') as f:
      self.model = pickle.load(f)

    blocks = pd.read_json(blocks_path)
    self.num_blocks = len(blocks)

    # Symbol alphabet: a character's index here is its block id.
    # The backslashes are escaped explicitly. The previous spelling used the
    # invalid escapes "\(" and "\]", which warn on Python 3.12+. The string
    # content is identical, so symbol indices do not change.
    # NOTE(review): the backslash after '&' looks like a mangled single quote
    # from string.punctuation; it is kept so that indices stay stable.
    self.symbols = ('0123456789abcdefghijklmnopqrstuvwxyz'
                    'ABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&\\()*+,-./:;<=>?@[\\]^_`{|}~')

  def predict_state(self, state):
    """Return the model's scalar prediction for ``state.content``.

    ``state`` is any object with a ``content`` string attribute.
    """
    x_ft = self.featurize(state.content)
    return self.model.predict(x_ft)[0]

  def featurize(self, string):
    """One-hot encode ``string`` into a single-row feature matrix.

    Returns:
      Array of shape ``(1, len(string) * self.num_blocks)``. An empty string
      yields shape ``(1, 0)``.

    Raises:
      ValueError: a character is not in ``self.symbols``.
      IndexError: a character's symbol index is >= ``self.num_blocks``.
    """
    rows = [self.symbol_ohe(c) for c in string]
    if not rows:
      # np.concatenate([]) raises, so return an explicitly empty row instead.
      return np.zeros((1, 0))
    return np.concatenate(rows).reshape(1, -1)

  def symbol_ohe(self, symbol):
    """Return a length-``num_blocks`` float one-hot vector for ``symbol``."""
    zs = np.zeros(self.num_blocks)
    zs[self.symbols.index(symbol)] = 1.0
    return zs


def test():
  """Smoke-test the proxy on a short and a full-length string.

  Loads the proxy from its default file paths. It zero-pads the features of
  the 4-symbol string '1234' up to the width of a 6-symbol string, then prints
  predictions for both strings and the 6-symbol feature shape.
  """
  model = sEH_GBR_Proxy()

  # The regressor takes a fixed number of features (6 positions x num_blocks),
  # so shorter strings are zero-padded on the right.
  n_features = 6 * model.num_blocks
  short_ft = model.featurize('1234')
  features = np.zeros((1, n_features))
  features[0, :short_ft.shape[1]] = short_ft[0]
  print(model.model.predict(features))

  full_ft = model.featurize('123456')
  print(full_ft.shape)
  print(model.model.predict(full_ft))

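# Legacy, unbatched version of generate_allpreds, kept for reference only
# (superseded by the batched implementation below). It skipped index 0 and
# put the least-significant base-18 digit at position 0, so its output order
# differs from the batched version's.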
# def generate_allpreds():
#   model = sEH_GBR_Proxy()
#   # string is the 18^6 = 34M strings
#   res = []
#   for i in range(1, 18**6):
#     state = np.zeros(18 * 6)
#     for j in range(6):
#       state[i % 18 + j * 18] = 1
#       i = i // 18
#     res.append(model.model.predict(state))
#   print(np.max(res))
#   print(np.min(res))
#   print(np.mean(res))
#   with open('sehstr_gbtr_allpreds.pkl', 'wb') as f:
#     pickle.dump(res, f)

def _onehot_index_batch(indices, base, length):
  """One-hot encode flat indices as ``length`` base-``base`` digits.

  Digits come from ``np.unravel_index`` in C order, so position 0 holds the
  most-significant digit. Returns a float64 array of shape
  ``(len(indices), base * length)`` with a single 1.0 in each ``base``-wide
  segment.
  """
  digits = np.stack(np.unravel_index(indices, (base,) * length), axis=1)
  # Column of the hot bit for each (row, position): segment offset + digit.
  cols = digits + base * np.arange(length)
  onehot = np.zeros((len(indices), base * length))
  onehot[np.arange(len(indices))[:, None], cols] = 1.0
  return onehot


def generate_allpreds(batch_size=10000, length=6,
                      out_path='sehstr_gbtr_allpreds.pkl'):
  """Predict every block string of ``length`` positions and pickle the results.

  Enumerates all ``num_blocks ** length`` strings in batches of
  ``batch_size``. With the default block file this is presumably 18 ** 6 (about
  34M). The function prints the max, min, and mean prediction, then pickles a
  1-D float array whose entry i is the prediction for flat index i. See
  _onehot_index_batch for the digit order.
  """
  model = sEH_GBR_Proxy()
  # Use the same one-hot width as sEH_GBR_Proxy.featurize.
  base = model.num_blocks
  total = base ** length

  # Preallocate rather than collect batches and concatenate them; the
  # concatenation would briefly hold two copies of the full result.
  predictions = np.empty(total)
  for start in range(0, total, batch_size):
    end = min(total, start + batch_size)
    state_batch = _onehot_index_batch(np.arange(start, end), base, length)
    predictions[start:end] = model.model.predict(state_batch)

  print(np.max(predictions))
  print(np.min(predictions))
  print(np.mean(predictions))

  with open(out_path, 'wb') as f:
    pickle.dump(predictions, f)

if __name__ == '__main__':
  # Default entry point: enumerate every string and pickle its prediction.
  # Call test() here instead for a quick smoke check of the proxy.
  generate_allpreds()