import torch
import torch.nn as nn
import torch.nn.functional as F


def kl_divergence_loss(student_output: torch.Tensor, teacher_output: torch.Tensor) -> torch.Tensor:
    """Return the KL divergence between teacher and student softmax distributions.

    Each input is flattened to ``(size(0), -1)`` and a softmax is taken over the
    flattened second dimension, so every row along dim 0 is treated as one
    distribution. The result is ``F.kl_div`` with ``reduction='batchmean'``,
    i.e. the summed pointwise divergence divided by the size of dim 0.

    Args:
        student_output: Student logits; supplies the log-probabilities
            (the ``input`` argument of ``F.kl_div``).
        teacher_output: Teacher logits; supplies the target probabilities.

    Returns:
        A scalar tensor holding the loss.
    """
    # reshape (not view) so non-contiguous inputs, e.g. the result of a
    # transpose/permute, are accepted; contiguous inputs still yield a view.
    student_flat = student_output.reshape(student_output.size(0), -1)
    teacher_flat = teacher_output.reshape(teacher_output.size(0), -1)

    student_log_probs = F.log_softmax(student_flat, dim=1)
    teacher_probs = F.softmax(teacher_flat, dim=1)
    return F.kl_div(student_log_probs, teacher_probs, reduction='batchmean')


def noise_mse_loss(teacher_output, ground_truth_noise):
    """Return the mean squared error between ``teacher_output`` and ``ground_truth_noise``.

    The squared differences are averaged over all elements (``reduction="mean"``).
    """
    loss = F.mse_loss(teacher_output, ground_truth_noise, reduction="mean")
    return loss


def cfg_compute(student_pred_uncond, student_pred_cond, guidance_scale):
    """Combine unconditional and conditional predictions with a guidance scale.

    Computes ``uncond + guidance_scale * (cond - uncond)``: a scale of 0 returns
    the unconditional prediction, a scale of 1 returns the conditional one, and
    larger values extrapolate past the conditional prediction.
    """
    # Offset from the unconditional prediction towards the conditional one.
    cond_offset = student_pred_cond - student_pred_uncond
    return student_pred_uncond + guidance_scale * cond_offset


def scg_loss_compute(student_output_uncond: torch.Tensor, teacher_output: torch.Tensor) -> torch.Tensor:
    """Return the KL divergence between teacher and unconditional-student softmaxes.

    Each input is flattened to ``(size(0), -1)`` and a softmax is taken over the
    flattened second dimension. The result is ``F.kl_div`` with
    ``reduction='batchmean'``, i.e. the summed pointwise divergence divided by
    the size of dim 0.

    Args:
        student_output_uncond: Student logits; supplies the log-probabilities
            (the ``input`` argument of ``F.kl_div``).
        teacher_output: Teacher logits; supplies the target probabilities.

    Returns:
        A scalar tensor holding the loss.
    """
    # reshape (not view) so non-contiguous inputs, e.g. the result of a
    # transpose/permute, are accepted; contiguous inputs still yield a view.
    student_flat = student_output_uncond.reshape(student_output_uncond.size(0), -1)
    teacher_flat = teacher_output.reshape(teacher_output.size(0), -1)

    student_log_probs = F.log_softmax(student_flat, dim=1)
    teacher_probs = F.softmax(teacher_flat, dim=1)
    return F.kl_div(student_log_probs, teacher_probs, reduction='batchmean')


def mse_loss_compute(student_output, ground_truth_noise):
    """Return the mean squared error between ``student_output`` and ``ground_truth_noise``.

    The squared differences are averaged over all elements (``reduction="mean"``).
    """
    loss = F.mse_loss(student_output, ground_truth_noise, reduction="mean")
    return loss