import torch
from rlkit import conf
from rlkit.core.logging import red
from rlkit.launchers.pipeline import Pipeline, PipelineCtx
from rlkit.policies.gaussian_policy import GaussianPolicy, UnnormalizeGaussianPolicy
from rlkit.torch.algorithms.sac.mgpac import MGIQLPipeline, MGPacTrainer
from rlkit.torch.distributions import Delta, GaussianMixture
import rlkit.torch.pytorch_util as ptu
import os.path as osp
from torch import optim

class MGPacTruncTrainer(MGPacTrainer):
    """MGPacTrainer variant with its own pessimistic action selection.

    This class overrides ``get_pessimistic_action`` and provides the two
    proposal generators it relies on (``compute_max_proposal`` and
    ``compute_jensen_proposal``). Helpers used below but not defined here
    (``get_action_dist``, ``calc_q_LB``, ``calc_log_p_mu``,
    ``get_shift_denominator``, ``sample_delta``) and the attributes
    ``delta``, ``num_delta``, ``delta_range``, ``use_max_lambda`` and
    ``action_selection_mode`` are expected to come from ``MGPacTrainer``.
    """

    def __init__(
        self,
        policy,
        qfs,
        target_qfs,
        discount=0.99,
        reward_scale=1,
        policy_lr=0.001,
        qf_lr=0.001,
        optimizer_class=optim.Adam,
        soft_target_tau=0.01,
        target_update_period=1,
        plotter=None,
        render_eval_paths=False,
        beta_LB=0.5,
        delta_range=None,
        target_quantile=0.7,
        num_delta=None,
        action_selection_mode="max_from_both",
        use_max_lambda=False,
        IQN=True,
    ):
        """Forward every argument, positionally and unchanged, to the parent.

        Because the arguments are passed positionally, their order here has to
        match the parameter order of ``MGPacTrainer.__init__``.
        """
        super().__init__(
            policy,
            qfs,
            target_qfs,
            discount,
            reward_scale,
            policy_lr,
            qf_lr,
            optimizer_class,
            soft_target_tau,
            target_update_period,
            plotter,
            render_eval_paths,
            beta_LB,
            delta_range,
            target_quantile,
            num_delta,
            action_selection_mode,
            use_max_lambda,
            IQN,
        )

    def get_pessimistic_action(self, obs) -> Delta:
        """Select one action per observation and wrap it in a ``Delta``.

        * If ``self.delta_range == [0.0, 0.0]``: evaluate ``calc_q_LB`` at
          every mixture-component mean and return the mean with the highest
          value (no shift is applied).
        * Otherwise call ``self.sample_delta()`` and dispatch on
          ``self.action_selection_mode``:
          ``"jensen"`` -> ``compute_jensen_proposal``;
          ``"max"`` -> ``compute_max_proposal``;
          ``"max_from_both"`` -> compute both and keep, per observation, the
          proposal whose returned value is larger.

        Raises:
            NotImplementedError: for any other ``action_selection_mode``.
            Exception: in ``"max_from_both"`` mode when the selected actions
                contain NaN.
        """
        dist: GaussianMixture = self.get_action_dist(obs)

        if self.delta_range == [0.0, 0.0]:  #! we should mask out low probability
            batch_size = obs.shape[0]
            # One copy of each observation per mixture component, batch-major,
            # matching the flattening of dist.mean on the next line.
            obs_exp = obs.repeat_interleave(dist.num_gaussians, dim=0)
            qfs = self.calc_q_LB(
                obs_exp, dist.mean.reshape(batch_size * dist.num_gaussians, -1)
            ).reshape(batch_size, dist.num_gaussians)
            idx = torch.argmax(qfs, dim=1)
            selected_actions = dist.mean[torch.arange(len(idx)), idx]
            return Delta(selected_actions)

        self.sample_delta()

        if self.action_selection_mode == "jensen":
            jensen_proposal, _ = self.compute_jensen_proposal(obs, dist)
            return Delta(jensen_proposal)
        elif self.action_selection_mode == "max":
            max_proposal, _ = self.compute_max_proposal(obs, dist)
            return Delta(max_proposal)
        elif self.action_selection_mode == "max_from_both":
            jensen_proposal, jensen_value = self.compute_jensen_proposal(obs, dist)
            max_proposal, max_value = self.compute_max_proposal(obs, dist)

            with torch.no_grad():
                # [batch_size, 2, act_dim]
                proposal = torch.cat(
                    [max_proposal.unsqueeze(1), jensen_proposal.unsqueeze(1)], dim=1
                )
                # [batch_size, 2]
                value = torch.cat(
                    [max_value.unsqueeze(1), jensen_value.unsqueeze(1)], dim=1
                )

                # On ties argmax picks index 0, i.e. the "max" proposal.
                idx = torch.argmax(value, dim=1)
                selected_actions = proposal[torch.arange(len(idx)), idx]
                if torch.any(torch.isnan(selected_actions)):
                    red("not good, found nan actions")
                    raise Exception("Action selection is NaN!")
            return Delta(selected_actions)
        else:
            raise NotImplementedError

    def compute_max_proposal(
        self, obs, dist: GaussianMixture
    ) -> "tuple[torch.Tensor, torch.Tensor]":
        """Propose actions by shifting each mixture-component mean.

        For every component (and every entry of ``self.delta``) the component
        mean is moved along ``Sigma * grad_a q_LB`` (normalised by
        ``get_shift_denominator``), scaled by ``sqrt(2 * max_delta)``, and
        clamped to ``[-1, 1]``. ``calc_q_LB`` is then evaluated at all
        proposals, components whose weight is below ``0.2 / num_gaussians``
        are masked out, and the best proposal per observation is returned.

        Note: calls ``requires_grad_()`` on ``dist.mean`` in place.

        Returns:
            ``(select_mu_proposal, max_value)``: the chosen proposal for each
            observation and its lower-bound Q value.
        """
        # * preliminaries
        mu_beta = dist.mean
        weights = dist.weights

        num_gaussians = dist.num_gaussians
        batch_size = obs.shape[0]

        # * calculate delta. this is the m distance constraint. we require the Mahalanobis (m) distance to be <= this value.
        # per-component variances
        Sigma_beta = torch.pow(dist.stddev, 2)
        # [batch_size, num_gaussian]
        log_weights = weights.log()
        # [batch_size, num_gaussian]
        log_p_mu = self.calc_log_p_mu(Sigma_beta)

        if self.use_max_lambda:
            pseudo_log_p_mu = (log_weights + log_p_mu).max(-1, keepdim=True)[0]
        else:
            pseudo_log_p_mu = (log_weights + log_p_mu).sum(-1, keepdim=True)

        # [batch_size, num_delta, num_gaussian]
        max_delta = 2 * (  #! refer to appendix in paper
            self.delta + (log_weights - pseudo_log_p_mu + log_p_mu).unsqueeze(1)
        ).clamp(min=0.0)

        # * calculate gradient of q lower bound w.r.t action
        mu_beta.requires_grad_()
        # [batch_size * num_gaussian, obs_dim]
        obs_exp = obs.repeat_interleave(num_gaussians, dim=0)
        # [batch_size * num_gaussian, act_dim]
        mu_beta_reshaped = mu_beta.reshape(-1, mu_beta.shape[-1])

        # Get the lower bound of the Q estimate at the component means
        q_LB = self.calc_q_LB(obs_exp, mu_beta_reshaped)
        # [batch_size, num_gaussian]
        q_LB = q_LB.reshape(-1, num_gaussians)

        # Obtain the gradient of q_LB wrt to a
        # with a evaluated at the component means
        grad = torch.autograd.grad(q_LB.sum(), mu_beta)  #! this returns a tuple!!
        # same shape as mu_beta (asserted below)
        grad = grad[0]

        assert grad is not None
        assert mu_beta.shape == grad.shape

        # * calculate proposals
        denom = self.get_shift_denominator(grad, Sigma_beta)
        # [batch_size, 1, num_gaussians, action_dim]
        direction = (torch.mul(Sigma_beta, grad) / denom).unsqueeze(1)

        # [batch_size, num_delta, num_gaussians, action_dim]
        delta_mu = torch.sqrt(2 * max_delta).unsqueeze(-1) * direction

        # Flattened per observation as index d * num_gaussians + g.
        mu_proposal = torch.clamp(mu_beta.unsqueeze(1) + delta_mu, -1, 1).reshape(
            batch_size, self.num_delta * num_gaussians, -1
        )

        # * get the lower bounded q
        # repeat_interleave (not repeat) so that row i of obs_exp belongs to the
        # same observation as row i of the batch-major flattened proposals.
        # ``repeat`` would tile the batch and mis-pair them for batch_size > 1.
        obs_exp = obs.repeat_interleave(self.num_delta * num_gaussians, dim=0)
        q_LB = self.calc_q_LB(
            obs_exp,
            mu_proposal.reshape(batch_size * self.num_delta * num_gaussians, -1),
        )
        q_LB = q_LB.reshape(batch_size, num_gaussians * self.num_delta)
        # mask low probabilities; weights.repeat tiles the component axis so
        # column d * num_gaussians + g is compared against weights[:, g]
        q_LB[(weights.repeat(1, self.num_delta) < 0.2 / num_gaussians)] = -torch.inf

        # * argmax the proposals
        max_value, idx = torch.max(q_LB, dim=1)
        select_mu_proposal = mu_proposal[torch.arange(len(idx)), idx]

        return select_mu_proposal, max_value

    def compute_jensen_proposal(
        self, obs, dist: GaussianMixture
    ) -> "tuple[torch.Tensor, torch.Tensor]":
        """Propose actions by shifting a precision-weighted mean of the mixture.

        ``mu_bar`` is the average of the component means weighted by
        ``weights / variance``. It is shifted along the gradient of
        ``calc_q_LB`` with a step derived from ``self.delta`` (see
        ``pseudo_delta``), clamped to ``[-1, 1]``, and the proposal with the
        highest ``calc_q_LB`` value per observation is returned. Proposals
        whose ``pseudo_delta`` is negative are masked out.

        Returns:
            ``(select_mu_proposal, jensen_value)``. If every ``pseudo_delta``
            is negative, returns ``mu_bar`` itself with a value of ``-inf``
            for each observation.
        """
        # * preliminaries
        mean_per_comp = dist.mean
        weights = dist.weights
        batch_size = obs.shape[0]
        Sigma_beta = torch.pow(dist.stddev, 2) + 1e-6
        normalized_factor = (weights.unsqueeze(-1) / Sigma_beta).sum(
            1
        )  #! this is "A" in the paper
        mu_bar = (weights.unsqueeze(-1) / Sigma_beta * mean_per_comp).sum(
            1
        ) / normalized_factor

        # * calculate delta. this is the m distance constraint. we require the Mahalanobis (m) distance to be <= this value.
        jensen_delta = self.delta  # this is flexible

        # Obtain the change in mu
        pseudo_delta = (
            2 * jensen_delta
            - (weights * (torch.pow(mean_per_comp, 2) / Sigma_beta).sum(-1))
            .sum(1, keepdim=True)
            .unsqueeze(1)
            + (torch.pow(mu_bar, 2) * normalized_factor)
            .sum(1, keepdim=True)
            .unsqueeze(1)
        )
        mu_bar.requires_grad_()

        if torch.all(pseudo_delta < 0):
            # No feasible shift anywhere. Return one -inf per observation so the
            # value can be concatenated with the [batch_size] values of the
            # max proposal (a single-element tensor only works for batch 1).
            return mu_bar, mu_bar.new_full((batch_size,), -torch.inf)

        # * calculate gradient of q lower bound w.r.t action
        q_LB = self.calc_q_LB(obs, mu_bar)
        # Obtain the gradient of q_LB wrt to a
        # with a evaluated at mu_bar
        grad = torch.autograd.grad(q_LB.sum(), mu_bar)[0]

        assert grad is not None
        assert mu_bar.shape == grad.shape

        denom = self.get_shift_denominator(grad, 1 / normalized_factor)

        numerator = torch.sqrt((pseudo_delta).clamp(min=0.0))
        delta_mu = numerator * (
            torch.mul(1 / normalized_factor, grad) / denom
        ).unsqueeze(1)

        # * calculate proposals
        mu_proposal = torch.clamp(mu_bar.unsqueeze(1) + delta_mu, -1, 1)

        # * get the lower bounded q
        # repeat_interleave (not repeat) keeps each observation aligned with
        # its own num_delta proposals after the batch-major reshape below.
        obs_exp = obs.repeat_interleave(self.num_delta, dim=0)
        q_LB = self.calc_q_LB(
            obs_exp, mu_proposal.reshape(batch_size * self.num_delta, -1)
        )
        q_LB = q_LB.reshape(batch_size, self.num_delta)
        # exclude proposals whose constraint budget is (numerically) negative
        q_LB[(pseudo_delta <= -1e-10).squeeze(-1)] = -torch.inf

        # * argmax the proposals
        jensen_value, idx = torch.max(q_LB, dim=1)
        select_mu_proposal = mu_proposal[torch.arange(len(idx)), idx]

        return select_mu_proposal, jensen_value


