'''
    RNA binding
    FLEXS RNABinding oracle as reward
    Start from scratch
'''

import copy, pickle, functools
import numpy as np
from tqdm import tqdm
import pandas as pd
import torch
from polyleven import levenshtein
from itertools import combinations, product

import gflownet.trainers as trainers
from gflownet.GFNs import models
from gflownet.MDPs import seqpamdp, seqinsertmdp, seqarmdp
from gflownet.monitor import TargetRewardDistribution, Monitor, diversity

import flexs
from flexs import baselines
import flexs.utils.sequence_utils as s_utils

def dynamic_inherit_mdp(base, args):
  """Build an RNA-binding MDP instance whose class derives from ``base``.

  The class is defined at call time so the same reward logic can be placed
  on top of whichever sequence MDP base class the caller passes in.

  Args:
    base: MDP base class to inherit from. Its constructor is called with
      ``args`` plus the ``alphabet`` and ``forced_stop_len`` keywords.
    args: Experiment configuration. This block reads ``rna_length``,
      ``rna_task``, ``scale_reward_min``, ``scale_reward_max`` and
      ``reward_exp``.

  Returns:
    An instance of the locally defined ``RNAMDP`` class.
  """

  class RNAMDP(base):
    """Sequence MDP over {U, C, G, A} rewarded by the FLEXS RNABinding oracle."""

    def __init__(self, args):
      super().__init__(args,
                       alphabet=["U", "C", "G", "A"],
                       forced_stop_len=args.rna_length)
      self.args = args
      self.rna_task = args.rna_task
      self.rna_length = args.rna_length

      print('Loading data ...')
      landscape_name = f'L{self.rna_length}_RNA{self.rna_task}'
      problem = flexs.landscapes.rna.registry()[landscape_name]
      self.oracle = flexs.landscapes.RNABinding(**problem['params'])
      print(problem)

      # The dataset was collected by querying 100,000 RNA sequences
      # generated by random mutation from a starting sequence.
      with open(f"datasets/L{self.rna_length}_RNA{self.rna_task}/rewards.pkl", "rb") as f:
        self.rewards = pickle.load(f)

      # Reward shaping hyper-parameters.
      self.SCALE_REWARD_MIN = args.scale_reward_min
      self.SCALE_REWARD_MAX = args.scale_reward_max
      self.REWARD_EXP = args.reward_exp

      # Shape the dataset rewards: clip from below, exponentiate, then
      # rescale so the largest shaped reward equals SCALE_REWARD_MAX.
      shaped = np.array(list(self.rewards))
      shaped = np.maximum(shaped, self.SCALE_REWARD_MIN)
      shaped = shaped ** self.REWARD_EXP
      self.scale = self.SCALE_REWARD_MAX / max(shaped)
      shaped = shaped * self.scale
      self.scaled_rewards = shaped

      # A sample counts as a mode when its shaped reward reaches the top
      # `mode_percentile` fraction of the dataset's shaped rewards.
      mode_percentile = 0.001
      self.mode_r_threshold = np.percentile(shaped, 100*(1-mode_percentile))

    # Core
    def reward(self, x):
      """Return the shaped oracle reward of leaf state ``x``.

      The oracle fitness of ``x.content`` goes through the same clip,
      exponent and scale steps that were applied to the dataset rewards.
      """
      assert x.is_leaf, 'Error: Tried to compute reward on non-leaf node.'
      fitness = self.oracle.get_fitness([x.content]).item()
      clipped = np.maximum(fitness, self.SCALE_REWARD_MIN)
      return (clipped ** self.REWARD_EXP) * self.scale

    def is_mode(self, x, r):
      """Return whether shaped reward ``r`` reaches the mode threshold."""
      return r >= self.mode_r_threshold

    def unnormalize(self, r):
      """Undo the scale and exponent steps of the reward shaping.

      The lower clipping applied during shaping is not reversed.
      """
      unscaled = r / self.scale
      return unscaled ** (1 / self.REWARD_EXP)

    '''
      Interpretation & visualization
    '''
    def dist_func(self, state1, state2):
      """ Levenshtein distance between the ``content`` of two states. """
      return levenshtein(state1.content, state2.content)

    def make_monitor(self):
      """Create a Monitor whose target distribution is the shaped dataset rewards."""
      target_dist = TargetRewardDistribution()
      target_dist.init_from_base_rewards(self.scaled_rewards)
      return Monitor(self.args,
                     target_dist,
                     dist_func=self.dist_func,
                     is_mode_f=self.is_mode,
                     callback=self.add_monitor,
                     unnormalize=self.unnormalize)

    def add_monitor(self, xs, rs, allXtoR):
      """ Monitor callback; currently contributes no extra entries. """
      return {}

    def reduce_storage(self):
      """Drop the raw and shaped dataset rewards from this instance."""
      del self.rewards, self.scaled_rewards

  return RNAMDP(args)


