"""
Implicit diffusion Q-learning (IDQL) for diffusion policy.

"""

import logging
import torch
import einops
import copy

import torch.nn.functional as F

log = logging.getLogger(__name__)

from onpolicy.algorithms.diffusion_ac.dppo.diffusion_rwr import RWRDiffusion


def expectile_loss(diff, expectile=0.8):
    """Return the element-wise asymmetric (expectile) squared error of ``diff``.

    Entries with ``diff > 0`` are weighted by ``expectile``; all other entries
    are weighted by ``1 - expectile``. No reduction is applied, so the result
    has the same shape as ``diff``.
    """
    squared_error = diff**2
    # asymmetric weight: favours one sign of the residual over the other
    asymmetric_weight = torch.where(diff > 0, expectile, 1 - expectile)
    return asymmetric_weight * squared_error


class IDQLDiffusion(RWRDiffusion):
    """Diffusion policy trained with implicit diffusion Q-learning (IDQL).

    Bundles a diffusion actor with a double-Q critic, a Polyak-averaged copy
    of that critic (``target_q``) and a state-value critic. When sampling,
    several candidate actions are drawn from the actor for every observation
    and one of them is selected with the help of the critics.
    """

    def __init__(
        self,
        actor,
        critic_q,
        critic_v,
        **kwargs,
    ):
        """Set up the actor and critics.

        Args:
            actor: denoising network; passed to the parent class as ``network``.
            critic_q: critic called as ``critic_q(obs, actions)`` and expected
                to return two Q estimates.
            critic_v: critic called as ``critic_v(obs)``.
            **kwargs: forwarded unchanged to the parent constructor.
        """
        super().__init__(network=actor, **kwargs)
        # NOTE(review): self.device is presumably set by the parent constructor
        self.critic_q = critic_q.to(self.device)
        # copy the module that was moved to the device, so the target never
        # depends on ``.to`` having modified its argument in place
        self.target_q = copy.deepcopy(self.critic_q)
        self.critic_v = critic_v.to(self.device)

        # assign actor
        self.actor = self.network

    # ---------- RL training ----------#

    def compute_advantages(self, obs, actions):
        """Return the flat advantage ``min(Q1, Q2)(obs, actions) - V(obs)``.

        Q comes from the target critic and carries no gradient; V comes from
        the live value critic, so gradients flow through V only.
        """
        # get current Q-function, stop gradient
        with torch.no_grad():
            current_q1, current_q2 = self.target_q(obs, actions)
        # flatten so a [B, 1] critic output cannot broadcast against the
        # flat V below into a [B, B] matrix
        q = torch.min(current_q1, current_q2).reshape(-1)

        # get the current V-function
        v = self.critic_v(obs).reshape(-1)

        # compute advantage
        return q - v

    def loss_critic_v(self, obs, actions):
        """Return the mean expectile loss of the advantages (value-critic loss)."""
        adv = self.compute_advantages(obs, actions)
        return expectile_loss(adv).mean()

    def loss_critic_q(self, obs, next_obs, actions, rewards, terminated, gamma):
        """Return the summed MSE of both Q estimates against the TD target.

        The target is ``rewards + gamma * V(next_obs) * (1 - terminated)``,
        with V evaluated without gradient.

        Args:
            obs, next_obs: inputs accepted by the critics.
            actions: actions paired with ``obs``.
            rewards: per-transition rewards (flattened internally).
            terminated: per-transition termination flags; bool, integer or
                float tensors are accepted.
            gamma: discount factor.
        """
        # get current Q-function (flattened to line up with the flat target)
        current_q1, current_q2 = self.critic_q(obs, actions)
        current_q1 = current_q1.reshape(-1)
        current_q2 = current_q2.reshape(-1)

        # get the next V-function, stop gradient
        with torch.no_grad():
            next_v = self.critic_v(next_obs).reshape(-1)

        # terminal state mask; ``1 - bool_tensor`` is not supported by torch,
        # so non-float flags are cast first
        if not torch.is_floating_point(terminated):
            terminated = terminated.to(next_v.dtype)
        mask = (1 - terminated).reshape(-1)

        # target value
        rewards = rewards.reshape(-1)
        discounted_q = rewards + gamma * next_v * mask

        # Update critic
        q_loss = torch.mean((current_q1 - discounted_q) ** 2) + torch.mean(
            (current_q2 - discounted_q) ** 2
        )
        return q_loss

    def update_target_critic(self, tau):
        """Polyak-average the target critic: ``target = (1 - tau) * target + tau * critic_q``."""
        # in-place parameter updates must not be recorded by autograd
        with torch.no_grad():
            for target_param, source_param in zip(
                self.target_q.parameters(), self.critic_q.parameters()
            ):
                target_param.copy_(target_param * (1.0 - tau) + source_param * tau)

    # override
    def p_losses(
        self,
        x_start,
        cond,
        t,
    ):
        """Return the plain (not reward-weighted) denoising MSE loss.

        ``x_start`` is noised to step ``t`` via ``self.q_sample`` and the
        network prediction is regressed onto the noise when
        ``self.predict_epsilon`` is set, otherwise onto ``x_start``.
        The original author notes this matches diffusion.py.
        """
        # Forward process
        noise = torch.randn_like(x_start)
        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)

        # Predict
        x_recon = self.network(x_noisy, t, cond=cond)

        # mse_loss already averages over all elements
        target = noise if self.predict_epsilon else x_start
        return F.mse_loss(x_recon, target)

    # ---------- Sampling ----------#

    # override
    @torch.no_grad()
    def forward(
        self,
        cond,
        deterministic=False,
        num_sample=10,
        critic_hyperparam=0.7,  # sampling weight for implicit policy
        use_expectile_exploration=True,
    ):
        """Sample ``num_sample`` candidates per observation and pick one each.

        Assumes state-only conditioning (no rgb in ``cond``).

        Args:
            cond: dict whose ``"state"`` entry is a ``[B, T, D]`` tensor.
            deterministic: forwarded to the parent sampler; when true, the
                candidate with the highest target Q is returned.
            num_sample: number of candidates drawn per observation (>= 1).
            critic_hyperparam: weight given to candidates with positive
                advantage when sampling; the rest get ``1 - critic_hyperparam``.
            use_expectile_exploration: when false, the highest-Q candidate is
                always picked, even if ``deterministic`` is false.

        Returns:
            Tensor of shape ``[B, H, A]`` holding the selected candidate for
            each observation.

        Raises:
            ValueError: if ``num_sample`` is smaller than 1.
        """
        if num_sample < 1:
            raise ValueError(f"num_sample must be >= 1, got {num_sample}")

        state = cond["state"]
        B, T, D = state.shape
        S = num_sample

        # tile obs S times along dim 0, sample-major: [S*B, T, D]
        cond_repeat = state.repeat(S, 1, 1)

        # for eval, use less noisy samples --- there is still DDPM noise, but final action uses small min_sampling_std
        samples = super().forward(
            {"state": cond_repeat},
            deterministic=deterministic,
        )
        _, H, A = samples.shape

        # get current Q-function, arranged as [S, B]
        current_q1, current_q2 = self.target_q({"state": cond_repeat}, samples)
        q = torch.min(current_q1, current_q2).reshape(S, B)

        if deterministic or (not use_expectile_exploration):
            # Use argmax -- filter out suboptimal Q during inference
            chosen = q.argmax(0)  # [B]
        else:
            # Sample as an implicit policy for exploration:
            # weight candidates by the sign of their advantage
            v = self.critic_v({"state": cond_repeat}).reshape(S, B)
            adv = q - v

            # expectile exploration policy
            tau_weights = torch.where(adv > 0, critic_hyperparam, 1 - critic_hyperparam)
            tau_weights = tau_weights / tau_weights.sum(0)  # normalize

            # one candidate index per batch element
            chosen = torch.multinomial(tau_weights.T, 1).squeeze(1)  # [B]

        # pick candidate chosen[b] for every batch element b -> [B, H, A]
        batch_index = torch.arange(B, device=chosen.device)
        return samples.reshape(S, B, H, A)[chosen, batch_index]