import os
import torch
import time
import random
import numpy as np
from PIL import Image
from einops import rearrange
from torchvision.utils import make_grid
from torchvision import transforms
from tqdm import tqdm
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler


def set_seed(seed):
    """Seed every RNG used by this script so results are reproducible.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch (CPU, and all
    CUDA devices when CUDA is available). Also puts cuDNN into deterministic
    mode and disables its autotuner, trading some speed for repeatability.

    Args:
        seed: Integer seed applied to every generator.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # also covers multi-GPU setups
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


# Seed all RNGs once at import time so every run samples identically.
SEED = 42
set_seed(seed=SEED)


def load_model_from_config(config, ckpt, device='cuda:1'):
    """Build the model described by ``config.model`` and load weights from ``ckpt``.

    Args:
        config: Parsed configuration; only ``config.model`` is read here and is
            passed to ``instantiate_from_config``.
        ckpt: Path to the checkpoint file. A leading ``~`` is expanded, since
            ``torch.load`` treats it as a literal directory name.
        device: Target device; falls back to ``'cpu'`` when CUDA is unavailable.

    Returns:
        The model, moved to the chosen device and switched to eval mode.
    """
    ckpt = os.path.expanduser(ckpt)
    print(f"从 {ckpt} 加载模型")
    # Deserialize on the CPU first to avoid a GPU memory spike.
    pl_sd = torch.load(ckpt, map_location="cpu")
    # Lightning checkpoints nest the weights under "state_dict"; otherwise
    # treat the file itself as a bare state dict.
    sd = pl_sd["state_dict"] if "state_dict" in pl_sd else pl_sd
    model = instantiate_from_config(config.model)
    # strict=False tolerates mismatches, so surface them instead of silently
    # running with partially loaded weights.
    missing, unexpected = model.load_state_dict(sd, strict=False)
    if missing:
        print(f"缺失的参数键数量: {len(missing)}")
    if unexpected:
        print(f"多余的参数键数量: {len(unexpected)}")
    device = device if torch.cuda.is_available() else 'cpu'
    model.to(device)
    print(f"模型设备: {model.device}")
    model.eval()
    return model


def get_model(config_path, ckpt_path, device='cuda:1'):
    """Load the config file at ``config_path`` and build the model from ``ckpt_path``.

    Both paths may start with ``~``; it is expanded here because
    ``OmegaConf.load`` and ``torch.load`` do not expand it themselves.

    Args:
        config_path: Path to the model's YAML config.
        ckpt_path: Path to the checkpoint to load.
        device: Target device forwarded to ``load_model_from_config``.

    Returns:
        The loaded model in eval mode.
    """
    config = OmegaConf.load(os.path.expanduser(config_path))
    return load_model_from_config(config, os.path.expanduser(ckpt_path), device=device)


def load_image_conditions(cond_image_dir, image_size=256):
    """Load and preprocess every condition image in ``cond_image_dir``.

    Each image is converted to RGB, resized to ``image_size`` x ``image_size``,
    turned into a tensor in [0, 1] by ``ToTensor`` and normalized with
    mean/std 0.5 per channel, which maps values to [-1, 1].

    Args:
        cond_image_dir: Directory to scan (non-recursive). A leading ``~`` is
            expanded, since ``os.listdir`` does not do it.
        image_size: Side length the images are resized to.

    Returns:
        ``(cond_images, cond_image_paths)``: the preprocessed tensors and their
        source paths, in the same order, sorted by path.
    """
    transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])

    cond_image_dir = os.path.expanduser(cond_image_dir)
    image_extensions = ('.png', '.jpg', '.jpeg', '.bmp', '.gif')
    # Sort so the output order (and thus batching) is deterministic; skip
    # directories that merely have an image-like name.
    cond_image_paths = sorted(
        os.path.join(cond_image_dir, filename)
        for filename in os.listdir(cond_image_dir)
        if filename.lower().endswith(image_extensions)
        and os.path.isfile(os.path.join(cond_image_dir, filename))
    )

    cond_images = []
    for img_path in cond_image_paths:
        # Close the file handle deterministically; PIL otherwise keeps it open
        # for multi-frame formats such as GIF.
        with Image.open(img_path) as img:
            cond_images.append(transform(img.convert('RGB')))

    return cond_images, cond_image_paths


# Sampling hyper-parameters
batch_size = 20
ddim_steps = 1
ddim_eta = 0.0
uncondition_guidance_scale = 3.0
device = 'cuda:1'  # adjust to your GPU setup

# Paths. os.path.expanduser resolves the leading "~": open(), os.listdir() and
# torch.load() would otherwise treat "~" as a literal directory name.
config_path = os.path.expanduser('~/tai/sdcopy/configs/latent-diffusion/tai-sem-ldm-vp-f8-new.yaml')
exp_name = 'tai-ldm-sem-vq-f8-eval-twostage-ddim1test'  # experiment name; also names the output directory
cond_image_dir = os.path.expanduser('~/tai/sdcopy/logs/eval/tai-sem-ldm-vq-f8-10000-805-ddim1test')  # condition-image directory

# Pick the checkpoint based on the experiment name
if 'sem' in exp_name:
    ckpt_path = '~/tai/sdcopy/logs/2025-07-28T00-35-10_tai-sem-ldm-vp-f8-new/checkpoints/epoch=000043.ckpt'
else:
    ckpt_path = '~/tai/sdcopy/logs/2025-07-26T03-04-23_tai-om-ldm-vp-f8-new/checkpoints/epoch=000029.ckpt'
ckpt_path = os.path.expanduser(ckpt_path)

# Create the output directory (relative to the working directory)
eval_dir = os.path.join('logs/eval', exp_name)
os.makedirs(eval_dir, exist_ok=True)

# Build the model and its DDIM sampler
print("加载模型中...")
model = get_model(config_path=config_path, ckpt_path=ckpt_path, device=device)
sampler = DDIMSampler(model, device=device)

# Load and preprocess the condition images (one tensor per path)
print("加载条件图片中...")
cond_images, cond_image_paths = load_image_conditions(cond_image_dir=cond_image_dir)
num_conds = len(cond_image_paths)
print(f"共加载 {num_conds} 张条件图片")

# Accumulated per-batch sampling time, in seconds
total_generation_time = 0.0

# Generate images: walk the condition images in fixed-size batches.
num_batches = (num_conds + batch_size - 1) // batch_size  # ceil(num_conds / batch_size)
with torch.no_grad():
    with model.ema_scope():  # sample with the EMA weights
        for bid in tqdm(range(num_batches), desc="生成批次"):
            start_idx = bid * batch_size
            end_idx = min(start_idx + batch_size, num_conds)
            n_samples = end_idx - start_idx  # >= 1 for every bid < num_batches

            print(f'批次 {bid + 1}/{num_batches}，样本数: {n_samples}')

            batch_start_time = time.perf_counter()

            # Stack this batch's condition images and move them to the device
            batch_cond_tensor = torch.stack(cond_images[start_idx:end_idx]).to(device)

            # Encode the condition images into conditioning features
            # (the encoder is whatever the YAML config defines — confirm)
            c = model.get_learned_conditioning(batch_cond_tensor)

            # NOTE(review): the upstream ldm DDIMSampler applies classifier-free
            # guidance only when unconditional_conditioning is not None, so with
            # None the guidance scale likely has no effect — confirm for this fork.
            samples_ddim, _ = sampler.sample(
                S=ddim_steps,
                conditioning=c,
                batch_size=n_samples,
                shape=[4, 32, 32],  # latent shape
                verbose=False,
                unconditional_guidance_scale=uncondition_guidance_scale,
                unconditional_conditioning=None,
                eta=ddim_eta,
            )

            # Decode latents to images and map from [-1, 1] to [0, 1]
            x_samples_ddim = model.decode_first_stage(samples_ddim)
            x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

            # CUDA kernels run asynchronously: wait for the sampling/decoding work
            # to finish before reading the timer, otherwise mostly launch
            # overhead would be measured.
            if torch.device(device).type == 'cuda':
                torch.cuda.synchronize(device)
            batch_time = time.perf_counter() - batch_start_time
            total_generation_time += batch_time
            print(f"批次 {bid+1} 生成时间: {batch_time:.4f} 秒, 平均每张: {batch_time/n_samples:.4f} 秒")

            # Copy the whole batch to the host once and convert to HWC uint8
            batch_np = (255. * rearrange(x_samples_ddim.cpu().numpy(), 'b c h w -> b h w c')).astype(np.uint8)

            # Save each sample under the stem of its condition image.
            # NOTE(review): stems can collide (e.g. "a.png" and "a.jpg"), in
            # which case a later sample overwrites an earlier one.
            for img_np, src_path in zip(batch_np, cond_image_paths[start_idx:end_idx]):
                base_name = os.path.splitext(os.path.basename(src_path))[0]
                save_path = os.path.join(eval_dir, f'sample_{base_name}.png')
                Image.fromarray(img_np).save(save_path)

print("图像生成完成，结果保存在:", eval_dir)
# Guard the divisions: an empty condition directory (num_conds == 0) or a zero
# measured time would otherwise raise ZeroDivisionError after all the work.
avg_time_per_image = total_generation_time / num_conds if num_conds else 0.0
print("\n所有图片生成完成!")
print(f"总生成时间: {total_generation_time:.4f} 秒")
if num_conds:
    print(f"平均每张图片生成时间: {avg_time_per_image:.4f} 秒")
    if avg_time_per_image > 0:
        print(f"每秒可生成: {1/avg_time_per_image:.4f} 张图片")
else:
    print("没有生成任何图片，无法计算平均生成时间")

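# NOTE: disabled earlier variant of this script (reads condition-image paths from a .txt list); kept for reference only.
# import os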
# import torch
# import random
# import numpy as np
# from PIL import Image
# from einops import rearrange
# from torchvision.utils import make_grid
# from torchvision import transforms
# from tqdm import tqdm
# from omegaconf import OmegaConf
# from ldm.util import instantiate_from_config
# from ldm.models.diffusion.ddim import DDIMSampler


# def set_seed(seed):
#     torch.manual_seed(seed)
#     np.random.seed(seed)
#     random.seed(seed)
#     if torch.cuda.is_available():
#         torch.cuda.manual_seed(seed)
#         torch.cuda.manual_seed_all(seed)
#     torch.backends.cudnn.deterministic = True
#     torch.backends.cudnn.benchmark = False


# SEED = 42
# set_seed(SEED)


# def load_model_from_config(config, ckpt, device='cuda:1'):
#     print(f"Loading model from {ckpt}")
#     pl_sd = torch.load(ckpt, map_location="cpu")
#     sd = pl_sd["state_dict"]
#     model = instantiate_from_config(config.model)
#     m, u = model.load_state_dict(sd, strict=False)
#     device = device if torch.cuda.is_available() else 'cpu'
#     model.to(device)
#     print(model.device)
#     model.eval()
#     return model


# def get_model(config_path, ckpt_path, device='cuda:1'):
#     config = OmegaConf.load(config_path)
#     model = load_model_from_config(config, ckpt_path, device=device)
#     return model


# # 图片预处理（保持和训练时一致）
# def preprocess_image(image_path, size=256):
#     transform = transforms.Compose([
#         transforms.Resize((size, size)),
#         transforms.ToTensor(),
#         transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
#     ])
#     img = Image.open(image_path).convert('RGB')
#     return transform(img)


# batch_size = 20
# ddim_steps = 20
# ddim_eta = 0.0
# uncondition_guidance_scale = 3.0
# device = 'cuda:1'

# # 路径配置（和原代码格式保持一致）
# config_path = '~/tai/sdcopy/configs/latent-diffusion/tai-image-ldm-vq-f8.yaml'  # 替换为你的图片条件配置文件
# exp_name = 'tai-image-ldm-vq-f8-intra-two-10000-729'  # 保持命名风格
# if 'image' in exp_name:
#     ckpt_path = 'logs/2025-07-29T00-26-49_tai-image-ldm-vq-f8/checkpoints/epoch=000049.ckpt'  # 替换为你的图片条件模型
# else:
#     ckpt_path = '~/tai/sdcopy/logs/2025-07-22T19-22-21_tai-om-ldm-vq-f8/checkpoints/epoch=000039.ckpt'
# eval_dir = os.path.join('logs/eval', exp_name)

# # 条件图片路径文件（类似原代码的cond_file，用txt存储图片路径列表）
# cond_image_list = '~/tai/sd/metric/full_image_conds_10000.txt'  # 每行一个图片路径

# if not os.access(eval_dir, os.F_OK):
#     os.makedirs(eval_dir)

# # Load Model
# model = get_model(config_path, ckpt_path, device=device)
# sampler = DDIMSampler(model, device=device)

# # Load Cond（保持和原代码一样的批量加载逻辑）
# with open(cond_image_list, 'r') as f:
#     cond_image_paths = [line.strip() for line in f.readlines() if line.strip()]
# num_conds = len(cond_image_paths)
# print(f"Total condition images: {num_conds}")

# with torch.no_grad():
#     with model.ema_scope():
#         # 保持和原代码一致的批次循环和进度显示
#         for bid in tqdm(range((num_conds-1)//batch_size + 1), desc="Total progress"):
#             n_samples = min(batch_size, num_conds - bid*batch_size)
#             start_idx = bid * batch_size
#             end_idx = start_idx + n_samples
#             batch_paths = cond_image_paths[start_idx:end_idx]
            
#             print('Batch {} / {}: '.format(bid+1, (num_conds-1)//batch_size + 1))
            
#             # 加载并预处理当前批次的条件图片（替代原代码的向量加载）
#             batch_conds = []
#             for path in batch_paths:
#                 img_tensor = preprocess_image(path)
#                 batch_conds.append(img_tensor)
#             batch_conds = torch.stack(batch_conds).to(device)
            
#             # 条件编码（对应原代码的get_learned_7_conditioning）
#             c = model.get_learned_conditioning(batch_conds)
            
#             # 采样过程（完全保持和原代码一致）
#             samples_ddim, _ = sampler.sample(S=ddim_steps,
#                                             conditioning=c,
#                                             batch_size=n_samples,
#                                             shape=[4, 32, 32],
#                                             verbose=False,
#                                             unconditional_guidance_scale=uncondition_guidance_scale,
#                                             unconditional_conditioning=None,  # 图片条件下通常不使用自身作为无条件引导
#                                             eta=ddim_eta)
            
#             # 解码和后处理（保持一致）
#             x_samples_ddim = model.decode_first_stage(samples_ddim)
#             x_samples_ddim = torch.clamp((x_samples_ddim+1.0)/2.0, 
#                                         min=0.0, max=1.0)
            
#             # 保存图片（保持和原代码一样的命名格式）
#             for i in range(n_samples):
#                 grid = x_samples_ddim[i:i+1].unsqueeze(0)
#                 grid = rearrange(grid, 'n b c h w -> (n b) c h w')
#                 grid = make_grid(grid, nrow=1)

#                 grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
#                 name = 'c{:>06d}'.format(bid*batch_size + i)  # 保持原代码的编号风格
#                 save_path = os.path.join(eval_dir, 'samples_{}.png'.format(name))
#                 Image.fromarray(grid.astype(np.uint8)).save(save_path)