import os
import pickle
import numpy as np
import torch
from nns import  logF, RND
from torch.optim import Adam

from gfn_base import GFlowNetBase
from torch.nn.utils.rnn import pad_sequence

class GAFlowNet(GFlowNetBase):
    """GFlowNet trained with a trajectory-balance loss plus GAFN intrinsic rewards.

    On top of what ``GFlowNetBase`` provides (forward/backward policies,
    ``logZ``, the replay ``memory``, the ``logger`` and their optimisers --
    all defined outside this file), this class owns two extra networks:

    * ``flow_estimator`` (``logF``): a learned per-state log-flow, and
    * ``intrinsic_reward_estimator`` (``RND``): produces a per-step intrinsic
      reward and has its own training loss.

    Per step, ``log r_int(s') - logF(s)`` is merged into the backward
    log-probability with a log-sum-exp before the trajectory-balance residual
    ``logPF + logZ - logPB - log R`` is squared (see :meth:`train`).
    """

    def __init__(
        self, env, learning_rate=1e-3, batch_size=32, buffer_size=10000,
        train_freq=16, gradient_steps=10, learning_starts=100,
        sample_method=0,
        use_filter=False,
        device='auto', continuous=True, tensorboard_log=None, verbose=False,
        hidden_sizes=None, activation_fn=torch.nn.ReLU, initial_z=0.0,
        num_val_samples=0,
        reward_scale=0.005,
        model_dir=None,
        validation_env=None,
        data_env=None,
        temperature=0,
        no_decay=False,
        timeout_mask=False,
        filter_upper=3,
        filter_lower=2,
        epsilon_random=0.1,
    ):
        """Build the base GFlowNet plus the GAFN flow and intrinsic-reward networks.

        Most arguments are forwarded positionally to ``GFlowNetBase``; see its
        signature for their meaning. Arguments used directly here:

        Args:
            env: environment; ``env.observation_space.shape[0]`` sets the
                state dimension of the GAFN networks.
            hidden_sizes: hidden layer sizes for the base networks and
                ``logF``. ``None`` (default) means ``[256, 256]``.
            activation_fn: activation class for ``logF`` (and the base).
            reward_scale: forwarded to the ``RND`` intrinsic-reward module.
            temperature: weight on the augmented reward in
                ``log(R + temperature * R_aug)`` during :meth:`train`.
        """
        # Fresh list per instance; a shared mutable default would leak any
        # mutation made by a consumer into every later instance.
        if hidden_sizes is None:
            hidden_sizes = [256, 256]

        # NOTE(review): the literal positional values None, 100 and 0 below
        # fill GFlowNetBase parameters whose names are not visible here.
        super().__init__(
            env, learning_rate, batch_size, buffer_size,
            train_freq, gradient_steps, learning_starts,
            None, 100,
            temperature, sample_method,
            use_filter,
            device, continuous, tensorboard_log, verbose,
            hidden_sizes, activation_fn,
            initial_z, num_val_samples,
            0,
            model_dir, validation_env, data_env, no_decay,
            timeout_mask, filter_upper, filter_lower, epsilon_random,
        )

        self.state_dim = env.observation_space.shape[0]
        # GAFN flow estimator: log F(s).
        self.flow_estimator = logF(
            self.state_dim, hidden_sizes, activation_fn, device=self.device
        ).to(self.device)
        # GAFN intrinsic reward estimator.
        self.intrinsic_reward_estimator = RND(
            state_dim=self.state_dim, reward_scale=reward_scale, device=self.device
        ).to(self.device)

        self.flow_optim = Adam(self.flow_estimator.parameters(), lr=self.learning_rate)
        self.intrinsic_reward_optim = Adam(
            self.intrinsic_reward_estimator.parameters(), lr=self.learning_rate
        )

        # Extra param-group keys; presumably consumed by a learning-rate
        # schedule elsewhere (e.g. in GFlowNetBase) -- not visible here.
        for optim in (self.flow_optim, self.intrinsic_reward_optim):
            optim.param_groups[0]['initial_lr'] = self.learning_rate
            optim.param_groups[0]['min_lr'] = self.learning_rate / 100

    def _sample_batch(self):
        """Draw one batch from ``self.memory`` according to ``self.sample_method``.

        1 -> ``biased_sample``, 2 -> ``generalized_biased_sample``,
        3 -> ``mixed_priority_sample``, anything else -> ``sample``.
        Returns the sampler's 6-tuple
        ``(obs, acts, next_obs, rews, augmented_rews, idx)`` unchanged.
        """
        # Look methods up by name so only the selected one must exist.
        sampler_names = {
            1: 'biased_sample',
            2: 'generalized_biased_sample',
            3: 'mixed_priority_sample',
        }
        name = sampler_names.get(self.sample_method, 'sample')
        return getattr(self.memory, name)(self.batch_size)

    def _policy_log_probs(self, obs_pad, acts_pad, next_obs_pad, lengths):
        """Return per-step ``(logPF, logPB)`` for a padded batch of trajectories.

        Continuous envs pass ``max_t`` and the timeout mask flag to the forward
        policy; discrete envs pass forward/backward action masks from the env.
        """
        if self.continuous:
            logPF = self.forward_policy.evaluate_actions(
                obs_pad, acts_pad, self.env.unwrapped.max_t,
                lengths=lengths, use_mask=self.timeout_mask,
            )
            logPB = self.backward_policy.evaluate_actions(
                next_obs_pad, acts_pad, lengths=lengths
            )
        else:
            forward_mask = self.env.unwrapped.get_forward_action_masks(obs_pad)
            backward_mask = self.env.unwrapped.get_backward_action_masks(next_obs_pad)
            logPF = self.forward_policy.evaluate_actions(
                obs_pad, acts_pad, forward_mask, lengths=lengths
            )
            logPB = self.backward_policy.evaluate_actions(
                next_obs_pad, acts_pad, backward_mask, lengths=lengths
            )
        return logPF, logPB

    def _intrinsic_log_rewards(self, obs_pad, next_obs_pad, lengths):
        """Return the per-step GAFN term ``log r_int(s') - logF(s)``.

        Raises:
            ValueError: if the flow estimate is NaN/inf, or the resulting term
                contains NaN or +inf.
        """
        # NOTE(review): bare squeeze() drops every size-1 dim; this assumes
        # logF returns (batch, seq, 1) with batch and seq > 1 -- consider
        # squeeze(-1) once the logF output shape is confirmed.
        flow = self.flow_estimator(obs_pad).squeeze()
        if torch.any(torch.isnan(flow)) or torch.any(torch.isinf(flow)):
            print(flow)
            raise ValueError("Invalid flow found.")

        intrinsic = self.intrinsic_reward_estimator.compute_intrinsic_reward(next_obs_pad, lengths)
        intrinsic = torch.log(intrinsic) - flow

        # -inf (log of a zero reward) is tolerated: it contributes nothing to
        # the log-sum-exp in train(). NaN or +inf would only surface later as
        # a NaN loss, after priorities/logs were already written, so fail here.
        if torch.any(torch.isnan(intrinsic)) or torch.any(torch.isposinf(intrinsic)):
            print(intrinsic)
            raise ValueError("Invalid intrisic reward.")
        return intrinsic

    def train(self):
        """Run ``self.gradient_steps`` optimisation steps on replayed trajectories.

        Each step samples a batch, pads it, and computes per trajectory
        ``(sum logPF + logZ - sum logaddexp(logPB, intr) - log R)^2`` with
        ``R = max(rews, 1e-10) + temperature * augmented_rews``. The mean of
        that plus the intrinsic-reward estimator's loss is minimised with
        every optimiser. When ``sample_method >= 2`` the per-trajectory losses
        are pushed back to the memory as priorities.

        Raises:
            ValueError: if the flow, intrinsic term, log-probs or loss become
                invalid (NaN / infinite).
        """
        if self.sample_method == 1:
            self.memory.update_threshold(self.batch_size)

        optims = (
            self.forward_optim,
            self.backward_optim,
            self.logZ_optim,
            self.flow_optim,
            self.intrinsic_reward_optim,
        )
        traj_losses = 0
        traj_losses_std = 0
        for _ in range(self.gradient_steps):
            (batch_obs, batch_acts, batch_next_obs,
             batch_rews, batch_augmented_rews, batch_idx) = self._sample_batch()

            batch_obs_pad = pad_sequence(batch_obs, batch_first=True)
            batch_acts_pad = pad_sequence(batch_acts, batch_first=True)
            batch_next_obs_pad = pad_sequence(batch_next_obs, batch_first=True)
            lengths = torch.tensor(np.array([len(obs) for obs in batch_obs]), device=self.device)

            logPF, logPB = self._policy_log_probs(
                batch_obs_pad, batch_acts_pad, batch_next_obs_pad, lengths
            )
            intrinsic_rewards = self._intrinsic_log_rewards(
                batch_obs_pad, batch_next_obs_pad, lengths
            )

            # Per-trajectory totals; the backward side absorbs the intrinsic
            # term per step via a numerically stable log-sum-exp.
            logPF = torch.sum(logPF, dim=1)
            logPB = torch.sum(torch.logaddexp(logPB, intrinsic_rewards), dim=1)

            batch_rews = torch.clamp(torch.tensor(batch_rews, device=self.device), min=1e-10)
            batch_augmented_rews = torch.tensor(batch_augmented_rews, device=self.device)
            log_reward = torch.log(batch_rews + self.temperature * batch_augmented_rews)

            rnd_loss = self.intrinsic_reward_estimator.compute_loss(batch_next_obs_pad, lengths)

            if torch.any(torch.isinf(logPF)) or torch.any(torch.isinf(logPB)):
                raise ValueError("Infinite logprobs found")

            # Squared trajectory-balance residual, one entry per trajectory.
            loss = (logPF + self.logZ - logPB - log_reward).pow(2)
            loss_detached = loss.detach().cpu().numpy()

            if self.sample_method >= 2:
                self.memory.update_priority(batch_idx, loss_detached)

            self.memory.push_train_logs(
                logPF.detach().cpu().numpy(),
                logPB.detach().cpu().numpy(),
                loss_detached,
                batch_rews.cpu().numpy(),
                self.logZ.item(),
            )

            traj_losses += loss.mean().detach().item()
            traj_losses_std += loss.std().detach().item()

            # Total objective: trajectory loss plus the estimator's own loss.
            loss = loss.mean() + rnd_loss
            if torch.any(torch.isinf(loss)) or torch.any(torch.isnan(loss)):
                print(loss)
                raise ValueError("Invalid loss found.")

            for optim in optims:
                optim.zero_grad()
            loss.backward()
            for optim in optims:
                optim.step()

            self.logger['log_Z'].append(self.logZ.item())

        # Count every step taken (the previous loop-index form was off by one).
        self._n_updates += self.gradient_steps
        if self.gradient_steps > 0:
            self.logger['traj_losses'].append(traj_losses / self.gradient_steps)
            self.logger['traj_losses_std'].append(traj_losses_std / self.gradient_steps)

    def save(self, model_dir):
        """Write weights, optimiser states and progress counters to ``model_dir``.

        ``model_dir`` is created if missing. The intrinsic-reward estimator and
        its optimiser are saved as ``rnd.pth`` / ``rnd_optim.pth`` so that a
        resumed run does not re-initialise them.
        """
        os.makedirs(model_dir, exist_ok=True)
        torch.save(self.forward_policy.state_dict(), f'{model_dir}/forward_policy.pth')
        torch.save(self.backward_policy.state_dict(), f'{model_dir}/backward_policy.pth')
        torch.save(self.flow_estimator.state_dict(), f'{model_dir}/logF.pth')
        torch.save(self.intrinsic_reward_estimator.state_dict(), f'{model_dir}/rnd.pth')
        torch.save(self.logZ, f'{model_dir}/logZ.pth')
        # Optimiser states.
        torch.save(self.forward_optim.state_dict(), f'{model_dir}/forward_optim.pth')
        torch.save(self.backward_optim.state_dict(), f'{model_dir}/backward_optim.pth')
        torch.save(self.flow_optim.state_dict(), f'{model_dir}/logF_optim.pth')
        torch.save(self.logZ_optim.state_dict(), f'{model_dir}/logZ_optim.pth')
        torch.save(self.intrinsic_reward_optim.state_dict(), f'{model_dir}/rnd_optim.pth')
        # Training progress counters: (t_so_far, i_so_far, e_so_far).
        current_progress = (self.logger['t_so_far'], self.logger['i_so_far'], self.logger['e_so_far'])
        with open(f'{model_dir}/progress.pkl', 'wb') as f:
            pickle.dump(current_progress, f)

    def save_replay_buffer(self, model_dir):
        """Persist the replay buffer to ``model_dir`` via ``memory.save``."""
        self.memory.save(model_dir)

    def load(self, model_dir, load_optim=False):
        """Restore weights, and optionally optimiser states and progress, from ``model_dir``.

        RND files (``rnd.pth`` / ``rnd_optim.pth``) are loaded only if present,
        so checkpoints written before they were saved still load.

        Returns:
            ``(t_so_far, i_so_far, e_so_far)`` from ``progress.pkl`` when
            ``load_optim`` is true, otherwise ``(0, 0, 0)``.
        """
        def _path(name):
            return f'{model_dir}/{name}'

        def _load(name):
            # map_location lets a GPU-written checkpoint load on CPU and vice versa.
            # NOTE: torch.load/pickle execute code from the file; load trusted checkpoints only.
            return torch.load(_path(name), map_location=self.device)

        self.forward_policy.load_state_dict(_load('forward_policy.pth'))
        self.backward_policy.load_state_dict(_load('backward_policy.pth'))
        self.flow_estimator.load_state_dict(_load('logF.pth'))
        # Copy in place rather than rebinding self.logZ: logZ_optim keeps a
        # reference to the existing tensor, and rebinding would leave the
        # optimiser stepping a stale object while the loss uses the new one.
        with torch.no_grad():
            self.logZ.copy_(_load('logZ.pth'))
        if os.path.exists(_path('rnd.pth')):
            self.intrinsic_reward_estimator.load_state_dict(_load('rnd.pth'))

        if not load_optim:
            return 0, 0, 0

        self.forward_optim.load_state_dict(_load('forward_optim.pth'))
        self.backward_optim.load_state_dict(_load('backward_optim.pth'))
        self.flow_optim.load_state_dict(_load('logF_optim.pth'))
        self.logZ_optim.load_state_dict(_load('logZ_optim.pth'))
        if os.path.exists(_path('rnd_optim.pth')):
            self.intrinsic_reward_optim.load_state_dict(_load('rnd_optim.pth'))

        with open(_path('progress.pkl'), 'rb') as f:
            t_so_far, i_so_far, e_so_far = pickle.load(f)
        return t_so_far, i_so_far, e_so_far

    def load_replay_buffer(self, model_dir):
        """Restore the replay buffer from ``model_dir`` via ``memory.load``."""
        self.memory.load(model_dir)
