import os
import torch
import time
import random
import pandas as pd
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler

import numpy as np 
from PIL import Image
from einops import rearrange
from torchvision.utils import make_grid
from tqdm import tqdm


def set_seed(seed):
    """Seed every random number generator this script relies on.

    Seeds PyTorch, NumPy and Python's ``random`` module with the same value,
    additionally seeds the CUDA generators when CUDA is available, and sets
    the cuDNN ``deterministic`` / ``benchmark`` flags.

    Args:
        seed: Integer seed applied to every generator.
    """
    for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
        seed_fn(seed)

    if torch.cuda.is_available():
        # Current device first, then all devices (covers multi-GPU runs).
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


# Fixed seed, applied as soon as the module is executed so that everything
# below (model construction and sampling) runs with seeded generators.
SEED = 42
set_seed(SEED)


def load_model_from_config(config, ckpt, device='cuda:1'):
    """Instantiate the model described by ``config`` and load checkpoint weights.

    Args:
        config: Configuration object whose ``model`` attribute is handed to
            ``instantiate_from_config``.
        ckpt: Path to a checkpoint file that contains a ``"state_dict"``
            entry. A leading ``~`` is expanded to the user's home directory.
        device: Device the model is moved to when CUDA is available;
            otherwise the model is placed on the CPU.

    Returns:
        The instantiated model, switched to eval mode, on the chosen device.

    Raises:
        KeyError: If the checkpoint has no ``"state_dict"`` entry.
    """
    # torch.load does not expand "~" itself.
    ckpt = os.path.expanduser(ckpt)
    print(f"Loading model from {ckpt}")
    # Deserialize onto the CPU first so loading works when GPU memory is tight.
    pl_sd = torch.load(ckpt, map_location="cpu")
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)

    # strict=False tolerates key mismatches; report them instead of silently
    # discarding the result, so a wrong checkpoint/config pairing is visible.
    missing, unexpected = model.load_state_dict(sd, strict=False)
    if missing:
        print(f"missing keys ({len(missing)}): {missing}")
    if unexpected:
        print(f"unexpected keys ({len(unexpected)}): {unexpected}")

    device = device if torch.cuda.is_available() else 'cpu'
    model.to(device)
    print(model.device)
    model.eval()
    return model


def get_model(config_name, ckpt_path, device='cuda:1'):
    """Load a model from a named latent-diffusion config and a checkpoint.

    Args:
        config_name: Either a bare config name, resolved to
            ``configs/latent-diffusion/<config_name>.yaml``, or a path
            (without the ``.yaml`` suffix) that is absolute or starts with
            ``~``; such a path is used as-is after ``~`` expansion.
        ckpt_path: Path to the checkpoint file; a leading ``~`` is expanded.
        device: Device forwarded to ``load_model_from_config``.

    Returns:
        The model returned by ``load_model_from_config``.
    """
    # Expand "~" before joining: os.path.join drops the default directory
    # only when the later component is absolute, and "~/..." is not until
    # it has been expanded.
    config_file = os.path.expanduser('{}.yaml'.format(config_name))
    config = OmegaConf.load(os.path.join('configs/latent-diffusion', config_file))
    return load_model_from_config(config, os.path.expanduser(ckpt_path), device=device)


# ---------------------------------------------------------------------------
# Sampling configuration
# ---------------------------------------------------------------------------
batch_size = 20
ddim_steps = 20
ddim_eta = 0.0
uncondition_guidance_scale = 3.0
# Resolve the device once so the model, the sampler and the input tensors all
# agree; without CUDA everything runs on the CPU.
device = 'cuda:1' if torch.cuda.is_available() else 'cpu'

# Neither np.load nor torch.load expands "~", so expand it up front.
cond_file = os.path.expanduser('~/Project/CcGAN-SEM/dataset/124labels.npy')
config_name = os.path.expanduser('~/tai/sdcopy/configs/latent-diffusion/tai-sem-ldm-vq-f8')
exp_name = 'tai-sem-ldm-vq-f8-10000-805-ddim20test'
if 'sem' in exp_name:
    ckpt_path = 'logs/2025-07-28T00-26-49_tai-sem-ldm-vq-f8/checkpoints/epoch=000049.ckpt'
    # Alternative checkpoint:
    # ckpt_path = "logs/2025-05-10T01-43-37_tai-sem-ldm-vq-f8/checkpoints/last.ckpt"