def load_checkpoint_trunc_policy(ctx: PipelineCtx):
    """Load a saved policy for the current env/seed into ``ctx.policy``.

    The checkpoint is read from
    ``<checkpoint_path>/<MGTrunc.path1>/<env_id>/<seed>/params.pt`` (all taken
    from ``conf.CheckpointParams`` and ``ctx.variant``), mapped onto
    ``ptu.device``, and the object stored under ``'trainer/policy'`` is
    assigned to ``ctx.policy``.

    Raises:
        FileNotFoundError: if the checkpoint file does not exist.
    """
    # the current policy is unnormalized.
    iql_params = conf.CheckpointParams.MGTrunc

    checkpoint_file = osp.join(
        conf.CheckpointParams.checkpoint_path,
        iql_params.path1,
        ctx.variant["env_id"],
        str(ctx.variant["seed"]),
        "params.pt",
    )

    try:
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        params = torch.load(checkpoint_file, map_location=ptu.device)
    except FileNotFoundError as err:
        # A missing primary checkpoint is treated as a hard error. This used
        # to be an `assert False` followed by a fallback to `iql_params.path2`;
        # the fallback could only run under `python -O` (asserts stripped), so
        # behaviour depended on interpreter flags. Fail explicitly instead.
        raise FileNotFoundError(
            f"Policy checkpoint not found: {checkpoint_file}"
        ) from err

    ctx.policy = params['trainer/policy']

   


# Build the pipeline for the truncated variant from MGIQLPipeline and use this
# module's checkpoint loader for the 'load_checkpoint_policy' step.
# NOTE(review): assumes Pipeline.from_ returns a separate pipeline object, so
# the replace() below leaves MGIQLPipeline itself untouched — confirm.
MGTruncPipeline = Pipeline.from_(MGIQLPipeline, 'MGTruncPipeline')
MGTruncPipeline.replace('load_checkpoint_policy', load_checkpoint_trunc_policy)