
from torchvision import transforms
import random
import torch
from PIL import Image
from torch.utils.data import Dataset
import sys
sys.path.append('/home/***/work/doob_apps/hug')
from src.models.CT_model_predictor import RotationPredictorCNN
from src.models.CT_autoencoder import Autoencoder, SmallAutoencoder, Autoencoder32
from src.preference.CT_learn_preference import CTImageDataset, RandomRotationWithLabel

from torch.autograd.functional import jacobian

import os
import wandb
import torch.nn.functional as F

import datetime

def _sample_latents(unet, noise_scheduler, batch_size, latent_size, device):
    """Draw one batch of latents by running the full reverse-diffusion loop.

    Args:
        unet: Noise-prediction model, called as ``unet(sample, t).sample``.
        noise_scheduler: Scheduler providing ``timesteps`` and ``step``.
        batch_size: Number of latents to generate in this pass.
        latent_size: Height/width of each single-channel starting-noise latent.
        device: Device the sampling runs on.

    Returns:
        The denoised latents produced by the last ``noise_scheduler.step`` call,
        starting from noise of shape ``(batch_size, 1, latent_size, latent_size)``.
    """
    # Draw on the CPU generator and then move to the device, so a fixed
    # torch.manual_seed keeps producing the same starting noise (drawing
    # directly on CUDA would use a different RNG stream).
    sample = torch.randn(batch_size, 1, latent_size, latent_size).to(device)
    # Inference only: no autograd graph is needed anywhere in the loop.
    with torch.no_grad():
        for t in noise_scheduler.timesteps:
            residual = unet(sample, t).sample
            sample = noise_scheduler.step(residual, t, sample).prev_sample
    return sample


def main():
    """Sample latents with a DDPM UNet, decode, score and save the images.

    All paths and hyper-parameters are hard-coded below. The script:

    1. loads the pre-trained ``Autoencoder32`` and the ``DDPMPipeline`` stored
       in ``diffusion_dir``, optionally replacing the pipeline's UNet weights
       with the state dict at ``unet_path``;
    2. runs ``iteration`` reverse-diffusion passes of ``bs`` latents each and
       decodes every latent with ``autoencoder.decoder`` into a 64x64 image;
    3. saves all decoded images as ``decoded_images.pth`` and as one PNG per
       image in a timestamped directory under ``output_root``;
    4. scores every decoded image with ``RotationPredictorCNN`` (mean of the
       absolute prediction) and logs the images and that average to wandb.
    """
    from diffusers import DDPMPipeline

    device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")

    autoencoder_path = "/home/***/work/doob_apps/hug/src/pretrain/autoencoder/20240918_2031_32_0.2/autoencoder_epoch_19.pth"
    diffusion_dir = '/home/***/work/doob_apps/hug/src/pretrain/CT_diffusion/20240918_2049/my_pipeline'
    # NOTE(review): this is a relative path, resolved against the current
    # working directory (unlike the absolute paths around it) — confirm it is
    # intended. Set it to None to keep the pipeline's own UNet weights.
    unet_path = 'hug/outputs/figures/test_upperbound_CT/20241112_150004_cu_1/unet_training_10.pth'
    predictor_path = "/home/***/work/doob_apps/hug/src/preference/CT_predictor_20240919_1636/rotation_predictor.pth"
    output_root = '/home/***/work/doob_apps/hug/outputs/CT_diffusion'

    latent_size = 32
    bs = 16               # latents generated per reverse-diffusion pass
    iteration = 1         # number of passes; total images = bs * iteration
    cal_jacobian = False  # debug: time the UNet Jacobian for a single latent

    config = {
        "autoencoder_path": autoencoder_path,
        "diffusion_dir": diffusion_dir,
        # Record which UNet weights were used so the run is reproducible.
        "unet_path": unet_path,
    }

    # Load the pre-trained autoencoder (only its decoder is used below).
    autoencoder = Autoencoder32()
    # map_location lets GPU-saved weights load on the CPU fallback as well.
    autoencoder.load_state_dict(torch.load(autoencoder_path, map_location=device))
    autoencoder.eval()
    autoencoder.to(device)

    # Seed after the autoencoder is built and before the pipeline is loaded;
    # this order is kept so seeded runs stay reproducible.
    random.seed(0)
    torch.manual_seed(0)

    wandb.init(project='CTImage-diffusion-test', config=config)
    wandb.watch(autoencoder)

    image_pipe = DDPMPipeline.from_pretrained(diffusion_dir)
    unet = image_pipe.unet
    unet.to(device)
    if unet_path is not None:
        unet.load_state_dict(torch.load(unet_path, map_location=device))
        print("load unet from", unet_path)
    wandb.watch(unet)

    noise_scheduler = image_pipe.scheduler

    # Run `iteration` passes of `bs` latents each, decoding every pass, and
    # stack the results into (bs * iteration, 1, ...) tensors.
    latent_batches = []
    decoded_batches = []
    for pass_idx in range(iteration):
        print(f"sampling {bs} images")
        sample = _sample_latents(unet, noise_scheduler, bs, latent_size, device)
        wandb.log({"iter": pass_idx})
        with torch.no_grad():
            decoded_batch = autoencoder.decoder(sample).view(bs, 1, 64, 64)
        latent_batches.append(sample)
        decoded_batches.append(decoded_batch)
    latents = torch.cat(latent_batches, dim=0)
    decoded_images = torch.cat(decoded_batches, dim=0)

    print("decoded_images.shape:", decoded_images.shape)

    if cal_jacobian:
        start = datetime.datetime.now()
        # Use a separate name so `sample` still holds the full batch.
        # jacobian() makes its own grad-enabled copy of the input.
        x0 = sample[:1].detach().clone()

        def output_fn(x):
            return unet(x, 999)["sample"]

        jac = jacobian(output_fn, x0)
        print("calc jacobian time:", datetime.datetime.now() - start)
        print("jacobian shape:", tuple(jac.shape))

    # Timestamped output directory.
    now_str = datetime.datetime.now().strftime('%Y%m%d_%H%M')
    dirname = os.path.join(output_root, now_str)
    os.makedirs(dirname, exist_ok=True)

    # Save the stacked decoded images as a tensor file.
    torch.save(decoded_images, os.path.join(dirname, "decoded_images.pth"))

    predictor = RotationPredictorCNN()
    predictor.load_state_dict(torch.load(predictor_path, map_location=device))
    predictor.eval()
    predictor.to(device)

    # Score, log and save every generated image (not only the last batch),
    # so the PNGs and loss_avg match decoded_images.pth.
    num_images = decoded_images.shape[0]
    loss_sum = 0.0
    for i, (latent_i, decoded_image_i) in enumerate(zip(latents, decoded_images)):
        latent_i = latent_i.view(1, 1, latent_size, latent_size)
        with torch.no_grad():
            # .item() assumes the predictor yields one value per image.
            pred_loss = torch.abs(predictor(decoded_image_i.unsqueeze(0))).item()
        loss_sum += pred_loss
        wandb.log({"image": wandb.Image(decoded_image_i), "latent image": wandb.Image(latent_i)})
        # Convert the (1, 64, 64) tensor to a PIL image and save it locally.
        img = transforms.ToPILImage()(decoded_image_i.cpu())
        img.save(os.path.join(dirname, f"decoded_image_{i}.png"))
    loss_avg = loss_sum / num_images
    wandb.log({"loss_avg": loss_avg})
    print(f"loss_avg: {loss_avg}")

    wandb.finish()

# Run the sampling/evaluation only when executed as a script, not on import.
if __name__ == "__main__":
    main()