import numpy as np
import torch
import torch.nn.functional as F
from matplotlib import pyplot as plt
from PIL import Image
import torchvision
from datasets import load_dataset
from torchvision import transforms

import matplotlib.pyplot as plt

import base64
from io import BytesIO

import datetime

import os, json

from diffusers import DDPMScheduler, UNet2DModel, DDPMPipeline

from torch.autograd.functional import jacobian
from concurrent.futures import ThreadPoolExecutor

from tqdm import tqdm

import sys
sys.path.append("/home/***/work/doob_apps/hug")

from src.utils.img_func import show_images, make_grid, preprocess, transform
from src.finetune.obj import objective_dpo
from src.models.model_potential import ModelPotential

def _save_image_grid(images, path: str, title: "str | None" = None) -> None:
    """Render ``images`` as a grid via ``show_images`` and save it to ``path``.

    A fresh figure is created and closed on every call. Otherwise consecutive
    grids would be drawn onto the same axes, and later images would be layered
    over earlier ones.
    """
    fig, ax = plt.subplots()
    ax.imshow(show_images(images))
    if title is not None:
        ax.set_title(title)
    ax.axis("off")
    fig.savefig(path)
    plt.close(fig)


def main(
    output_root: str = "/home/***/work/doob_apps/hug/outputs/obj/tests",
    samples_path: str = "/home/***/work/doob_apps/hug/outputs/samples/20240910_104356/samples_ref.pth",
) -> None:
    """Smoke-test the DPO objective and the potential model on saved samples.

    Steps:

    1. Load the image tensor written by ``sample_ref.py``.
    2. Save a grid of the first 64 images.
    3. Call ``objective_dpo._preference`` on the first (up to) 32 images
       against a random permutation of those same images.
    4. Call ``objective_dpo.potential`` with a ``ModelPotential`` on the first
       4 images.
    5. Save that 4-image grid with the potentials in the title.

    Args:
        output_root: Directory under which a timestamped (``%Y%m%d_%H%M%S``)
            output folder is created for this run.
        samples_path: Path of the saved sample tensor, laid out as
            ``(batch_size, channel, height, width)``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Each run writes into its own timestamped sub-directory.
    dirname = os.path.join(
        output_root, datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    )
    os.makedirs(dirname, exist_ok=True)

    # Load onto the CPU first so a tensor saved from a GPU still loads on a
    # CPU-only machine; slices are moved to `device` where they are consumed.
    images = torch.load(samples_path, map_location="cpu")
    print("image.shape:", images.shape)
    # Layout is (batch_size, channel, height, width).
    _save_image_grid(images[:64], os.path.join(dirname, "butterfly.png"))
    print("sample_ref.pyの生成した画像を表示しました")

    class_obj = objective_dpo()

    # Pair the first images with a random permutation of the same images.
    # Clamp to the batch size so a file with fewer than 32 samples does not
    # raise an IndexError.
    n_pref = min(32, images.shape[0])
    images_32 = images[:n_pref].to(device)
    images_32_shuffled = images[torch.randperm(n_pref)].to(device)

    pref = class_obj._preference(images_32, images_32_shuffled)
    print("pref:", pref)

    model_potential = ModelPotential()
    model_potential.to(device)
    # Replicate the model across GPUs when more than one is available.
    if torch.cuda.device_count() > 1:
        model_potential = torch.nn.DataParallel(model_potential)

    images_batch = images[:4].to(device)
    potential_batch = class_obj.potential(images_batch, model_potential)
    print("potential_batch:", potential_batch)

    # Show the evaluated images with their potentials in the title.
    _save_image_grid(
        images_batch,
        os.path.join(dirname, "butterfly_obj.png"),
        title="potential:" + str(potential_batch),
    )

if __name__ == "__main__":
    # Run the smoke test only when executed as a script, not on import.
    main()