import copy
from typing import Optional

import torch
import torch.nn as nn
from cleandiffuser.nn_classifier import BaseNNClassifier
from .base import BaseClassifier


class PolicyClassifier(BaseClassifier):
    """
    Classifier that scores a sample ``x`` by how well its action part agrees
    with the action proposed by a policy for its state part.

    ``x`` is split along the last dimension: ``x[..., :obs_dim]`` is treated as
    the state and ``x[..., obs_dim:-2]`` as the action. The last two features
    are ignored by this class (their meaning is not visible here).

    Two scores are available:

    * ``logp``: ``temperature`` times the log-density of the action under a
      unit-variance Normal centred on ``model.actor(state)``, summed over the
      action dimension. This is the score used by ``gradients``.
    * ``logp_mse``: ``-temperature * MSE(policy_action, action)``.

    Args:
        nn_classifier: Forwarded unchanged to ``BaseClassifier``.
        obs_dim: Number of leading features of ``x`` that form the state.
        temperature: Multiplicative scale applied to the score.
        ema_rate: Forwarded unchanged to ``BaseClassifier``.
        grad_clip_norm: Forwarded unchanged to ``BaseClassifier``.
        optim_params: Forwarded unchanged to ``BaseClassifier``.
        device: Forwarded unchanged to ``BaseClassifier``.
    """

    def __init__(
            self, nn_classifier, obs_dim, temperature: float = 1.0,
            ema_rate: float = 0.995, grad_clip_norm: Optional[float] = None,
            optim_params: Optional[dict] = None, device: str = "cpu"):
        super().__init__(nn_classifier, ema_rate, grad_clip_norm, optim_params, device)
        self.temperature = temperature
        self.obs_dim = obs_dim

    def modify(self, policy, task_emb):
        """
        Replace the policy used for scoring and store a task embedding.

        ``policy`` becomes ``self.model``; ``task_emb`` is what ``logp_mse``
        passes to ``self.model.pi``. ``logp_mse`` reads ``self.task_emb``, which
        is only set here, so call this method before using ``logp_mse``.
        """
        self.model = policy
        self.task_emb = task_emb

    def logp_mse(self, x: torch.Tensor, noise: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """
        Negative, temperature-scaled MSE between the policy action and the
        action contained in ``x``.

        Args:
            x: Sample whose last dimension holds state, action and two
                trailing features that are ignored.
            noise: Not used by this implementation.
            y: Not used by this implementation.

        Returns:
            Tensor with the last dimension reduced to size 1 (``keepdim=True``).
        """
        # The state is detached so gradients w.r.t. ``x`` do not flow through
        # the state slice.
        state = x[..., :self.obs_dim].detach()
        action = x[..., self.obs_dim:-2]
        # Element 1 of the tuple returned by ``pi`` is used as the policy
        # action -- presumably the sampled action; confirm against the policy.
        pi_action = self.model.pi(state, self.task_emb)[1]
        return -self.temperature * ((pi_action - action) ** 2).mean(-1, keepdim=True)

    def logp(self, x: torch.Tensor, noise: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """
        Temperature-scaled log-density of the action in ``x`` under a
        unit-variance Normal centred on ``self.model.actor(state)``.

        Args:
            x: Sample whose last dimension holds state, action and two
                trailing features that are ignored.
            noise: Not used by this implementation.
            y: Not used by this implementation.

        Returns:
            Tensor with the last (action) dimension summed out.
        """
        state = x[..., :self.obs_dim].detach()
        action = x[..., self.obs_dim:-2]
        # The policy mean is treated as a constant: gradients only flow
        # through ``action``. ``as_tensor(...).detach()`` gives the same values
        # as the former ``torch.tensor(...)`` call without the copy-construct
        # warning and the extra copy when the actor already returns a tensor.
        pi = torch.as_tensor(self.model.actor(state)).detach()
        pi_dist = torch.distributions.Normal(pi, 1.0)
        return self.temperature * pi_dist.log_prob(action).sum(dim=-1)

    def gradients(self, x: torch.Tensor, noise: torch.Tensor, c: torch.Tensor):
        """
        Compute ``logp`` and the normalised gradient of its sum w.r.t. ``x``.

        The gradient is divided by its norm taken over the whole tensor (not
        per sample). If that norm is exactly zero the returned gradient is all
        zeros instead of NaN.

        The caller's ``x`` is left untouched: differentiation happens on a
        detached copy, so ``x.requires_grad`` is not modified.

        Args:
            x: Sample to differentiate with respect to.
            noise: Forwarded to ``logp``.
            c: Forwarded to ``logp`` as ``y``.

        Returns:
            Tuple ``(logp, normalized_grad)``, both detached from the graph.
        """
        with torch.enable_grad():
            # Fresh leaf tensor so the ``requires_grad`` flag does not leak
            # onto the tensor owned by the caller.
            x_in = x.detach().requires_grad_(True)
            logp = self.logp(x_in, noise, c)
            grad = torch.autograd.grad([logp.sum()], [x_in])[0]

        grad_norm = grad.norm()
        # Guard only the exact-zero case; NaN/inf norms propagate as before so
        # numerical problems are not silently hidden.
        normalized_grad = torch.where(
            grad_norm == 0, torch.zeros_like(grad), grad / grad_norm)
        return logp.detach(), normalized_grad.detach()