import os
import time
import math
import pickle
from typing_extensions import override
import numpy as np
import torch
from gfn_base import GFlowNetBase
from torch.nn.utils.rnn import pad_sequence
from nns import ContinuousForwardPolicy, ContinuousBackwardPolicy, DiscreteForwardPolicy, DiscreteBackwardPolicy
from torch.optim import Adam

# GFlowNet with an adaptive teacher for sparse-reward environments, trained with the trajectory balance loss.
class AdaptiveTeacherGFlowNet(GFlowNetBase):
      def __init__(
            self, env, learning_rate = 1e-3, batch_size = 32, buffer_size = 10000,
            train_freq=16, gradient_steps = 10, learning_starts = 100,
            temperature = 1,
            sample_method = 0,
            use_filter = False,
            weighting = "geometricwithin",
            lamda: float = 0.9,
            device = 'auto', continuous = True, tensorboard_log = None, verbose = False,
            hidden_sizes = None,
            activation_fn = torch.nn.ReLU,
            initial_z = 0.0,
            num_val_samples=0,
            pessimistic_updates = 0,
            model_dir=None,
            validation_env=None,
            data_env=None,
            no_decay = False,
            timeout_mask = False,
            filter_upper = 3,
            filter_lower = 2,
            epsilon_random = 0.1,
            alpha = 0):
            """Build the student (via ``GFlowNetBase``) and an additional teacher GFlowNet.

            Most arguments are forwarded positionally and unchanged to
            ``GFlowNetBase.__init__``; their meaning is defined there.  Only the
            arguments handled in this constructor are described below.

            Args:
                  weighting: accepted for interface compatibility; not used by this constructor.
                  lamda: accepted for interface compatibility; not used by this constructor.
                  hidden_sizes: hidden layer sizes shared by the student and teacher
                        networks.  ``None`` (the default) means ``[256, 256]``.
                  activation_fn: activation class passed to the teacher policy networks.
                  initial_z: initial value of the teacher's learnable ``teacher_logZ`` scalar
                        (also forwarded to the base class).
                  alpha: stored as ``self.alpha``; ``train`` uses it to weight the
                        environment log-reward inside the teacher's reward.
            """
            # Use a fresh list per instance rather than one mutable default shared by all instances.
            if hidden_sizes is None:
                  hidden_sizes = [256, 256]

            # The two literals (None, 100) fill base-class parameters that this subclass
            # does not expose; their meaning is defined by GFlowNetBase.
            super().__init__(
                  env, learning_rate, batch_size, buffer_size,
                  train_freq, gradient_steps, learning_starts,
                  None, 100,
                  temperature, sample_method,
                  use_filter,
                  device, continuous, tensorboard_log, verbose,
                  hidden_sizes, activation_fn,
                  initial_z, num_val_samples,
                  pessimistic_updates,
                  model_dir, validation_env, data_env, no_decay,
                  timeout_mask, filter_upper, filter_lower, epsilon_random)

            # Create the teacher networks with the same architecture arguments as passed above.
            if self.continuous:
                  self.forward_teacher_policy = ContinuousForwardPolicy(self.state_dim, self.action_dim - 1, hidden_sizes, activation_fn, device = self.device).to(self.device)
                  self.backward_teacher_policy = ContinuousBackwardPolicy(self.state_dim, self.action_dim - 1, hidden_sizes, activation_fn, device = self.device).to(self.device)
            else:
                  # Note the asymmetry: the discrete forward policy gets action_dim,
                  # the discrete backward policy gets action_dim - 1.
                  self.forward_teacher_policy = DiscreteForwardPolicy(self.state_dim, self.action_dim, hidden_sizes, activation_fn, device = self.device).to(self.device)
                  self.backward_teacher_policy = DiscreteBackwardPolicy(self.state_dim, self.action_dim - 1, hidden_sizes, activation_fn, device = self.device).to(self.device)

            # Learnable scalar log-partition estimate of the teacher.
            self.teacher_logZ = torch.tensor(initial_z, requires_grad = True, device = self.device, dtype=torch.float32)

            # The teacher's logZ is optimized with a 100x larger learning rate than its policies.
            self.teacher_forward_optim = Adam(self.forward_teacher_policy.parameters(), lr=self.learning_rate)
            self.teacher_backward_optim = Adam(self.backward_teacher_policy.parameters(), lr=self.learning_rate)
            self.teacher_logZ_optim = Adam([self.teacher_logZ], lr=self.learning_rate * 100)

            # Record the learning-rate bounds on every teacher optimizer.  Previously the
            # forward optimizer only received 'initial_lr' and the backward optimizer only
            # 'min_lr'; now each optimizer carries both keys.
            lr_bounds = (
                  (self.teacher_forward_optim, self.learning_rate, self.learning_rate/100),
                  (self.teacher_backward_optim, self.learning_rate, self.learning_rate/100),
                  (self.teacher_logZ_optim, self.learning_rate * 100, self.learning_rate),
            )
            for optim, initial_lr, min_lr in lr_bounds:
                  group = optim.param_groups[0]
                  group['initial_lr'] = initial_lr
                  group['min_lr'] = min_lr

            self.alpha = alpha # the weight for the teacher loss
            self.reward_augment = 1.0

      def teacher_predict(self, obs, reach_end = None, action_mask= None, use_mask = False):
            """Query the teacher forward policy for an action and return it as a NumPy array.

            Args:
                  obs: observation(s) passed straight to the teacher forward policy.
                  reach_end: forwarded to the policy in the continuous case only.
                  action_mask: forwarded to the policy in the discrete case only.
                  use_mask: forwarded to the policy in the continuous case only.

            Returns:
                  The policy output, detached, moved to CPU and converted with ``.numpy()``.
            """
            teacher = self.forward_teacher_policy
            eps = self.epsilon_random

            # The continuous and discrete policies take different positional arguments.
            if not self.continuous:
                  sampled = teacher(obs, action_mask, epsilon = eps)
            else:
                  sampled = teacher(obs, reach_end, use_mask, epsilon = eps)

            return sampled.detach().cpu().numpy()
      
      @override
      def collect_rollouts_parallel(self):
            """Roll out one batch of ``self.train_freq`` parallel trajectories into the replay memory.

            A single coin flip (probability 0.5) decides whether the teacher
            (``teacher_predict``) or the student (``predict``) acts for the whole batch.
            Stepping continues until every slot has recorded a finished trajectory; the
            collected per-step lists are then handed to ``self.memory.push_trajs``.

            Returns:
                  A pair ``(num_collected_steps, num_collected_episodes)``.
            """
            n_envs = self.train_freq
            final_rewards = np.zeros(n_envs)
            total_steps = 0
            traj_lens = np.zeros(n_envs, dtype = int)

            # One draw per call: the same policy acts for every step of this batch.
            act = self.teacher_predict if np.random.random() < 0.5 else self.predict

            with torch.no_grad():
                  finished = np.zeros(n_envs, dtype = bool)
                  states, actions, next_states, rewards = [], [], [], []
                  final_augmented = np.ones(n_envs, dtype = np.float32)
                  obs, _ = self.data_env.reset()

                  step_idx = 0
                  while not finished.all():
                        states.append(obs)
                        if self.continuous:
                              # Non-zero only for unfinished slots once step_idx reaches max_t - 1.
                              at_horizon = (step_idx >= self.env.unwrapped.max_t - 1) * np.ones(n_envs) * (1-finished)
                              action = act(obs, reach_end = at_horizon, use_mask=self.timeout_mask)
                        else:
                              action = act(obs, action_mask = self.env.unwrapped.get_forward_action_masks(obs))

                        prev_obs = obs
                        # Five-tuple step result; presumably (obs, reward, terminated, truncated, info)
                        # in the Gymnasium convention -- confirm against data_env.
                        obs, rew, done_now, trunc_now, step_info = self.data_env.step(action) # if done, automatically reset

                        actions.append(action)
                        # Zero out rewards of slots whose trajectory was already recorded.
                        rewards.append(rew * (~finished))
                        # For done slots keep the previous observation instead of the one returned by step.
                        obs[done_now, :] = prev_obs[done_now, :]
                        next_states.append(obs)

                        newly_done = done_now & (~finished) & (~trunc_now)
                        if np.any(newly_done):
                              traj_lens[newly_done] = len(states)
                              final_rewards[newly_done] = rewards[-1][newly_done]
                              final_augmented[newly_done] = np.array([
                                    entry['augmented_rew']
                                    for entry in step_info['final_info'][newly_done]
                              ])
                              finished[newly_done] = True

                        total_steps += np.sum(~finished)
                        step_idx += 1

                  total_episodes = np.sum(finished)

                  self.memory.push_trajs(states, actions, next_states, rewards, traj_lens, final_augmented)

            self.logger['buffer_rews'] = final_rewards
            self.logger['buffer_lens'] = traj_lens

            return total_steps, total_episodes

      def train(self):
            """Run ``self.gradient_steps`` joint updates of the student and the teacher.

            Each step samples a batch of trajectories from ``self.memory`` (sampler chosen
            by ``self.sample_method``), sums the per-step log-probabilities over each
            trajectory, and minimizes the squared residual
            ``(logPF + logZ - logPB - log_reward) ** 2`` for the student.  The teacher is
            trained with the same residual form against a reward built from the student's
            detached per-trajectory loss and ``self.alpha``.

            Side effects: steps all six optimizers, pushes per-batch logs to
            ``self.memory``, updates priorities when ``sample_method >= 2``, appends to
            ``self.logger`` and advances ``self._n_updates`` by ``self.gradient_steps``.

            Raises:
                  ValueError: if a summed student log-probability is infinite, or if the
                        student loss contains inf or NaN.
            """
            # Pick the replay sampler once; sample_method is constant for this call.
            if self.sample_method == 1:
                  self.memory.update_threshold(self.batch_size)
                  draw_batch = self.memory.biased_sample
            elif self.sample_method == 2:
                  draw_batch = self.memory.generalized_biased_sample
            elif self.sample_method == 3:
                  draw_batch = self.memory.mixed_priority_sample
            else:
                  draw_batch = self.memory.sample

            traj_losses = 0
            traj_losses_std = 0

            for _ in range(self.gradient_steps):
                  # Sample from the rollout buffer
                  batch_obs, batch_acts, batch_next_obs, batch_rews, batch_augmented_rews, batch_idx = draw_batch(self.batch_size)

                  # Pad the variable-length trajectories so the whole batch is evaluated at once.
                  batch_obs_pad = pad_sequence(batch_obs, batch_first = True)
                  batch_acts_pad = pad_sequence(batch_acts, batch_first = True)
                  batch_next_obs_pad = pad_sequence(batch_next_obs, batch_first = True)

                  # True (unpadded) length of every trajectory.
                  lengths = torch.tensor(np.array([len(obs) for obs in batch_obs]), device = self.device)

                  # Forward pass: the student and the teacher are evaluated on the same batch.
                  if self.continuous:
                        logPF = self.forward_policy.evaluate_actions(batch_obs_pad, batch_acts_pad, self.env.unwrapped.max_t, lengths = lengths, use_mask = self.timeout_mask)
                        logPB = self.backward_policy.evaluate_actions(batch_next_obs_pad, batch_acts_pad, lengths = lengths)

                        teacher_logPF = self.forward_teacher_policy.evaluate_actions(batch_obs_pad, batch_acts_pad, self.env.unwrapped.max_t, lengths = lengths, use_mask = self.timeout_mask)
                        teacher_logPB = self.backward_teacher_policy.evaluate_actions(batch_next_obs_pad, batch_acts_pad, lengths = lengths)
                  else:
                        forward_mask = self.env.unwrapped.get_forward_action_masks(batch_obs_pad)
                        backward_mask = self.env.unwrapped.get_backward_action_masks(batch_next_obs_pad)
                        logPF = self.forward_policy.evaluate_actions(batch_obs_pad, batch_acts_pad, forward_mask, lengths = lengths)
                        logPB = self.backward_policy.evaluate_actions(batch_next_obs_pad, batch_acts_pad, backward_mask, lengths = lengths)

                        teacher_logPF = self.forward_teacher_policy.evaluate_actions(batch_obs_pad, batch_acts_pad, forward_mask, lengths = lengths)
                        teacher_logPB = self.backward_teacher_policy.evaluate_actions(batch_next_obs_pad, batch_acts_pad, backward_mask, lengths = lengths)

                  # Sum over dim 1 to get one log-probability per trajectory.
                  logPF = torch.sum(logPF, dim = 1)
                  logPB = torch.sum(logPB, dim = 1)

                  teacher_logPF = torch.sum(teacher_logPF, dim = 1)
                  teacher_logPB = torch.sum(teacher_logPB, dim = 1)

                  batch_rews = torch.tensor(batch_rews, device = self.device)
                  batch_augmented_rews = torch.tensor(batch_augmented_rews, device = self.device)

                  # Temperature either scales the log-reward or weights the augmented reward.
                  if self.multiply_temperature:
                        log_reward = torch.log(batch_rews)/self.temperature
                  else:
                        log_reward = torch.log(batch_rews + self.temperature * batch_augmented_rews)

                  if torch.any(torch.isinf(logPF)) or torch.any(torch.isinf(logPB)):
                        raise ValueError("Infinite logprobs found")

                  # Per-trajectory squared trajectory-balance residual of the student.
                  loss = (logPF + self.logZ - logPB - log_reward).pow(2)

                  if torch.any(torch.isinf(loss)) or torch.any(torch.isnan(loss)):
                        raise ValueError(f"Invalid loss found, loss: {loss}")

                  traj_losses += loss.mean().detach().item()
                  traj_losses_std += loss.std().detach().item()

                  loss_detached = loss.detach().cpu().numpy()

                  # The teacher's reward is built from the student's detached loss, so no
                  # gradient flows from the teacher objective into the student.
                  # NOTE(review): by operator precedence this evaluates to
                  # alpha * log_reward + (1 - alpha) * loss - log(loss + 1e-8).
                  # Confirm this is intended rather than (1 - alpha) * log(loss + 1e-8).
                  teacher_log_reward = self.alpha * log_reward + (1-self.alpha) * loss.detach() - torch.log(loss.detach() + 1e-8)
                  teacher_loss = (teacher_logPF + self.teacher_logZ - teacher_logPB - teacher_log_reward).pow(2)

                  # Priority-based samplers are refreshed with the per-trajectory student loss.
                  if self.sample_method >= 2:
                        self.memory.update_priority(batch_idx, loss_detached)

                  self.memory.push_train_logs(logPF.detach().cpu().numpy(), logPB.detach().cpu().numpy(), loss_detached, batch_rews.cpu().numpy(), self.logZ.item())

                  # Reduce to scalar losses for backpropagation.
                  loss = loss.mean()
                  teacher_loss = teacher_loss.mean()

                  if self.verbose:
                        print("Teacher loss: ", teacher_loss)

                  # Optimize the models
                  self.forward_optim.zero_grad()
                  self.backward_optim.zero_grad()
                  self.logZ_optim.zero_grad()

                  self.teacher_forward_optim.zero_grad()
                  self.teacher_backward_optim.zero_grad()
                  self.teacher_logZ_optim.zero_grad()

                  loss.backward()
                  teacher_loss.backward()

                  # Only the teacher's gradients are clipped (following the github repository
                  # of the adaptive teacher); the student's gradients are left unclipped.
                  torch.nn.utils.clip_grad_norm_(self.forward_teacher_policy.parameters(), 1e1)
                  torch.nn.utils.clip_grad_norm_(self.backward_teacher_policy.parameters(), 1e1)
                  torch.nn.utils.clip_grad_norm_(self.teacher_logZ, 1e1)

                  self.forward_optim.step()
                  self.backward_optim.step()
                  self.logZ_optim.step()

                  self.teacher_forward_optim.step()
                  self.teacher_backward_optim.step()
                  self.teacher_logZ_optim.step()

                  self.logger['log_Z'].append(self.logZ.item())

            # Count every update performed.  The previous code added the last loop index
            # (gradient_steps - 1) and failed with NameError when gradient_steps was 0.
            self._n_updates += self.gradient_steps

            # Log the mean and std of the student loss, averaged over the gradient steps.
            self.logger['traj_losses'].append(traj_losses/self.gradient_steps)
            self.logger['traj_losses_std'].append(traj_losses_std/self.gradient_steps)
      
      def save(self, model_dir):
            """Write a full checkpoint of the student and the teacher into ``model_dir``.

            Saved with ``torch.save`` as ``<name>.pth``: the policy state dicts, both
            ``logZ`` tensors and all six optimizer state dicts.  The progress counters
            ``(t_so_far, i_so_far, e_so_far)`` from ``self.logger`` are pickled to
            ``progress.pkl``.

            Args:
                  model_dir: target directory; created (with parents) if it does not exist.
            """
            # exist_ok avoids the race between a separate existence check and the creation.
            os.makedirs(model_dir, exist_ok=True)

            # (file stem, object) pairs, written in this order.
            artifacts = (
                  ('forward_policy', self.forward_policy.state_dict()),
                  ('backward_policy', self.backward_policy.state_dict()),
                  ('logZ', self.logZ),
                  # student optimizers
                  ('forward_optim', self.forward_optim.state_dict()),
                  ('backward_optim', self.backward_optim.state_dict()),
                  ('logZ_optim', self.logZ_optim.state_dict()),
                  # teacher networks and optimizers
                  ('teacher_forward_policy', self.forward_teacher_policy.state_dict()),
                  ('teacher_backward_policy', self.backward_teacher_policy.state_dict()),
                  ('teacher_logZ', self.teacher_logZ),
                  ('teacher_forward_optim', self.teacher_forward_optim.state_dict()),
                  ('teacher_backward_optim', self.teacher_backward_optim.state_dict()),
                  ('teacher_logZ_optim', self.teacher_logZ_optim.state_dict()),
            )
            for stem, obj in artifacts:
                  torch.save(obj, f'{model_dir}/{stem}.pth')

            # save the t_so_far, i_so_far, e_so_far counters (in this order)
            current_progress = (self.logger['t_so_far'], self.logger['i_so_far'], self.logger['e_so_far'])
            with open(f'{model_dir}/progress.pkl', 'wb') as f:
                  pickle.dump(current_progress, f)

      def save_replay_buffer(self, model_dir):
            """Save the replay memory by delegating to ``self.memory.save(model_dir)``."""
            # Only the memory is written here; the networks are handled by ``save``.
            self.memory.save(model_dir)

      def load(self, model_dir, load_optim = False):
            """Restore a checkpoint written by ``save`` from ``model_dir``.

            The student policies and ``logZ`` are always restored.  With
            ``load_optim=True`` the student optimizers, the complete teacher (policies,
            ``teacher_logZ``, optimizers) and the progress counters are restored as well.

            Args:
                  model_dir: directory containing the checkpoint files.
                  load_optim: also restore optimizer state, the teacher and the counters.

            Returns:
                  ``(t_so_far, i_so_far, e_so_far)`` read from ``progress.pkl`` when
                  ``load_optim`` is true, otherwise ``(0, 0, 0)``.
            """
            def _read(stem):
                  # map_location lets a checkpoint saved on another device be loaded here.
                  return torch.load(f'{model_dir}/{stem}.pth', map_location = self.device)

            # Load our model
            self.forward_policy.load_state_dict(_read('forward_policy'))
            self.backward_policy.load_state_dict(_read('backward_policy'))
            # Copy in place instead of rebinding: the optimizer holds a reference to the
            # existing tensor, so assigning a new tensor would leave logZ un-optimized.
            with torch.no_grad():
                  self.logZ.copy_(_read('logZ'))

            if not load_optim:
                  return 0, 0, 0

            self.forward_optim.load_state_dict(_read('forward_optim'))
            self.backward_optim.load_state_dict(_read('backward_optim'))
            self.logZ_optim.load_state_dict(_read('logZ_optim'))

            # NOTE(review): the teacher weights are only restored together with the
            # optimizers; confirm that loading without load_optim should skip them.
            self.forward_teacher_policy.load_state_dict(_read('teacher_forward_policy'))
            self.backward_teacher_policy.load_state_dict(_read('teacher_backward_policy'))
            # Same in-place copy as above so teacher_logZ_optim keeps tracking this tensor.
            with torch.no_grad():
                  self.teacher_logZ.copy_(_read('teacher_logZ'))
            self.teacher_forward_optim.load_state_dict(_read('teacher_forward_optim'))
            self.teacher_backward_optim.load_state_dict(_read('teacher_backward_optim'))
            self.teacher_logZ_optim.load_state_dict(_read('teacher_logZ_optim'))

            # load the t_so_far, i_so_far, e_so_far counters.
            # NOTE: pickle must only be used on checkpoints from a trusted source.
            with open(f'{model_dir}/progress.pkl', 'rb') as f:
                  t_so_far, i_so_far, e_so_far = pickle.load(f)

            return t_so_far, i_so_far, e_so_far
      
      def load_replay_buffer(self, model_dir):
            """Restore the replay memory by delegating to ``self.memory.load(model_dir)``."""
            # Load the memory
            self.memory.load(model_dir)