import torch
import requests
from PIL import Image
from io import BytesIO
from matplotlib import pyplot as plt
from transformers import AutoTokenizer, PretrainedConfig, CLIPTokenizer, CLIPImageProcessor, CLIPTextModel
import torchvision.transforms as T
from diffusers import (
    AutoencoderKL,
    StableDiffusionPipeline, 
    UNet2DConditionModel,
    DPMSolverMultistepScheduler, 
    AutoPipelineForImage2Image,
    UniPCMultistepScheduler,
    DDPMScheduler,
    DDIMScheduler,
    DPMSolverSDEScheduler,
    ) 
import random      
from diffusers.utils import load_image
import os
import numpy as np
from torchvision import transforms
from CVJL import walk_dir, combine_image
from modules import ParaCondModel, PWTT

# Input lists for the generation run below. Each list holds one entry per
# sample, in the same order.

# NOTE(review): not referenced anywhere else in the visible script.
VALID_IMAGES_HE = ['some/ori.png']

# Conditioning images; presumably region-of-interest crops, judging by the
# file name (TODO confirm).
VALID_IMAGES = ['some/roi.png']

# Companion images, presumably paired by index with VALID_IMAGES.
VALID_IMAGES_E = ['some/ori_HEW.png']

# Text prompt used for the matching conditioning image.
VALID_PROMPTS = ['hematoxylin and eosin breast histopathology microscopic image']

def make_test_data(path, plist):
    """Re-root every entry of *plist* under the directory *path*.

    Only the last '/'-separated component (the file name) of each entry is
    kept. It is then joined onto *path* with ``os.path.join``.

    Args:
        path: Directory to place each file name under.
        plist: Iterable of '/'-separated path strings.

    Returns:
        A new list of joined paths, in the same order as *plist*.
    """
    return [os.path.join(path, entry.rsplit('/', 1)[-1]) for entry in plist]

# Report how many conditioning images are configured.
print('total images: ', str(len(VALID_IMAGES)))

# The assembled pipeline is moved onto this CUDA device.
device = 'cuda:0'
# Root of the trained model export. The text encoder and VAE are read from its
# top-level sub-folders; the UNet comes from the `checkpoint-12000` sub-folder.
base_model_path = '/workspace/data2/BRACS_region_control/path_img_base_512patch'

# Load each component in bfloat16 before handing it to the pipeline.
unet = UNet2DConditionModel.from_pretrained(
    f"{base_model_path}/checkpoint-12000/unet",
    torch_dtype=torch.bfloat16,
)
text_encoder = CLIPTextModel.from_pretrained(
    f"{base_model_path}/text_encoder",
    torch_dtype=torch.bfloat16,
)
vae = AutoencoderKL.from_pretrained(
    f"{base_model_path}/vae",
    torch_dtype=torch.bfloat16,
)

# PWTT is the project pipeline class (from `modules`). It receives the base
# path plus the preloaded components, so those override whatever it would
# load itself (presumably; confirm against PWTT).
pipe = PWTT(
    base_model_path,
    unet=unet,
    text_encoder=text_encoder,
    vae=vae,
    torch_dtype=torch.bfloat16,
).to(device)

# Swap in a UniPC multistep scheduler built from the current scheduler's config.
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)

# Generate NUM_SAMPLES images for each (conditioning image, prompt) pair and
# write them to SAVE_DIR.

# NOTE(review): the original loop saved to an undefined `save_path` and printed
# an undefined `cond_scale`. Both crashed with NameError on the first
# iteration. SAVE_DIR is a placeholder; point it at the intended output
# location.
SAVE_DIR = 'generated'
GUIDANCE_SCALE = 6.5  # guidance_scale passed to the pipeline
NUM_SAMPLES = 1       # images generated per conditioning image

if len(VALID_IMAGES) != len(VALID_PROMPTS):
    # zip() would silently drop the unmatched tail and skip images or prompts.
    raise ValueError(
        f'VALID_IMAGES ({len(VALID_IMAGES)}) and VALID_PROMPTS '
        f'({len(VALID_PROMPTS)}) must have the same length'
    )

os.makedirs(SAVE_DIR, exist_ok=True)

# Loop-invariant preprocessing. ToTensor turns a PIL image into a float
# tensor scaled to [0, 1]. The [-1, 1] Normalize step stays disabled, as in
# the original script (presumably PWTT expects [0, 1] input; TODO confirm).
cond_img_transforms = transforms.Compose(
    [
        transforms.ToTensor(),
        # transforms.Normalize([0.5], [0.5]),
    ])

# Output files reuse each conditioning image's file name inside SAVE_DIR.
for p_img_pth, p_prompt, save_path in zip(
        VALID_IMAGES, VALID_PROMPTS, make_test_data(SAVE_DIR, VALID_IMAGES)):
    # NOTE(review): the original also opened the matching VALID_IMAGES_E file
    # here but never used it, so that load was dropped. Re-add it if PWTT
    # should receive that image.
    with Image.open(p_img_pth) as src:  # close the file handle promptly
        cond_img = cond_img_transforms(src.convert('RGB')).unsqueeze(0)  # add batch dim
    print(GUIDANCE_SCALE)
    print(p_prompt)

    root, ext = os.path.splitext(save_path)
    for i in range(NUM_SAMPLES):
        image = pipe(
            p_prompt,
            num_inference_steps=25,
            image=cond_img,
            guidance_scale=GUIDANCE_SCALE,
            # Passed through to PWTT; their meaning is defined in modules.PWTT.
            large_latent_height=1024 * 2,
            large_latent_width=1024 * 2,
            # generator=torch.manual_seed(2300),  # uncomment for reproducible sampling
        ).images[0]

        # The sample index suffix keeps multiple samples from overwriting each other.
        image.save(f'{root}_{i}{ext}')

