"""
  seh as string
"""
import pickle, functools
import numpy as np
from omegaconf import OmegaConf
import torch
# NOTE(review): anomaly detection is a global debugging aid; PyTorch documents that it
# slows autograd noticeably. Consider disabling it for full-scale training runs.
torch.autograd.set_detect_anomaly(True)
import math

import gflownet.trainers_sehstr as trainers
from gflownet.MDPs import molstrmdp
from gflownet.monitor import TargetRewardDistribution, Monitor
from gflownet.GFNs import models

from datasets.sehstr import gbr_proxy

from rdkit import Chem
from rdkit.Chem.rdMolDescriptors import GetMorganFingerprintAsBitVect
from rdkit.DataStructs import FingerprintSimilarity


    
class SEHstringMDP(molstrmdp.MolStrMDP):
  """ Base MDP for the sEH task on string-represented molecules.

  Provides mode detection, a Morgan-fingerprint distance used for diversity,
  and monitor construction. Subclasses must implement `reward` and set
  `mode_r_threshold`, `rewards` and `scaled_rewards`, which `is_mode`,
  `make_monitor` and `reduce_storage` read.
  """
  def __init__(self, args):
    super().__init__(args)
    self.args = args
    # Explicit check instead of `assert` so the guard still runs under `python -O`.
    if args.blocks_file != 'datasets/sehstr/block_18.json':
      raise ValueError('ERROR - x_to_r and rewards are designed for block_18.json')

  def reward(self, x):
    """ Reward of leaf state `x`; must be implemented by subclasses. """
    raise NotImplementedError("Subclasses must implement this method")

  def is_mode(self, x, r):
    """ Return True when reward `r` reaches `self.mode_r_threshold` (set by subclasses).

    `x` is accepted for interface compatibility and is not used.
    """
    return r >= self.mode_r_threshold

  # Diversity
  def dist_states(self, state1, state2):
    """ Tanimoto distance (1 - Tanimoto similarity) between the Morgan fingerprints of two states. """
    fp1 = self.get_morgan_fp(state1)
    fp2 = self.get_morgan_fp(state2)
    return 1 - FingerprintSimilarity(fp1, fp2)

  def get_morgan_fp(self, state):
    """ Radius-2, 1024-bit Morgan fingerprint of `state`, memoised per instance.

    The memo is a per-instance dict rather than `functools.cache` on the
    method. A method-level cache keys on `self` and keeps every MDP instance
    alive for the lifetime of the process.

    `state` must be hashable because it is used as a dict key. The memo is
    unbounded: it holds one entry per distinct state seen.
    """
    cache = getattr(self, '_morgan_fp_cache', None)
    if cache is None:
      cache = self._morgan_fp_cache = {}
    try:
      return cache[state]
    except KeyError:
      pass
    mol = self.state_to_mol(state)
    fp = GetMorganFingerprintAsBitVect(mol, 2, nBits=1024)
    cache[state] = fp
    return fp

  def make_monitor(self):
    """ Make monitor, called during training.

    The target reward distribution is initialised from `self.scaled_rewards`,
    so this must be called before `reduce_storage`.
    """
    target = TargetRewardDistribution()
    target.init_from_base_rewards(self.scaled_rewards)
    return Monitor(self.args, target, dist_func=self.dist_states,
                   is_mode_f=self.is_mode)

  def reduce_storage(self):
    """ Release the `rewards` and `scaled_rewards` arrays to free memory.

    Safe to call more than once: attributes that are already gone are skipped.
    """
    for attr in ('rewards', 'scaled_rewards'):
      try:
        delattr(self, attr)
      except AttributeError:
        pass  # already released


class SEH_gfn(SEHstringMDP):
  """ sEH string MDP for GFlowNet training, with rewards from the GBR proxy.

  Attributes set here:
    proxy_model      -- gbr_proxy.sEH_GBR_Proxy queried by `reward`.
    rewards          -- values of the offline dataset pickle (args.offline_dataset).
    mean, std        -- statistics of `rewards`; used to standardise proxy predictions.
    scaled_rewards   -- exp of the standardised precomputed proxy predictions;
                        used by `make_monitor` as the target distribution.
    mode_r_threshold -- 99.9th percentile of `scaled_rewards`.
  """
  def __init__(self, args):
    super().__init__(args)
    self.args = args
    self.proxy_model = gbr_proxy.sEH_GBR_Proxy(args)

    # Only the offline dataset's values are used here, to get normalisation statistics.
    with open(args.offline_dataset, 'rb') as handle:
      dataset = pickle.load(handle)
    self.rewards = np.array(list(dataset.values()))
    self.mean = np.mean(self.rewards)
    self.std = np.std(self.rewards)

    # Precomputed proxy predictions, standardised by their own mean/std.
    # NOTE(review): list() over a dict would yield its keys. This assumes the
    # pickle holds a sequence of predictions; confirm.
    preds_path = f'datasets/sehstr/proxy_sample{args.Bsize}_allpreds.pkl'
    with open(preds_path, 'rb') as handle:
      preds = np.array(list(pickle.load(handle)))
    z_scores = (preds - np.mean(preds)) / np.std(preds)
    # NOTE(review): unlike `reward`, this applies neither args.beta nor
    # args.ralpha, and it normalises by the predictions' own statistics.
    # is_mode(x, reward(x)) therefore compares values on different scales;
    # confirm this is intended.
    self.scaled_rewards = np.exp(z_scores)

    # Modes: the top 0.1% of scaled rewards.
    top_fraction = 0.001
    self.mode_r_threshold = np.percentile(self.scaled_rewards, 100*(1-top_fraction))

  def reward(self, x):
    """ Reward of leaf state `x`: exp(beta * z - ralpha), where z is the proxy
    prediction standardised by the offline dataset's mean/std. Not memoised.
    """
    assert x.is_leaf, 'Error: Tried to compute reward on non-leaf node.'
    z = (self.proxy_model.predict_state(x) - self.mean) / self.std
    return math.exp(z * self.args.beta - self.args.ralpha)


