import copy
from typing import Optional

import torch
import torch.nn as nn

from cleandiffuser.nn_classifier import BaseNNClassifier
from .base import BaseClassifier


class NewPolicyClassifier(BaseClassifier):
    """
    Classifier whose log-probability measures how well two action predictions agree.

    For a trajectory tensor ``x``, consecutive steps along dim 1 are paired as
    ``(state, next_state)``. The score is defined as

        logp = -temperature * mean((pi_action - inv_action) ** 2, dim=-1)

    where ``pi_action = model.pi(state, task_emb)[1]`` and
    ``inv_action = model.invdyn(state, next_state)``.

    NOTE(review): ``model`` and ``task_emb`` are installed through ``modify``;
    ``task_emb`` is not set anywhere else in this class, so ``modify`` presumably
    has to be called before ``logp`` / ``gradients`` -- confirm against the caller
    and the base class.
    """
    def __init__(
            self, nn_classifier, obs_dim, temperature: float = 1.0,
            ema_rate: float = 0.995, grad_clip_norm: Optional[float] = None,
            optim_params: Optional[dict] = None, device: str = "cpu"):
        """
        Args:
            nn_classifier: Forwarded unchanged to ``BaseClassifier.__init__``.
            obs_dim: Stored on the instance as ``self.obs_dim``; not read by any
                method of this class.
            temperature: Scale applied to the negated squared error in ``logp``.
            ema_rate: Forwarded to ``BaseClassifier.__init__``.
            grad_clip_norm: Forwarded to ``BaseClassifier.__init__``.
            optim_params: Forwarded to ``BaseClassifier.__init__``.
            device: Forwarded to ``BaseClassifier.__init__``.
        """
        super().__init__(nn_classifier, ema_rate, grad_clip_norm, optim_params, device)
        self.temperature = temperature
        self.obs_dim = obs_dim

    def modify(self, policy, task_emb):
        """
        Replace the scored model and the task embedding used by ``logp``.

        Args:
            policy: Object stored as ``self.model``; ``logp`` calls its ``pi`` and
                ``invdyn`` methods.
            task_emb: Value passed as the second argument of ``policy.pi``.
        """
        self.model = policy
        self.task_emb = task_emb

    def logp(self, x: torch.Tensor, noise: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """
        Score ``x`` by the negative mean squared gap between the two action predictions.

        Args:
            x: Tensor indexed with three axes; dim 1 is split into consecutive
                ``(state, next_state)`` pairs, so the result has one fewer entry
                along that dim. Presumably laid out as (batch, horizon, obs) --
                TODO confirm.
            noise: Unused here; kept for interface compatibility.
            y: Unused here; kept for interface compatibility.

        Returns:
            ``-temperature * mean((pi_action - inv_action) ** 2)`` reduced over the
            last dim with ``keepdim=True``.
        """
        state = x[:, :-1, :]
        next_state = x[:, 1:, :]
        # Only element [1] of whatever `pi` returns is used as the action.
        pi_action = self.model.pi(state, self.task_emb)[1]
        inv_action = self.model.invdyn(state, next_state)
        squared_gap = (pi_action - inv_action) ** 2
        return -self.temperature * squared_gap.mean(-1, keepdim=True)

    def gradients(self, x: torch.Tensor, noise: torch.Tensor, c: torch.Tensor):
        """
        Compute ``logp`` and the unit-norm gradient of ``logp.sum()`` w.r.t. ``x``.

        Args:
            x: Input passed to ``logp``. It is not modified: differentiation runs on
                a detached copy, so its ``requires_grad`` flag is left as it was.
            noise: Forwarded to ``logp``.
            c: Forwarded to ``logp`` as its third argument.

        Returns:
            A tuple ``(logp, normalized_grad)``, both detached. ``normalized_grad``
            is the gradient divided by its norm taken over all elements (one norm
            for the whole tensor, not one per sample). An all-zero gradient is
            returned as zeros.
        """
        # Differentiate w.r.t. a fresh leaf. `Tensor.detach()` is not in-place, so
        # calling it on `x` afterwards cannot undo an in-place `requires_grad_()`.
        x_in = x.detach().requires_grad_(True)
        with torch.enable_grad():
            logp = self.logp(x_in, noise, c)
            grad = torch.autograd.grad([logp.sum()], [x_in])[0]
        # Guard the division: a zero gradient would otherwise produce NaNs.
        grad_norm = grad.norm().clamp_min(torch.finfo(grad.dtype).tiny)
        normalized_grad = grad / grad_norm
        return logp.detach(), normalized_grad.detach()