import torch
import ipdb
from torch.autograd import grad
from torch.distributions.categorical import Categorical
import torch.nn as nn
import numpy as np

def _sampled_nll(log_probs, n_samples):
    """Return the NLL of labels sampled from the model's own predictive distribution.

    Draws ``n_samples`` labels per row of ``log_probs`` from
    ``Categorical(logits=log_probs)``. The log-probabilities of the drawn labels
    are summed over all draws and rows, and the total is divided by the number
    of rows.

    Args:
        log_probs: 2-D tensor (rows, classes) of normalized log-probabilities.
        n_samples: number of labels to draw per row.

    Returns:
        0-d tensor, differentiable w.r.t. ``log_probs``.
    """
    dist = torch.distributions.Categorical(logits=log_probs)
    # Shape (n_samples, rows); the sampled indices themselves carry no gradient.
    samples = dist.sample(torch.Size((n_samples,)))
    # picked[r, s] = log_probs[r, samples[s, r]]. Gathering with the transposed
    # index avoids materialising an (n_samples, rows, classes) copy of log_probs.
    picked = log_probs.gather(1, samples.t())
    return -picked.sum() / log_probs.size(0)


def compute_hessian_reg(reg_type, model, sample_ndata, output, device, args, keep_gradient=False):
    """Compute a sampled-gradient regularizer over a subset of ``model``'s parameters.

    Parameters are selected by name via ``get_name_list(reg_type, model, args)``.
    Labels are sampled from the model's own predictive distribution
    (softmax over the last dim of ``output``). The gradient of the resulting
    NLL w.r.t. the selected parameters is then used as follows:

    * ``"hessianreg"`` in ``reg_type``: sum of squared gradient entries,
      scaled by ``batch / sample_ndata``.
    * ``"hessian2reg"`` in ``reg_type``: draws two independent label sets and,
      per parameter tensor, squares the dot product of the two gradients. The
      sum is scaled by ``batch**2 / sample_ndata``.

    The names suggest a Hessian/Fisher-style curvature proxy.
    NOTE(review): that interpretation is not verified here.

    Args:
        reg_type: string; matched by substring (``"hessianreg"`` is checked first).
        model: module whose ``named_parameters()`` are filtered by name.
        sample_ndata: number of label draws per example.
        output: 2-D logits tensor (batch, classes); must be part of an autograd
            graph that reaches the selected parameters.
        device: unused here; kept for interface compatibility.
        args: forwarded to ``get_name_list``.
        keep_gradient: if True, build a differentiable graph and return the
            regularizer as a tensor. If False, return a Python float.

    Returns:
        With ``keep_gradient=True``: a 0-d tensor, or ``0.0`` if ``reg_type``
        matches neither variant. With ``keep_gradient=False``: a float, or the
        sentinel ``-1.0`` if ``reg_type`` matches neither variant.
    """
    name_list = get_name_list(reg_type, model, args)
    params = [param for name, param in model.named_parameters() if name in name_list]

    reg_mean = 0.0

    if "hessianreg" in reg_type:
        batch_size = output.size(0)
        log_probs = torch.log_softmax(output, dim=-1)
        loss = _sampled_nll(log_probs, sample_ndata)
        # retain_graph=True so the caller can still backpropagate through ``output``.
        grads = torch.autograd.grad(
            loss, params, create_graph=keep_gradient, retain_graph=True)
        norm_cost = sum(g.pow(2).sum() for g in grads)
        reg_mean = norm_cost * batch_size / sample_ndata

    elif "hessian2reg" in reg_type:
        batch_size = output.size(0)
        log_probs = torch.log_softmax(output, dim=-1)
        # Two independent label draws, taken in the same order as before.
        loss_1 = _sampled_nll(log_probs, sample_ndata)
        loss_2 = _sampled_nll(log_probs, sample_ndata)
        # The graph is retained across both calls (and for the caller). A
        # double-backward graph is only built when the result must be differentiable.
        grads_1 = torch.autograd.grad(
            loss_1, params, create_graph=keep_gradient, retain_graph=True)
        grads_2 = torch.autograd.grad(
            loss_2, params, create_graph=keep_gradient, retain_graph=True)
        # reshape (not view): autograd gradients are not guaranteed contiguous.
        norm_cost = sum(
            torch.dot(g1.reshape(-1), g2.reshape(-1)) ** 2
            for g1, g2 in zip(grads_1, grads_2))
        reg_mean = norm_cost * batch_size * batch_size / sample_ndata

    if not keep_gradient:
        # A non-tensor here means no variant matched; keep the historical -1.0 sentinel.
        if isinstance(reg_mean, torch.Tensor):
            reg_mean = reg_mean.detach().item()
        else:
            reg_mean = -1.0
    return reg_mean