import numpy as np
import torch
import torch.nn.functional as F
from matplotlib import pyplot as plt
from PIL import Image
import torchvision
from datasets import load_dataset
from torchvision import transforms

import matplotlib.pyplot as plt

import base64
from io import BytesIO

import datetime

import os, json

from diffusers import DDPMScheduler, UNet2DModel, DDPMPipeline

from torch.autograd.functional import jacobian
from concurrent.futures import ThreadPoolExecutor

from tqdm import tqdm

import sys
sys.path.append("/home/***/work/doob_apps/hug")

from src.utils.img_func import show_images, make_grid, preprocess, transform
from src.finetune.obj import objective_dpo
# from src.models.model_potential import ModelPotential
from src.models.CT_model_predictor import RotationPredictorCNN
from src.utils.set_seed import set_seed

# Root of the project checkout; every input/output path below is derived from it.
_HUG_ROOT = "/home/***/work/doob_apps/hug"

# Saved tensors of generated images to evaluate, keyed by ``mode``.
_IMAGE_PATHS = {
    "CT": os.path.join(
        _HUG_ROOT, "outputs/CT_diffusion/20240919_1459/decoded_images.pth"
    ),
    "butterfly": os.path.join(
        _HUG_ROOT, "outputs/samples/20240910_104356/samples_ref.pth"
    ),
}

# Trained RotationPredictorCNN weights used for the final predictor check.
_PREDICTOR_PATH = os.path.join(
    _HUG_ROOT, "src/preference/CT_predictor_202409151939/rotation_predictor.pth"
)

# Parent directory under which each run creates its own timestamped folder.
_OUTPUT_ROOT = os.path.join(_HUG_ROOT, "outputs/obj/tests")


def _load_images(mode):
    """Load the saved image tensor registered for *mode* onto the CPU.

    Raises:
        ValueError: if *mode* has no entry in ``_IMAGE_PATHS``.
    """
    try:
        path = _IMAGE_PATHS[mode]
    except KeyError:
        raise ValueError(
            f"unknown mode {mode!r}; expected one of {sorted(_IMAGE_PATHS)}"
        ) from None
    # map_location="cpu" lets tensors saved from a GPU run load on a CPU-only
    # machine; callers move slices to the compute device explicitly.
    return torch.load(path, map_location="cpu")


def _save_grid(images, mode, path, title=None):
    """Render *images* as one grid via ``show_images`` and save it to *path*.

    The figure is closed afterwards so the next grid starts on a fresh canvas
    instead of being drawn on top of this one.
    """
    fig = plt.figure()
    plt.imshow(show_images(images, mode=mode))
    if title is not None:
        plt.title(title)
    plt.axis("off")
    fig.savefig(path)
    plt.close(fig)


def main():
    """Visualise generated images and sanity-check the DPO objective helpers.

    Steps:
      1. Load a saved tensor of generated images and save a grid of the first 64.
      2. Print ``objective_dpo._preference`` between the first 32 images and a
         random permutation of those same images.
      3. Score the first 64 images with ``objective_dpo.potential`` using a
         ``RotationPredictorCNN`` and save them sorted by ascending score.
      4. Run the trained predictor weights from ``_PREDICTOR_PATH`` on that
         batch and save it sorted by ascending prediction.

    All figures are written to a new timestamped folder under ``_OUTPUT_ROOT``.
    """
    set_seed(222)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    mode = "CT"

    # One timestamped output folder per run, e.g. <_OUTPUT_ROOT>/CT_20240920_101500.
    dirname = os.path.join(
        _OUTPUT_ROOT, datetime.datetime.now().strftime(f"{mode}_%Y%m%d_%H%M%S")
    )
    os.makedirs(dirname, exist_ok=True)

    # Show the images produced by sample_ref.py.
    images = _load_images(mode)
    # The original author notes the layout is (batch_size, channel, height, width).
    print("image.shape:", images.shape)
    _save_grid(images[:64], mode, os.path.join(dirname, f"{mode}.png"))
    print("CTのsample_ref.pyの生成した画像を表示しました")

    class_obj = objective_dpo(mode=mode)

    # Preference between the first 32 images and a random permutation of them.
    # The permutation is drawn on the CPU generator (as before) so seeded runs
    # reproduce. It is then moved to the images' device for indexing.
    images_32 = images[:32].to(device)
    perm = torch.randperm(images_32.shape[0])
    images_32_shuffled = images_32[perm.to(device)]
    pref = class_obj._preference(images_32, images_32_shuffled)
    print("pref:", pref)

    model_potential = RotationPredictorCNN().to(device)
    # NOTE(review): no trained weights are loaded into model_potential here.
    # Unless objective_dpo.potential loads them itself, the potential below is
    # computed with a randomly initialised network. Confirm this is intended.
    if torch.cuda.device_count() > 1:
        # Split the batch across all visible GPUs.
        model_potential = torch.nn.DataParallel(model_potential)

    # The previous `images_32[:64]` could never hold more than 32 images.
    # Take the intended 64-image batch (fewer if the tensor is smaller).
    images_batch = images[:64].to(device)
    potential_batch = class_obj.potential(images_batch, model_potential)
    print("potential_batch:", potential_batch)

    # Reorder the images from lowest to highest potential and save the grid.
    potential_batch, indices = torch.sort(potential_batch)
    images_batch = images_batch[indices]
    _save_grid(
        images_batch,
        mode,
        os.path.join(dirname, f"{mode}_obj.png"),
        title="potential:" + str(potential_batch),
    )

    # Check the trained predictor on the same (potential-sorted) batch.
    predictor = RotationPredictorCNN().to(device)
    # map_location keeps GPU-saved weights loadable on a CPU-only machine.
    predictor.load_state_dict(torch.load(_PREDICTOR_PATH, map_location=device))
    predictor.eval()
    with torch.no_grad():
        # Outputs created under no_grad carry no graph, so no .detach() is needed.
        y = predictor(images_batch).squeeze(1)

    # Reorder the images by ascending predicted value and save the grid.
    y, indices = torch.sort(y)
    images_batch = images_batch[indices]
    _save_grid(
        images_batch,
        mode,
        os.path.join(dirname, f"{mode}_predicted_angles.png"),
        title="predicted Angles:" + str(y),
    )


if __name__ == "__main__":
    main()