import math
import torch
from torch._six import inf
from torch.distributions import Normal, Categorical

def get_action(mu, std):
    """Sample an action from N(mu, std) and return it as a NumPy array.

    Args:
        mu: tensor of means.
        std: tensor of standard deviations.

    Returns:
        numpy.ndarray holding the sampled values. The sample is detached from
        the autograd graph and copied to host memory first, so inputs living
        on a CUDA device work as well as CPU inputs.
    """
    action = torch.normal(mu, std)
    # ``.detach()`` replaces the deprecated ``.data``; ``.cpu()`` is a no-op
    # for CPU tensors and is required before ``.numpy()`` for CUDA tensors.
    return action.detach().cpu().numpy()

def get_entropy(mu, std):
    """Return the mean entropy of the Gaussian N(mu, std).

    The entropy is evaluated element-wise by ``torch.distributions.Normal``
    and then averaged over every element, yielding a 0-d tensor.
    """
    return Normal(loc=mu, scale=std).entropy().mean()

def log_prob_density(x, mu, std):
    """Return the log-density of ``x`` under the diagonal Gaussian N(mu, std).

    Per-element log-densities
    ``-(x - mu)^2 / (2 std^2) - log(std) - 0.5 log(2 pi)`` are summed over
    dim 1 (kept as a size-1 dim), so 2-D inputs of shape ``(batch, d)``
    produce a ``(batch, 1)`` result.

    Args:
        x: tensor of values to score.
        mu: tensor of means (broadcastable with ``x``).
        std: tensor of standard deviations (broadcastable with ``x``).

    Returns:
        Tensor of summed log-densities with dim 1 kept.
    """
    # The ``- log(std)`` normalisation term was previously missing, which
    # made the result wrong whenever std != 1.
    log_density = (-(x - mu).pow(2) / (2 * std.pow(2))
                   - torch.log(std)
                   - 0.5 * math.log(2 * math.pi))
    return log_density.sum(1, keepdim=True)

def log_probs(x, mu, std):
    """Return the element-wise log-density of ``x`` under N(mu, std).

    Computes ``-(x - mu)^2 / (2 std^2) - log(std) - 0.5 log(2 pi)`` with no
    reduction; the output has the broadcast shape of ``x``, ``mu`` and ``std``.

    Args:
        x: tensor of values to score.
        mu: tensor of means.
        std: tensor of standard deviations.

    Returns:
        Tensor of per-element Gaussian log-densities.
    """
    # ``- log(std)`` is the normalisation term that was previously omitted.
    return (-(x - mu).pow(2) / (2 * std.pow(2))
            - torch.log(std)
            - 0.5 * math.log(2 * math.pi))
     
def kl_divergence(mu, logvar):
    """Closed-form KL divergence from N(mu, exp(logvar)) to N(0, I).

    Evaluates ``0.5 * sum_j(mu_j^2 + exp(logvar_j) - logvar_j - 1)``,
    reducing over dim 1 so each row yields one value.
    """
    mean_sq = mu ** 2
    variance = logvar.exp()
    per_dim = mean_sq + variance - logvar - 1
    return 0.5 * per_dim.sum(dim=1)

def save_checkpoint(state, filename):
    """Serialize ``state`` to ``filename`` via :func:`torch.save`.

    NOTE(review): this function is defined a second time, identically, later
    in this module; that later definition is the one in effect at import.
    """
    torch.save(obj=state, f=filename)

def get_reward(discrim, state, action, *, eps: float = 1e-8):
    """Return ``-log(D(state, action))`` as a Python float.

    ``state`` and ``action`` are converted to tensors, concatenated along
    dim 0 and given a leading batch dimension of 1 before being passed to
    ``discrim`` without gradient tracking. The first element of the output is
    read with ``.item()``, so it must hold a single value for this pair
    (assumed to be a probability in (0, 1], e.g. a sigmoid output -- TODO
    confirm against the discriminator).

    Args:
        discrim: callable applied to the ``(1, len(state) + len(action))``
            state-action tensor (assuming 1-D ``state``/``action``).
        state: array-like state, presumably 1-D.
        action: array-like action, presumably 1-D.
        eps: floor applied to the discriminator output before the log, so an
            output of exactly 0 (or below ``eps``) yields the finite reward
            ``-log(eps)`` instead of raising ``ValueError: math domain error``.

    Returns:
        float reward.
    """
    state_action = torch.cat((torch.Tensor(state), torch.Tensor(action))).unsqueeze(0)
    with torch.no_grad():
        prob = discrim(state_action)[0].item()
    return -math.log(max(prob, eps))

