# 打通sd生成全流程，从输入性能到输出图像-----保存1w张
import torch
import numpy as np
import random
from PIL import Image
from einops import rearrange
from torchvision.utils import make_grid
import os
from ldm.models.diffusion.ddim import DDIMSampler
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config
import timeit
# Wall-clock reference point for the total-runtime report printed at the end of the script.
start = timeit.default_timer()


def set_seed(seed):
    """Seed every RNG the script relies on and request deterministic cuDNN.

    Seeds Python's ``random`` module, NumPy's global generator and torch's
    generator. When CUDA is available, it also seeds the current CUDA device
    and all other visible devices. cuDNN autotuning is turned off and
    deterministic kernels are requested.

    Args:
        seed: integer seed applied to all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        # Current device first, then every GPU (multi-GPU setups).
        for seeder in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
            seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


# Global seed applied once at import time so repeated runs draw the same noise.
SEED = 42
set_seed(seed=SEED)


def load_model_from_config(config, ckpt, verbose=False):
    """Instantiate the model described by ``config`` and load weights from ``ckpt``.

    Args:
        config: OmegaConf config whose ``model`` node is passed to
            ``instantiate_from_config``.
        ckpt: path to a checkpoint file holding a ``"state_dict"`` entry.
        verbose: if True, print the full lists of missing/unexpected keys.
            Their counts are always printed when non-empty.

    Returns:
        The model in eval mode, on ``cuda:0`` when CUDA is available and on
        the CPU otherwise.

    Raises:
        KeyError: if the checkpoint has no ``"state_dict"`` entry.
    """
    print(f"Loading model from {ckpt}")
    # Load onto the CPU first so a checkpoint larger than free GPU memory still loads.
    # NOTE(review): torch.load unpickles arbitrary objects; only load trusted checkpoints.
    pl_sd = torch.load(ckpt, map_location="cpu")
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)
    # strict=False tolerates key mismatches. Report them rather than discarding
    # the result: a silently half-loaded model yields meaningless samples.
    missing, unexpected = model.load_state_dict(sd, strict=False)
    if missing:
        print(f"missing keys: {len(missing)}")
        if verbose:
            print(missing)
    if unexpected:
        print(f"unexpected keys: {len(unexpected)}")
        if verbose:
            print(unexpected)
    device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
    model.to(device)
    print(model.device)
    model.eval()
    return model


def get_model(config_path="configs/latent-diffusion/tai-sem-ldm-vq-f8.yaml",
              ckpt_path="logs/2025-05-10T01-43-37_tai-sem-ldm-vq-f8/checkpoints/last.ckpt"):
    """Load the latent-diffusion model used for sampling.

    Args:
        config_path: path to the OmegaConf YAML describing the model.
            Defaults to the tai-sem-ldm-vq-f8 config.
        ckpt_path: path to the checkpoint to load. Defaults to the ``last.ckpt``
            of the 2025-05-10 run.

    Returns:
        The model returned by ``load_model_from_config``.
    """
    config = OmegaConf.load(config_path)
    return load_model_from_config(config, ckpt_path)


# 加载模型

model = get_model()
# Use the same device choice as the loader (cuda:0 if available, else CPU).
# Hard-coding 'cuda:0' here would crash on CPU-only machines.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
model = model.to(device)
sampler = DDIMSampler(model)

# 生成样本


# mean
# 9.4624e+02
# 1.0423e+03
# 1.2470e+01
# 3.4270e+01
# 3.8528e+01
# 1.4171e+03
# 1.5875e-01
# y1 = 9.4624e+02
# y2 = 1.0423e+03
# y3 = 1.2470e+01
# y4 = 3.4270e+01
# y5 = 3.8528e+01
# y6 = 1.4171e+03
# y7 = 1.5875e-01

# # hkht 1309，1094，10，55，12.5，1709，0.08
# y1 = 1309
# y2 = 1094
# y3 = 10
# y4 = 55
# y5 = 12.5
# y6 = 1709
# y7 = 0.08
# # bqzb  950，850，12.5，32，46，1350，0.235
# y1 = 950
# y2 = 850
# y3 = 12.5
# y4 = 32
# y5 = 46
# y6 = 1350
# y7 = 0.235
# # chuan  900，818，13，30.83，50.67，1300，0.257
# y1 = 900
# y2 = 818
# y3 = 13
# y4 = 30.83
# y5 = 50.67
# y6 = 1300
# y7 = 0.257
# yi  1000，880，8，36.7，41.3，1400，0.213
# Seven scalar condition values fed to model.get_learned_7_conditioning ("yi" preset).
y1, y2, y3, y4, y5, y6, y7 = 1000, 880, 8, 36.7, 41.3, 1400, 0.213

# Sampling settings.
n_samples = 10  # number of images generated, one DDIM run each (other values used: 1, 10000)
ddim_steps = 20  # passed to sampler.sample(S=...)
ddim_eta = 0.0  # passed to sampler.sample(eta=...)
scale = 3.0  # guidance scale, passed as unconditional_guidance_scale


with torch.no_grad():
    with model.ema_scope():
        # The seven condition inputs are identical for every sample, so build
        # the tensors and both conditioning embeddings once, outside the loop.
        cond_inputs = [torch.tensor(v).to(model.device) for v in (y1, y2, y3, y4, y5, y6, y7)]
        uc = model.get_learned_7_conditioning(*cond_inputs)
        # NOTE(review): uc is built from the same inputs as c; only return_oc differs.
        # If both calls return the same embedding, guidance presumably reduces
        # to the conditional prediction and `scale` has no effect. Confirm which
        # "unconditional" input the model was trained with.
        c = model.get_learned_7_conditioning(*cond_inputs, return_oc=False)
        cond = c.unsqueeze(0)
        uncond = uc.unsqueeze(0)

        # Create the output directory once; no-op if it already exists.
        os.makedirs('logs/debug', exist_ok=True)

        for i in range(n_samples):
            # Each iteration draws fresh initial noise; the conditioning is reused.
            samples_ddim, _ = sampler.sample(S=ddim_steps,
                                             conditioning=cond,
                                             batch_size=1,
                                             shape=[4, 32, 32],  # [4, 64, 64],
                                             verbose=False,
                                             unconditional_guidance_scale=scale,
                                             unconditional_conditioning=uncond,
                                             eta=ddim_eta)

            # Decode the latents and map [-1, 1] to [0, 1].
            x_samples_ddim = model.decode_first_stage(samples_ddim)
            x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0,
                                         min=0.0, max=1.0)

            # For a one-image batch, make_grid returns that image as (C, H, W)
            # and expands a 1-channel image to 3 channels, so PIL can save it.
            grid = make_grid(x_samples_ddim, nrow=1)
            grid = 255. * grid.permute(1, 2, 0).cpu().numpy()

            name = "yi_tai-enhance_om_" + str(i)
            Image.fromarray(grid.astype(np.uint8)).save("./logs/debug/samples_{}.png".format(name))


stop = timeit.default_timer()
# Total wall-clock time since `start`, covering seeding, model loading and sampling.
# This script only samples, so the message no longer claims "GAN training".
print("Sampling finished; Time elapsed: {}s".format(stop - start))
# Usage: python sample.py