""" 
PAC with a Q Lower Bound, with pretrained q and pi beta from one step repository.
"""

from collections import OrderedDict, namedtuple
from os import path as osp
from typing import List, Tuple

import torch
import torch.optim as optim
import torch.nn as nn
from rlkit import conf

from rlkit.core.loss import LossStatistics
from rlkit.data_management.offline_wrappers import load_gt_policy
from rlkit.launchers.pipeline.helpers import load_checkpoint_iql_policy
from rlkit.policies.gaussian_policy import UnnormalizeTanhGaussianPolicy
from rlkit.torch.distributions import (
    Delta,
    MultivariateDiagonalNormal,
    TanhDelta,
    TanhNormal,
)
from rlkit.core.loss import LossFunction
from rlkit.launchers.pipeline import Pipeline, PipelineCtx, Pipelines
from rlkit.torch.networks.mlp import ParallelMlp, QuantileMlp
from rlkit.torch.torch_rl_algorithm import TorchTrainer
from rlkit.core.logging import add_prefix
from rlkit.core.logging.eval_util import create_stats_ordered_dict
import rlkit.torch.pytorch_util as ptu
import numpy as np

# Named container for the losses returned by PACTrainer.compute_loss
# (currently only the Q-function loss). The typename matches the bound name
# so instances repr as ``PACLosses(...)`` and can be pickled by reference.
PACLosses = namedtuple(
    "PACLosses",
    "qfs_loss",
)


