import torch
from diffusers import StableDiffusionPipeline
from losses import adain_loss
from pipeline_sd import ADPipeline
import utils
from PIL import Image
import torchvision.transforms as transforms
import copy
import os

# Select the compute device: first CUDA GPU if available, otherwise CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Initialise the trainable Stable Diffusion pipeline.
# The SD v1.5 weights were never published under the "CompVis" organisation
# (it hosts v1-1 .. v1-4 only); the maintained v1.5 repo is the one below.
model_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"
pipe = ADPipeline.from_pretrained(model_id, torch_dtype=torch.float32, safety_checker=None)
pipe = pipe.to(device)
# The UNet is the network whose attention weights are fine-tuned below.
pipe.classifier = pipe.unet
pipe.cache = utils.DataCache()
# NOTE(review): self_layers=(0, 10) presumably selects which self-attention
# layers the controller acts on — confirm against utils.Controller.
pipe.controller = utils.Controller(self_layers=(0, 10))

# Frozen copy of the pipeline, used only to extract the target (reference)
# features the trainable model is aligned against.
frozen_pipe = copy.deepcopy(pipe).to(device)
frozen_pipe.classifier = frozen_pipe.unet
# NOTE(review): this reuses the trainable pipeline's DataCache object rather
# than creating a separate one — confirm the sharing is intentional.
frozen_pipe.cache = pipe.cache
frozen_pipe.controller = utils.Controller(self_layers=(0, 10))

# Freeze every weight of the frozen UNet.
frozen_pipe.classifier.requires_grad_(False)

def _is_trainable_param(param_name):
    """Return True for UNet parameters of a self-attention ("attn1") q/k/v projection."""
    return "attn1" in param_name and any(
        proj in param_name for proj in ("to_q", "to_k", "to_v")
    )


# Train only the self-attention q/k/v projections; freeze every other UNet weight.
for name, param in pipe.classifier.named_parameters():
    param.requires_grad = _is_trainable_param(name)

# Snapshot the initial values of the trainable weights. This is taken AFTER
# freezing so that only the q/k/v projections are copied, not the whole UNet.
# The training loop penalises L1 drift away from these values.
original_params = {
    name: param.detach().clone()
    for name, param in pipe.classifier.named_parameters()
    if param.requires_grad
}

# Optimizer over exactly the trainable parameters (same order as named_parameters).
optimizer = torch.optim.Adam(
    [param for param in pipe.classifier.parameters() if param.requires_grad],
    lr=1e-4,
)

# Install the attention-control hooks on the trainable UNet first, then on the
# frozen reference UNet, each with the controller and cache of its pipeline.
for hooked_pipe in (pipe, frozen_pipe):
    utils.register_attn_control(
        hooked_pipe.classifier,
        controller=hooked_pipe.controller,
        cache=hooked_pipe.cache,
    )
del hooked_pipe

# Image preprocessing: resize to 512x512, then convert the PIL image to a tensor.
# NOTE(review): ToTensor yields values in [0, 1]; confirm pipe.image2latent
# applies whatever rescaling the VAE expects.
transform = transforms.Compose(
    [
        transforms.Resize(size=(512, 512)),
        transforms.ToTensor(),
    ]
)

# Reference (image path, text prompt) pairs used as style targets.
# Append more tuples here to add references.
_REFERENCE_SPECS = [
    ("/content/The-night.jpg", "the starry night by van gogh "),
    ("/content/The-whield-field-with-cypresses.png", "wheat field with cypresses by Van Gogh"),
    ("/content/self-potraite.png", "the self-portraite by Van Gogh"),
    ("/content/Bedroom-in-arles.png", "Bedroom in Arles by Vincent van Gogh"),
    ("/content/Bridge-at=Trinquetaille.png", "Bridge at Trinquetaille by Vincent van Gogh"),
    ("/content/cafe-Terrace-at-night.png", "Café Terrace at Night by Vincent van Gogh"),
    ("/content/Enclosed-Field-with-risingsun.png", "Enclosed Field with Rising Sun by Vincent van Gogh"),
    ("/content/Entranc-to-quarry.png", "Entrance to a Quarry by Vincent van Gogh"),
    ("/content/Fish-boats-on-Beach.png", "Fishing Boats on the Beach at Saintes-Maries by Vincent van Gogh"),
    ("/content/Harvest-at-la-crau.png", "Harvest at La Crau, with Montmajour in the Background by Vincent van Gogh"),
    ("/content/Irise.png", "Irises by Vincent van Gogh"),
    ("/content/Landscape-at-Saint-Remy.png", "Landscape at Saint-Rémy by Vincent van Gogh"),
    ("/content/Landscape-with-Snow.png", "Landscape with Snow by Vincent van Gogh"),
    ("/content/Olive-Trees.png", "Olive Trees by Vincent van Gogh"),
    ("/content/The-Red-Vineyards.png", "Red Vineyards at Arles by Vincent van Gogh"),
    ("/content/Sunflowers.png", "Sunflowers by Vincent van Gogh"),
    ("/content/The-Vhurch-at-Auver.png", "The Church at Auvers by Vincent van Gogh"),
    ("/content/The-Cottage.png", "The Cottage by Vincent van Gogh"),
    ("/content/The-Yellow-House.png", "The Yellow House by Vincent van Gogh"),
    ("/content/Two-Cut-Flowers.png", "Two Cut Sunflowers by Vincent van Gogh"),
    ("/content/painting-of-a-tree.png", "Painting of a tree in the style of Van Gogh"),
    ("/content/painting-of-a-city.png", "A painting of a city in the style of van gogh"),
    ("/content/Almond Blossoms.png", "Almond Blossoms by Vincent van Gogh"),
    ("/content/Green-wheat-Field-withcypress.png", "Green Wheat Field with Cypress by Vincent van Gogh"),
    ("/content/dog.png", "The dog by van gogh style"),
    ("/content/panda.png", "The panda by van gogh style"),
    ("/content/cat.png", "The cat by van gogh style"),
    ("/content/Sorrow.png", "Sorrow by Vincent van Gogh"),
    ("/content/Sower-with-setting-sun.png", "Sower with Setting Sun by Vincent van Gogh"),
    ("/content/The-Old-Mill.png", "The Old Mill by Vincent van Gogh"),
    ("/content/Vase-with-Fifteen-Sunflowers.png", "Vase with Fifteen Sunflowers by Vincent van Gogh"),
    # All images live under /content/. NOTE(review): the space before ".png"
    # is kept as written — confirm it matches the filename on disk.
    ("/content/The-Zouave .png", "The Zouave by Vincent van Gogh"),
    ("/content/starrynight.png", "the starry night by van gogh "),
]

