import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
from tqdm import tqdm
import datetime
import json

import sys
sys.path.append("/home/***/work/doob_apps/hug")

from src.models.model_potential import ModelPotential
from src.finetune.obj import objective_dpo
from src.finetune.inner_loop import Inner_Loop
from src.utils.set_seed import set_seed
from src.utils.img_func import show_images
from src.models.CT_model_predictor import RotationPredictorCNN

import matplotlib.pyplot as plt
import wandb

def _load_json(path):
    """Read and return the JSON document stored at *path*."""
    with open(path) as f:
        return json.load(f)


def _write_run_params(dirname, params):
    """Write ``(name, value)`` pairs to ``<dirname>/idx_loop_beta.txt``.

    Each pair becomes one ``"name: value"`` line, in the order given.

    Args:
        dirname: Existing output directory for this run.
        params: Iterable of ``(name, value)`` pairs.
    """
    lines = [f"{name}: {value!s}\n" for name, value in params]
    with open(os.path.join(dirname, "idx_loop_beta.txt"), mode="w") as f:
        f.writelines(lines)


def main():
    """Run one CT inner-loop training step and report the resulting potential.

    Steps:
      1. Load the CT config and the regularization config (beta, gamma,
         y_min/y_max, pot_min/pot_max), merge them, and log the merged config
         to a wandb run.
      2. Create a timestamped output directory and record the run parameters
         in ``idx_loop_beta.txt``.
      3. Train a new potential with ``Inner_Loop.train_potential``, starting
         from a freshly constructed ``RotationPredictorCNN``.
      4. Evaluate the regularized objective for the new potential.
      5. Evaluate the new potential on the first 4 saved decoded CT images.
         Save a figure of those images, titled with the potential values, to
         ``<output dir>/CT_obj.png``.
    """
    print("## testing inner loop ##")

    set_seed(222)
    device = "cuda"
    idx_loop = 1  # alternative: 2

    project_root = "/home/***/work/doob_apps/hug"

    config = _load_json(os.path.join(project_root, "configs", "configs_CT.json"))
    config_reg = _load_json(
        os.path.join(project_root, "configs", "regularization_CT.json")
    )
    beta = config_reg["beta"]
    gamma = config_reg["gamma"]
    y_min = config_reg["y_min"]
    y_max = config_reg["y_max"]
    pot_min = config_reg["pot_min"]
    pot_max = config_reg["pot_max"]
    print("idx_loop: ", idx_loop, ", beta: ", beta)

    # Merge the regularization settings into the main config so both are
    # logged together; on key collisions the regularization value wins.
    config.update(config_reg)

    wandb.init(project='CTImage-diffusion-test-inner-loop', config=config)

    # Create a per-run output directory named after the current timestamp.
    now_str = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
    dirname = os.path.join(
        project_root, "outputs", "figures", "test_inner_loop", "CT_" + now_str
    )
    os.makedirs(dirname, exist_ok=True)

    objective_class = objective_dpo(device=device, mode="CT")

    # Record the run's hyperparameters next to its figures.
    _write_run_params(
        dirname,
        (
            ("idx_loop", idx_loop),
            ("beta", beta),
            ("gamma", gamma),
            ("y_min", y_min),
            ("y_max", y_max),
            ("pot_min", pot_min),
            ("pot_max", pot_max),
        ),
    )

    previousPotential = RotationPredictorCNN().to(device)
    # Wrap in DataParallel when more than one GPU is visible.
    if torch.cuda.device_count() > 1:
        previousPotential = torch.nn.DataParallel(previousPotential)

    potential_for_inner_loop = Inner_Loop(beta=beta, device=device, mode="CT")
    newPotential = potential_for_inner_loop.train_potential(
        previousPotential, idx_loop, savedirname=dirname, plot_data=True
    )

    # NOTE: the "before" objective for previousPotential is not computed here;
    # the earlier version of that call also required a `model_sbm` argument.
    after_obj, after_dpo, after_kl = objective_class.reglarized_objective(newPotential, beta)

    print("idx_loop: ", idx_loop)
    print("after: ", after_obj)
    print("after (dpo, kl): ", after_dpo, after_kl)

    # Load on CPU first so only the 4 images actually used are moved to the
    # device, rather than the whole saved tensor.
    images = torch.load(
        os.path.join(
            project_root, "outputs", "CT_diffusion", "20240919_1459",
            "decoded_images.pth",
        ),
        map_location="cpu",
    )
    images_batch = images[:4].to(device)

    # Evaluation only: no autograd graph is needed for this forward pass.
    with torch.no_grad():
        newpot_batch = newPotential(images_batch)
    print("newpot_batch:", newpot_batch)

    # Use a dedicated figure. train_potential(..., plot_data=True) may leave
    # its own figure as matplotlib's "current" one, and implicit plt.* calls
    # would then draw into it.
    fig, ax = plt.subplots()
    ax.imshow(show_images(images_batch))
    # Plain floats keep the title readable (no device / dtype repr noise).
    # Assumes the potential returns a tensor — TODO confirm for RotationPredictorCNN.
    pot_values = newpot_batch.detach().cpu().flatten().tolist()
    ax.set_title("potential:" + str(pot_values))
    ax.axis("off")
    fig.savefig(os.path.join(dirname, "CT_obj.png"))
    plt.close(fig)

    # Explicitly finalize the wandb run started above.
    wandb.finish()

    print("## test end ##")

if __name__ == "__main__":
    # Run the inner-loop test only when executed as a script, not on import.
    main()