import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from collections import defaultdict
from rlf.algos.on_policy.on_policy_base import OnPolicy
import numpy as np

class PPO(OnPolicy):
    """PPO updater with optional expert / BC-model action-loss regularizers.

    ``update`` performs the clipped-surrogate policy step and the (clipped)
    value step. Depending on the configured flags it additionally adds
    penalty terms to the action loss:

    * ``add_expert_regularizer``: discrete-action penalties computed on a
      batch drawn from ``self.policy.expert_loader``.
    * ``add_cdf_regularizer`` / ``add_pdf_regularizer``: continuous-action
      penalties computed on a batch drawn from ``self.policy.expert_loader``.
    * ``add_bc_model_regularizer``: penalties computed on the rollout batch
      using the output of ``self.policy.bc_policy``.

    The coefficients and counters read here (``exres_coef``,
    ``regularizer_coef``, ``non_expert_entropy_coef``, ``cdf_coef``,
    ``pdf_coef``, ``reg_const_bound``, ``bc_model_reg_coef``,
    ``sum_prob_update_ts``, ``sum_prob_update_freq``) are not defined in this
    class; they are expected to be set elsewhere.
    """

    # The discrete regularizers treat actions ``0 .. _NUM_EXPERT_ACTIONS - 1``
    # as the expert's action subset.
    _NUM_EXPERT_ACTIONS = 4

    # (env-name substring, default box bound) for the continuous BC-model
    # regularizer, checked in this order when no explicit bound is configured.
    _BC_CDF_DEFAULT_BOUNDS = (
        ('FetchPickAndPlaceDiffHoldoutTS150', 0.1),
        ('MBRLmaze2d', 0.1),
        ('FetchPushEnvCustomTS500', 0.05),
    )

    def update(self, rollouts):
        """Run the PPO update on ``rollouts``.

        Args:
            rollouts: rollout storage; must provide ``compute_advantages()``
                and ``get_generator(advantages, num_mini_batch)``.

        Returns:
            A ``defaultdict`` mapping log names to the accumulated value
            divided by ``num_epochs * num_mini_batch``.

        Raises:
            ValueError: if ``add_bc_model_regularizer`` is set for an
                environment name that no BC-model branch handles.
            IndexError: if a discrete expert label is outside the expert
                action subset.
        """
        self._compute_returns(rollouts)
        advantages = rollouts.compute_advantages()

        use_clipped_value_loss = True

        log_vals = defaultdict(lambda: 0)

        # Configuration flags do not change during an update; read them once.
        use_expert_reg = self._arg('add_expert_regularizer')
        use_cdf_reg = self._arg('add_cdf_regularizer')
        use_pdf_reg = self._arg('add_pdf_regularizer')
        use_bc_reg = self._arg('add_bc_model_regularizer')
        clip_param = self._arg('clip_param')

        for _ in range(self._arg('num_epochs')):
            data_generator = rollouts.get_generator(
                advantages, self._arg('num_mini_batch'))
            expert_loader = None
            if use_expert_reg or use_cdf_reg or use_pdf_reg:
                expert_loader = self.policy.expert_loader
            bc_policy = self.policy.bc_policy if use_bc_reg else None

            for sample in data_generator:
                ac_eval = self.policy.evaluate_actions(
                    sample['state'],
                    sample['other_state'],
                    sample['hxs'], sample['mask'],
                    sample['action'])

                # Clipped surrogate objective.
                ratio = torch.exp(ac_eval['log_prob'] - sample['prev_log_prob'])
                surr1 = ratio * sample['adv']
                surr2 = torch.clamp(ratio,
                                    1.0 - clip_param,
                                    1.0 + clip_param) * sample['adv']
                action_loss = -torch.min(surr1, surr2).mean(0)

                # Regularizer tensors to log for this mini-batch. Insertion
                # order (expert, cdf, pdf, bc) fixes the key order in log_vals.
                reg_logs = {}
                if use_expert_reg:
                    action_loss = self._apply_expert_regularizer(
                        action_loss, expert_loader, reg_logs)
                if use_cdf_reg:
                    action_loss = self._apply_cdf_regularizer(
                        action_loss, expert_loader, reg_logs)
                if use_pdf_reg:
                    action_loss = self._apply_pdf_regularizer(
                        action_loss, expert_loader, reg_logs)
                if use_bc_reg:
                    action_loss = self._apply_bc_model_regularizer(
                        action_loss, sample, bc_policy, reg_logs)

                if use_clipped_value_loss:
                    value_delta = (ac_eval['value'] - sample['value']).clamp(
                        -clip_param, clip_param)
                    value_pred_clipped = sample['value'] + value_delta
                    value_losses = (ac_eval['value'] - sample['return']).pow(2)
                    value_losses_clipped = (
                        value_pred_clipped - sample['return']).pow(2)
                    value_loss = 0.5 * torch.max(value_losses,
                                                 value_losses_clipped).mean()
                else:
                    value_loss = 0.5 * (sample['return'] - ac_eval['value']).pow(2).mean()

                loss = (value_loss * self._arg('value_loss_coef') + action_loss -
                        ac_eval['ent'].mean() * self._arg('entropy_coef'))

                self._standard_step(loss)

                log_vals['value_loss'] += value_loss.sum().item()
                log_vals['action_loss'] += action_loss.sum().item()
                log_vals['dist_entropy'] += ac_eval['ent'].mean().item()
                for log_name, log_tensor in reg_logs.items():
                    log_vals[log_name] += log_tensor.sum().item()

        num_updates = self._arg('num_epochs') * self._arg('num_mini_batch')
        for k in log_vals:
            log_vals[k] /= num_updates

        return log_vals

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------

    def _policy_dist_on_expert_batch(self, expert_loader):
        """Draw one expert batch and evaluate the policy on its states.

        Returns:
            ``(expert_sample, predicted_ac_dist)`` where ``expert_sample`` is
            the raw batch from the loader.
        """
        # NOTE(review): ``next(iter(loader))`` builds a fresh iterator on every
        # call; whether consecutive calls return different batches depends on
        # the loader (e.g. shuffling) -- confirm against its construction.
        expert_sample = next(iter(expert_loader))
        expert_states = expert_sample['state'].to(self.args.device)
        # The mask passed for expert samples is all ones.
        expert_mask = torch.ones(
            expert_sample['state'].shape[0], 1).to(self.args.device)
        predicted_ac_dist, _, _ = self.policy(
            expert_states, {}, {}, expert_mask)
        return expert_sample, predicted_ac_dist

    def _wrong_expert_action_index(self, expert_labels):
        """Return, per row, the expert-subset actions other than the label.

        Args:
            expert_labels: 1-D integer tensor with values in
                ``[0, _NUM_EXPERT_ACTIONS)``.

        Returns:
            int64 tensor of shape ``[batch, _NUM_EXPERT_ACTIONS - 1]`` on
            ``self.args.device``, each row in ascending order.

        Raises:
            IndexError: if any label lies outside the expert subset.
        """
        num_expert = self._NUM_EXPERT_ACTIONS
        out_of_range = (expert_labels < 0) | (expert_labels >= num_expert)
        if bool(out_of_range.any()):
            raise IndexError(
                f'expert action label outside [0, {num_expert - 1}]')
        candidates = torch.arange(num_expert - 1, device=expert_labels.device)
        # Skip over the labelled action: candidates at or above the label are
        # shifted up by one, e.g. label 1 -> [0, 2, 3].
        shift = (candidates >= expert_labels[:, None]).to(candidates.dtype)
        return (candidates + shift).to(self.args.device)

    # ------------------------------------------------------------------
    # Discrete expert regularizers
    # ------------------------------------------------------------------

    def _apply_expert_regularizer(self, action_loss, expert_loader, reg_logs):
        """Add the discrete expert regularizer(s) to ``action_loss`` in place.

        Tensors to be logged are stored in ``reg_logs``. Returns
        ``action_loss``.
        """
        # The expert regularizers below index categorical probabilities.
        assert self.action_space.__class__.__name__ == 'Discrete', 'Expert KL loss is only supported for discrete action space'
        expert_sample, predicted_ac_dist = self._policy_dist_on_expert_batch(
            expert_loader)
        predicted_ac_prob = predicted_ac_dist.probs
        # Probabilities of the expert action subset only.
        predicted_ac_prob_expert = predicted_ac_prob[:, 0:self._NUM_EXPERT_ACTIONS]

        if self._arg('add_expert_restricted_regularizer'):
            restricted_regularizer = self._restricted_expert_regularizer(
                predicted_ac_dist, expert_sample)
            action_loss += self.exres_coef * restricted_regularizer
            # Logged without the coefficient.
            reg_logs['restricted_regularizer'] = restricted_regularizer
            return action_loss

        regularizer = self._kl_or_ce_regularizer(
            predicted_ac_prob_expert, expert_sample)
        # In-place add keeps the dtype of ``action_loss`` even when the
        # regularizer is float64 (cross-entropy branch).
        action_loss += regularizer
        reg_logs['actor_regularizer'] = regularizer

        if self._arg('add_non_expert_entropy'):
            non_expert_entropy_loss = self._non_expert_entropy(
                predicted_ac_prob, predicted_ac_prob_expert.shape[1],
                expert_sample)
            action_loss -= self.non_expert_entropy_coef * non_expert_entropy_loss
            reg_logs['non_expert_entropy_loss'] = non_expert_entropy_loss
        return action_loss

    def _restricted_expert_regularizer(self, predicted_ac_dist, expert_sample):
        """Batch-mean penalty on the expert-subset actions the expert did not take.

        The variant is chosen by the ``restricted_regularizer_*`` flags, in
        the order sum_log, sum_prob, related_ac, default (log of summed
        probabilities).
        """
        num_expert = self._NUM_EXPERT_ACTIONS
        # esa: expert sample actions
        esa_labels = expert_sample['actions'][:, 0].to(int).to(self.args.device)
        wrong_index = self._wrong_expert_action_index(esa_labels)

        if self._arg("restricted_regularizer_sum_log"):
            expert_logits = predicted_ac_dist.logits[:, 0:num_expert]
            wrong_logits = torch.gather(expert_logits, 1, wrong_index)
            restricted_regularizer = wrong_logits.sum(dim=1)
        elif self._arg("restricted_regularizer_sum_prob"):
            # Only every ``sum_prob_update_freq``-th call contributes a penalty.
            self.sum_prob_update_ts += 1
            if self.sum_prob_update_ts % self.sum_prob_update_freq == 0:
                expert_probs = predicted_ac_dist.probs[:, 0:num_expert]
                wrong_probs = torch.gather(expert_probs, 1, wrong_index)
                restricted_regularizer = wrong_probs.sum(dim=1)
                crop_threshold = self._arg("restricted_regularizer_sum_prob_clamp_threshold")
                restricted_regularizer = torch.clamp(restricted_regularizer, crop_threshold, 1)

                if self._arg('restricted_regularizer_sum_prob_wrap_log'):
                    restricted_regularizer = torch.log(restricted_regularizer)
            else:
                restricted_regularizer = torch.tensor(0.0).to(self.args.device)
        elif self._arg("restricted_regularizer_related_ac"):
            all_probs = predicted_ac_dist.probs
            wrong_probs = torch.gather(all_probs[:, 0:num_expert], 1, wrong_index)
            restricted_regularizer = wrong_probs.sum(dim=1)
            # There are 8 actions (0-7): 0-3 are expert actions, 4-7 are
            # non-expert actions. Each expert action additionally penalises
            # two non-expert actions:
            #   0 -> (5, 6), 1 -> (6, 7), 2 -> (4, 7), 3 -> (4, 5)
            ooc_table = torch.tensor(
                [[5, 6], [6, 7], [4, 7], [4, 5]], device=esa_labels.device)
            ooc_labels = ooc_table[esa_labels]
            ooc_probs = torch.gather(all_probs, 1, ooc_labels)
            restricted_regularizer = restricted_regularizer + ooc_probs.sum(dim=1)
        else:
            expert_probs = predicted_ac_dist.probs[:, 0:num_expert]
            wrong_probs = torch.gather(expert_probs, 1, wrong_index)
            if self._arg('reg_with_clamp'):
                restricted_regularizer = torch.log(torch.clamp(wrong_probs.sum(1), 1e-10, 1))
            else:
                restricted_regularizer = torch.log(wrong_probs.sum(dim=1) + 1e-10)
        return restricted_regularizer.mean()

    def _kl_or_ce_regularizer(self, predicted_ac_prob_expert, expert_sample):
        """Return the KL / CE expert regularizer, already scaled by ``regularizer_coef``.

        Returns a zero tensor when neither ``add_kl_regularizer`` nor
        ``add_ce_regularizer`` is set.
        """
        num_expert = self._NUM_EXPERT_ACTIONS
        if self._arg('add_kl_regularizer'):
            # Smoothed one-hot target built from the expert action labels.
            EPS = 1e-5
            expert_labels = expert_sample['actions'][:, 0].to(int)
            # TODO: this action space needed to be replaced with the actual action space
            target = F.one_hot(expert_labels, num_expert).to(self.args.device)
            target = target * (1 - self.action_space.n * EPS) + EPS
            target_log_prob = torch.log(target)
            # NOTE(review): F.kl_div documents its ``input`` as
            # log-probabilities, but raw probabilities are passed here, and
            # the smoothing uses ``action_space.n`` while the one-hot has
            # ``num_expert`` classes. Left unchanged so the training
            # objective stays identical -- confirm intent.
            return self.regularizer_coef * F.kl_div(
                predicted_ac_prob_expert, target_log_prob,
                reduction='batchmean', log_target=True)
        if self._arg('add_ce_regularizer'):
            expert_labels = expert_sample['actions'][:, 0].to(int)
            target = F.one_hot(expert_labels, num_expert).to(self.args.device)
            target = target.to(torch.float64)
            # NOTE(review): this multiplies probabilities (not
            # log-probabilities) with the one-hot target and averages over
            # every entry -- confirm this is the intended "cross-entropy".
            ce_loss = (-predicted_ac_prob_expert * target).mean()
            return self.regularizer_coef * ce_loss
        return torch.tensor(0.0).to(self.args.device)

    def _non_expert_entropy(self, predicted_ac_prob, num_expert_cols, expert_sample):
        """Entropy over {expert's action} + {all actions beyond the expert subset}.

        The gathered probabilities are renormalised per row before the
        entropy is computed; the result is the batch mean.
        """
        # esa: expert sample actions
        esa_labels = expert_sample['actions'][:, 0].to(int).to(self.args.device)
        esa_labels = esa_labels[:, None]
        non_expert_actions = torch.arange(
            num_expert_cols, self.action_space.n, device=esa_labels.device)
        non_expert_actions = non_expert_actions.expand(esa_labels.shape[0], -1)
        gather_index = torch.cat((esa_labels, non_expert_actions), dim=1)
        selected_probs = torch.gather(predicted_ac_prob, 1, gather_index)

        selected_probs = selected_probs / selected_probs.sum(dim=1, keepdim=True)
        entropy = -selected_probs * torch.log(selected_probs + 1e-6)
        return entropy.sum(dim=1).mean()

    # ------------------------------------------------------------------
    # Continuous (CDF / PDF) regularizers
    # ------------------------------------------------------------------

    def _continuous_expert_batch(self, expert_loader, regularizer_name):
        """Check the action bounds, then return ``(dist, expert_actions)``.

        ``dist`` is the policy's action distribution on one expert batch and
        ``expert_actions`` are that batch's actions on ``self.args.device``.
        """
        assert np.all(self.action_space.low == -1) and np.all(self.action_space.high == 1), \
            f"The action space must be [-1, 1] for {regularizer_name} regularizer"
        expert_sample, predicted_ac_dist = self._policy_dist_on_expert_batch(
            expert_loader)
        expert_actions = expert_sample['actions'].to(self.args.device)
        return predicted_ac_dist, expert_actions

    def _cdf_mass_in_box(self, dist, lower_bounds, upper_bounds):
        """Per-sample product over action dims of ``cdf(upper) - cdf(lower)``.

        Args:
            dist: distribution exposing ``cdf``.
            lower_bounds, upper_bounds: tensors of shape
                ``[batch_size, action_dim]``.

        Returns:
            Tensor of shape ``[batch_size]``.
        """
        if self.policy.squash_action:
            lower_bounds = self.policy.scale_action(lower_bounds)
            upper_bounds = self.policy.scale_action(upper_bounds)

        cdf_lower = dist.cdf(lower_bounds)
        cdf_upper = dist.cdf(upper_bounds)

        # Product (not average) across action dimensions.
        return (cdf_upper - cdf_lower).prod(dim=1)

    def _cdf_regularizer_loss(self, dist, expert_action, up_bd=0.1, eps=1e-3,
                              clamp_expert_range=False):
        """Mass inside the box ``[-up_bd, up_bd]`` minus mass near the expert action.

        Args:
            dist: predicted action distribution (must expose ``cdf``).
            expert_action: tensor of shape ``[batch_size, action_dim]``.
            up_bd: half-width of the constrained box.
            eps: half-width of the interval around each expert action.
            clamp_expert_range: if True, the expert interval is clamped to
                the box before its mass is computed.

        Returns:
            Scalar tensor: batch mean of ``box_mass - expert_mass``.
        """
        batch_size, action_dim = expert_action.shape

        # Box constraint range: [-up_bd, up_bd] in every dimension.
        box_lower = torch.full(
            (batch_size, action_dim), -up_bd,
            dtype=expert_action.dtype, device=expert_action.device)
        box_upper = torch.full(
            (batch_size, action_dim), up_bd,
            dtype=expert_action.dtype, device=expert_action.device)
        cdf_in_range = self._cdf_mass_in_box(dist, box_lower, box_upper)

        # Expert action range: [action - eps, action + eps].
        expert_lower = expert_action - eps
        expert_upper = expert_action + eps
        if clamp_expert_range:
            expert_lower = expert_lower.clamp(-up_bd, up_bd)
            expert_upper = expert_upper.clamp(-up_bd, up_bd)
        cdf_expert = self._cdf_mass_in_box(dist, expert_lower, expert_upper)

        return (cdf_in_range - cdf_expert).mean()

    def _pdf_regularizer_loss(self, dist, expert_action, constraint_bound=0.1,
                              eps=1e-3, n_samples=100):
        """Average density of uniform samples from the box, excluding the expert's neighbourhood.

        Args:
            dist: predicted action distribution (must expose ``log_prob``).
            expert_action: tensor of shape ``[batch_size, action_dim]``.
            constraint_bound: samples are drawn from
                ``[-constraint_bound, constraint_bound]``.
            eps: a sample counts as valid only if every dimension lies
                outside ``[expert - eps, expert + eps]``.
            n_samples: number of uniform samples per batch element.

        Returns:
            Scalar tensor; batch elements without any valid sample
            contribute 0.
        """
        batch_size, action_dim = expert_action.shape

        uniform_samples = torch.empty(
            (n_samples, batch_size, action_dim),
            device=expert_action.device,
        ).uniform_(-constraint_bound, constraint_bound)

        expert_lower = expert_action.unsqueeze(0) - eps  # [1, batch_size, action_dim]
        expert_upper = expert_action.unsqueeze(0) + eps  # [1, batch_size, action_dim]

        valid = (uniform_samples < expert_lower) | (uniform_samples > expert_upper)
        valid = valid.all(dim=-1)  # [n_samples, batch_size]

        if self.policy.squash_action:
            uniform_samples = self.policy.scale_action(uniform_samples)
        # NOTE(review): the exponentiated log-probs are summed over the last
        # dimension; if ``log_prob`` is per-dimension this is a sum, not a
        # product, of densities -- confirm intent.
        sample_probs = dist.log_prob(uniform_samples).exp().sum(dim=-1)

        prob_sum = (sample_probs * valid).sum(dim=0)  # [batch_size]
        valid_counts = valid.sum(dim=0)

        # Dividing by a zero count and overwriting the entry afterwards still
        # produces NaN gradients (0/0 in the backward pass). Clamp the
        # denominator and select 0 for empty columns instead.
        return torch.where(
            valid_counts > 0,
            prob_sum / valid_counts.clamp(min=1),
            torch.zeros_like(prob_sum),
        ).mean()

    def _apply_cdf_regularizer(self, action_loss, expert_loader, reg_logs):
        """Add ``cdf_coef * cdf_loss`` to ``action_loss`` in place and log it."""
        predicted_ac_dist, expert_actions = self._continuous_expert_batch(
            expert_loader, 'cdf')
        cdf_regularizer = self._cdf_regularizer_loss(
            predicted_ac_dist, expert_actions, up_bd=self.reg_const_bound)
        cdf_regularizer = self.cdf_coef * cdf_regularizer
        action_loss += cdf_regularizer
        reg_logs['cdf_regularizer'] = cdf_regularizer
        return action_loss

    def _apply_pdf_regularizer(self, action_loss, expert_loader, reg_logs):
        """Add ``pdf_coef * pdf_loss`` to ``action_loss`` in place and log it."""
        predicted_ac_dist, expert_actions = self._continuous_expert_batch(
            expert_loader, 'pdf')
        pdf_regularizer = self._pdf_regularizer_loss(
            predicted_ac_dist, expert_actions,
            constraint_bound=self.reg_const_bound,
            eps=self.args.pdf_eps, n_samples=self.args.pdf_n_samples)
        pdf_regularizer = self.pdf_coef * pdf_regularizer
        action_loss += pdf_regularizer
        reg_logs['pdf_regularizer'] = pdf_regularizer
        return action_loss

    # ------------------------------------------------------------------
    # BC-model regularizer
    # ------------------------------------------------------------------

    def _apply_bc_model_regularizer(self, action_loss, sample, bc_policy, reg_logs):
        """Add the BC-model regularizer for the rollout batch ``sample``.

        For ``MiniGrid`` environments the argmax of the BC output is used as
        the expert label and the policy's probability on the other
        expert-subset actions is penalised. For the continuous environments
        listed in ``_BC_CDF_DEFAULT_BOUNDS`` the BC output is used as the
        expert action of the CDF regularizer.

        Raises:
            ValueError: if the environment name matches neither branch.
        """
        env_name = self._arg('env_name')
        default_bound = next(
            (bound for name, bound in self._BC_CDF_DEFAULT_BOUNDS
             if name in env_name),
            None)

        if 'MiniGrid' in env_name:
            bc_ac_eval = bc_policy(state=sample['state'], rnn_hxs=sample['hxs'], mask=sample['mask'])
            # Treat the BC model's argmax action as the expert label.
            bc_labels = bc_ac_eval[0].argmax(dim=1)
            wrong_index = self._wrong_expert_action_index(bc_labels)

            policy_dist, _, _ = self.policy(
                sample['state'], sample['other_state'],
                sample['hxs'], sample['mask'])
            wrong_probs = torch.gather(policy_dist.probs, 1, wrong_index)
            bc_regularizer = wrong_probs.sum(dim=1)
            if self._arg('add_bc_model_regularizer_log_sum_pi'):
                bc_regularizer = torch.log(bc_regularizer + 1e-10)
            bc_regularizer = bc_regularizer.mean()
            action_loss += self.bc_model_reg_coef * bc_regularizer
            # Logged without the coefficient in this branch.
            reg_logs['bc_model_regularizer'] = bc_regularizer
        elif default_bound is not None:
            # Use bc_policy to obtain the expert actions.
            bc_output = bc_policy(state=sample['state'], rnn_hxs=sample['hxs'], mask=sample['mask'])
            # Same argument order as every other policy call on rollout
            # samples: (state, other_state, hxs, mask).
            predicted_ac_dist, _, _ = self.policy(
                sample['state'], sample['other_state'],
                sample['hxs'], sample['mask'])
            reg_const_bound = self._arg('bc_model_regularizer_bound')
            if reg_const_bound is None:
                reg_const_bound = default_bound
            bc_regularizer = self._cdf_regularizer_loss(
                predicted_ac_dist, bc_output[0], up_bd=reg_const_bound,
                clamp_expert_range=True)
            bc_regularizer = self.bc_model_reg_coef * bc_regularizer
            action_loss += bc_regularizer
            # Logged with the coefficient in this branch.
            reg_logs['bc_model_regularizer'] = bc_regularizer
        else:
            # No branch would produce a regularizer; fail before the optimizer
            # step rather than with an unbound local when logging.
            raise ValueError(
                f'add_bc_model_regularizer is not supported for env {env_name!r}')
        return action_loss

    def get_add_args(self, parser):
        """Register the PPO command-line options on ``parser``.

        Adds ``clip-param``, ``entropy-coef`` and ``value-loss-coef``, each
        prefixed with ``self.arg_prefix``, after the options of the parent
        class.
        """
        super().get_add_args(parser)
        parser.add_argument(f"--{self.arg_prefix}clip-param",
            type=float,
            default=0.2,
            help='ppo clip parameter')

        parser.add_argument(f"--{self.arg_prefix}entropy-coef",
            type=float,
            default=0.01,
            help='entropy term coefficient (old default: 0.01)')

        parser.add_argument(f"--{self.arg_prefix}value-loss-coef",
            type=float,
            default=0.5,
            help='value loss coefficient')