else:
    ckpt_path = os.path.expanduser(
        '~/tai/sdcopy/logs/2025-07-22T19-22-21_tai-om-ldm-vq-f8/checkpoints/epoch=000039.ckpt'
    )
eval_dir = os.path.join('logs/eval', exp_name)

# exist_ok avoids the check-then-create race of testing for the directory first.
os.makedirs(eval_dir, exist_ok=True)

# Load the model and build the DDIM sampler around it.
model = get_model(config_name, ckpt_path, device=device)
sampler = DDIMSampler(model, device=device)

# Load the conditioning labels; one row per image to generate.
conds = np.load(cond_file)

total_images = conds.shape[0]
print(f"总图片数量: {total_images}")

num_batches = (total_images - 1) // batch_size + 1
total_generation_time = 0.0

# CUDA kernels run asynchronously, so the clock is only meaningful after an
# explicit synchronize on the device doing the work.
on_cuda = device.startswith('cuda')
sync_device = torch.device(device)

with torch.no_grad(), model.ema_scope():
    for bid in tqdm(range(num_batches)):
        batch_offset = bid * batch_size
        ys = conds[batch_offset:batch_offset + batch_size]
        n_samples = ys.shape[0]
        print('Batch {} / {}: '.format(bid+1, num_batches))

        if on_cuda:
            torch.cuda.synchronize(sync_device)
        batch_start_time = time.perf_counter()

        # The first seven label columns are passed as separate tensors.
        label_columns = [torch.tensor(ys[:, k]).to(device) for k in range(7)]
        c = model.get_learned_7_conditioning(*label_columns, return_oc=False)

        # NOTE(review): the same tensor is passed as both the conditional and
        # the unconditional conditioning. In a standard classifier-free
        # guidance formulation that would make the guidance scale a no-op --
        # confirm against this project's sampler whether a separate
        # unconditional embedding was intended.
        samples_ddim, _ = sampler.sample(
            S=ddim_steps,
            conditioning=c,
            batch_size=n_samples,
            shape=[4, 32, 32],
            verbose=False,
            unconditional_guidance_scale=uncondition_guidance_scale,
            unconditional_conditioning=c,
            eta=ddim_eta,
        )
        x_samples_ddim = model.decode_first_stage(samples_ddim)
        # Map the decoder output from [-1, 1] to [0, 1].
        x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

        # Wait for the queued GPU work before reading the clock, then
        # accumulate the per-batch generation time.
        if on_cuda:
            torch.cuda.synchronize(sync_device)
        batch_time = time.perf_counter() - batch_start_time
        total_generation_time += batch_time
        print(f"批次 {bid+1} 生成时间: {batch_time:.4f} 秒, 平均每张: {batch_time/n_samples:.4f} 秒")

        # Save every sample of the batch as its own PNG, named by its
        # global index in the label file.
        for image_index, sample in enumerate(x_samples_ddim, start=batch_offset):
            # make_grid is kept for its single-image handling.
            grid = make_grid(sample.unsqueeze(0), nrow=1)

            # CHW float in [0, 1] -> HWC uint8 image.
            grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
            save_path = os.path.join(eval_dir, f'samples_c{image_index:>06d}.png')
            Image.fromarray(grid.astype(np.uint8)).save(save_path)

print("\n所有图片生成完成!")
print(f"总生成时间: {total_generation_time:.4f} 秒")
# Averages are undefined for an empty label file (or a zero elapsed time);
# skip them instead of dividing by zero.
if total_images > 0 and total_generation_time > 0:
    avg_time_per_image = total_generation_time / total_images
    print(f"平均每张图片生成时间: {avg_time_per_image:.4f} 秒")
    print(f"每秒可生成: {1/avg_time_per_image:.4f} 张图片")