class SEH_il(SEHstringMDP):
  """ sEH string MDP for the 'il' task; builds `initial_XtoR` from the offline dataset.

  args.flag selects the source of the initial rewards:
    'proxy' -- GBR proxy predictions on the offline dataset's states;
    'true'  -- the reward values stored in the offline dataset.
  Any other value raises ValueError, rather than leaving `initial_XtoR` unset.
  """
  def __init__(self, args):
    super().__init__(args)
    # Fail fast, before any file loading, on a flag that would leave initial_XtoR undefined.
    if args.flag not in ('proxy', 'true'):
      raise ValueError(f"Unknown args.flag {args.flag!r}; expected 'proxy' or 'true'.")
    self.args = args
    self.proxy_model = gbr_proxy.sEH_GBR_Proxy(args)

    with open(args.offline_dataset, 'rb') as f:
      offline_datasets = pickle.load(f)
    # Keys without a `content` attribute are converted via self.state(k, True);
    # the others are used as-is.
    keys = [self.state(k, True) if not hasattr(k, "content") else k for k in offline_datasets.keys()]
    self.rewards = np.array(list(offline_datasets.values()))
    self.mean = np.mean(self.rewards)
    self.std = np.std(self.rewards)
    self.scaled_rewards = np.exp(((self.rewards - self.mean) / self.std) * self.args.beta)

    if args.flag == 'proxy':
      base = np.array([self.proxy_model.predict_state(key) for key in keys])
    else:  # 'true'
      base = self.rewards
    # NOTE(review): +ralpha here vs -ralpha in `reward`; confirm the sign difference is intended.
    il_rewards = np.exp(((base - self.mean) / self.std) * self.args.beta + self.args.ralpha)
    self.initial_XtoR = dict(zip(keys, il_rewards))

    # Modes: the top 0.1% of scaled rewards.
    mode_percentile = 0.001
    self.mode_r_threshold = np.percentile(self.scaled_rewards, 100*(1-mode_percentile))

  def reward(self, x):
    """ Reward of leaf state `x`: exp(beta * z - ralpha), where z is the proxy
    prediction standardised by the offline dataset's mean/std.

    Results are memoised in a per-instance dict, because `functools.cache` on
    a method would key on `self` and keep the instance alive. `x` must be
    hashable. The memo is unbounded: one entry per distinct leaf evaluated.
    """
    cache = getattr(self, '_reward_cache', None)
    if cache is None:
      cache = self._reward_cache = {}
    try:
      return cache[x]
    except KeyError:
      pass
    assert x.is_leaf, 'Error: Tried to compute reward on non-leaf node.'
    pred = self.proxy_model.predict_state(x)
    r = (pred-self.mean) / self.std
    r = math.exp(r * self.args.beta - self.args.ralpha)
    cache[x] = r
    return r



if __name__ == "__main__":
  print('Running experiment sehstr ...')
  args = OmegaConf.load(f"./settings/seh_1000_proxy_2_0.5.yaml")
  args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
  print(f"{args.device=}")
  
  if args.task == "gfn":
    mdp = SEH_gfn(args)
    actor = molstrmdp.MolStrActor(args, mdp)
    model = models.make_model(args, mdp, actor)
    monitor = mdp.make_monitor()
    mdp.reduce_storage()
    trainer = trainers.TrainerGFN(args, model, mdp, actor, monitor)
    trainer.learn()
      
  if args.task == "il":
    mdp = SEH_il(args)
    actor = molstrmdp.MolStrActor(args, mdp)
    model = models.make_model(args, mdp, actor)
    monitor = mdp.make_monitor()
    trainer = trainers.TrainerIL(args, model, mdp, actor, monitor)
    trainer.learn()
    mdp.reduce_storage()