import random, time
import numpy as np
import torch
from tqdm import tqdm
import ray
import math
import pickle
import os, signal
from datetime import datetime
from .data import Experience
    

class Trainer:
  """Base trainer holding the run configuration and the reference reward table.

  Attributes:
    args: Run configuration. ``args.all_dataset`` is the path of a pickled
      mapping from x to raw reward that is read at construction time.
    mdp: Environment object. ``mdp.state(key, True)`` is called to convert
      dataset keys that do not expose a ``content`` attribute.
    callbackXtoR: Dict, empty at construction.
    true_datasets: Dict mapping each (converted) key to ``exp(z)``, where
      ``z`` is the reward standardised with the mean and standard deviation
      of all rewards in the dataset.
  """

  def __init__(self, args, mdp):
    self.args = args
    self.mdp = mdp
    self.callbackXtoR = dict()

    # NOTE(security): unpickling can execute arbitrary code. Only point
    # ``args.all_dataset`` at files that come from a trusted source.
    with open(args.all_dataset, 'rb') as f:
      all_datasets = pickle.load(f)

    # Keys exposing ``content`` are used as-is (presumably they already are
    # state objects -- confirm against the dataset producer); everything else
    # goes through ``mdp.state``.
    keys = [
      k if hasattr(k, "content") else self.mdp.state(k, True)
      for k in all_datasets
    ]
    values = np.array(list(all_datasets.values()))
    mean = np.mean(values)
    std = np.std(values)
    if std == 0:
      # All rewards are identical: dividing by zero would turn every entry
      # into NaN. With a unit scale the z-scores are 0 and every scaled
      # reward is exp(0) == 1.
      std = 1.0
    scaled_rewards = np.exp((values - mean) / std)
    self.true_datasets = dict(zip(keys, scaled_rewards))

    
class TrainerGFN(Trainer):
  """Trainer whose rounds use only freshly sampled (online) batches."""

  def __init__(self, args, model, mdp, actor, monitor):
    super().__init__(args, mdp)
    self.model = model
    self.actor = actor
    self.monitor = monitor

  # ------------------------------------------------------------------
  # Training
  # ------------------------------------------------------------------
  def learn(self):
    """Run ``args.num_active_learning_rounds`` rounds of online training.

    Per round:
      * ``num_online_batches_per_round`` times, draw
        ``num_samples_per_online_batch`` forward samples with
        ``epsilon=args.explore_epsilon`` and call ``model.train`` on that
        batch ``num_steps_per_batch`` times;
      * on every ``monitor_fast_every``-th round except round 0, log
        ``monitor_num_samples`` samples drawn with ``epsilon=0``;
      * call ``monitor.maybe_eval_samplelog``;
      * on every ``save_every_x_active_rounds``-th round except round 0,
        save a checkpoint.

    After the last round the final parameters are saved and, if at least one
    round ran, ``monitor.maybe_eval_samplelog`` is called once more.
    """
    args = self.args
    num_online = args.num_online_batches_per_round
    online_bsize = args.num_samples_per_online_batch
    monitor_fast_every = args.monitor_fast_every
    monitor_num_samples = args.monitor_num_samples
    num_rounds = args.num_active_learning_rounds

    # Stays None when no round runs, so the final evaluation can be skipped
    # instead of failing on an unbound loop variable.
    round_num = None
    for round_num in tqdm(range(num_rounds)):
      print(f'Starting learning round {round_num+1} / {num_rounds} ...')

      # Online training
      for _ in range(num_online):
        with torch.no_grad():
          explore_data = self.model.batch_fwd_sample(
            online_bsize, epsilon=args.explore_epsilon)

        for _ in range(args.num_steps_per_batch):
          self.model.train(explore_data)

      if round_num % monitor_fast_every == 0 and round_num > 0:
        self._log_true_policy_samples(round_num, monitor_num_samples)

      self.monitor.maybe_eval_samplelog(self.model, round_num, self.callbackXtoR)

      if round_num % args.save_every_x_active_rounds == 0 and round_num > 0:
        self.model.save_params(self._checkpoint_path(f'_round_{round_num}.pth'))

    print('Finished training.')
    self.model.save_params(self._checkpoint_path('_final.pth'))
    if round_num is not None:
      self.monitor.maybe_eval_samplelog(self.model, round_num, self.callbackXtoR)
    return

  def _log_true_policy_samples(self, round_num, num_samples):
    """Sample with ``epsilon=0``, attach reference rewards, log to the monitor."""
    # Sampling for monitoring needs no gradients; this mirrors how the
    # exploration batches are drawn.
    with torch.no_grad():
      samples = self.model.batch_fwd_sample(num_samples, epsilon=0)

    relabelled = []
    for exp in samples:
      # Overwrite the reward fields with the scaled reference reward of x.
      reward = self.true_datasets[exp.x]
      relabelled.append(exp._replace(r=reward, logr=np.log(reward)))

    self.monitor.log_samples(round_num, relabelled)

  def _checkpoint_path(self, suffix):
    """Return ``saved_models_dir + run_name + suffix``.

    Plain string concatenation: ``args.saved_models_dir`` has to carry its
    own trailing path separator.
    """
    return self.args.saved_models_dir + self.args.run_name + suffix