class PACTrainer(TorchTrainer, LossFunction):
    """Trainer that picks pessimistic actions using a lower bound on Q.

    The lower bound is ``mean(Q) - beta_LB * std(Q)`` taken across the
    Q-functions in ``qfs`` (see ``calc_q_LB``). When ``delta_range`` is not
    ``[0, 0]``, ``get_pessimistic_action`` shifts the behavior policy's mean
    action along the gradient of that bound by ``num_delta`` sampled step
    sizes and returns the proposal with the highest bound. Otherwise it returns
    the behavior policy's mean action unchanged. ``train_from_torch`` and
    ``compute_loss`` are not implemented in this class.
    """

    def __init__(
        self,
        policy,
        qfs,
        target_qfs,
        discount=0.99,
        reward_scale=1,
        policy_lr=0.001,
        qf_lr=0.001,
        optimizer_class=optim.Adam,
        soft_target_tau=0.01,
        target_update_period=1,
        plotter=None,
        render_eval_paths=False,
        # NEW PARAMS
        beta_LB=0.5,
        delta_range=None,
        num_delta=None,
        target_quantile=0.7,
        IQN=True,
    ):
        """
        Args:
            policy: Behavior policy. Calling it on observations is expected to
                return an action distribution (``TanhNormal`` when ``IQN`` is
                True; otherwise something exposing ``mean``/``stddev``).
            qfs: Q-functions used for the lower bound. Either a list of
                modules or a single ``nn.Module``; only the latter gets an
                optimizer.
            target_qfs: Target Q-functions; returned by ``networks`` and
                ``get_snapshot``.
            beta_LB: Weight on the Q std in the lower bound.
            delta_range: ``[low, high]`` range the step sizes are sampled
                from. ``None`` or ``[0, 0]`` disables the shift.
            num_delta: Number of proposals per observation. Defaults to
                ``int((high - low) * 10)``, and is at least 1 for a
                non-degenerate range.
            target_quantile: Stored and printed; not used by this class.
            IQN: If True, Q-functions are queried via ``get_mean`` and the
                policy is treated as tanh-squashed.
        """
        super().__init__()
        if delta_range is None:
            delta_range = [0.0, 0.0]
        # Normalise to a list so the ``== [0.0, 0.0]`` checks below and in
        # get_pessimistic_action also work for tuples / array-likes.
        delta_range = list(delta_range)

        self.iqn = IQN
        self.policy = policy
        self.qfs: List[QuantileMlp] = qfs
        self.target_qfs = target_qfs
        self.soft_target_tau = soft_target_tau
        self.target_update_period = target_update_period

        self.plotter = plotter
        self.render_eval_paths = render_eval_paths

        self.qf_criterion = nn.MSELoss()

        self.policy_optimizer = optimizer_class(
            self.policy.parameters(),
            lr=policy_lr,
        )
        # A list of Q-functions gets no optimizer. Keep the attribute defined
        # so ``optimizers`` does not raise AttributeError.
        self.qfs_optimizer = None
        if isinstance(qfs, nn.Module):
            self.qfs_optimizer = optimizer_class(
                self.qfs.parameters(),
                lr=qf_lr,
            )

        self.discount = discount
        self.reward_scale = reward_scale
        self._n_train_steps_total = 0
        self._need_to_update_eval_statistics = True
        self.eval_statistics = OrderedDict()

        # New params
        self.beta_LB = beta_LB
        self.delta_range = delta_range

        assert (
            len(self.delta_range) == 2
        ), f"Delta range ({self.delta_range}) should be in the form of [lower_range, upper_range]!"

        if self.delta_range != [0.0, 0.0]:
            assert (
                self.delta_range[0] < self.delta_range[1]
            ), f"Delta range ({self.delta_range}) should be in the form of [lower_range, upper_range]!"

        if num_delta is None:
            num_delta = int((self.delta_range[1] - self.delta_range[0]) * 10)
            if self.delta_range != [0.0, 0.0]:
                # A range narrower than 0.1 would otherwise yield zero
                # proposals and make the argmax over proposals fail.
                num_delta = max(num_delta, 1)

        self.num_delta = num_delta
        self.target_quantile = target_quantile
        self.behavior_policy = self.policy  # for naming's sake

        print("-----------------")
        print("Delta:", delta_range)
        print("Num delta:", self.num_delta)
        print("target_quantile:", target_quantile)
        print("beta_LB:", self.beta_LB)
        print("-----------------")

    def get_action_dist(self, obs):
        """Return the behavior policy's action distribution for ``obs``."""
        return self.behavior_policy(obs)

    def get_iqn_outputs(self, obs, act):
        """Evaluate every Q-function on ``(obs, act)``.

        Returns:
            Tensor of shape ``[batch_size, len(self.qfs)]``. Column ``i`` holds
            Q-function ``i``'s estimate: its ``get_mean`` when ``self.iqn`` is
            True, otherwise its squeezed forward output.
        """
        res = ptu.zeros(obs.shape[0], len(self.qfs))
        for i, qf in enumerate(self.qfs):
            if self.iqn:
                res[:, i] = qf.get_mean(obs, act)
            else:
                res[:, i] = qf(obs, act).squeeze()
        return res

    def get_shift_denominator(self, grad, sigma):
        """Return ``sqrt(g^T Sigma g) + 1e-7`` for a diagonal ``Sigma``.

        Args:
            grad: Gradient ``g`` of the Q lower bound w.r.t. the action.
            sigma: Diagonal of ``Sigma``, i.e. per-dimension variances.
                Callers pass ``stddev ** 2``. Same shape as ``grad``.

        Returns:
            Tensor reduced over dim 1 with ``keepdim=True``. The ``1e-7``
            guards against division by zero when the gradient vanishes.
        """
        # With Sigma diagonal, g^T Sigma g = sum_i g_i^2 * Sigma_ii.
        weighted = torch.mul(torch.pow(grad, 2), sigma)
        return torch.sqrt(torch.sum(weighted, dim=1, keepdim=True)) + 1e-7

    def sample_delta(self):
        """Sample ``num_delta`` step sizes in ``delta_range`` via ``ptu.rand``.

        The result is stored in ``self.delta`` with shape ``[1, num_delta, 1]``
        so it broadcasts over the batch and action dimensions.
        """
        low, high = self.delta_range
        self.delta = (ptu.rand(self.num_delta) * (high - low) + low).reshape(
            1, self.num_delta, 1
        )

    def get_pessimistic_action(self, obs) -> "TanhDelta | Delta":
        """Return a deterministic action distribution for ``obs``.

        With ``delta_range == [0, 0]`` this is the behavior policy's mean:
        ``TanhDelta`` of the pre-tanh mean when ``self.iqn``, else ``Delta``
        of ``dist.mean``. Otherwise fresh step sizes are sampled and the
        proposal with the highest Q lower bound is returned.
        """
        dist = self.get_action_dist(obs)
        if self.delta_range == [0.0, 0.0]:
            if self.iqn:
                return TanhDelta(dist.normal_mean)
            return Delta(dist.mean)

        self.sample_delta()
        if self.iqn:
            return self.compute_pessimistic_action(obs, dist)
        return self.compute_pessimistic_action_iql(obs, dist)

    def calc_q_LB(self, obs, act):
        """Return the per-sample lower bound ``mean(Q) - beta_LB * std(Q)``.

        Statistics are taken across the Q-functions (last dim of
        ``get_iqn_outputs``). With a single Q-function the unbiased std is
        undefined (NaN). In that case the spread term is treated as zero and
        the bound reduces to that Q-function's value.
        """
        q = self.get_iqn_outputs(obs, act)
        mu_q = q.mean(-1)
        if q.shape[-1] < 2:
            return mu_q
        sigma_q = q.std(-1)
        return mu_q - self.beta_LB * sigma_q

    def _proposal_shift(self, grad, variance):
        """Return ``sqrt(2 * delta) * Sigma g / sqrt(g^T Sigma g)`` per delta.

        Output shape is ``[batch_size, num_delta, action_dim]``: one offset
        per sampled step size in ``self.delta``.
        """
        denom = self.get_shift_denominator(grad, variance)
        direction = (torch.mul(variance, grad) / denom).unsqueeze(1)
        return torch.sqrt(2 * self.delta) * direction

    def _select_best_proposal(self, obs, q_actions, proposals):
        """Return, per observation, the proposal with the highest Q lower bound.

        Args:
            obs: Batch of observations (dim 0 is the batch).
            q_actions: ``[batch_size, num_delta, action_dim]`` actions that are
                scored by the Q-functions.
            proposals: Tensor indexed like ``q_actions``. Its selected entries
                are returned; these may differ from ``q_actions``, e.g. the
                pre-tanh values.
        """
        batch_size = obs.shape[0]
        obs_exp = obs.repeat_interleave(self.num_delta, dim=0)
        flat_actions = q_actions.reshape(batch_size * self.num_delta, -1)
        q_LB = self.calc_q_LB(obs_exp, flat_actions).reshape(
            batch_size, self.num_delta
        )
        select_idx = q_LB.argmax(1)
        # Build the row index on the same device as the selected indices.
        rows = torch.arange(batch_size, device=select_idx.device)
        return proposals[rows, select_idx]

    def compute_pessimistic_action(self, obs, dist: TanhNormal):
        """Shift the pre-tanh mean of ``dist`` to maximise the Q lower bound.

        The gradient of the bound is taken w.r.t. the pre-tanh mean. The
        ``num_delta`` proposals are scored after ``tanh``, and the best one is
        returned as a ``TanhDelta`` (still in pre-tanh space).
        """
        pre_tanh_mu_beta = dist.normal_mean
        pre_tanh_mu_beta.requires_grad_()

        # Gradient of the Q lower bound w.r.t. the pre-tanh mean action.
        q_LB = self.calc_q_LB(obs, torch.tanh(pre_tanh_mu_beta))
        grad = torch.autograd.grad(q_LB.sum(), pre_tanh_mu_beta)[0]

        assert grad is not None
        assert pre_tanh_mu_beta.shape == grad.shape

        # Sigma_beta: per-dimension variance taken from dist.stddev.
        Sigma_beta = torch.pow(dist.stddev, 2)
        delta_mu = self._proposal_shift(grad, Sigma_beta)

        # Unsqueeze so each observation's mean is paired with its own
        # proposals: [B, 1, A] + [B, num_delta, A] -> [B, num_delta, A].
        mu_proposal = pre_tanh_mu_beta.unsqueeze(1) + delta_mu
        selected = self._select_best_proposal(
            obs, torch.tanh(mu_proposal), mu_proposal
        )
        return TanhDelta(selected)

    def compute_pessimistic_action_iql(self, obs, dist: MultivariateDiagonalNormal):
        """Non-squashed variant of ``compute_pessimistic_action``.

        Works directly on ``dist.mean``. Proposals are clamped to ``[-1, 1]``
        before scoring, and the best one is returned as a ``Delta``.
        """
        mu_beta = dist.mean
        mu_beta.requires_grad_()

        # Gradient of the Q lower bound w.r.t. the mean action.
        q_LB = self.calc_q_LB(obs, mu_beta)
        grad = torch.autograd.grad(q_LB.sum(), mu_beta)[0]

        assert grad is not None
        assert mu_beta.shape == grad.shape

        # Sigma_beta: per-dimension variance taken from dist.stddev.
        Sigma_beta = torch.pow(dist.stddev, 2)
        delta_mu = self._proposal_shift(grad, Sigma_beta)

        # [B, 1, A] + [B, num_delta, A], clamped to the action bounds.
        mu_proposal = torch.clamp(mu_beta.unsqueeze(1) + delta_mu, -1, 1)
        selected = self._select_best_proposal(obs, mu_proposal, mu_proposal)
        return Delta(selected)

    def train_from_torch(self, batch):
        raise NotImplementedError

    def compute_loss(
        self,
        batch,
        skip_statistics=False,
    ) -> Tuple[PACLosses, LossStatistics]:
        raise NotImplementedError

    def get_diagnostics(self):
        """Return the parent diagnostics merged with ``self.eval_statistics``."""
        stats = super().get_diagnostics()
        stats.update(self.eval_statistics)
        return stats

    def end_epoch(self, epoch):
        self._need_to_update_eval_statistics = True

    @property
    def networks(self):
        if isinstance(self.qfs, list):
            return self.qfs + self.target_qfs + [self.policy]
        else:
            return [
                self.policy,
                self.qfs,
                self.target_qfs,
            ]

    @property
    def optimizers(self):
        """Existing optimizers; ``qfs_optimizer`` is omitted when it is None."""
        return [
            optimizer
            for optimizer in (self.qfs_optimizer, self.policy_optimizer)
            if optimizer is not None
        ]

    def get_snapshot(self):
        return dict(
            policy=self.policy,
            qfs=self.qfs,
            target_qfs=self.target_qfs,
        )