def get_q_value(q_value, state, action):
    """Score ``action`` by its log-density under the output of ``q_value``.

    ``state`` and ``action`` are each wrapped in a batch of size one. The
    result of ``q_value(state_batch)`` is unpacked as the ``(mu, std)``
    arguments of ``log_prob_density``. Everything runs without gradient
    tracking.

    NOTE(review): this function is defined a second time later in this
    module; that later definition is the one in effect at import.

    Returns:
        The first element of the flattened log-density tensor (a 0-d tensor).
    """
    batched_state = torch.Tensor([state])
    batched_action = torch.Tensor([action])
    with torch.no_grad():
        gaussian_params = q_value(batched_state)
        scores = log_prob_density(batched_action, *gaussian_params)
        return scores.flatten()[0]

def clip_grad_norm(parameters, max_norm: float, norm_type: float = 2.0,
        error_if_nonfinite: bool = False) -> torch.Tensor:
    r"""Return the total gradient norm of ``parameters``, repairing non-finite grads.

    The norm is computed over all gradients together, as if they were
    concatenated into a single vector.

    Finite gradients are NOT rescaled to ``max_norm``: the in-place clipping
    step is disabled in this project, and that behaviour is preserved here.
    NOTE(review): confirm whether norm clipping should be re-enabled.

    If any gradient contains ``nan``/``inf``/``-inf`` and
    ``error_if_nonfinite`` is False, those entries are repaired in place
    (``nan`` -> 0, ``+inf`` -> ``max_norm``, ``-inf`` -> ``-max_norm``) and
    the norm is computed on the repaired gradients.

    Args:
        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
            single Tensor whose gradients are inspected.
        max_norm (float or int): magnitude used to replace infinite entries.
        norm_type (float or int or str): type of the used p-norm. ``'inf'``
            or ``float('inf')`` selects the infinity norm.
        error_if_nonfinite (bool): if True, raise instead of repairing when any
            gradient contains ``nan``, ``inf`` or ``-inf``. Default: False.

    Returns:
        Total norm of the parameter gradients (viewed as a single vector);
        ``tensor(0.)`` when no parameter has a gradient.

    Raises:
        RuntimeError: if ``error_if_nonfinite`` is True and a gradient is
            non-finite.
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    grads = [p.grad for p in parameters if p.grad is not None]
    max_norm = float(max_norm)
    norm_type = float(norm_type)
    if len(grads) == 0:
        return torch.tensor(0.)
    device = grads[0].device

    if not all(bool(torch.isfinite(g).all()) for g in grads):
        if error_if_nonfinite:
            raise RuntimeError(
                f"The total norm of order {norm_type} for gradients from "
                "`parameters` is non-finite, so it cannot be clipped.")
        print("Clip Infinite grads")
        for g in grads:
            # Modify the gradient storage in place so ``p.grad`` sees the fix.
            g.detach().nan_to_num_(nan=0.0, posinf=max_norm, neginf=-max_norm)

    # ``torch.norm`` with p=inf already returns the max absolute value, so one
    # expression handles both finite p and the infinity norm.
    per_grad = [torch.norm(g.detach(), norm_type).to(device) for g in grads]
    return torch.norm(torch.stack(per_grad), norm_type)
    
def clip_grad_value(parameters, clip_value: float) -> None:
    r"""Replace non-finite gradient entries of ``parameters`` with finite values.

    Each parameter's ``.grad`` is replaced by a copy in which ``nan`` becomes
    0, ``+inf`` becomes 1 and ``-inf`` becomes -1. Finite entries are left
    unchanged. Parameters without a gradient are skipped.

    NOTE(review): despite the name, gradients are NOT clamped to
    ``[-clip_value, clip_value]``; ``clip_value`` is only converted to float
    (rejecting non-numeric input) and is otherwise unused. Confirm whether
    clamping was intended.

    Args:
        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
            single Tensor whose gradients are sanitized.
        clip_value (float or int): see the note above.
    """
    if isinstance(parameters, torch.Tensor):
        # Iterating a bare tensor would visit its slices rather than the
        # parameter itself, silently skipping the real gradient.
        parameters = [parameters]
    clip_value = float(clip_value)
    for p in parameters:
        if p.grad is None:
            continue
        # The replacement values are finite, so the result is always finite.
        p.grad = p.grad.nan_to_num(nan=0.0, posinf=1.0, neginf=-1.0)

def get_q_value(q_value, state, action):
    """Return the log-density of ``action`` under ``q_value``'s prediction for ``state``.

    Both inputs are wrapped into single-element batches. The outputs of
    ``q_value`` are splatted into ``log_prob_density`` as its ``mu``/``std``
    arguments. No gradients are tracked.

    Returns:
        0-d tensor: the first entry of the flattened log-density.
    """
    with torch.no_grad():
        state_batch, action_batch = torch.Tensor([state]), torch.Tensor([action])
        log_density = log_prob_density(action_batch, *q_value(state_batch))
        return log_density.flatten()[0]
    

def save_checkpoint(state, filename):
    """Write ``state`` to disk at ``filename`` (thin wrapper over torch.save)."""
    return torch.save(state, filename)