import os
import time
import math
import pickle
import numpy as np
import torch
from torch.optim import Adam

from nns import ContinuousBackwardPolicy, ContinuousForwardPolicy, DiscreteForwardPolicy, DiscreteBackwardPolicy
from memory import Memory

from abc import abstractmethod
from torch.utils.tensorboard import SummaryWriter

class GFlowNetBase:
      def __init__(self, env,
                   learning_rate = 1e-3, batch_size = 32, buffer_size = 10000,
                   train_freq=16, gradient_steps = 10, learning_starts = 100,
                   explorative_policy = None, explorative_num = 100,
                   temperature = 1,
                   sample_method = 0,
                   use_filter = False,
                   device = 'auto', continuous = True, tensorboard_log = None, verbose = False,
                   hidden_sizes = None,
                   activation_fn = torch.nn.ReLU,
                   initial_z = 0.0,
                   num_val_samples=0,
                   pessimistic_updates = 0,
                   model_dir=None,
                   validation_env=None,
                   data_env=None,
                   no_decay = False,
                   timeout_mask = False,
                   filter_upper = 3,
                   filter_lower = 2,
                   epsilon_random = 0.1,
                   temperature_rate = 40,
                   multiply_temperature = False
                   ):
            """
                  Build the policies, optimizers, replay memory and logging state.

                  Parameters:
                        env - environment used by collect_rollouts; its observation_space and
                              action_space determine the network input/output sizes
                        learning_rate - Adam step size of both policies (logZ uses 100x this value)
                        batch_size - learn() only trains once the memory holds at least this many entries
                        buffer_size - capacity handed to Memory
                        train_freq - number of episodes gathered per rollout collection; also the
                                     number of parallel trajectories in collect_rollouts_parallel
                        gradient_steps - stored on the instance; not read by the methods of this base class
                        learning_starts - learn() only trains once this many episodes were collected
                        explorative_policy - optional policy used by learn() to pre-fill the memory
                        explorative_num - number of episodes to collect with explorative_policy
                        temperature - initial temperature, updated by _update_learning_rate
                        sample_method, use_filter, pessimistic_updates, filter_upper, filter_lower -
                              stored on the instance; not read by the methods of this base class
                        device - torch device string; 'auto' picks 'cuda' when available, else 'cpu'
                        continuous - selects the continuous or the discrete policy networks
                        tensorboard_log - log directory handed to SummaryWriter
                        verbose - print progress summaries to stdout
                        hidden_sizes - hidden layer sizes of the policy networks; None means [256, 256]
                        activation_fn - activation class handed to the policy networks
                        initial_z - initial value of the learnable logZ scalar
                        num_val_samples - number of samples used by the validation step in learn(); 0 disables it
                        model_dir - directory where learn() writes val_errs.pkl
                        validation_env - optional environment used by learn() to draw validation samples
                        data_env - optional environment; when set, learn() uses collect_rollouts_parallel
                        no_decay - when True the temperature is never updated
                        timeout_mask - forwarded to the continuous forward policy as use_mask
                        epsilon_random - initial epsilon forwarded to the forward policy; decayed linearly
                        temperature_rate - steepness of the sigmoid-shaped temperature schedule
                        multiply_temperature - use the linear temperature schedule instead of the sigmoid one

                  Return:
                        None
            """
            # A list literal as default would be shared by every instance; build a fresh one instead.
            if hidden_sizes is None:
                  hidden_sizes = [256, 256]

            # initialize the hyperparameters
            self.learning_rate = learning_rate
            self.buffer_size = buffer_size
            self.batch_size = batch_size
            self.temperature = temperature
            self.temperature_init = temperature
            self.gradient_steps = gradient_steps
            self.learning_starts = learning_starts
            self.explorative_num = explorative_num
            self.train_freq = train_freq
            self.sample_method = sample_method
            self.device = device
            self.verbose = verbose
            self.use_filter = use_filter
            self.num_val_samples = num_val_samples
            self.pessimistic_updates = pessimistic_updates
            self.model_dir = model_dir
            self.validation_env = validation_env
            self.data_env = data_env
            self.no_decay = no_decay
            self.timeout_mask = timeout_mask
            self.filter_upper = filter_upper
            self.filter_lower = filter_lower
            self.epsilon_random = epsilon_random
            self.epsilon_random_init = epsilon_random
            self.temperature_rate = temperature_rate
            self.multiply_temperature = multiply_temperature

            if self.device == 'auto':
                  self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

            self.state_dim = env.observation_space.shape[0]
            self.continuous = continuous

            # NOTE(review): several networks are sized with action_dim - 1; presumably one action
            # component is reserved (e.g. a stop/terminate flag) -- confirm against nns.py.
            if continuous:
                  self.action_dim = env.action_space.shape[0]
                  self.forward_policy = ContinuousForwardPolicy(self.state_dim, self.action_dim - 1, hidden_sizes, activation_fn, device = self.device).to(self.device)
                  self.backward_policy = ContinuousBackwardPolicy(self.state_dim, self.action_dim - 1, hidden_sizes, activation_fn, device = self.device).to(self.device)
            else:
                  self.action_dim = env.action_space.n
                  self.forward_policy = DiscreteForwardPolicy(self.state_dim, self.action_dim, hidden_sizes, activation_fn, device = self.device).to(self.device)
                  self.backward_policy = DiscreteBackwardPolicy(self.state_dim, self.action_dim - 1, hidden_sizes, activation_fn, device = self.device).to(self.device)

            # Learnable scalar with its own optimizer below
            self.logZ = torch.tensor(initial_z, requires_grad = True, device = self.device, dtype=torch.float32)

            self.memory = Memory(self.buffer_size, device = self.device)

            self.explorative_policy = explorative_policy # Note: the input and output should be the same as the forward policy

            self.forward_optim = Adam(self.forward_policy.parameters(), lr=self.learning_rate)
            self.backward_optim = Adam(self.backward_policy.parameters(), lr=self.learning_rate)
            self.logZ_optim = Adam([self.logZ], lr=self.learning_rate * 100) # According to the original paper, the learning rate of logZ is greater than the forward and backward policy

            # 'initial_lr' and 'min_lr' are the bounds read by _update_learning_rate
            self.forward_optim.param_groups[0]['initial_lr'] = self.learning_rate
            self.backward_optim.param_groups[0]['initial_lr'] = self.learning_rate
            self.logZ_optim.param_groups[0]['initial_lr'] = self.learning_rate * 100

            self.forward_optim.param_groups[0]['min_lr'] = self.learning_rate/100
            self.backward_optim.param_groups[0]['min_lr'] = self.learning_rate/100
            self.logZ_optim.param_groups[0]['min_lr'] = self.learning_rate

            self._n_updates = 0

            self.logger = {
                  'delta_t': time.time_ns(),
                  't_so_far': 0,          # timesteps so far
                  'i_so_far': 0,          # iterations so far
                  'e_so_far': 0,          # episodes so far
                  'batch_lens': [],       # episodic lengths in batch
                  'batch_rews': [],       # episodic returns in batch
                  'buffer_lens': [],      # episodic lengths of the latest rollout collection (read by _log_summary)
                  'buffer_rews': [],      # final rewards of the latest rollout collection (read by _log_summary)
                  'traj_losses': [],     # trajectory balance losses in current iteration
                  'traj_losses_std': [], # trajectory balance losses std in current iteration
                  'log_Z': [],            # log_Z in current iteration
            }

            self.env = env

            self.writer = SummaryWriter(tensorboard_log)
      
      # single env version
      def collect_rollouts(self):
            buffer_lens = []
            buffer_rews_report = []

            num_collected_steps, num_collected_episodes = 0, 0

            with torch.no_grad():
                  while num_collected_episodes < self.train_freq:
                        ep_acts = []
                        ep_obs = []
                        ep_next_obs = []
                        ep_rews = []

                        # Reset the environment. sNote that obs is short for observation. 
                        obs, _ = self.env.reset()
                        done = False
                        ep_t = 0

                        while not done:
                              # Track observations in this buffer
                              ep_obs.append(obs)
                              # Calculate action and make a step in the env. 
                              # Note that rew is short for reward.
                              if self.continuous:
                                    action = self.predict(obs, reach_end = np.array(ep_t >= self.env.unwrapped.max_t - 1), use_mask=self.timeout_mask)
                              else:
                                    action = self.predict(obs, action_mask = self.env.unwrapped.get_forward_action_masks(obs))
                              # print("Before: ", obs)
                              obs, rew, done, truncated, augmented_rew = self.env.step(action)
                              
                              # print("After: ", obs)
                              # Track recent action, and action log probability
                              ep_acts.append(action)
                              ep_next_obs.append(obs)
                              ep_rews.append(rew)

                              ep_t += 1

                        if not truncated:
                              num_collected_steps += len(ep_obs)
                              num_collected_episodes += 1

                              buffer_lens.append(len(ep_obs))
                              buffer_rews_report.append(ep_rews[-1])

                              ep_rtgs = ep_rews.copy()
                              augmented_rew = augmented_rew['augmented_rew']
                              # add to the memory
                              self.memory.push_traj(ep_obs, ep_acts, ep_next_obs, ep_rtgs, augmented_rew)

            # Log the episodic returns and episodic lengths in this buffer.
            self.logger['buffer_rews'] = buffer_rews_report
            self.logger['buffer_lens'] = buffer_lens

            return num_collected_steps, num_collected_episodes
      
      # parallel version
      def collect_rollouts_parallel(self):
            buffer_rews_report = np.zeros(self.train_freq) 
            num_collected_steps, num_collected_episodes = 0, 0
            buffer_lens = np.zeros(self.train_freq, dtype = int)
            with torch.no_grad():
                  dones = np.zeros(self.train_freq, dtype = bool)
                  obs_buffer = []
                  action_buffer = []
                  next_obs_buffer = []
                  reward_buffer = []
                  augmented_rew_buffer = np.ones(self.train_freq, dtype = np.float32)
                  obs, _ = self.data_env.reset()

                  prev_obs = obs
                  ep_t = 0
                  while not np.all(dones):
                        obs_buffer.append(obs)
                        if self.continuous:
                              action = self.predict(obs, reach_end =  (ep_t >= self.env.unwrapped.max_t - 1) * np.ones(self.train_freq) * (1-dones), use_mask=self.timeout_mask)
                        else:
                              action = self.predict(obs, action_mask = self.env.unwrapped.get_forward_action_masks(obs))
                        
                        prev_obs = obs
                        obs, rew, d, t, augmented_rew = self.data_env.step(action) # if done, automatically reset
                        
                        action_buffer.append(action)
                        reward_buffer.append(rew * (~dones))
                        obs[d, :] = prev_obs[d, :]
                        next_obs_buffer.append(obs)

                        mask = d & (~dones) & (~t)
                        if np.any(mask):
                              buffer_lens[mask] = len(obs_buffer)
                              buffer_rews_report[mask] = reward_buffer[-1][mask] 
                              augmented_rew_values = np.array([entry['augmented_rew'] 
                                    for entry in augmented_rew['final_info'][mask]])
                              augmented_rew_buffer[mask] = augmented_rew_values
                              dones[mask] = True

                        num_collected_steps += np.sum(~dones)
                        ep_t += 1

                  num_collected_episodes += np.sum(dones)

                  self.memory.push_trajs(obs_buffer, action_buffer, next_obs_buffer, reward_buffer, buffer_lens, augmented_rew_buffer)

            self.logger['buffer_rews'] = buffer_rews_report
            self.logger['buffer_lens'] = buffer_lens
            
            return num_collected_steps, num_collected_episodes
      
      # unused, for sampling transitions for RL updates
      def get_rollouts(self, buffer_obs, buffer_acts, buffer_next_obs, buffer_rtgs, buffer_log_probs, buffer_log_probs2):
            buffer_size = len(buffer_obs)
            indices = np.random.permutation(buffer_size)

            start_idx = 0
            while start_idx + self.batch_size < buffer_size:
                  idx = indices[start_idx: start_idx + self.batch_size]
                  yield buffer_obs[idx], buffer_acts[idx], buffer_next_obs[idx], buffer_rtgs[idx], buffer_log_probs[idx], buffer_log_probs2[idx]
                  start_idx += self.batch_size
            if start_idx < buffer_size:
                  idx = indices[start_idx:]
                  yield buffer_obs[idx], buffer_acts[idx], buffer_next_obs[idx], buffer_rtgs[idx], buffer_log_probs[idx], buffer_log_probs2[idx]

      
      # off-policy learning
      def learn(self, total_iterations, t_start = 0, i_start = 0, e_start = 0):
            if self.verbose:
                  print(f"Learning... Using {self.batch_size} episodes per updates, ", end='')
                  print(f"at least {self.train_freq} episodes per update for {total_iterations} iterations.")
            t_so_far = t_start # Timesteps simulated so far
            i_so_far = i_start # Iterations ran so far
            e_so_far = e_start # Episodes simulated so far

            # collect data from the explorative policy
            if self.explorative_policy is not None:
                  explorative_e_so_far = 0
                  while explorative_e_so_far < self.explorative_num:
                        tmp_t, tmp_e = self.collect_rollouts_from_explorative_policy()
                        explorative_e_so_far += tmp_e

                  # Print a summary of the collected data
                  self._log_summary()

            if self.explorative_policy is None and self.explorative_num > 0:
                  print("Warning: Explorative policy is not provided, collect data from random policy.")

            val_errs = []
            while i_so_far < total_iterations:
                  # increment the number of epoches
                  if self.data_env is not None:
                        tmp_t, tmp_e = self.collect_rollouts_parallel()
                  else:
                        tmp_t, tmp_e = self.collect_rollouts()
                  e_so_far += tmp_e

                  # Logging timesteps so far and iterations so far
                  self.logger['e_so_far'] = e_so_far
                  # print(e_so_far)

                  if len(self.memory) >= self.batch_size and e_so_far >= self.learning_starts:
                        self.train()
                        i_so_far += 1

                        t_so_far += tmp_t
                        self.logger['t_so_far'] = t_so_far
                        self.logger['i_so_far'] = i_so_far

                        # Update optimizers learning rate with linear decay
                        gfn_optimizers = [] # self.logZ_optim, self.forward_optim, self.backward_optim
                        # Update learning rate according to lr schedule
                        self._update_learning_rate(gfn_optimizers, i_so_far, total_iterations)

                        # Print a summary of our training so far
                        self._log_summary()

                  if self.num_val_samples != 0 and i_so_far % 100 == 0:
                        # print("Validating")
                        if self.validation_env is not None:
                              samples = []
                              with torch.no_grad():
                                    # compute the error by generating some trajectories with the current model
                                    ep_t = 0
                                    s, _ = self.validation_env.reset()
                                    dones = np.zeros((s.shape[0],), dtype = bool)
                                    while(len(samples) < self.num_val_samples):
                                          if self.continuous:
                                                a = self.predict(s, reach_end = (ep_t >= self.env.unwrapped.max_t - 1) * np.ones(s.shape[0]) * (1-dones), use_mask = self.timeout_mask)
                                          else:
                                                # we use the original env to get the action mask
                                                action_mask = self.env.unwrapped.get_forward_action_masks(s)
                                                a = self.predict(s, action_mask = action_mask)

                                          prev_s = s
                                          s, _, d, _, _ = self.validation_env.step(a) # will automatically reset if done, use prev_s
                                          if np.any(d & (~dones)):
                                                samples +=  list(self.env.unwrapped.get_state(prev_s[d&(~dones), :]))
                                                prev_s[d, :] = s[d, :]
                                                # stop take new samples if we have enough
                                                remaining_samples = self.num_val_samples - len(samples) - np.sum(~dones)
                                                if remaining_samples < 0:
                                                      indices_to_mark = np.where(d&(~dones))[0][: -remaining_samples]
                                                      dones[indices_to_mark] = True
                                          ep_t += 1
                              samples = samples[-self.num_val_samples:]
                        else:
                              samples = self.env.unwrapped.get_state(np.array(self.memory.visited_end_states[-min(self.num_val_samples, len(self.memory.visited_end_states)):]))
                        val_err = self.env.unwrapped.get_error(samples)
                        print(f"Temperate: = {self.temperature}")
                        print(f"Validation error = {val_err}")
                        val_errs.append(val_err)
            # save val_errs
            with open(os.path.join(self.model_dir, "val_errs.pkl"), 'wb') as f:
                  pickle.dump(val_errs, f)

      def _update_learning_rate(self, gfn_optimizers, i_so_far, total_iterations):
            if not isinstance(gfn_optimizers, list):
                  gfn_optimizers = [gfn_optimizers]

            discount_ratio_gfn = 1 - i_so_far / (total_iterations + 1e-8)

            if self.epsilon_random_init > 0:
                  self.epsilon_random = discount_ratio_gfn * self.epsilon_random_init

            if self.temperature_init > 0: # diminishing the learning rate, unused in the paper
                  for optimizer in gfn_optimizers:
                        for param_group in optimizer.param_groups:
                              param_group['lr'] = max(param_group['min_lr'], param_group['initial_lr'] * discount_ratio_gfn)

            # update temperature
            if not self.no_decay:
                  if self.multiply_temperature:
                        self.temperature = (self.temperature_init - 1) * discount_ratio_gfn + 1
                  else:
                        self.temperature = max(self.temperature_init * 1/(1+ math.pow(10, 10 - self.temperature_rate * discount_ratio_gfn)), 0)

      def predict(self, obs, reach_end = None, action_mask= None, use_mask = False): 
            # Query the forward poilcy for an action
            if self.continuous:
                  action = self.forward_policy(obs, reach_end, use_mask, epsilon = self.epsilon_random)
            else:
                  action = self.forward_policy(obs, action_mask, epsilon = self.epsilon_random)

            # Return the action
            return action.detach().cpu().numpy()
      
      def _log_summary(self):
            """
                  Print to stdout what we've logged so far in the most recent batch.

                  Parameters:
                        None

                  Return:
                        None
            """
            # Calculate logging values. I use a few python shortcuts to calculate each value
            # without explaining since it's not too important to PPO; feel free to look it over,
            # and if you have any questions you can email me (look at bottom of README)
            delta_t = self.logger['delta_t']
            self.logger['delta_t'] = time.time_ns()
            delta_t = (self.logger['delta_t'] - delta_t) / 1e9
            delta_t = round(delta_t, 2)

            t_so_far = self.logger['t_so_far']
            i_so_far = self.logger['i_so_far']
            e_so_far = self.logger['e_so_far']
            avg_ep_lens = np.mean(self.logger['buffer_lens'])
            avg_ep_rews = np.mean(self.logger['buffer_rews'])
            avg_traj_loss = np.mean(self.logger['traj_losses'])
            avg_traj_loss_std = np.mean(self.logger['traj_losses_std'])
            avg_log_Z = np.mean(self.logger['log_Z'])

            # Round decimal places for more aesthetic logging messages
            avg_ep_lens = round(avg_ep_lens, 2)
            avg_ep_rews = round(avg_ep_rews, 2)
            avg_traj_loss = round(avg_traj_loss, 5)
            avg_traj_loss_std = round(avg_traj_loss_std, 5)
            avg_log_Z = round(avg_log_Z, 5)

            # Print logging statements
            if self.verbose:
                  print(flush=True)
                  print(f"-------------------- Epoch #{i_so_far} --------------------", flush=True)
                  print(f"Average Episodic Length: {avg_ep_lens}", flush=True)
                  print(f"Average Episodic Return: {avg_ep_rews}", flush=True)
                  print(f"Average Traj Loss: {avg_traj_loss:.5f}", flush=True)
                  print(f"Average Traj Loss Std: {avg_traj_loss_std:.5f}", flush=True)
                  print(f"Average log_Z: {avg_log_Z:.5f}", flush=True)

                  print(f"Timesteps So Far: {t_so_far}", flush=True)
                  print(f"Episode So Far: {e_so_far}", flush=True)
                  print(f"Iteration took: {delta_t} secs", flush=True)
                  print(f"------------------------------------------------------", flush=True)
                  print(flush=True)

            self.writer.add_scalar('train/episodic_length', avg_ep_lens, i_so_far)
            self.writer.add_scalar('train/episodic_return', avg_ep_rews, i_so_far)
            self.writer.add_scalar('train/traj_loss', avg_traj_loss, i_so_far)
            self.writer.add_scalar('train/traj_loss_std', avg_traj_loss_std, i_so_far)
            self.writer.add_scalar('train/log_Z', avg_log_Z, i_so_far)

            # Reset batch-specific logging data
            self.logger['batch_lens'] = []
            self.logger['batch_rews'] = []
            self.logger['traj_losses'] = []
            self.logger['traj_losses_std'] = []

      @abstractmethod
      def train(self):
            """
                  Training hook to be overridden by subclasses; the base version does nothing.

                  learn() calls it once per iteration, after rollouts were collected and the
                  memory holds at least batch_size entries.
            """
            pass

      @abstractmethod
      def save(self, model_dir):
            """
                  Saving hook to be overridden by subclasses; the base version does nothing.

                  Parameters:
                        model_dir - presumably the directory to write the model to (confirm in subclasses)
            """
            pass

      @abstractmethod
      def save_replay_buffer(self, model_dir):
            """
                  Replay-buffer saving hook to be overridden by subclasses; the base version does nothing.

                  Parameters:
                        model_dir - presumably the directory to write the buffer to (confirm in subclasses)
            """
            pass

      @abstractmethod
      def load(self, model_dir, load_optim = False):
            """
                  Loading hook to be overridden by subclasses; the base version does nothing.

                  Parameters:
                        model_dir - presumably the directory to read the model from (confirm in subclasses)
                        load_optim - presumably whether optimizer state is restored as well (confirm in subclasses)
            """
            pass

      @abstractmethod
      def load_replay_buffer(self, model_dir):
            """
                  Replay-buffer loading hook to be overridden by subclasses; the base version does nothing.

                  Parameters:
                        model_dir - presumably the directory to read the buffer from (confirm in subclasses)
            """
            pass

      