def sg_sanity_check(ctx: PipelineCtx):
    """Pipeline step: require ``ctx.variant["checkpoint_params"] == "SG"``.

    Raises:
        AssertionError: if the variant uses other checkpoint params. The check
            is an explicit raise rather than ``assert`` so it stays active
            under ``python -O``, while keeping the original exception type.
    """
    checkpoint_params = ctx.variant["checkpoint_params"]
    if checkpoint_params != "SG":
        raise AssertionError(
            f'Expected checkpoint_params == "SG", got {checkpoint_params!r}'
        )


# Derive "SGBasePipeline" from the offline zero-step PAC pipeline and put
# sg_sanity_check at position 0 of its step list (presumably the first step
# executed, so non-"SG" variants fail early).
# NOTE(review): assumes Pipeline.from_ copies the step list. If it aliases it,
# this insert would also modify Pipelines.offline_zerostep_pac_pipeline.
SGBasePipeline = Pipeline.from_(
    Pipelines.offline_zerostep_pac_pipeline, "SGBasePipeline"
)
SGBasePipeline.pipeline.insert(0, sg_sanity_check)

# * --------------------------------------------------


def variable_epoch_load_checkpoint_policy(ctx: PipelineCtx):
    """Pipeline step: load the policy checkpoint for ``ctx.variant["epoch_no"]``.

    ``params`` is the ``conf.CheckpointParams`` entry named by
    ``ctx.variant["checkpoint_params"]``. The file
    ``<checkpoint_path>/<params.path>/<env_id>/<seed>/itr_<epoch_no>.pt`` is
    loaded onto the CPU, and ``ctx.policy`` is set to its ``params.key`` entry.
    If ``params.unnormalize`` is set, the policy is wrapped in
    ``UnnormalizeTanhGaussianPolicy(ctx.obs_mean, ctx.obs_std, ...)``, except
    for halfcheetah-medium-expert-v2 with seed < 4.
    """
    variant = ctx.variant
    params = getattr(conf.CheckpointParams, variant["checkpoint_params"])

    checkpoint_file = osp.join(
        conf.CheckpointParams.checkpoint_path,
        params.path,
        variant["env_id"],
        str(variant["seed"]),
        # The ".pt" suffix must be inside the string literal; accessing
        # ``.pt`` on a str raises AttributeError.
        f"itr_{variant['epoch_no']}.pt",
    )
    ctx.policy = torch.load(checkpoint_file, map_location="cpu")[params.key]

    # NOTE(review): why these specific runs skip un-normalization is not
    # visible here. Presumably their checkpoints already handle it; confirm.
    if params.unnormalize and not (
        variant["env_id"] == "halfcheetah-medium-expert-v2"
        and variant["seed"] < 4
    ):
        ctx.policy = UnnormalizeTanhGaussianPolicy(
            ctx.obs_mean, ctx.obs_std, ctx.policy
        )


# Variant of SGBasePipeline in which the step named "load_checkpoint_policy"
# is replaced by variable_epoch_load_checkpoint_policy, which loads the
# checkpoint saved at ctx.variant["epoch_no"].
EpochBCExperiment = Pipeline.from_(SGBasePipeline, "EpochBCExperiment")
EpochBCExperiment.replace(
    "load_checkpoint_policy", variable_epoch_load_checkpoint_policy
)

# * --------------------------------------------------


def load_ground_truth_policy(ctx: PipelineCtx):
    """Pipeline step: set ``ctx.policy`` to ``load_gt_policy(ctx.dataset)``."""
    gt_policy = load_gt_policy(ctx.dataset)
    ctx.policy = gt_policy


# Variant of SGBasePipeline (registered as "GroundTruthExperiment") in which
# the step named "load_checkpoint_policy" is replaced by
# load_ground_truth_policy, which takes the policy from load_gt_policy(dataset).
GTExperiment = Pipeline.from_(SGBasePipeline, "GroundTruthExperiment")
GTExperiment.replace("load_checkpoint_policy", load_ground_truth_policy)

# * --------------------------------------------------
