import torch
from torch.nn import BCEWithLogitsLoss
from functions.denoising import compute_alpha
import diffusers
from diffusers import DDIMPipeline
from diffusers.models.unets.unet_2d import UNet2DOutput

import types

def noise_estimation_loss(model,
                          x0: torch.Tensor,
                          t: torch.LongTensor,
                          e: torch.Tensor,
                          b: torch.Tensor, keepdim=False):
    """Squared-error loss between the noise ``e`` and the model's prediction.

    The clean input ``x0`` is mixed with ``e`` using the cumulative product of
    ``1 - b`` selected at the indices ``t``; the mixed sample and ``t`` (cast
    to float) are passed to ``model.forward``.

    Args:
        model: Object exposing ``forward(x, t)``; its output is subtracted
            from ``e``.
        x0: Clean inputs. Expected to be 4-D, since the loss sums over
            dims 1, 2 and 3.
        t: Integer indices into the cumulative product of ``1 - b``.
        e: Noise tensor combined with ``x0``.
        b: 1-D tensor from which the mixing coefficients are derived.
        keepdim: When True, return one loss value per sample; otherwise
            return the mean over dim 0.

    Returns:
        Per-sample sums of squared errors when ``keepdim`` is True, else
        their mean over the batch dimension.
    """
    # Cumulative product of (1 - b), picked per sample and reshaped so it
    # broadcasts over the three trailing dimensions.
    cumulative = torch.cumprod(1 - b, dim=0)
    alpha_bar = torch.index_select(cumulative, 0, t).view(-1, 1, 1, 1)

    signal = x0 * torch.sqrt(alpha_bar)
    noise = e * torch.sqrt(1.0 - alpha_bar)
    noisy_input = signal + noise

    predicted = model.forward(noisy_input, t.float())
    per_sample = torch.square(e - predicted).sum(dim=(1, 2, 3))
    return per_sample if keepdim else per_sample.mean(dim=0)

