from torchvision import transforms
import random
import torch
from PIL import Image
from torch.utils.data import Dataset
import sys
sys.path.append('/home/***/work/doob_apps/hug')
from src.models.CT_model_predictor import RotationPredictorCNN
from src.preference.CT_learn_preference import RandomRotationWithLabel
from src.utils.img_func import show_images

import os
import torch
import matplotlib.pyplot as plt
import wandb
import datetime

def main():
    """Score a batch of decoded CT images with the rotation predictor and log them.

    Steps:
      1. Load a trained ``RotationPredictorCNN`` checkpoint and a tensor of
         decoded images.
      2. Score the first ``batch_size`` images with ``|model(x)|``.
      3. Sort the images by ascending score.
      4. Save a grid figure into a timestamped output directory.
      5. Log each image to Weights & Biases with its score as the caption.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Timestamped output directory; makedirs also creates the parent directory.
    now_str = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
    dirname = "/home/***/work/doob_apps/hug/outputs/figures/test_predictor/CT_" + now_str
    os.makedirs(dirname, exist_ok=True)

    # Load the model. map_location lets a checkpoint that was saved from a GPU
    # be loaded on a CPU-only machine (device is chosen dynamically above).
    model = RotationPredictorCNN()
    state_dict = torch.load(
        "/home/***/work/doob_apps/hug/src/preference/CT_predictor_20240919_1636/rotation_predictor.pth",
        map_location=device,
    )
    model.load_state_dict(state_dict)
    model.eval()
    model.to(device)

    wandb.init(project='CTImage-diffusion-test-predictor')

    # Load to CPU first so a GPU-saved tensor also loads without CUDA,
    # then move only the slice we actually use onto the target device.
    images = torch.load(
        "/home/***/work/doob_apps/hug/outputs/CT_diffusion/20240919_1459/decoded_images.pth",
        map_location="cpu",
    )
    batch_size = 64
    images_batch = images[:batch_size].to(device)
    print(images_batch.shape)

    with torch.no_grad():
        # NOTE(review): assumes the model emits one score per image, i.e. shape
        # (B,) or (B, 1) — TODO confirm. reshape(-1) (instead of squeeze())
        # keeps a 1-D tensor even when B == 1, so sort/indexing below still work.
        newpot_batch = torch.abs(model(images_batch)).reshape(-1)
    print(newpot_batch.shape)

    # Sort scores in ascending order and reorder the images to match.
    newpot_batch, indices = torch.sort(newpot_batch)
    images_batch = images_batch[indices]

    img = show_images(images_batch, mode='CT')
    plt.imshow(img)
    plt.axis("off")
    plt.savefig(dirname + "/ct_predictor.png")
    plt.close()  # release the figure once it has been written to disk

    # Move data to CPU and convert scores to plain floats so that captions
    # read "preference: 0.123" rather than a tensor repr with a device suffix.
    images_cpu = images_batch.cpu()
    scores = newpot_batch.cpu().tolist()
    # Logging is best-effort: report the failure but do not abort the run.
    try:
        wandb.log({"ct_predictor": [
            wandb.Image(image, caption=f'preference: {score}')
            for image, score in zip(images_cpu, scores)
        ]})
    except Exception as e:
        print(f"wandb.log failed: {e}")
    finally:
        wandb.finish()

if __name__ == "__main__":
    # Entry point: run the predictor evaluation only when executed as a script,
    # never as a side effect of importing this module.
    main()