import torch
import numpy as np
import math
import json
import os
import gc
import torch.nn as nn

import sys
sys.path.append("/home//work/doob_apps/hug")
from src.models.CT_model_predictor import RotationPredictorCNN
from src.finetune.doob import sampling_doob
from tqdm import tqdm
import random, copy

class objective_dpo():
    '''
    DPO objective for fine-tuning a reference image model through a learned potential.

    The code uses -potential(x) as the DPO log ratio log p(x) / p_ref(x), i.e. the fine-tuned
    density is p(x) ∝ p_ref(x) · exp(-potential(x)). Pairwise preferences come from a colour
    loss ("butterfly" mode) or from a rotation predictor ("CT" mode). Hyper-parameters are
    read from the JSON config selected by ``mode``; reference images are loaded from the
    tensor file named by the config's "image_ref_path".
    '''

    # JSON config file read for each supported mode.
    _CONFIG_PATHS = {
        "butterfly": "/home//work/doob_apps/hug/configs/configs.json",
        "CT": "/home//work/doob_apps/hug/configs/configs_CT.json",
    }

    def __init__(self, device='cuda', mode='butterfly'):
        """Load the config for ``mode`` and draw the first set of reference images.

        Args:
            device: torch device the reference images (and CT predictor) are moved to.
            mode: "butterfly" (colour preference) or "CT" (rotation-predictor preference).
        """
        self.device = device
        self.eps = 1e-10  # added inside log() to avoid log(0)
        self.gamma = 1  # DPO temperature; overwritten by config["gamma"]
        # Number of reference images used per evaluation; should be a multiple of the number of GPUs.
        self.n_samples_from_ref = 4 * 8
        self.preference_mode = 1
        # Default preferred colour; overwritten by config["preferred_color"] in butterfly mode.
        if mode == "butterfly":
            self.target_color = (0.1, 0.1, 0.9)
        self.pot_max = 0.0
        self.pot_min = 0.0
        self.mode = mode
        self._load_config()
        self._load_data()

    # ------------------------------------------------------------------ setup

    def _load_config(self):
        """Read the JSON config for ``self.mode`` and copy its hyper-parameters onto ``self``.

        In CT mode this also builds the rotation predictor and loads its weights.

        Raises:
            ValueError: if ``self.mode`` is not "butterfly" or "CT".
        """
        try:
            config_path = self._CONFIG_PATHS[self.mode]
        except KeyError:
            raise ValueError(
                f"unknown mode {self.mode!r}; expected one of {sorted(self._CONFIG_PATHS)}"
            ) from None
        with open(config_path, "r") as f:
            config = json.load(f)
        print("config: ", config)
        assert self.mode == config["mode"]
        self.image_size = config["image_size"]
        self.image_ref_path = config["image_ref_path"]
        self.gamma = config["gamma"]
        self.pot_max = config["pot_max"]
        self.pot_min = config["pot_min"]
        if self.mode == "butterfly":
            self.target_color = tuple(config["preferred_color"][:3])
        elif self.mode == "CT":
            self.predictor = RotationPredictorCNN()
            self.predictor.load_state_dict(
                torch.load(config["predictor_path"], map_location=self.device)
            )
            self.predictor.to(self.device)
            self.predictor.eval()

    def _load_data(self):
        """Draw ``n_samples_from_ref`` reference images at random from the saved pool.

        The tensor stored at ``self.image_ref_path`` is loaded on the CPU, a random subset of
        ``n_samples_from_ref`` rows (in random order) is moved to ``self.device`` and kept as
        ``self.images_ref``; the rest of the pool is released.

        Raises:
            ValueError: if the pool holds fewer than ``n_samples_from_ref`` images.
        """
        pool = torch.load(self.image_ref_path, map_location="cpu")
        n_pool = pool.size(0)
        if n_pool < self.n_samples_from_ref:
            raise ValueError(
                f"reference pool has {n_pool} images, need at least {self.n_samples_from_ref}"
            )
        # Sample from the whole pool rather than shuffling only its first n images.
        chosen = torch.randperm(n_pool)[: self.n_samples_from_ref]
        self.images_ref = pool[chosen].to(self.device)
        del pool
        self._free_memory()

    @staticmethod
    def _free_memory():
        """Run the garbage collector and return cached, unused CUDA blocks to the driver."""
        gc.collect()
        torch.cuda.empty_cache()

    def convert_to_tensor(self, x):
        """Convert a NumPy array to a tensor on ``self.device``; return anything else unchanged."""
        if isinstance(x, np.ndarray):
            return torch.tensor(x).to(self.device)
        return x

    @staticmethod
    def _as_vector(pot):
        """Drop a trailing singleton dimension so an (N, 1) potential becomes (N,)."""
        if pot.dim() == 2 and pot.size(1) == 1:
            return pot.squeeze(1)
        return pot

    # ------------------------------------------------------------ preferences

    def _color_loss(self, images):
        """Mean absolute distance of each image's pixels from ``self.target_color``.

        ``self.target_color`` is an (R, G, B) triple in [0, 1] that is mapped to [-1, 1]
        before comparison, so ``images`` are presumably normalised to [-1, 1].

        Args:
            images: (batch_size, channel, height, width) tensor.

        Returns:
            (batch_size,) tensor of errors.
        """
        target = torch.tensor(self.target_color).to(images.device) * 2 - 1  # [0, 1] -> [-1, 1]
        target = target[None, :, None, None]  # (1, C, 1, 1) broadcasts over (b, c, h, w)
        return torch.abs(images - target).mean(dim=(1, 2, 3))

    def _CT_predictor_loss(self, images):
        """Absolute value of the rotation predictor's output, computed without gradients.

        Args:
            images: (batch_size, channel, height, width) tensor.

        Returns:
            ``|predictor(images)|``; presumably (batch_size, 1) — callers squeeze it.
        """
        with torch.no_grad():
            return torch.abs(self.predictor(images).detach())

    def _preference(self, images_1, images_2):
        """Return 1.0 where ``images_1[i]`` is preferred over ``images_2[i]``, else 0.0.

        butterfly: the image with the lower colour loss wins.
        CT: the image with the lower absolute predictor output wins.

        Args:
            images_1, images_2: (batch_size, channel, height, width) tensors.

        Returns:
            (batch_size,) float tensor of 0/1 preferences.

        Raises:
            ValueError: if ``self.mode`` is not "butterfly" or "CT".
        """
        if self.mode == "butterfly":
            preference = self._color_loss(images_1) - self._color_loss(images_2)
            return (preference < 0.0).float()
        if self.mode == "CT":
            preference = self._CT_predictor_loss(images_1) - self._CT_predictor_loss(images_2)
            return (preference < 0.0).float().squeeze()
        raise ValueError(f"unknown mode {self.mode!r}")

    # ------------------------------------------------------- DPO primitives

    def _sigmoidLog(self, log_ratio_win, log_ratio_lose):
        """sigmoid(gamma * (log_ratio_win - log_ratio_lose)), exponent clamped to [-10, 10]."""
        diff = -self.gamma * (log_ratio_win - log_ratio_lose)
        return 1 / (1 + torch.exp(torch.clamp(diff, min=-10, max=10)))

    def _LogSigmoidLog(self, log_ratio_win, log_ratio_lose):
        """log(sigmoid(gamma * (win - lose)) + eps) — the per-pair DPO log-likelihood."""
        return torch.log(self._sigmoidLog(log_ratio_win, log_ratio_lose) + self.eps)

    def _deriv_LogSigmoidLog_mult_by_p(self, log_ratio_win, log_ratio_lose, wrt="win"):
        """Derivative of log sigmoid(gamma * (win - lose)) w.r.t. one of its log ratios.

        Returns ``gamma * (1 - sigmoid)`` for ``wrt="win"`` and its negation for ``wrt="lose"``.

        Raises:
            ValueError: if ``wrt`` is neither "win" nor "lose".
        """
        if wrt not in ("win", "lose"):
            raise ValueError(f"wrt must be 'win' or 'lose', got {wrt!r}")
        grad = self.gamma * (1 - self._sigmoidLog(log_ratio_win, log_ratio_lose))
        return grad if wrt == "win" else -grad

    # ------------------------------------------------------------ estimators

    @staticmethod
    def _batched_potential(model_potential, inputs, chunk_size):
        """Evaluate ``model_potential`` on ``inputs`` in slices of ``chunk_size`` without gradients.

        Every row is evaluated, including a shorter final slice.
        """
        with torch.no_grad():
            return torch.cat(
                [model_potential(chunk).detach() for chunk in torch.split(inputs, chunk_size)],
                dim=0,
            )

    def _check_pairing(self, x, samples, batch_size, n_samples):
        """Spot-check the layout built by repeat_interleave / repeat on a few pixels."""
        channels = {"butterfly": (0, 1), "CT": (0,)}.get(self.mode, ())
        for c in channels:
            assert x[0, c, 0, 0] == x[n_samples - 1, c, 0, 0]
            # A second copy of the reference images only exists when batch_size > 1.
            if batch_size > 1:
                assert samples[0, c, 0, 0] == samples[n_samples, c, 0, 0]

    def _expectation(self, x, model_potential, for_what, beta=1):
        """Monte-Carlo estimates over ``self.images_ref`` used by the public objectives.

        Args:
            x: query images, presumably (batch_size, C, H, W) — TODO confirm. Only used when
                ``for_what == "potential"``; ignored (may be None) otherwise.
            model_potential: callable mapping a batch of images to potentials of shape
                (N,) or (N, 1).
            for_what: "potential", "objective" or "regularized_objective".
            beta: weight of the KL term (only used by "regularized_objective").

        Returns:
            "potential": (batch_size,) tensor; "objective": scalar tensor;
            "regularized_objective": tuple (DPO - beta * KL, DPO, KL).

        Raises:
            ValueError: for an unknown ``for_what``.
        """
        if for_what == "potential":
            return self._expectation_potential(x, model_potential)
        if for_what == "objective":
            return self._expectation_objective(model_potential)
        if for_what == "regularized_objective":
            return self._expectation_regularized(model_potential, beta)
        raise ValueError(f"unknown for_what {for_what!r}")

    def _expectation_potential(self, x, model_potential):
        """Per-query estimate of dF/dp at each image in ``x``, averaged over the reference set.

        Returns:
            (batch_size,) tensor.
        """
        batch_size = x.shape[0]
        samples_from_ref = self.images_ref
        n_samples = samples_from_ref.size(0)
        # Pair every query with every reference image: row k pairs x[k // n] with ref[k % n].
        x = x.repeat_interleave(n_samples, dim=0)  # [x0]*n + [x1]*n + ...
        samples = samples_from_ref.repeat(batch_size, 1, 1, 1)  # [r0 .. r(n-1)] * batch_size
        self._check_pairing(x, samples, batch_size, n_samples)

        # Preference of the query over each reference image and vice versa.
        p_bt_wrt_x2 = self._preference(x, samples)
        p_bt_wrt_x1 = self._preference(samples, x)

        # Evaluate the potential in ~8 slices to bound peak memory.
        mini_batch_size = max(1, x.shape[0] // 8)
        print("batch_size: ", batch_size)
        print("n_samples: ", n_samples)
        print("mini_batch_size: ", mini_batch_size)
        pot_ref = self._as_vector(self._batched_potential(model_potential, samples, mini_batch_size))
        pot_x = self._as_vector(self._batched_potential(model_potential, x, mini_batch_size))
        del x, samples
        self._free_memory()

        expected_shape = (batch_size * n_samples,)
        assert pot_x.shape == expected_shape
        assert pot_ref.shape == expected_shape

        # Clamp the density ratios to the configured potential range.
        exp_pot_inv = torch.clamp(torch.exp(pot_x), min=math.exp(self.pot_min), max=math.exp(self.pot_max))
        exp_pot_ref = torch.clamp(torch.exp(-pot_ref), min=math.exp(-self.pot_max), max=math.exp(-self.pot_min))

        # The log ratio is -potential; the derivative w.r.t. the losing ratio is the negation.
        grad_win = self._deriv_LogSigmoidLog_mult_by_p(-pot_x, -pot_ref, wrt="win")
        grad_lose = -grad_win
        smp_wrt_x2 = grad_win * exp_pot_inv * p_bt_wrt_x2
        smp_wrt_x1 = grad_lose * exp_pot_inv * p_bt_wrt_x1

        # Terms where x wins must be >= 0, terms where x loses <= 0 (up to round-off).
        assert torch.sum(smp_wrt_x2 < -1e-4) == 0
        smp_wrt_x2 = torch.clamp(smp_wrt_x2, min=0.0)
        assert torch.sum(smp_wrt_x1 > 1e-4) == 0
        smp_wrt_x1 = torch.clamp(smp_wrt_x1, max=0.0)
        assert smp_wrt_x2.shape == expected_shape

        # One row per query, one column per reference image.
        chunks_wrt_x2 = smp_wrt_x2.view(batch_size, n_samples)
        chunks_wrt_x1 = smp_wrt_x1.view(batch_size, n_samples)
        if batch_size > 1:
            assert chunks_wrt_x2[1, 0] == smp_wrt_x2[n_samples]

        exp_pot_avg = torch.mean(exp_pot_ref)
        avg_wrt_x2 = torch.sum(chunks_wrt_x2, dim=1) * exp_pot_avg / n_samples
        avg_wrt_x1 = torch.sum(chunks_wrt_x1, dim=1) * exp_pot_avg / n_samples
        average = avg_wrt_x2 + avg_wrt_x1

        del smp_wrt_x2, smp_wrt_x1, chunks_wrt_x2, chunks_wrt_x1, exp_pot_avg, \
            exp_pot_inv, exp_pot_ref, pot_x, pot_ref, avg_wrt_x2, avg_wrt_x1
        self._free_memory()
        return average

    def _expectation_objective(self, model_potential):
        """DPO objective on the reference images paired with a shuffled copy of themselves.

        Gradients flow through ``model_potential``.

        Returns:
            Scalar tensor: sum of preference-weighted log-sigmoid terms / n_samples.
        """
        samples_from_ref = self.images_ref
        n_samples = samples_from_ref.size(0)
        samples_from_ref_2 = samples_from_ref[torch.randperm(n_samples)].to(self.device)
        lsl = self._LogSigmoidLog(
            -self._as_vector(model_potential(samples_from_ref)),
            -self._as_vector(model_potential(samples_from_ref_2)),
        )
        p_bt = self._preference(samples_from_ref, samples_from_ref_2)
        return torch.sum(lsl * p_bt) / n_samples

    def _expectation_regularized(self, model_potential, beta):
        """DPO objective and KL(p || p_ref), both estimated on ``self.images_ref`` without gradients.

        Returns:
            (average - beta * KL, average, KL), where ``average`` is the DPO objective.
        """
        samples_from_ref = self.images_ref
        n_samples = samples_from_ref.size(0)
        samples_from_ref_2 = samples_from_ref[torch.randperm(n_samples)].to(self.device)

        with torch.no_grad():
            pot_ref_1 = self._as_vector(model_potential(samples_from_ref).detach())
            pot_ref_2 = self._as_vector(model_potential(samples_from_ref_2).detach())

        # DPO objective.
        lsl = self._LogSigmoidLog(-pot_ref_1, -pot_ref_2)
        p_bt = self._preference(samples_from_ref, samples_from_ref_2)
        if torch.sum(p_bt) == 0:
            print("count_p_btが0です.")
        average = torch.sum(lsl * p_bt) / n_samples

        # KL(p || p_ref) with importance weights w = exp(-pot) / Z, Z = mean(exp(-pot)):
        # KL is estimated as mean(w * log w) over the reference samples.
        pot_ref = pot_ref_1.to(self.device)
        assert pot_ref.shape == (n_samples,)
        Z = (torch.sum(torch.exp(-pot_ref)) / n_samples).item()
        log_Z = torch.log(torch.tensor([Z], device=self.device))
        KL_tensor = (-pot_ref - log_Z) * torch.exp(-pot_ref) / Z
        print("KL_tensor: ", KL_tensor.shape)
        KL = torch.sum(KL_tensor) / n_samples

        print("(not negated) DPO: ", average)
        print("KL: ", KL)
        avg_and_kl = average - beta * KL

        del lsl, p_bt, log_Z, KL_tensor
        self._free_memory()
        return avg_and_kl, average, KL

    # ---------------------------------------------------------- public API

    def potential(self, x, model_potential):
        """Return -dF/dp evaluated at each query image in ``x`` (shape (batch_size,)).

        A fresh random set of reference images is drawn first. ``x`` may be a NumPy array.
        """
        self._load_data()
        x = self.convert_to_tensor(x)
        return - self._expectation(x, model_potential, "potential")

    def objective(self, model_potential):
        """Return -F, the negated DPO objective (minimising -F maximises F).

        Uses the current reference images (does not redraw them).
        """
        return - self._expectation(None, model_potential, "objective")

    def reglarized_objective(self, model_potential, beta=1):
        """Return (-(DPO - beta * KL), -DPO, KL) after drawing fresh reference images."""
        self._load_data()
        avg_and_kl, average, KL = self._expectation(None, model_potential, "regularized_objective", beta)
        return -avg_and_kl, -average, KL

    # Correctly spelled alias of ``reglarized_objective``.
    regularized_objective = reglarized_objective

    # ------------------------------------------------------------ upperbound

    @staticmethod
    def _sampling_ref(batch_size, model_sbm, noise_scheduler, device, mode):
        """Placeholder for sampling from the reference diffusion model; not implemented."""
        raise NotImplementedError("not implemented")

    def _sample_noised_images(self, model_sbm, noise_scheduler, config_doob, mode, batch_size, t, autoencoder=None):
        """Noise ``batch_size`` random reference images to diffusion step ``int(t * 100)``.

        The images are encoded with ``autoencoder.encoder`` when an autoencoder is given and
        used as-is otherwise.

        Returns:
            (x_t, noise): noised images (with ``requires_grad=True``) and the noise added.
        """
        with torch.no_grad():
            shuffled_images_ref = self.images_ref[torch.randperm(self.images_ref.size(0))].to(self.device)
            x_T = shuffled_images_ref[:batch_size]
            if autoencoder is not None:
                x_T = autoencoder.encoder(x_T)
            x_T = x_T.detach().to(self.device)
            del shuffled_images_ref
        noise = torch.randn_like(x_T).to(self.device)
        # Continuous time t is mapped to the scheduler's integer step.
        timestep = int(t * 100)
        x_t = noise_scheduler.add_noise(x_T, noise, torch.tensor(timestep, device=self.device))
        x_t.requires_grad = True
        del x_T
        self._free_memory()
        return x_t, noise

    def _predict_noise(self, model_sbm, noise_scheduler, config_doob, mode, model_potential, t, x, autoencoder=None):
        """Placeholder for Doob-h-transform noise prediction; not implemented."""
        print("oops! this function is not implemented yet.")
        raise NotImplementedError("not implemented")

    def _predict_noise_training(self, model_sbm, model_sbm_training, noise_scheduler, config_doob, mode, t, x):
        """Noise predicted by the trained and the frozen reference model for ``x`` at time ``t``.

        Both models' outputs are treated as clean-sample predictions x0 and converted to noise
        via eps = (x - sqrt(alpha_bar) * x0) / sqrt(1 - alpha_bar), with
        alpha_bar = alphas_cumprod[int(t * 100)]. ``x`` is processed in sub-batches of 2; only
        the trained model's prediction keeps gradients.

        Returns:
            (noise_pred_training, noise_pred_ref), each with the shape of ``x``.
        """
        sub_batch_size = 2
        alpha_bar = noise_scheduler.alphas_cumprod[int(t * 100)]
        sqrt_alpha_bar = torch.sqrt(alpha_bar)
        sqrt_one_minus_alpha_bar = torch.sqrt(1 - alpha_bar)
        # NOTE(review): the models receive the raw float t, while alpha_bar (and the noising in
        # _sample_noised_images) uses step int(t * 100) — confirm the intended timestep convention.
        timestep = torch.tensor(t, device=self.device)
        preds_ref, preds_training = [], []
        for x_i in torch.split(x, sub_batch_size):
            with torch.no_grad():
                sample_pred = model_sbm(x_i, timestep, return_dict=False)[0].detach()
                preds_ref.append((x_i - sqrt_alpha_bar * sample_pred) / sqrt_one_minus_alpha_bar)
            sample_pred_training = model_sbm_training(x_i, timestep, return_dict=False)[0]
            preds_training.append((x_i - sqrt_alpha_bar * sample_pred_training) / sqrt_one_minus_alpha_bar)
        pred_concat_training = torch.cat(preds_training, dim=0)
        pred_concat = torch.cat(preds_ref, dim=0)
        self._free_memory()
        return pred_concat_training, pred_concat

    def _calc_upperbound(self, model_sbm, noise_scheduler, config_doob, mode, autoencoder=None, model_potential=None, model_sbm_training=None):
        """Monte-Carlo estimate of the negated, preference-weighted diffusion DPO term.

        For ``num_exp`` random times t, two batches of noised reference images are drawn and the
        per-image error change (trained model vs. reference model) is fed to ``_LogSigmoidLog``.
        Terms are weighted by the preference of the first batch over the second.

        Raises:
            NotImplementedError: when ``model_potential`` is given (that path is unfinished).
        """
        if model_potential is not None:
            assert model_sbm_training is None
            raise NotImplementedError("not implemented")
        assert model_sbm_training is not None

        total = 0
        num_exp = 5
        batch_size = 10  # 250
        for _ in tqdm(range(1, num_exp + 1)):
            # Random continuous time in [0.01, 10].
            t = random.uniform(0.01, 10)
            print("cuda memory 475: ", int((torch.cuda.memory_allocated() / 1024 / 1024 / 1024)))
            print("t: ", t)
            x_t_w, noise_w = self._sample_noised_images(model_sbm, noise_scheduler, config_doob, mode, batch_size, t, autoencoder)
            x_t_l, noise_l = self._sample_noised_images(model_sbm, noise_scheduler, config_doob, mode, batch_size, t, autoencoder)
            pred_noise_w_doob, pred_noise_w_ref = self._predict_noise_training(model_sbm, model_sbm_training, noise_scheduler, config_doob, mode, t, x_t_w)
            pred_noise_l_doob, pred_noise_l_ref = self._predict_noise_training(model_sbm, model_sbm_training, noise_scheduler, config_doob, mode, t, x_t_l)
            error_w = torch.linalg.norm(pred_noise_w_doob - noise_w, dim=(1, 2, 3)) - torch.linalg.norm(pred_noise_w_ref - noise_w, dim=(1, 2, 3))
            error_l = torch.linalg.norm(pred_noise_l_doob - noise_l, dim=(1, 2, 3)) - torch.linalg.norm(pred_noise_l_ref - noise_l, dim=(1, 2, 3))
            # NOTE(review): Diffusion-DPO uses log sigmoid(-gamma * (error_w - error_l));
            # confirm the sign intended here.
            sigmoid = self._LogSigmoidLog(error_w, error_l)
            # Preferences in CT mode are judged on decoded images.
            if mode == 'CT' and autoencoder is not None:
                with torch.no_grad():
                    x_t_w = autoencoder.decoder(x_t_w)
                    x_t_l = autoencoder.decoder(x_t_l)
            p_bt = self._preference(x_t_w, x_t_l)
            total += torch.sum(p_bt * sigmoid)
            del x_t_w, x_t_l, noise_w, noise_l, pred_noise_w_doob, pred_noise_w_ref, \
                pred_noise_l_doob, pred_noise_l_ref, error_w, error_l, sigmoid, p_bt
            self._free_memory()
        return - total / num_exp / batch_size

