import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
from tqdm import tqdm
import datetime
import json

import sys
sys.path.append("/home/***/work/doob")

from src.models.nn_potential import ApproxPotential
from src.models.sbm import ScoreBaseModel
from src.objective.obj import objective_like_dpo
from src.potentials.potential_for_inner_loop import Potential_for_inner_loop
from src.utils.set_seed import set_seed

import matplotlib.pyplot as plt

import math

import gc

def _load_parallel_model(model, checkpoint_path, device):
    """Load a saved state dict into ``model`` and wrap it for multi-GPU evaluation.

    Args:
        model: Freshly constructed ``nn.Module`` whose architecture matches
            the checkpoint.
        checkpoint_path: Path to a file holding a state dict loadable with
            ``torch.load``.
        device: Device the module is moved to and the checkpoint is mapped onto.

    Returns:
        ``model`` wrapped in ``nn.DataParallel`` (outputs gathered on GPU 0)
        and switched to eval mode.
    """
    model = model.to(device)
    # map_location loads tensors straight onto ``device`` even if the
    # checkpoint was saved from a different GPU (or from CPU).
    state_dict = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(state_dict)
    # Parallelize across all visible GPUs, gathering outputs on GPU 0.
    parallel = nn.DataParallel(model, output_device=0)
    # This script only evaluates, so disable training-mode layer behaviour
    # (dropout, batch-norm statistics updates) for every loaded model.
    parallel.eval()
    return parallel


def main(
    beta=0.04,
    seed=223,
    sbm_checkpoint="outputs/pretrain/score_base_model.pth",
    potential_dir="outputs/figures/dual_averaging/20240916_105414",
    loop_indices=range(1, 201),
    wandb_project="MoG-upperbound",
):
    """Evaluate the objective for a series of saved potential checkpoints.

    For each index in ``loop_indices`` this loads
    ``<potential_dir>/model_potential_<index>.pth``. It then computes the
    upper bound via ``objective_like_dpo._calc_upperbound`` and the
    regularized objective (``loss_obj``, ``loss_dpo``, ``loss_kl``). The
    values are printed and logged to a wandb run.

    Args:
        beta: Weight passed to ``reglarized_objective``; also used to log
            ``beta * loss_kl``.
        seed: Seed passed to ``set_seed`` before any model is built.
        sbm_checkpoint: State-dict path for the pretrained ``ScoreBaseModel``.
        potential_dir: Directory containing ``model_potential_<i>.pth`` files.
        loop_indices: Iterable of checkpoint indices to evaluate, in order.
        wandb_project: wandb project the run is logged to.
    """
    print("## testing objective ##")
    set_seed(seed)
    # nn.DataParallel defaults to all visible GPUs and requires the wrapped
    # module to live on device_ids[0] (cuda:0), so the device is fixed here.
    device = "cuda:0"
    model_sbm = _load_parallel_model(ScoreBaseModel(2), sbm_checkpoint, device)
    objective_class = objective_like_dpo(device)

    import wandb
    wandb.init(project=wandb_project)

    for idx_loop in loop_indices:
        print("idx_loop: ", idx_loop)
        potential_path = os.path.join(potential_dir, f"model_potential_{idx_loop}.pth")
        potential = _load_parallel_model(ApproxPotential(2), potential_path, device)
        # Compute the objective terms for this checkpoint.
        upperbound = objective_class._calc_upperbound(model_sbm, model_potential=potential)
        print("upperbound: ", upperbound)
        loss_obj, loss_dpo, loss_kl = objective_class.reglarized_objective(potential, model_sbm, beta)
        print("## inner-loop ended. loss: ", loss_obj, ", idx_loop: ", idx_loop, ", beta: ", beta, ", loss_dpo: ", loss_dpo, ", loss_kl: ", loss_kl)
        # Log this checkpoint's metrics to wandb.
        wandb.log({"loss_obj": loss_obj, "loss_dpo": loss_dpo, "loss_kl": loss_kl, "loss_beta_kl": beta * loss_kl, "upperbound": upperbound, "idx_loop": idx_loop})
        # Free this checkpoint's GPU memory before loading the next one.
        del potential
        gc.collect()
        torch.cuda.empty_cache()

    # Flush and close the run explicitly rather than relying on interpreter
    # exit; this also allows main() to be called again in the same process.
    wandb.finish()
    print("## end testing objective ##")


if __name__ == "__main__":
    main()