import pickle, functools
import numpy as np
from omegaconf import OmegaConf
import torch
torch.autograd.set_detect_anomaly(True)
import math

import copy, pickle
import numpy as np
from polyleven import levenshtein

import gflownet.trainers_tf as trainers
from gflownet.GFNs import models
from gflownet.MDPs import seqpamdp
from gflownet.monitor import TargetRewardDistribution, Monitor
from datasets.tfbind8 import gbr_proxy


def dynamic_inherit_mdp(base, args):
  """Create a TFBind8 MDP class deriving from `base` and return an instance.

  The class is defined inside this function so that the parent MDP class can
  be selected by the caller at runtime.

  Args:
    base: MDP class to inherit from. Its constructor is invoked as
      `base(args, alphabet=..., forced_stop_len=...)` and its instances must
      offer `state(content, is_leaf)`.
    args: Experiment configuration. Fields read here: `offline_dataset`,
      `beta`, `flag` and `ralpha`.

  Returns:
    A freshly constructed `TFBind8MDP` object.
  """

  def to_string(seq):
    # Concatenate the string form of every element of `seq`.
    return ''.join(str(symbol) for symbol in seq)

  class TFBind8MDP(base):
    """MDP over length-8 sequences on the alphabet '0123'."""

    def __init__(self, args):
      super().__init__(args, alphabet=list('0123'), forced_stop_len=8)
      self.args = args
      self.proxy_model = gbr_proxy.sEH_GBR_Proxy(args)

      # NOTE: unpickling can execute arbitrary code; `args.offline_dataset`
      # must point to a trusted file.
      with open(args.offline_dataset, 'rb') as handle:
        dataset = pickle.load(handle)

      # Keys that already expose `.content` are used as-is; anything else is
      # turned into a leaf state built from its string form.
      states = [
          key if hasattr(key, 'content') else self.state(to_string(key), True)
          for key in dataset.keys()
      ]

      # Dataset statistics; they are reused by `reward` to standardise
      # proxy predictions.
      self.rewards = np.array(list(dataset.values()))
      self.mean = np.mean(self.rewards)
      self.std = np.std(self.rewards)

      # exp(beta * z-score) of the dataset rewards (no `ralpha` shift).
      standardized = (self.rewards - self.mean) / self.std
      self.scaled_rewards = np.exp(standardized * self.args.beta)

      # Initial state -> reward table, either from proxy predictions
      # ('proxy') or from the stored dataset rewards ('true').
      # NOTE(review): for any other flag value `initial_XtoR` is never
      # created - confirm that no caller relies on it in that case.
      if args.flag == 'proxy':
        preds = np.array([self.proxy_model.predict_state(s) for s in states])
        z_scores = (preds - self.mean) / self.std
      elif args.flag == 'true':
        z_scores = standardized
      else:
        z_scores = None
      if z_scores is not None:
        # NOTE(review): `ralpha` is added here but subtracted in `reward`;
        # confirm the opposite signs are intended.
        shifted = np.exp(z_scores * self.args.beta + self.args.ralpha)
        self.initial_XtoR = dict(zip(states, shifted))

      # Modes: sequences loaded from a fixed pickle, stored as leaf states.
      with open('datasets/tfbind8/modes_tfbind8.pkl', 'rb') as handle:
        mode_seqs = pickle.load(handle)
      self.modes = {self.state(to_string(seq), is_leaf=True)
                    for seq in mode_seqs}

    # Core
    def reward(self, x):
      """Return exp(beta * z - ralpha), z being the standardised proxy score.

      `x` must be a leaf state; the proxy prediction is standardised with
      the dataset mean and standard deviation computed in `__init__`.
      """
      assert x.is_leaf, 'Error: Tried to compute reward on non-leaf node.'
      z_score = (self.proxy_model.predict_state(x) - self.mean) / self.std
      return math.exp(z_score * self.args.beta - self.args.ralpha)

    def is_mode(self, x, r):
      """Tell whether state `x` is one of the loaded modes (`r` is unused)."""
      return x in self.modes

    # Interpretation & visualization
    def dist_func(self, state1, state2):
      """Levenshtein distance between the `.content` of two states.

      States are SeqPAState or SeqInsertState objects.
      """
      first, second = state1.content, state2.content
      return levenshtein(first, second)

    def make_monitor(self):
      """Build a Monitor whose target distribution uses `scaled_rewards`."""
      target_dist = TargetRewardDistribution()
      target_dist.init_from_base_rewards(self.scaled_rewards)
      return Monitor(
          self.args,
          target_dist,
          dist_func=self.dist_func,
          is_mode_f=self.is_mode,
      )

    def reduce_storage(self):
      """Drop the `rewards` and `scaled_rewards` arrays from this object."""
      del self.rewards, self.scaled_rewards

  return TFBind8MDP(args)


if __name__ == "__main__":
  # Script entry point: load the experiment settings, build the MDP, actor,
  # model and monitor, then run the IL trainer.
  print('Running experiment tfbind8 ...')
  args = OmegaConf.load("./settings/tf8_1000_proxy_2_0.yaml")
  # Use the GPU when one is visible to torch, otherwise fall back to CPU.
  args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
  print(f"{args.device=}")

  base = seqpamdp.SeqPrependAppendMDP
  actorclass = seqpamdp.SeqPAActor
  mdp = dynamic_inherit_mdp(base, args)
  actor = actorclass(args, mdp)
  model = models.make_model(args, mdp, actor)
  monitor = mdp.make_monitor()
  trainer = trainers.TrainerIL(args, model, mdp, actor, monitor)
  trainer.learn()
  # NOTE(review): the reward arrays are released only after training has
  # finished - confirm whether this was meant to run before `learn()`.
  mdp.reduce_storage()