import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
import datetime
import json

import sys
sys.path.append("/home/***/work/doob")

from src.models.nn_potential import ApproxPotential
from src.models.sbm import ScoreBaseModel
from src.objective.obj import objective_like_dpo
from src.potentials.potential_for_inner_loop import Potential_for_inner_loop
from src.utils.set_seed import set_seed
from src.models.init_models import xavier_init

import matplotlib.pyplot as plt

def _load_regularization_config(path):
    """Load and return the regularization hyperparameters stored as JSON at *path*."""
    with open(path, encoding="utf-8") as f:
        return json.load(f)


def _make_output_dir(base="outputs/figures/test_inner_loop"):
    """Create (if needed) and return a directory named after the current timestamp under *base*."""
    now_str = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
    dirname = os.path.join(base, now_str)
    os.makedirs(dirname, exist_ok=True)
    return dirname


def _write_run_params(dirname, params):
    """Write one ``key: value`` line per entry of *params* (in order) to idx_loop_beta.txt."""
    path = os.path.join(dirname, 'idx_loop_beta.txt')
    with open(path, mode='w', encoding="utf-8") as f:
        f.writelines(f"{key}: {value!s}\n" for key, value in params.items())


def main():
    """Run one inner-loop potential-training step and print the resulting objective.

    Steps:
      1. Read the regularization hyperparameters from
         configs/regularization.json.
      2. Create a timestamped directory under outputs/figures/test_inner_loop/
         and record the hyperparameters, the objective's ``p_bt_mode`` and the
         score-model checkpoint path there.
      3. Build an ``ApproxPotential`` initialized with ``xavier_init`` and a
         ``ScoreBaseModel`` loaded from outputs/pretrain/score_base_model.pth.
      4. Train a new potential with ``Potential_for_inner_loop.train_potential``.
      5. Print the values returned by ``reglarized_objective`` for it.
    """
    print("## testing inner loop ##")

    # Seed 111 was tried earlier; the author's note on it suggests the result
    # came out reversed, so 222 is used instead.
    set_seed(222)
    # Fall back to CPU so the script still runs on machines without a GPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    idx_loop = 1  # inner-loop index passed to train_potential (original note also mentions 2)

    config_reg = _load_regularization_config('configs/regularization.json')
    # Read every key up front so a malformed config fails before any output
    # directory is created.
    reg_keys = ("beta", "gamma", "y_min", "y_max", "pot_min", "pot_max")
    reg_params = {key: config_reg[key] for key in reg_keys}
    beta = reg_params["beta"]
    print("idx_loop: ", idx_loop, ", beta: ", beta)

    # Output directory for this run (one per launch, keyed by timestamp).
    dirname = _make_output_dir()

    objective_class = objective_like_dpo(device)
    p_bt_mode = objective_class.p_bt_mode

    # Record the run's settings next to its outputs; key order matches the
    # file layout: idx_loop, beta, gamma, y_min, y_max, pot_min, pot_max, p_bt_mode.
    _write_run_params(dirname, {"idx_loop": idx_loop, **reg_params, "p_bt_mode": p_bt_mode})

    # Starting potential. To warm-start from a saved potential instead, load
    # outputs/potential/model_potential.pth into it here.
    # (2 is presumably the data dimension — confirm against ApproxPotential.)
    previous_potential = ApproxPotential(2).to(device)
    xavier_init(previous_potential)

    model_sbm = ScoreBaseModel(2).to(device)
    sbm_path = 'outputs/pretrain/score_base_model.pth'
    # map_location lets a checkpoint saved on another device load onto `device`.
    model_sbm.load_state_dict(torch.load(sbm_path, map_location=device))
    # Record which score-model checkpoint this run used.
    with open(os.path.join(dirname, 'sbm_path.txt'), mode='w', encoding="utf-8") as f:
        f.write(sbm_path)

    potential_for_inner_loop = Potential_for_inner_loop(beta=beta, device=device)
    new_potential = potential_for_inner_loop.train_potential(
        previous_potential, model_sbm, idx_loop, savedirname=dirname, plot_data=True
    )

    # Only the post-training objective is evaluated; the pre-training value
    # for previous_potential is not computed here.
    after_obj, after_dpo, after_kl = objective_class.reglarized_objective(new_potential, model_sbm, beta)

    print("idx_loop: ", idx_loop)
    print("after: ", after_obj)
    print("after dpo: ", after_dpo)
    print("after kl: ", after_kl)

    print("## test end ##")

if __name__ == "__main__":
    # Run the inner-loop test only when executed as a script, not on import.
    main()