import os
import time
import math
import pickle
import numpy as np
import torch
from gfn_base import GFlowNetBase
from torch.nn.utils.rnn import pad_sequence

# GFlowNet for sparse-reward environments, trained with the trajectory balance (TB) loss.
class GFlowNet(GFlowNetBase):
      def train(self):
            """Run one training round: pessimistic PB updates, then trajectory balance steps.

            Phase 1 (``self.pessimistic_updates`` iterations): fit the backward
            policy alone by minimising ``-mean(sum_t log PB)`` on batches from
            ``self.memory.latest_sample``. The second argument passed to
            ``latest_sample`` grows by ``batch_size`` per iteration and is capped at
            ``self.memory.num_elements``.

            Phase 2 (``self.gradient_steps`` iterations): sample a batch with the
            strategy chosen by ``self.sample_method``:
            1 = ``biased_sample``, 2 = ``generalized_biased_sample``,
            3 = ``mixed_priority_sample``, anything else = ``sample``.
            It then minimises the trajectory balance loss
            ``(sum log PF + logZ - sum log PB - log R)^2`` over the forward policy,
            the backward policy and ``logZ``. When ``self.use_filter`` is set,
            outlier trajectories get an extra likelihood term (see
            ``_apply_trajectory_filter``).

            Side effects:
                  - Steps ``forward_optim``, ``backward_optim`` and ``logZ_optim``.
                  - Updates replay priorities when ``sample_method >= 2``.
                  - Pushes per-batch logs to the memory.
                  - Appends to ``self.logger['log_Z']``, ``['traj_losses']`` and
                    ``['traj_losses_std']``.
                  - Increments ``self._n_updates`` by ``self.gradient_steps``.

            Raises:
                  ValueError: if a trajectory log-probability is infinite, or the
                        loss contains inf/NaN.
            """
            if self.sample_method == 1:
                  self.memory.update_threshold(self.batch_size)

            traj_losses = 0
            traj_losses_std = 0

            # Phase 1: pessimistic maximum-likelihood updates of PB only.
            for pessimistic_step in range(self.pessimistic_updates):
                  window = min((pessimistic_step + 1) * self.batch_size, self.memory.num_elements)
                  batch_obs, batch_acts, batch_next_obs, batch_rews = self.memory.latest_sample(self.batch_size, window)
                  lengths = torch.tensor(np.array([len(obs) for obs in batch_obs]), device = self.device)
                  batch_acts_pad = pad_sequence(batch_acts, batch_first = True)
                  batch_next_obs_pad = pad_sequence(batch_next_obs, batch_first = True)
                  if self.continuous:
                        logPB = self.backward_policy.evaluate_actions(batch_next_obs_pad, batch_acts_pad, lengths = lengths)
                  else:
                        backward_mask = self.env.unwrapped.get_backward_action_masks(batch_next_obs_pad)
                        logPB = self.backward_policy.evaluate_actions(batch_next_obs_pad, batch_acts_pad, backward_mask, lengths = lengths)

                  # Sum per-step log-probs into trajectory log-probs, then maximise them.
                  pessimistic_loss = -torch.sum(logPB, dim = 1).mean()
                  self.backward_optim.zero_grad()
                  pessimistic_loss.backward()
                  self.backward_optim.step()

            # Phase 2: trajectory balance updates of PF, PB and logZ.
            for _ in range(self.gradient_steps):
                  # Sample from the replay buffer with the configured strategy.
                  if self.sample_method == 1:
                        sampled = self.memory.biased_sample(self.batch_size)
                  elif self.sample_method == 2:
                        sampled = self.memory.generalized_biased_sample(self.batch_size)
                  elif self.sample_method == 3:
                        sampled = self.memory.mixed_priority_sample(self.batch_size)
                  else:
                        sampled = self.memory.sample(self.batch_size)
                  batch_obs, batch_acts, batch_next_obs, batch_rews, batch_augmented_rews, batch_idx = sampled

                  # Pad the variable-length trajectories into batch-first tensors.
                  batch_obs_pad = pad_sequence(batch_obs, batch_first = True)
                  batch_acts_pad = pad_sequence(batch_acts, batch_first = True)
                  batch_next_obs_pad = pad_sequence(batch_next_obs, batch_first = True)
                  lengths = torch.tensor(np.array([len(obs) for obs in batch_obs]), device = self.device)

                  # Per-step log-probabilities. Padding is presumably handled inside
                  # evaluate_actions via `lengths` (confirm in the policy code).
                  if self.continuous:
                        logPF = self.forward_policy.evaluate_actions(batch_obs_pad, batch_acts_pad, self.env.unwrapped.max_t, lengths = lengths, use_mask = self.timeout_mask)
                        logPB = self.backward_policy.evaluate_actions(batch_next_obs_pad, batch_acts_pad, lengths = lengths)
                  else:
                        forward_mask = self.env.unwrapped.get_forward_action_masks(batch_obs_pad)
                        backward_mask = self.env.unwrapped.get_backward_action_masks(batch_next_obs_pad)
                        logPF = self.forward_policy.evaluate_actions(batch_obs_pad, batch_acts_pad, forward_mask, lengths = lengths)
                        logPB = self.backward_policy.evaluate_actions(batch_next_obs_pad, batch_acts_pad, backward_mask, lengths = lengths)

                  # Trajectory log-probabilities.
                  logPF = torch.sum(logPF, dim = 1)
                  logPB = torch.sum(logPB, dim = 1)

                  batch_rews = torch.tensor(batch_rews, device = self.device)
                  batch_augmented_rews = torch.tensor(batch_augmented_rews, device = self.device)

                  if self.multiply_temperature:
                        # Tempered reward: log R / T.
                        log_reward = torch.log(batch_rews)/self.temperature
                  else:
                        # Additively augmented reward: log(R + T * R_aug).
                        log_reward = torch.log(batch_rews + self.temperature * batch_augmented_rews)

                  if torch.any(torch.isinf(logPF)) or torch.any(torch.isinf(logPB)):
                        raise ValueError("Infinite logprobs found")

                  # Squared trajectory balance residual, one entry per trajectory.
                  loss = (logPF + self.logZ - logPB - log_reward).pow(2)

                  if torch.any(torch.isinf(loss)) or torch.any(torch.isnan(loss)):
                        raise ValueError(f"Invalid loss found, loss: {loss}")

                  traj_losses += loss.mean().detach().item()
                  traj_losses_std += loss.std().detach().item()

                  loss_detached = loss.detach().cpu().numpy()

                  if self.sample_method >= 2:
                        # Per-trajectory TB losses are fed back as replay priorities.
                        self.memory.update_priority(batch_idx, loss_detached)

                  self.memory.push_train_logs(logPF.detach().cpu().numpy(), logPB.detach().cpu().numpy(), loss_detached, batch_rews.cpu().numpy(), self.logZ.item())

                  if self.use_filter:
                        loss = self._apply_trajectory_filter(loss, logPF, logPB, log_reward, batch_rews)

                  loss = loss.mean()

                  # Joint update of both policies and logZ.
                  self.forward_optim.zero_grad()
                  self.backward_optim.zero_grad()
                  self.logZ_optim.zero_grad()
                  loss.backward()
                  self.forward_optim.step()
                  self.backward_optim.step()
                  self.logZ_optim.step()

                  self.logger['log_Z'].append(self.logZ.item())

            # Count every gradient step taken in this call (also valid when it is 0).
            self._n_updates += self.gradient_steps
            # Avoid dividing by zero when no gradient steps were run.
            n_steps = max(self.gradient_steps, 1)
            self.logger['traj_losses'].append(traj_losses/n_steps)
            self.logger['traj_losses_std'].append(traj_losses_std/n_steps)

      def _apply_trajectory_filter(self, loss, logPF, logPB, log_reward, batch_rews):
            """Add a likelihood term for outlier trajectories and return the per-trajectory loss.

            The score ``-log PF - log PB + log R`` is compared against its batch
            mean and standard deviation.

            - Trajectories scoring above ``mean + filter_upper * std`` get
              ``-(log PF + log PB)`` added to their loss. Minimising that raises
              their probability under both policies; these are labelled
              under-explored high-reward trajectories.
            - Trajectories below ``mean - filter_lower * std`` are only reported
              when ``self.verbose`` is set. Their loss is left unchanged.
            """
            batch_norm = - logPF.detach() - logPB.detach() + log_reward.detach()
            batch_mean = batch_norm.mean()
            batch_std = batch_norm.std()

            batch_filter1 = batch_norm > (batch_mean + self.filter_upper * batch_std)
            # Check verbose first so torch.any does not force a device sync when silent.
            if self.verbose and torch.any(batch_filter1):
                  print("Will promote:")
                  print("reward:", batch_rews[batch_filter1])
                  print("logPF:", logPF[batch_filter1])
                  print("logPB:", logPB[batch_filter1])
                  print("This", batch_norm[batch_filter1], "Mean:", batch_mean, "Std:", batch_std)

            loss = loss - (logPF + logPB) * batch_filter1

            if self.verbose:
                  batch_filter2 = batch_norm < (batch_mean - self.filter_lower * batch_std)
                  if torch.any(batch_filter2):
                        print("Will depress:")
                        print("reward:", batch_rews[batch_filter2])
                        print("logPF:", logPF[batch_filter2])
                        print("logPB:", logPB[batch_filter2])
                        print("This", batch_norm[batch_filter2], "Mean:", batch_mean, "Std:", batch_std)

            return loss
      
      def save(self, model_dir):
            """Save policies, logZ, optimizer states and training progress to ``model_dir``.

            Files written:
                  - ``forward_policy.pth``, ``backward_policy.pth``: policy state dicts.
                  - ``logZ.pth``: the ``logZ`` tensor itself.
                  - ``forward_optim.pth``, ``backward_optim.pth``, ``logZ_optim.pth``:
                    optimizer state dicts.
                  - ``progress.pkl``: a pickled ``(t_so_far, i_so_far, e_so_far)``
                    tuple taken from ``self.logger``.

            The directory is created if it does not exist. The replay buffer is not
            written here; ``save_replay_buffer`` handles it.
            """
            # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
            os.makedirs(model_dir, exist_ok = True)

            torch.save(self.forward_policy.state_dict(), os.path.join(model_dir, 'forward_policy.pth'))
            torch.save(self.backward_policy.state_dict(), os.path.join(model_dir, 'backward_policy.pth'))
            torch.save(self.logZ, os.path.join(model_dir, 'logZ.pth'))

            torch.save(self.forward_optim.state_dict(), os.path.join(model_dir, 'forward_optim.pth'))
            torch.save(self.backward_optim.state_dict(), os.path.join(model_dir, 'backward_optim.pth'))
            torch.save(self.logZ_optim.state_dict(), os.path.join(model_dir, 'logZ_optim.pth'))

            # Order must match the unpacking in `load`: (t_so_far, i_so_far, e_so_far).
            current_progress = (self.logger['t_so_far'], self.logger['i_so_far'], self.logger['e_so_far'])
            with open(os.path.join(model_dir, 'progress.pkl'), 'wb') as f:
                  pickle.dump(current_progress, f)

      def save_replay_buffer(self, model_dir):
            """Persist the replay buffer by delegating to ``self.memory.save(model_dir)``.

            Only the buffer is written here; model weights and optimizers are
            handled by ``save``.
            """
            replay_memory = self.memory
            replay_memory.save(model_dir)

      def load(self, model_dir, load_optim = False):
            """Load policies and logZ, and optionally optimizers and progress, from ``model_dir``.

            Args:
                  model_dir: directory previously written by ``save``.
                  load_optim: when True, also restore the three optimizer states and
                        the ``(t_so_far, i_so_far, e_so_far)`` progress tuple.

            Returns:
                  ``(t_so_far, i_so_far, e_so_far)`` read from ``progress.pkl`` when
                  ``load_optim`` is True, otherwise ``(0, 0, 0)``.
            """
            def _load(name):
                  # map_location lets checkpoints saved on GPU load on a CPU-only host.
                  return torch.load(os.path.join(model_dir, name), map_location = self.device)

            self.forward_policy.load_state_dict(_load('forward_policy.pth'))
            self.backward_policy.load_state_dict(_load('backward_policy.pth'))

            loaded_logZ = _load('logZ.pth')
            current_logZ = getattr(self, 'logZ', None)
            if isinstance(current_logZ, torch.Tensor) and current_logZ.shape == loaded_logZ.shape:
                  # Copy in place rather than rebinding self.logZ. logZ_optim was
                  # presumably built over this tensor object (set up outside this view;
                  # confirm). Rebinding would leave the optimizer stepping a stale
                  # tensor while the loss uses the new one, freezing logZ.
                  with torch.no_grad():
                        current_logZ.copy_(loaded_logZ)
            else:
                  self.logZ = loaded_logZ

            if not load_optim:
                  return 0, 0, 0

            self.forward_optim.load_state_dict(_load('forward_optim.pth'))
            self.backward_optim.load_state_dict(_load('backward_optim.pth'))
            self.logZ_optim.load_state_dict(_load('logZ_optim.pth'))

            # NOTE: pickle can execute arbitrary code; only load trusted checkpoints.
            with open(os.path.join(model_dir, 'progress.pkl'), 'rb') as f:
                  t_so_far, i_so_far, e_so_far = pickle.load(f)
            return t_so_far, i_so_far, e_so_far
      
      def load_replay_buffer(self, model_dir):
            """Restore the replay buffer by delegating to ``self.memory.load(model_dir)``."""
            replay_memory = self.memory
            replay_memory.load(model_dir)