class TrainerIL(Trainer):
  """Trainer whose rounds combine offline (known x -> reward) and online batches."""

  def __init__(self, args, model, mdp, actor, monitor):
    super().__init__(args, mdp)
    self.model = model
    self.actor = actor
    self.monitor = monitor

  # ------------------------------------------------------------------
  # Training
  # ------------------------------------------------------------------
  def learn(self, initial_XtoR=None):
    """Run ``args.num_active_learning_rounds`` rounds of mixed training.

    Args:
      initial_XtoR: Optional mapping x -> reward used as the offline pool.
        When it is ``None`` or empty, ``self.mdp.initial_XtoR`` is used.

    Per round:
      * build ``num_offline_batches_per_round`` offline batches via
        ``select_offline_xs`` and ``offline_PB_traj_sample``;
      * draw ``num_online_batches_per_round`` forward-sampled batches with
        ``epsilon=args.explore_epsilon``;
      * call ``model.train`` on the concatenation of both
        ``num_steps_per_batch`` times;
      * on every ``monitor_fast_every``-th round except round 0, log
        ``monitor_num_samples`` samples drawn with ``epsilon=0``;
      * call ``monitor.maybe_eval_samplelog``;
      * on every ``save_every_x_active_rounds``-th round except round 0,
        save a checkpoint.

    After the last round the final parameters are saved and, if at least one
    round ran, ``monitor.maybe_eval_samplelog`` is called once more.
    """
    args = self.args
    allXtoR = initial_XtoR if initial_XtoR else self.mdp.initial_XtoR
    num_online = args.num_online_batches_per_round
    num_offline = args.num_offline_batches_per_round
    online_bsize = args.num_samples_per_online_batch
    offline_bsize = args.num_samples_per_offline_batch
    monitor_fast_every = args.monitor_fast_every
    monitor_num_samples = args.monitor_num_samples
    num_rounds = args.num_active_learning_rounds
    # Adjacent literals instead of a backslash inside the string, which used
    # to leak the source indentation into the printed message.
    print('Starting active learning. '
          f'Each round: {num_online=}, {num_offline=}')

    # Stays None when no round runs, so the final evaluation can be skipped
    # instead of failing on an unbound loop variable.
    round_num = None
    for round_num in tqdm(range(num_rounds)):
      print(f'Starting learning round {round_num+1} / {num_rounds} ...')

      # Offline part: trajectories for already-known xs.
      offline_batch = []
      for _ in range(num_offline):
        offline_xs = self.select_offline_xs(allXtoR, offline_bsize)
        offline_batch.extend(self.offline_PB_traj_sample(offline_xs, allXtoR))

      # Online part: fresh forward samples.
      online_batch = []
      for _ in range(num_online):
        with torch.no_grad():
          online_batch.extend(
            self.model.batch_fwd_sample(online_bsize, epsilon=args.explore_epsilon))

      # Train on the combined dataset
      combined = offline_batch + online_batch
      for _ in range(args.num_steps_per_batch):
        self.model.train(combined)

      if round_num % monitor_fast_every == 0 and round_num > 0:
        self._log_true_policy_samples(round_num, monitor_num_samples)
      self.monitor.maybe_eval_samplelog(self.model, round_num, self.callbackXtoR)

      if round_num % args.save_every_x_active_rounds == 0 and round_num > 0:
        self.model.save_params(self._checkpoint_path(f'_round_{round_num}.pth'))

    print('Finished training.')
    self.model.save_params(self._checkpoint_path('_final.pth'))
    if round_num is not None:
      self.monitor.maybe_eval_samplelog(self.model, round_num, self.callbackXtoR)
    return

  def _log_true_policy_samples(self, round_num, num_samples):
    """Sample with ``epsilon=0``, attach reference rewards, log to the monitor."""
    # Sampling for monitoring needs no gradients; this mirrors how the
    # exploration batches are drawn.
    with torch.no_grad():
      samples = self.model.batch_fwd_sample(num_samples, epsilon=0)

    relabelled = []
    for exp in samples:
      # Overwrite the reward fields with the scaled reference reward of x.
      reward = self.true_datasets[exp.x]
      relabelled.append(exp._replace(r=reward, logr=np.log(reward)))

    self.monitor.log_samples(round_num, relabelled)

  def _checkpoint_path(self, suffix):
    """Return ``saved_models_dir + run_name + suffix``.

    Plain string concatenation: ``args.saved_models_dir`` has to carry its
    own trailing path separator.
    """
    return self.args.saved_models_dir + self.args.run_name + suffix

  # ------------------------------------------------------------------
  # Offline training
  # ------------------------------------------------------------------
  def select_offline_xs(self, allXtoR, batch_size):
    """Pick the xs of one offline batch.

    The strategy is read from ``args.get('offline_select', 'biased')`` and
    must be ``'biased'`` or ``'random'``.

    Raises:
      ValueError: If the configured strategy is neither of the two. Without
        this, ``None`` was returned and the failure surfaced later.
    """
    select = self.args.get('offline_select', 'biased')
    if select == 'biased':
      return self.__biased_sample_xs(allXtoR, batch_size)
    if select == 'random':
      return self.__random_sample_xs(allXtoR, batch_size)
    raise ValueError(
      f"Unknown offline_select strategy {select!r}; expected 'biased' or 'random'")

  def __biased_sample_xs(self, allXtoR, batch_size):
    """ Select xs for offline training. Returns List of [State].
        Draws half of the batch from xs whose reward is at or above the 90th
        percentile and the other half from xs at or below it, both with
        replacement. For odd batch sizes the extra sample comes from the
        lower pool, so exactly ``batch_size`` xs are returned.
        Returns an empty list when fewer than 10 xs are known.
    """
    if len(allXtoR) < 10:
      return []
    rewards = np.array(list(allXtoR.values()))
    threshold = np.percentile(rewards, 90)
    # An x whose reward equals the threshold lands in both pools; this also
    # guarantees that neither pool is ever empty.
    top_xs = [x for x, r in allXtoR.items() if r >= threshold]
    bottom_xs = [x for x, r in allXtoR.items() if r <= threshold]
    num_top = batch_size // 2
    num_bottom = batch_size - num_top
    return random.choices(top_xs, k=num_top) + \
           random.choices(bottom_xs, k=num_bottom)

  def __random_sample_xs(self, allXtoR, batch_size):
    """ Select xs for offline training, uniformly with replacement.
        Returns List of [State].
    """
    return random.choices(list(allXtoR), k=batch_size)

  def offline_PB_traj_sample(self, offline_xs, allXtoR):
    """ Sample trajectories for x using P_B, for offline training with TB.
        Returns List of [Experience], one per entry of ``offline_xs``, with
        ``r`` looked up in ``allXtoR`` and ``logr`` a float32 tensor on
        ``args.device``.
    """
    offline_rs = [allXtoR[x] for x in offline_xs]

    # Not subgfn: sample trajectories from backward policy
    print('Sampling trajectories from backward policy ...')
    with torch.no_grad():
      offline_trajs = self.model.batch_back_sample(offline_xs)

    device = self.args.device
    return [
      Experience(traj=traj, x=x, r=r,
                 logr=torch.log(torch.tensor(r, device=device, dtype=torch.float32)))
      for traj, x, r in zip(offline_trajs, offline_xs, offline_rs)
    ]