from torchvision import transforms
import random
import torch
from PIL import Image
from torch.utils.data import Dataset
import sys
sys.path.append('/home/***/work/doob_apps/hug')
from src.models.CT_model_predictor import RotationPredictorCNN
from src.preference.CT_learn_preference import RandomRotationWithLabel
from src.utils.img_func import show_images
from src.finetune.obj import objective_dpo

import os
import torch
import matplotlib.pyplot as plt
import wandb
import datetime

def main(
    samples_path="/home/***/work/doob_apps/hug/outputs/samples/20240910_104356/samples_ref.pth",
    output_root="/home/***/work/doob_apps/hug/outputs/figures/test_target_butterfly",
    batch_size=64,
):
    """Score saved samples with the butterfly color loss, plot and log them.

    Loads images from ``samples_path``, evaluates ``objective_dpo``'s color
    loss (mode ``"butterfly"``) on the first ``batch_size`` of them, sorts the
    images by that loss in ascending order, saves a grid figure into a new
    timestamped directory under ``output_root`` and logs every image with its
    loss to Weights & Biases.

    Args:
        samples_path: File written with ``torch.save`` holding the images,
            indexed along the first dimension (presumably (N, C, H, W) —
            TODO confirm against the sampler that produced it).
        output_root: Directory under which a ``bt_<YYYYmmdd_HHMMSS>`` run
            directory is created (intermediate directories are created too).
        batch_size: Maximum number of images to score; fewer are used if the
            file holds fewer.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # One timestamped directory per run; makedirs also creates output_root.
    now_str = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    dirname = os.path.join(output_root, "bt_" + now_str)
    os.makedirs(dirname, exist_ok=True)

    # The color loss of the DPO objective is used as the scoring function.
    class_obj = objective_dpo(device=device, mode="butterfly")
    color_loss = class_obj._color_loss

    wandb.init(project="butterfly-diffusion-test-target")

    # Load on CPU first so a file saved from a GPU also opens on a CPU-only
    # machine, and only the slice actually used is moved to `device`.
    images = torch.load(samples_path, map_location="cpu")
    images_batch = images[:batch_size].to(device)
    print(images_batch.shape)

    with torch.no_grad():
        # One score per image (the captions below pair score i with image i).
        # reshape(-1) instead of squeeze(): squeeze would collapse a
        # single-image batch to a 0-d tensor and break the indexing below.
        newpot_batch = color_loss(images_batch).reshape(-1)
    print(newpot_batch.shape)

    # Reorder images by ascending loss.
    newpot_batch, indices = torch.sort(newpot_batch)
    images_batch = images_batch[indices]

    img = show_images(images_batch, mode="butterfly")
    fig, ax = plt.subplots()
    ax.imshow(img)
    ax.axis("off")
    fig.savefig(os.path.join(dirname, "butterfly_loss.png"))
    plt.close(fig)  # release the figure instead of leaving it open

    # Logging is best-effort: report why it failed, but keep the saved figure.
    try:
        wandb.log(
            {
                "butterfly_loss": [
                    # .item() so the caption shows the number, not a tensor repr
                    wandb.Image(image, caption=f"preference: {score.item()}")
                    for image, score in zip(images_batch, newpot_batch)
                ]
            }
        )
    except Exception as e:
        print(f"wandb.log failed: {e}")
    finally:
        wandb.finish()

# Run main() only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()