def main(args):
  """Train a model on the RNA-binding task described by ``args``.

  Builds the MDP, actor, model and monitor, releases the MDP's cached
  dataset rewards, and then runs the trainer.

  Args:
    args: Experiment configuration. ``args.mdp_style`` selects the sequence
      MDP and must be 'pa', 'insert' or 'autoregressive'; the object is
      also passed on to the MDP, actor, model and trainer.

  Raises:
    ValueError: If ``args.mdp_style`` is not one of the supported styles.
  """
  print('Running experiment RNA ...')

  if args.mdp_style == 'pa':
    base = seqpamdp.SeqPrependAppendMDP
    actorclass = seqpamdp.SeqPAActor
  elif args.mdp_style == 'insert':
    base = seqinsertmdp.SeqInsertMDP
    actorclass = seqinsertmdp.SeqInsertActor
  elif args.mdp_style == 'autoregressive':
    base = seqarmdp.SeqAutoregressiveMDP
    actorclass = seqarmdp.SeqARActor
  else:
    # Fail with a clear message instead of an UnboundLocalError further down.
    raise ValueError(
        f"Unknown mdp_style {args.mdp_style!r}; "
        "expected 'pa', 'insert' or 'autoregressive'.")
  mdp = dynamic_inherit_mdp(base, args)

  actor = actorclass(args, mdp)
  model = models.make_model(args, mdp, actor)
  # The monitor must be created before reduce_storage(), which deletes the
  # scaled rewards that make_monitor() reads.
  monitor = mdp.make_monitor()

  mdp.reduce_storage()

  trainer = trainers.Trainer(args, model, mdp, actor, monitor)
  trainer.learn()

def eval(args):
  """Sample from a saved checkpoint, log the samples and pickle them.

  The function name shadows the ``eval`` builtin inside this module; it is
  kept unchanged for existing callers.

  Args:
    args: Experiment configuration. Reads ``mdp_style`` ('pa', 'insert' or
      'autoregressive'), ``saved_models_dir``, ``run_name``, ``ckpt``
      (``-1`` selects ``final.pth``, any other value ``round_<ckpt>.pth``)
      and ``eval_num_samples``.

  Raises:
    ValueError: If ``args.mdp_style`` is not one of the supported styles.
  """
  print('Running evaluation RNA ...')

  if args.mdp_style == 'pa':
    base = seqpamdp.SeqPrependAppendMDP
    actorclass = seqpamdp.SeqPAActor
  elif args.mdp_style == 'insert':
    base = seqinsertmdp.SeqInsertMDP
    actorclass = seqinsertmdp.SeqInsertActor
  elif args.mdp_style == 'autoregressive':
    base = seqarmdp.SeqAutoregressiveMDP
    actorclass = seqarmdp.SeqARActor
  else:
    # Fail with a clear message instead of an UnboundLocalError further down.
    raise ValueError(
        f"Unknown mdp_style {args.mdp_style!r}; "
        "expected 'pa', 'insert' or 'autoregressive'.")
  mdp = dynamic_inherit_mdp(base, args)

  actor = actorclass(args, mdp)
  model = models.make_model(args, mdp, actor)
  monitor = mdp.make_monitor()

  # The checkpoint and the result file share the same directory and tag.
  run_dir = args.saved_models_dir + args.run_name
  tag = 'final' if args.ckpt == -1 else f'round_{args.ckpt}'
  model.load_for_eval_from_checkpoint(f'{run_dir}/{tag}.pth')

  # evaluate
  with torch.no_grad():
    eval_samples = model.batch_fwd_sample(args.eval_num_samples, epsilon=0.0)

  # Keep the reward of the first occurrence of every distinct x.
  allXtoR = dict()
  for exp in eval_samples:
    allXtoR.setdefault(exp.x, exp.r)

  round_num = 1
  monitor.log_samples(round_num, eval_samples)
  # The returned log is not stored by this function.
  monitor.eval_samplelog(model, round_num, allXtoR)

  # save results
  with open(f'{run_dir}/{tag}_eval_samples.pkl', "wb") as f:
    pickle.dump(eval_samples, f)