reference_pairs = [
    {"image_path": image_path, "text_prompt": text_prompt}
    for image_path, text_prompt in _REFERENCE_SPECS
]

# Encode every reference pair once, up front.
# Fail fast, before any GPU work, if a reference image is missing.
_missing_images = [
    p["image_path"] for p in reference_pairs if not os.path.isfile(p["image_path"])
]
if _missing_images:
    raise FileNotFoundError("Reference image(s) not found: " + ", ".join(_missing_images))

# Nothing encoded here is trained, so run without autograd. Otherwise each cached
# latent/embedding could keep its VAE / text-encoder graph alive for the entire
# run, and later backward passes could flow into the text encoder.
reference_data = []
with torch.no_grad():
    for pair in reference_pairs:
        with Image.open(pair["image_path"]) as raw_image:
            style_image = raw_image.convert("RGB")  # convert() loads pixels, so closing is safe
        style_image = transform(style_image).unsqueeze(0).to(device)
        style_latent = pipe.image2latent(style_image)
        # Embedding of this pair's text prompt (first element of encode_prompt's result).
        style_prompt_embeds = pipe.encode_prompt(pair["text_prompt"], device, 1, False)[0]
        reference_data.append((style_latent, style_prompt_embeds))

    # Embedding of the empty prompt, used to condition the frozen model.
    null_embeds = pipe.encode_prompt("", device, 1, False)[0]

# Fixed random starting latent fed to the trainable model.
# NOTE(review): (1, 4, 64, 64) presumably matches the VAE latent of a 512x512 image.
input_latents = torch.randn((1, 4, 64, 64), device=device)

pipe.null_embeds = null_embeds

# Training hyperparameters.
num_epochs = 50
num_steps_per_epoch = 1
reg_lambda = 0.000325  # weight of the L1 penalty on drift from the initial weights

# Main training loop. For each reference, the trainable UNet's attention features
# on `input_latents` (conditioned on the reference prompt) are aligned with the
# frozen UNet's features on the noised reference latent.
for epoch in range(num_epochs):
    for ref_idx, (style_latent, null_embeds_for_style) in enumerate(reference_data):
        # Despite its name, this is the reference's text-prompt embedding. It is a
        # fixed conditioning input, so detach it. Backward then never reaches the
        # text encoder, and reusing it every epoch needs no retained graph.
        null_embeds_for_style = null_embeds_for_style.detach()
        for step in range(num_steps_per_epoch):
            # Random timestep drawn uniformly from the scheduler's training range.
            t = torch.randint(0, pipe.scheduler.config.num_train_timesteps, (1,), device=device)

            # Target features from the frozen model (no gradients needed).
            with torch.no_grad():
                q_ref, k_ref, v_ref, _ = frozen_pipe.extract_feature(
                    style_latent, t, pipe.null_embeds, add_noise=True
                )

            # Features from the trainable model.
            q_style, k_style, v_style, _ = pipe.extract_feature(
                input_latents, t, null_embeds_for_style, add_noise=False
            )

            # Attention alignment loss, summed over the paired feature entries.
            attn_loss = torch.zeros((), device=device)
            for q_s, k_s, v_s, q_r, k_r, v_r in zip(q_style, k_style, v_style, q_ref, k_ref, v_ref):
                attn_loss = attn_loss + adain_loss(q_s, k_s, v_s, q_r, k_r, v_r, lambda_q=0.5)

            # L1 regulariser keeping trainable weights close to their initial values.
            reg_loss = torch.zeros((), device=device)
            for name, param in pipe.classifier.named_parameters():
                if param.requires_grad:
                    reg_loss = reg_loss + torch.norm(param - original_params[name].to(device), p=1)

            total_loss = attn_loss + reg_lambda * reg_loss

            # Backprop and update. Every step builds a fresh graph, so it need
            # not be retained after backward.
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

            print(f"[Epoch {epoch+1}/{num_epochs}] Ref {ref_idx+1}/{len(reference_data)} Step {step+1}: "
                  f"Loss = {total_loss.item():.4f} (Attn: {attn_loss.item():.4f}, Reg: {reg_loss.item():.4f})")

# Write the fine-tuned pipeline to a local folder, creating it if needed.
save_folder = "trained_pipeline_folder"
os.makedirs(name=save_folder, exist_ok=True)

# NOTE(review): save_pretrained presumably serialises only the pipeline's
# registered components. Attributes assigned in this script (classifier, cache,
# controller, null_embeds) and the registered attention hooks are likely not
# persisted — confirm before relying on a reload.
pipe.save_pretrained(save_directory=save_folder)