def inference_attack_loss(model,
                          x0: torch.Tensor,
                          t: torch.LongTensor,
                          e: torch.Tensor,
                          b: torch.Tensor, classifier, args, keepdim=False):
    """Squared-error noise loss on a classifier-corrected noise prediction.

    Builds the noised sample ``x`` from ``x0`` and ``e``, evaluates
    ``model(x, t)`` and subtracts a correction derived from the gradient of
    the classifier's log-probabilities with respect to the noised sample.
    Only samples whose sigmoid output lies within ``args.tolerance`` of 0 or
    of 1 receive a correction, and in each case the gradient of the
    *opposite* class's log-probability is used.

    Side effect: prints the number of samples in each confident group and
    the norm of the correction on every call.

    Args:
        model: Callable invoked as ``model(x, t)``; its (corrected) output is
            compared against ``e``.
        x0: Clean inputs. Expected to be 4-D, since the loss sums over
            dims 1, 2 and 3.
        t: Integer indices into the cumulative product of ``1 - b``.
        e: Noise tensor combined with ``x0``.
        b: 1-D tensor from which the mixing coefficients are derived.
        classifier: Callable invoked as ``classifier(x_t, t)``; its output is
            passed through a sigmoid.
            NOTE(review): presumably shaped ``(B, 1)`` so that the two
            trailing ``unsqueeze`` calls broadcast against the 4-D
            gradient -- confirm against the classifier definition.
        args: Object providing ``scale`` (multiplier on the gradients) and
            ``tolerance`` (threshold for treating a prediction as confident).
        keepdim: When True, return one loss value per sample; otherwise
            return the mean over dim 0.

    Returns:
        Per-sample sums of squared errors when ``keepdim`` is True, else
        their mean over the batch dimension.
    """
    at = compute_alpha(b, t.long())
    a = (1 - b).cumprod(dim=0).index_select(0, t).view(-1, 1, 1, 1)
    x = x0 * a.sqrt() + e * (1.0 - a).sqrt()
    # From the denoising code.
    # NOTE(review): the sibling losses pass ``t.float()`` to their models,
    # whereas ``t`` is passed unchanged here -- confirm this is intended.
    et = model(x, t)

    # Added inside both logarithms so that a saturated sigmoid (exactly 0 or
    # 1) cannot produce log(0) and, through it, inf/NaN gradients.
    EPSILON = 0.0001
    with torch.enable_grad():
        # The classifier pass must run with grad mode on, otherwise no graph
        # links ``probs`` to ``x_t`` when the caller is under ``no_grad``.
        # ``x_t`` is detached so the correction never back-propagates into
        # ``model`` or the inputs.
        x_t = x.detach().requires_grad_(True)
        probs = torch.sigmoid(classifier(x_t, t))
        ones = torch.ones_like(probs)
        scale_t = (1 - at).sqrt()
        # Gradient of log(1 - p): the log-probability of category 0.
        grad_0 = scale_t * torch.autograd.grad(
            torch.log(ones - probs + EPSILON), x_t, ones,
            retain_graph=True)[0] * args.scale
        # Gradient of log(p): the log-probability of category 1.
        grad_1 = scale_t * torch.autograd.grad(
            torch.log(probs + EPSILON), x_t, ones)[0] * args.scale

    high_prob_of_0 = (probs < args.tolerance).int()
    high_prob_of_1 = (probs > 1 - args.tolerance).int()
    # Samples confidently assigned to category 0 receive the category-1
    # gradient and vice versa; all other samples get a zero correction.
    correction_factor_0 = high_prob_of_0.unsqueeze(-1).unsqueeze(-1) * grad_1
    correction_factor_1 = high_prob_of_1.unsqueeze(-1).unsqueeze(-1) * grad_0
    correction_factor = correction_factor_0 + correction_factor_1
    print('prob 1:', high_prob_of_1.sum(), 'prob 0:', high_prob_of_0.sum())
    print('norm is', torch.norm(correction_factor))

    # Remaining NaNs are zeroed as a last line of defence (hard coded for now).
    output = et - torch.nan_to_num(correction_factor, nan=0.0)
    per_sample = (e - output).square().sum(dim=(1, 2, 3))
    if keepdim:
        return per_sample
    return per_sample.mean(dim=0)


def binary_classification_loss(model, x0, t, e, b, y, keepdim=False):
    """Binary cross-entropy (on logits) of a classifier evaluated on noised inputs.

    ``x0`` is mixed with the noise ``e`` using the cumulative product of
    ``1 - b`` selected at the indices ``t``; the result and ``t`` (cast to
    float) are fed to ``model``, whose output is treated as logits.

    Args:
        model: Callable invoked as ``model(x, t)`` returning logits.
        x0: Clean inputs.
        t: Integer indices into the cumulative product of ``1 - b``.
        e: Noise tensor combined with ``x0``.
        b: 1-D tensor from which the mixing coefficients are derived.
        y: Binary targets; cast to float before the loss is computed.
        keepdim: When True, return the unreduced element-wise loss;
            otherwise return its mean.

    Returns:
        The element-wise loss when ``keepdim`` is True, else a scalar mean.
    """
    # Per-sample mixing coefficient, reshaped to broadcast over three
    # trailing dimensions.
    alpha_bar = torch.cumprod(1 - b, dim=0).index_select(0, t).view(-1, 1, 1, 1)
    noisy_input = x0 * torch.sqrt(alpha_bar) + e * torch.sqrt(1.0 - alpha_bar)

    logits = model(noisy_input, t.float())  # [B, 1]
    criterion = BCEWithLogitsLoss(reduction='none')
    elementwise = criterion(logits, y.float())
    return elementwise if keepdim else elementwise.mean()


# Maps a loss identifier string to the function implementing that loss.
loss_registry = {
    "simple": noise_estimation_loss,
    "classification": binary_classification_loss,
    "inference_attack": inference_attack_loss,
}