def tradeoff(args):
  """Extend a checkpoint's samples with MH samples, then log and save them.

  With ``args.ckpt == 0`` no checkpoint is loaded: the model stays as built
  by ``models.make_model`` and the sample history starts empty. Otherwise
  the ``round_<ckpt>`` weights and the pickled samples of that round are
  loaded from the run directory.

  Args:
    args: Experiment configuration. Reads ``mdp_style`` ('pa', 'insert' or
      'autoregressive'), ``saved_models_dir``, ``run_name``, ``ckpt``,
      ``num_samples_per_online_batch``, ``num_active_learning_rounds``,
      ``k`` and ``eval_num_samples``.

  Raises:
    ValueError: If ``args.mdp_style`` is not one of the supported styles.
  """
  print('Running evaluation RNA ...')

  if args.mdp_style == 'pa':
    base = seqpamdp.SeqPrependAppendMDP
    actorclass = seqpamdp.SeqPAActor
  elif args.mdp_style == 'insert':
    base = seqinsertmdp.SeqInsertMDP
    actorclass = seqinsertmdp.SeqInsertActor
  elif args.mdp_style == 'autoregressive':
    base = seqarmdp.SeqAutoregressiveMDP
    actorclass = seqarmdp.SeqARActor
  else:
    # Fail with a clear message instead of an UnboundLocalError further down.
    raise ValueError(
        f"Unknown mdp_style {args.mdp_style!r}; "
        "expected 'pa', 'insert' or 'autoregressive'.")
  mdp = dynamic_inherit_mdp(base, args)

  actor = actorclass(args, mdp)
  model = models.make_model(args, mdp, actor)
  monitor = mdp.make_monitor()

  # Every artifact read or written below lives in the run directory.
  run_dir = args.saved_models_dir + args.run_name

  if args.ckpt == 0:
    eval_samples = []
  else:
    model.load_for_eval_from_checkpoint(f'{run_dir}/round_{args.ckpt}.pth')
    # NOTE(review): pickle.load is only safe on files this project wrote.
    with open(f'{run_dir}/round_{args.ckpt}_sample.pkl', 'rb') as f:
      eval_samples = pickle.load(f)

  # Keep the reward of the first occurrence of every distinct x.
  allXtoR = dict()
  for exp in eval_samples:
    allXtoR.setdefault(exp.x, exp.r)

  # evaluate
  with torch.no_grad():
    mcmc_samples = model.batch_mh_sample(
        args.num_samples_per_online_batch,
        args.num_active_learning_rounds - args.ckpt,
        epsilon=0.0,
        k=args.k)

  for exp in mcmc_samples:
    allXtoR.setdefault(exp.x, exp.r)
  eval_samples.extend(mcmc_samples)

  round_num = 1
  monitor.log_samples(round_num, eval_samples)
  log = monitor.eval_samplelog(model, round_num, allXtoR)

  # save results
  with open(f'{run_dir}/round_{args.ckpt}_mcmc.pkl', "wb") as f:
    pickle.dump(log, f)

  with open(f'{run_dir}/round_{args.ckpt}_sample_mcmc.pkl', "wb") as f:
    pickle.dump(eval_samples, f)

  # Only the most recent eval_num_samples samples go into this file.
  with open(f'{run_dir}/round_{args.ckpt}_eval_samples.pkl', "wb") as f:
    pickle.dump(eval_samples[-args.eval_num_samples:], f)

