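'''
    RNA binding
    Reward oracle: FLEXS RNABinding landscape (queried directly, no learned proxy)
    Entry points: main (train), eval, tradeoff, number_of_modes, analysis
'''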

import copy, pickle, functools
import numpy as np
from tqdm import tqdm
import pandas as pd
import torch
from polyleven import levenshtein

import gflownet.trainers as trainers
from gflownet.GFNs import models
from gflownet.MDPs import seqpamdp, seqinsertmdp, seqarmdp
from gflownet.monitor import TargetRewardDistribution, Monitor, diversity

import flexs
from flexs import baselines
import flexs.utils.sequence_utils as s_utils

def dynamic_inherit_mdp(base, args):

  class RNAMDP(base):
    def __init__(self, args):
      super().__init__(args,
                       alphabet=["U", "C", "G", "A"],
                       forced_stop_len=args.rna_length)
      self.args = args
      self.rna_task = args.rna_task
      self.rna_length = args.rna_length
      args.alphabet = ["U", "C", "G", "A"]
      # dataset = UTRDataset()
      # self.proxy_model = TransformerOracle(dataset, noise_std=0.1)
      print(f'Loading data ...')
      problem = flexs.landscapes.rna.registry()[f'L{self.rna_length}_RNA{self.rna_task}']
      self.oracle = flexs.landscapes.RNABinding(**problem['params'])
      print(problem)
      
      # Dataset is collecting by querying 100,000 rna sequences 
      # which is generated by random mutation from starting sequence
      with open(f"datasets/L{self.rna_length}_RNA{self.rna_task}/rewards.pkl", "rb") as f:
        self.rewards = pickle.load(f)
        
      # scale rewards
      py = np.array(list(self.rewards))

      self.SCALE_REWARD_MIN = args.scale_reward_min
      self.SCALE_REWARD_MAX = args.scale_reward_max
      self.REWARD_EXP = args.reward_exp

      py = np.maximum(py, self.SCALE_REWARD_MIN)
      py = py ** self.REWARD_EXP
      self.scale = self.SCALE_REWARD_MAX / max(py)
      py = py * self.scale

      self.scaled_rewards = py

      
      self.mode_info_file = args.mode_info_file + f"L{self.rna_length}_RNA{self.rna_task}/mode_info.pkl"
      with open(self.mode_info_file, 'rb') as f:
        mode_info = pickle.load(f)
      
      self.modes = [mode_info['modes'], mode_info['modes_hamming_ball1'], mode_info['modes_hamming_ball2']]
     
      # define modes as top % of xhashes.
      #mode_percentile = 0.001
      #self.mode_r_threshold = np.percentile(py, 100*(1-mode_percentile))
      #del self.rewards
      
    
    # Core
    @functools.lru_cache(maxsize=None) # Remove this for subGFN
    def reward(self, x):
      #assert x.is_leaf, 'Error: Tried to compute reward on non-leaf node.'
      # return self.scaled_oracle[x]
      # pred = self.proxy_model.params["model"].predict(
      #   {"input_ids": np.array([self.char_to_idx[c] for c in list(x.content)]).reshape(1, -1)}
      # )[0].item()
      # print(x.content)

      with torch.no_grad():
        temp = copy.deepcopy(x)
        if len(x.content) != self.rna_length:
          temp = copy.deepcopy(x)
          right_am = 'A'*((self.rna_length-len(x.content))//2)
          left_am = 'A'*((self.rna_length-len(x.content))//2)
          if len(x.content) % 2 != 0:
            left_am += 'A'
          temp.content = left_am + temp.content + right_am
          r = self.oracle.get_fitness([temp.content]).item()
        else:
          r = self.oracle.get_fitness([x.content]).item()
      
      r = np.maximum(r, self.SCALE_REWARD_MIN)
      r = r ** self.REWARD_EXP
      r = r * self.scale
      return r


    @functools.lru_cache(maxsize=None) # Remove this for subGFN
    def reward_valid(self, x):
      #assert x.is_leaf, 'Error: Tried to compute reward on non-leaf node.'
      # return self.scaled_oracle[x]
      # pred = self.proxy_model.params["model"].predict(
      #   {"input_ids": np.array([self.char_to_idx[c] for c in list(x.content)]).reshape(1, -1)}
      # )[0].item()
      # print(x.content)

      r = self.oracle.get_fitness([x.content]).item()
      r = np.maximum(r, self.SCALE_REWARD_MIN)
      r = r ** self.REWARD_EXP
      r = r * self.scale
      return r


    def is_mode(self, x, r, g = 0):
      return x.content in self.modes[g]

    def unnormalize(self, r):
      r = r / self.scale
      r = r ** (1 / self.REWARD_EXP)
      return r

    '''
      Interpretation & visualization
    '''
    def dist_func(self, state1, state2):
      """ States are SeqPAState or SeqInsertState objects. """
      return levenshtein(state1.content, state2.content)

    def make_monitor(self):
      target = TargetRewardDistribution()
      target.init_from_base_rewards(self.scaled_rewards)
      return Monitor(self.args, target, dist_func=self.dist_func,
                     is_mode_f=self.is_mode, callback=self.add_monitor,
                     unnormalize=self.unnormalize)

    def add_monitor(self, xs, rs, allXtoR):
      """ Reimplement scoring with oracle, not unscaled oracle (used as R). """
      tolog = dict()
      return tolog
    
    def reduce_storage(self):
      del self.rewards
      del self.scaled_rewards

  return RNAMDP(args)


def main(args):
  """Train a GFlowNet on the RNA-binding task.

  Builds the MDP, actor and model for `args.mdp_style`, creates the monitor,
  frees the dataset rewards, then runs the trainer.

  Raises:
    ValueError: if `args.mdp_style` is not 'pa', 'insert' or 'autoregressive'.
  """
  print('Running experiment RNA ...')

  if args.mdp_style == 'pa':
    base = seqpamdp.SeqPrependAppendMDP
    actorclass = seqpamdp.SeqPAActor
  elif args.mdp_style == 'insert':
    base = seqinsertmdp.SeqInsertMDP
    actorclass = seqinsertmdp.SeqInsertActor
  elif args.mdp_style == 'autoregressive':
    base = seqarmdp.SeqAutoregressiveMDP
    actorclass = seqarmdp.SeqARActor
  else:
    raise ValueError(f'Unknown mdp_style: {args.mdp_style!r}')
  mdp = dynamic_inherit_mdp(base, args)

  actor = actorclass(args, mdp)
  model = models.make_model(args, mdp, actor)
  # make_monitor reads mdp.scaled_rewards, which reduce_storage deletes,
  # so the order of these two calls matters.
  monitor = mdp.make_monitor()
  mdp.reduce_storage()

  trainer = trainers.Trainer(args, model, mdp, actor, monitor)
  trainer.learn()






def eval(args):
  """Sample from a trained checkpoint and pickle the samples.

  Loads `final.pth` when `args.ckpt == -1`, otherwise `round_{args.ckpt}.pth`,
  from `saved_models_dir + run_name`. Draws `eval_num_samples` samples with
  epsilon=0.0 and logs them through the monitor. Writes them to
  `final_eval_samples.pkl` / `round_{ckpt}_eval_samples.pkl` in the same
  directory.

  Raises:
    ValueError: if `args.mdp_style` is not 'pa', 'insert' or 'autoregressive'.
  """
  print('Running evaluation RNA ...')

  if args.mdp_style == 'pa':
    base = seqpamdp.SeqPrependAppendMDP
    actorclass = seqpamdp.SeqPAActor
  elif args.mdp_style == 'insert':
    base = seqinsertmdp.SeqInsertMDP
    actorclass = seqinsertmdp.SeqInsertActor
  elif args.mdp_style == 'autoregressive':
    base = seqarmdp.SeqAutoregressiveMDP
    actorclass = seqarmdp.SeqARActor
  else:
    raise ValueError(f'Unknown mdp_style: {args.mdp_style!r}')
  mdp = dynamic_inherit_mdp(base, args)

  actor = actorclass(args, mdp)
  model = models.make_model(args, mdp, actor)
  monitor = mdp.make_monitor()

  # ckpt == -1 selects the final checkpoint; any other value is a round number.
  tag = 'final' if args.ckpt == -1 else f'round_{args.ckpt}'
  run_dir = args.saved_models_dir + args.run_name
  model.load_for_eval_from_checkpoint(run_dir + '/' + f'{tag}.pth')

  with torch.no_grad():
    eval_samples = model.batch_fwd_sample(args.eval_num_samples, epsilon=0.0)

  # Keep the first reward seen for each unique x.
  allXtoR = dict()
  for exp in eval_samples:
    allXtoR.setdefault(exp.x, exp.r)

  round_num = 1
  monitor.log_samples(round_num, eval_samples)
  monitor.eval_samplelog(model, round_num, allXtoR)

  with open(run_dir + '/' + f'{tag}_eval_samples.pkl', "wb") as f:
    pickle.dump(eval_samples, f)

def tradeoff(args):
  """Evaluate a checkpoint with `model.batch_mh_sample` and save log and samples.

  Loads `round_{args.ckpt}.pth`, or keeps the freshly built model when
  `args.ckpt == 0`. Logs the samples through the monitor and writes three
  pickles under `saved_models_dir + run_name`:
    round_{ckpt}_mcmc.pkl         -- value returned by monitor.eval_samplelog
    round_{ckpt}_sample_mcmc.pkl  -- all sampled experiences
    round_{ckpt}_eval_samples.pkl -- the last `eval_num_samples` experiences

  Raises:
    ValueError: if `args.mdp_style` is not 'pa', 'insert' or 'autoregressive'.
  """
  print('Running evaluation RNA ...')

  if args.mdp_style == 'pa':
    base = seqpamdp.SeqPrependAppendMDP
    actorclass = seqpamdp.SeqPAActor
  elif args.mdp_style == 'insert':
    base = seqinsertmdp.SeqInsertMDP
    actorclass = seqinsertmdp.SeqInsertActor
  elif args.mdp_style == 'autoregressive':
    base = seqarmdp.SeqAutoregressiveMDP
    actorclass = seqarmdp.SeqARActor
  else:
    raise ValueError(f'Unknown mdp_style: {args.mdp_style!r}')
  mdp = dynamic_inherit_mdp(base, args)

  actor = actorclass(args, mdp)
  model = models.make_model(args, mdp, actor)
  monitor = mdp.make_monitor()

  run_dir = args.saved_models_dir + args.run_name
  # ckpt == 0 evaluates the model as initialised; otherwise load that round.
  if args.ckpt != 0:
    model.load_for_eval_from_checkpoint(run_dir + '/' + f'round_{args.ckpt}.pth')

  # Second positional argument of batch_mh_sample: training rounds remaining
  # after this checkpoint, divided by (eval_num_samples // num_samples_per_online_batch),
  # plus one. NOTE(review): presumably the number of MH iterations -- confirm.
  batches_per_eval = args.eval_num_samples // args.num_samples_per_online_batch
  num_mh_iters = (args.num_active_learning_rounds - args.ckpt) // batches_per_eval + 1
  with torch.no_grad():
    mcmc_samples = model.batch_mh_sample(args.eval_num_samples, num_mh_iters, epsilon=0.0, k=7)

  # Materialise once so both the dedup pass and the saved list see every sample,
  # even if batch_mh_sample returns a one-shot iterable.
  eval_samples = list(mcmc_samples)

  # Keep the first reward seen for each unique x.
  allXtoR = dict()
  for exp in eval_samples:
    allXtoR.setdefault(exp.x, exp.r)

  round_num = 1
  monitor.log_samples(round_num, eval_samples)
  log = monitor.eval_samplelog(model, round_num, allXtoR)

  prefix = run_dir + '/' + f'round_{args.ckpt}'
  with open(prefix + '_mcmc.pkl', "wb") as f:
    pickle.dump(log, f)
  with open(prefix + '_sample_mcmc.pkl', "wb") as f:
    pickle.dump(eval_samples, f)
  with open(prefix + '_eval_samples.pkl', "wb") as f:
    pickle.dump(eval_samples[-args.eval_num_samples:], f)

def get_neighbors(x, mdp):
  """Return all sequences one substitution away from `x.content`.

  Only the first `mdp.forced_stop_len` positions are mutated. Each one is
  replaced by every symbol of `mdp.alphabet` that differs from the current
  symbol. Results are ordered by position first, then by alphabet order.
  """
  chars = list(x.content)
  return [
      "".join(chars[:pos] + [ch] + chars[pos + 1:])
      for pos in range(mdp.forced_stop_len)
      for ch in mdp.alphabet
      if chars[pos] != ch
  ]

def number_of_modes(args):
  """Count local-maximum modes among the saved training samples, per batch.

  Reads `final_sample.pkl` from the run directory and walks it in batches of
  `num_samples_per_online_batch` (the last batch may be partial). Each unique
  x is counted as a mode at most once. It counts when its unnormalized reward
  exceeds `args.threshold` and is strictly greater than the oracle fitness of
  every single-substitution neighbour. Per-batch counts are saved to
  `number_of_modes_{threshold}.npz` under key 'modes'.

  Raises:
    ValueError: if `args.mdp_style` is not 'pa', 'insert' or 'autoregressive'.
  """
  print('Running evaluation RNA ...')

  if args.mdp_style == 'pa':
    base = seqpamdp.SeqPrependAppendMDP
  elif args.mdp_style == 'insert':
    base = seqinsertmdp.SeqInsertMDP
  elif args.mdp_style == 'autoregressive':
    base = seqarmdp.SeqAutoregressiveMDP
  else:
    raise ValueError(f'Unknown mdp_style: {args.mdp_style!r}')
  mdp = dynamic_inherit_mdp(base, args)

  run_dir = args.saved_models_dir + args.run_name
  # NOTE: pickle.load can execute arbitrary code; only load trusted files.
  with open(run_dir + '/' + "final_sample.pkl", "rb") as f:
    generated_samples = pickle.load(f)

  batch_size = args.num_samples_per_online_batch
  # Ceil division so a trailing partial batch has its own slot (floor division
  # made a mode found in that batch index past the end of the array).
  num_batches = -(-len(generated_samples) // batch_size)
  mode_counts = np.zeros((num_batches, ))

  allXtoR = dict()
  with tqdm(total=len(generated_samples)) as pbar:
    for batch_idx, start in enumerate(range(0, len(generated_samples), batch_size)):
      batch = generated_samples[start:start + batch_size]
      for exp in batch:
        if exp.x in allXtoR:
          continue
        allXtoR[exp.x] = mdp.unnormalize(exp.r)
        if allXtoR[exp.x] > args.threshold:
          # One batched oracle call over all single-substitution neighbours.
          neighbor_r = mdp.oracle.get_fitness(get_neighbors(exp.x, mdp))
          if allXtoR[exp.x] > np.max(neighbor_r):
            mode_counts[batch_idx] += 1
      pbar.update(len(batch))
      pbar.set_postfix(number_of_modes=np.sum(mode_counts))
  print(np.sum(mode_counts))
  np.savez_compressed(run_dir + '/' + f'number_of_modes_{args.threshold}.npz', modes=mode_counts)


def analysis(args):
  """Print top-k reward and diversity of the final evaluation samples.

  Loads `final_eval_samples.pkl` from the run directory. For each unique x,
  the last reward seen is kept. For k in (10, 100, 128), prints the mean
  unnormalized reward and the `diversity` (via monitor.dist_func) of the k
  highest-reward states.

  Raises:
    ValueError: if `args.mdp_style` is not 'pa', 'insert' or 'autoregressive'.
  """
  print('Running evaluation RNA ...')

  if args.mdp_style == 'pa':
    base = seqpamdp.SeqPrependAppendMDP
  elif args.mdp_style == 'insert':
    base = seqinsertmdp.SeqInsertMDP
  elif args.mdp_style == 'autoregressive':
    base = seqarmdp.SeqAutoregressiveMDP
  else:
    raise ValueError(f'Unknown mdp_style: {args.mdp_style!r}')
  mdp = dynamic_inherit_mdp(base, args)
  monitor = mdp.make_monitor()

  ckpt_path = args.saved_models_dir + args.run_name
  print(ckpt_path)
  # NOTE: pickle.load can execute arbitrary code; only load trusted files.
  with open(ckpt_path + '/' + "final_eval_samples.pkl", "rb") as f:
    all_samples = pickle.load(f)

  # Dict comprehension: a later sample overwrites an earlier one with the same x.
  x_to_r = {exp.x: exp.r for exp in all_samples}
  sorted_x = sorted(x_to_r, key=x_to_r.get, reverse=True)

  for k in [10, 100, 128]:
    top_x = sorted_x[:k]
    top_rs = [x_to_r[x] for x in top_x]

    print(f"Top-{k} Reward: {np.mean([mdp.unnormalize(r) for r in top_rs]):.3f}")
    print(f"Top-{k} Diversity: {diversity(top_x, monitor.dist_func):.3f}")
    print()