def get_neighbors(x, mdp, hamming_ball=1):
  """Return every sequence obtained by rewriting ``hamming_ball`` positions of ``x``.

  For each choice of ``hamming_ball`` positions among the first
  ``mdp.forced_stop_len`` indices, every assignment of ``mdp.alphabet``
  symbols to those positions is tried. A rewritten position may receive
  its original symbol, so results differ from ``x`` in at most
  ``hamming_ball`` positions. ``x`` itself is never included.

  Args:
    x: Sequence to mutate, indexable up to ``mdp.forced_stop_len``.
    mdp: Object exposing ``forced_stop_len`` and ``alphabet``.
    hamming_ball: Number of positions rewritten at once.

  Returns:
    A set of neighbor strings.
  """
  positions = range(mdp.forced_stop_len)
  found = set()
  for sites in combinations(positions, hamming_ball):
    for symbols in product(mdp.alphabet, repeat=len(sites)):
      chars = list(x)
      for site, symbol in zip(sites, symbols):
        chars[site] = symbol
      candidate = "".join(chars)
      if candidate == x:
        continue
      found.add(candidate)
  return found

def number_of_modes(args):
  """Count, per batch, the newly discovered modes among saved final samples.

  Loads ``final_sample.pkl`` from the run directory and ``mode_info.pkl``
  from ``args.saved_models_dir``, walks the samples in chunks of
  ``args.num_samples_per_online_batch`` and, for each chunk, counts the
  not-yet-seen samples whose content belongs to each mode set. The
  per-batch counts are saved to ``number_of_modes_updated.npz`` in the run
  directory.

  Args:
    args: Experiment configuration. Reads ``mdp_style`` ('pa', 'insert' or
      'autoregressive'), ``saved_models_dir``, ``run_name`` and
      ``num_samples_per_online_batch``.

  Raises:
    ValueError: If ``args.mdp_style`` is not one of the supported styles.
  """
  print('Running evaluation RNA ...')

  if args.mdp_style == 'pa':
    base = seqpamdp.SeqPrependAppendMDP
  elif args.mdp_style == 'insert':
    base = seqinsertmdp.SeqInsertMDP
  elif args.mdp_style == 'autoregressive':
    base = seqarmdp.SeqAutoregressiveMDP
  else:
    # Fail with a clear message instead of an UnboundLocalError further down.
    raise ValueError(
        f"Unknown mdp_style {args.mdp_style!r}; "
        "expected 'pa', 'insert' or 'autoregressive'.")
  # NOTE(review): the MDP is not used below; it is still constructed so the
  # existing start-up behavior (data loading, prints) is unchanged.
  dynamic_inherit_mdp(base, args)

  # load saved samples and mode sets
  # NOTE(review): pickle.load is only safe on files this project wrote.
  ckpt_path = args.saved_models_dir + args.run_name
  with open(ckpt_path + '/' + f"final_sample.pkl", "rb") as f:
    generated_samples = pickle.load(f)

  with open(args.saved_models_dir + "mode_info.pkl", "rb") as f:
    mode_info = pickle.load(f)

  batch_size = args.num_samples_per_online_batch
  total = len(generated_samples)
  # Ceil division: a trailing partial batch needs its own slot, otherwise
  # indexing it raises IndexError when total % batch_size != 0.
  num_batches = -(-total // batch_size)
  counts = {key: np.zeros((num_batches, )) for key in mode_info}
  tracked_keys = ("modes", "modes_hamming_ball1", "modes_hamming_ball2")

  seen = set()
  with tqdm(total=total) as pbar:
    for batch_idx, start in enumerate(range(0, total, batch_size)):
      batch = generated_samples[start: start + batch_size]
      for exp in batch:
        # Only the first occurrence of a sample can count as a discovery.
        if exp.x in seen:
          continue
        seen.add(exp.x)
        content = exp.x.content
        for key in tracked_keys:
          if content in mode_info[key]:
            counts[key][batch_idx] += 1
      # Advance by the real batch length so the bar never overshoots.
      pbar.update(len(batch))
      pbar.set_postfix(number_of_modes=np.sum(counts["modes"]))
  print(np.sum(counts["modes"]))
  np.savez_compressed(ckpt_path + '/' + f'number_of_modes_updated.npz',
                      modes=counts["modes"],
                      modes_hamming_ball1=counts["modes_hamming_ball1"],
                      modes_hamming_ball2=counts["modes_hamming_ball2"])

def mode_ablations(args):
  """Compute mode sets under stricter local-peak definitions and pickle them.

  Sequences whose reward is strictly above the top-0.5% percentile of the
  loaded reward table are the "original" modes. Each is additionally kept
  as a radius-1 (radius-2) mode when its reward is strictly greater than
  the reward of every neighbor returned by ``get_neighbors`` with
  ``hamming_ball=1`` (``hamming_ball=2``). The three sets are written to
  ``mode_info.pkl`` in ``args.saved_models_dir``.

  Args:
    args: Experiment configuration. Reads ``mdp_style`` ('pa', 'insert' or
      'autoregressive'), ``rna_length``, ``rna_task`` and
      ``saved_models_dir``.

  Raises:
    ValueError: If ``args.mdp_style`` is not one of the supported styles.
    KeyError: If a neighbor of a mode is missing from the reward table.
  """
  print('Running evaluation RNA ...')

  if args.mdp_style == 'pa':
    base = seqpamdp.SeqPrependAppendMDP
  elif args.mdp_style == 'insert':
    base = seqinsertmdp.SeqInsertMDP
  elif args.mdp_style == 'autoregressive':
    base = seqarmdp.SeqAutoregressiveMDP
  else:
    # Fail with a clear message instead of an UnboundLocalError further down.
    raise ValueError(
        f"Unknown mdp_style {args.mdp_style!r}; "
        "expected 'pa', 'insert' or 'autoregressive'.")
  # The MDP supplies the alphabet and sequence length used by get_neighbors.
  mdp = dynamic_inherit_mdp(base, args)

  print("Loading Dataset...")
  # NOTE(review): pickle.load is only safe on files this project wrote.
  with open(f"exps/rna/reward_L{args.rna_length}_RNA{args.rna_task}_final.pkl", "rb") as f:
    x_to_r = pickle.load(f)

  rewards = np.array(list(x_to_r.values()))

  mode_percentile = 0.005
  mode_r_threshold = np.percentile(rewards, 100*(1-mode_percentile))

  print("Find Original Modes...")
  original_modes = [(x, r) for x, r in x_to_r.items() if r > mode_r_threshold]

  # Modes that are also strict local peaks within the given neighbor radius.
  modes_hamming_ball1 = set()
  modes_hamming_ball2 = set()

  with tqdm(total=len(original_modes)) as pbar:
    for x, r in original_modes:
      neighbors_ball1 = get_neighbors(x, mdp, hamming_ball=1)
      neighbors_ball2 = get_neighbors(x, mdp, hamming_ball=2)

      # assumes x_to_r holds a reward for every neighbor — TODO confirm;
      # a missing neighbor raises KeyError here.
      max_r_ball1 = max(x_to_r[nb] for nb in neighbors_ball1)
      max_r_ball2 = max(x_to_r[nb] for nb in neighbors_ball2)

      if r > max_r_ball1:
        modes_hamming_ball1.add(x)
      if r > max_r_ball2:
        modes_hamming_ball2.add(x)

      pbar.update(1)

  print("Mode Info")
  print(f"Original Num Modes. {len(original_modes)}")
  print(f"Filtering via Neighbors (range=1). {len(modes_hamming_ball1)}")
  print(f"Filtering via Neighbors (range=2). {len(modes_hamming_ball2)}")

  mode_info = {"modes": {x for x, _ in original_modes},
               "modes_hamming_ball1": modes_hamming_ball1,
               "modes_hamming_ball2": modes_hamming_ball2,}

  with open(args.saved_models_dir + "mode_info.pkl", "wb") as f:
    pickle.dump(mode_info, f)
        
        
        
