import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.utils import make_grid
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import torch.nn.functional as F
import wandb
import copy
import random
import os
from torchvision.utils import save_image
import matplotlib.pyplot as plt
from PIL import Image

# ==========================================
# 1. Configuration & hyperparameters
# ==========================================
_CUDA_AVAILABLE = torch.cuda.is_available()

CONFIG = dict(
    project_name="gan-experiment-compare-run-2151",  # wandb project name
    z_dim=100,  # length of the generator's noise vector
    batch_size=64,
    lr=0.0002,  # Adam learning rate for both G and D
    epochs_pretrain=10,  # epochs for the generation-0 GAN on real data
    epochs_per_gen=1,  # epochs per recursive generation
    total_generations=60,
    samples_to_generate=6000,  # size of each generation's synthetic dataset
    device="cuda" if _CUDA_AVAILABLE else "cpu",
    seeds=[2025, 2026, 2027],
    comparison_interval=10,  # a generator checkpoint is saved every N generations
    n_comparison_samples=16,  # images per model in comparison figures
    num_workers=2 if _CUDA_AVAILABLE else 0,  # DataLoader worker processes
    pin_memory=_CUDA_AVAILABLE,  # page-locked host memory for faster host->GPU copies
)

# Start-up report: which device the run uses and, on CUDA, basic GPU details.
if not torch.cuda.is_available():
    print("Running on CPU")
else:
    print(f"Running on GPU: {torch.cuda.get_device_name(0)}")
    _total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"GPU Memory: {_total_mem_gb:.2f} GB")
    print(f"Number of GPUs: {torch.cuda.device_count()}")

print(f"Device: {CONFIG['device']}")

if torch.cuda.is_available():
    # Let cuDNN autotune kernels; this pays off when input sizes stay fixed.
    # Note: set_seed() later sets this back to False in favour of determinism.
    torch.backends.cudnn.benchmark = True

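# ==========================================
# 2. Model definitions (unchanged; callers move them to CONFIG["device"])
# ==========================================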


class Generator(nn.Module):
    """MLP generator: maps a ``CONFIG["z_dim"]`` noise vector to a 1x28x28 image.

    Output values lie in (-1, 1) because the last layer is ``Tanh``. The layer
    stack is stored in ``self.model`` (an ``nn.Sequential``); that attribute
    name and the layer order determine the state-dict keys used by saved
    checkpoints, so both must stay as they are.
    """

    def __init__(self):
        super().__init__()

        def dense(n_in, n_out, batch_norm):
            # Linear -> [BatchNorm1d] -> LeakyReLU(0.2)
            layers = [nn.Linear(n_in, n_out)]
            if batch_norm:
                layers.append(nn.BatchNorm1d(n_out))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *dense(CONFIG["z_dim"], 128, batch_norm=False),
            *dense(128, 256, batch_norm=True),
            *dense(256, 512, batch_norm=True),
            nn.Linear(512, 784),
            nn.Tanh(),
        )

    def forward(self, z):
        flat = self.model(z)
        # 784 = 1 * 28 * 28: reshape each flat output into a single-channel image.
        return flat.view(flat.size(0), 1, 28, 28)


class Discriminator(nn.Module):
    """MLP discriminator: flattens each image and outputs one score in (0, 1).

    The first layer expects 784 inputs per sample, so images must flatten to
    784 values (e.g. 1x28x28). The final ``Sigmoid`` makes the output
    suitable for ``nn.BCELoss``. ``self.model`` and its layer order define
    the state-dict keys and must not change.
    """

    def __init__(self):
        super().__init__()
        widths = [784, 512, 256]
        layers = []
        for n_in, n_out in zip(widths, widths[1:]):
            layers += [nn.Linear(n_in, n_out), nn.LeakyReLU(0.2, inplace=True)]
        layers += [nn.Linear(widths[-1], 1), nn.Sigmoid()]
        self.model = nn.Sequential(*layers)

    def forward(self, img):
        return self.model(img.view(img.size(0), -1))


class Oracle(nn.Module):
    """Small CNN digit classifier used to judge generated images.

    Architecture: two 3x3 convolutions (1->32->64 channels, ReLU), a 2x2 max
    pool, then two fully-connected layers (9216->128->10). ``forward``
    returns log-probabilities over the 10 classes (log-softmax), which pairs
    with ``nn.NLLLoss``.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1)
        # 9216 = 64 * 12 * 12: a 28x28 input shrinks to 26 -> 24 -> 12 after conv, conv, pool.
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        features = F.max_pool2d(F.relu(self.conv2(F.relu(self.conv1(x)))), 2)
        hidden = F.relu(self.fc1(torch.flatten(features, 1)))
        return F.log_softmax(self.fc2(hidden), dim=1)


# ==========================================
# 3. 工具函数
# ==========================================


def set_seed(seed):
    """Seed every RNG the experiment uses and force deterministic cuDNN.

    Seeds Python's ``random``, NumPy, PyTorch's CPU generator and all CUDA
    generators with *seed*. It also sets ``cudnn.deterministic = True`` and
    ``cudnn.benchmark = False``. The second flag overrides the module-level
    ``benchmark = True`` setting from then on.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    print(f"--- Seed set to {seed} ---")


def train_oracle_model(train_loader):
    """Train the Oracle classifier that later scores generated images.

    A fresh ``Oracle`` is trained for 3 epochs with Adam (lr=1e-3) and NLL
    loss on *train_loader*, which must yield ``(images, labels)`` batches.
    Batches are moved to ``CONFIG["device"]``.

    Returns:
        The trained Oracle, switched to eval mode.
    """
    print("--- Pre-training Oracle (Judge) ---")
    device = CONFIG["device"]
    oracle = Oracle().to(device)
    opt = optim.Adam(oracle.parameters(), lr=0.001)
    nll = nn.NLLLoss()

    oracle.train()
    for _ in range(3):
        for images, labels in train_loader:
            images = images.to(device)
            labels = labels.to(device)
            opt.zero_grad()
            # Oracle.forward returns log-probabilities, which NLLLoss expects.
            nll(oracle(images), labels).backward()
            opt.step()

    print("--- Oracle Ready ---")
    oracle.eval()
    return oracle


def evaluate_generation(generator, oracle, num_samples=1000):
    """Score a generator by classifying its samples with the oracle.

    Draws ``num_samples`` noise vectors on ``CONFIG["device"]``, generates
    images and classifies them. Both modules are put in eval mode and left
    there.

    Returns:
        entropy: Shannon entropy (natural log) of the histogram of predicted
            labels over 10 classes. A higher value means predictions are
            spread more evenly across classes.
        avg_confidence: mean of the oracle's top-class probability.
        samples: the first 32 generated images (still on the device).
    """
    for module in (generator, oracle):
        module.eval()

    noise = torch.randn(num_samples, CONFIG["z_dim"], device=CONFIG["device"])

    with torch.no_grad():
        images = generator(noise)
        # The oracle emits log-probabilities; exponentiate before taking the max.
        top_prob, top_label = torch.exp(oracle(images)).max(dim=1)
        avg_confidence = top_prob.mean().item()

        # Only the label histogram needs to leave the device.
        label_counts = np.bincount(top_label.cpu().numpy(), minlength=10)
        freq = label_counts / label_counts.sum()
        freq = freq[freq > 0]  # drop empty classes: 0 * log(0) counts as 0
        entropy = -(freq * np.log(freq)).sum()

    return entropy, avg_confidence, images[:32]


def generate_comparison_images(generators_dict, generation, seed, fixed_z=None):
    """Feed the same noise through several generators and save a side-by-side figure.

    Each generator in *generators_dict* gets one panel that tiles all of its
    samples, 4 per row, so every sample is visible. Panels follow the dict's
    iteration order. A key of the form ``"lambda_<v>"`` is titled ``λ=<v>``;
    any other key is used as its own title. The figure is saved to
    ``comparisons/seed_<seed>/gen_<generation>_comparison.png``, shown with
    ``plt.show()``, and then closed.

    Args:
        generators_dict: non-empty mapping of model name -> generator. Every
            generator is switched to eval mode.
        generation: generation number used in the title and file name.
        seed: experiment seed used in the title and output directory.
        fixed_z: noise batch fed to every generator. If None, a fresh batch
            of ``CONFIG["n_comparison_samples"]`` vectors is drawn. Pass the
            same tensor across calls to compare models on identical inputs.

    Returns:
        (grid_img, save_path): the first is a ``make_grid(..., normalize=True)``
        tensor of all models' samples concatenated in dict order (used for
        wandb logging). The second is the path of the saved PNG.

    Raises:
        ValueError: if *generators_dict* is empty.
    """
    if not generators_dict:
        raise ValueError("generators_dict must contain at least one generator")

    if fixed_z is None:
        fixed_z = torch.randn(
            CONFIG["n_comparison_samples"], CONFIG["z_dim"], device=CONFIG["device"]
        )

    # Sample every model on the device; tensors only go to the CPU for plotting.
    comparison_images = {}
    for name, generator in generators_dict.items():
        generator.eval()
        with torch.no_grad():
            comparison_images[name] = generator(fixed_z)

    images_per_row = 4
    n_models = len(comparison_images)
    # squeeze=False keeps `axes` 2-D even when there is a single model.
    fig, axes = plt.subplots(1, n_models, figsize=(6 * n_models, 6), squeeze=False)
    fig.suptitle(f"Generation {generation} Comparison (Seed: {seed})", fontsize=16)

    for ax, (name, imgs) in zip(axes[0], comparison_images.items()):
        # Tanh output lies in [-1, 1]. Map it to [0, 1] and tile all samples
        # into one image. Fixed vmin/vmax give every panel the same intensity
        # scale instead of per-image autoscaling.
        panel = make_grid((imgs + 1) / 2, nrow=images_per_row, padding=2, pad_value=1.0)
        # make_grid expands single-channel input to 3 identical channels; show one.
        ax.imshow(panel[0].cpu().numpy(), cmap="gray", vmin=0.0, vmax=1.0)
        ax.axis("off")
        label = "λ=" + name[len("lambda_"):] if name.startswith("lambda_") else name
        ax.set_title(label, fontsize=12)

    plt.tight_layout()

    # Save the figure locally.
    save_dir = f"comparisons/seed_{seed}"
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, f"gen_{generation}_comparison.png")
    fig.savefig(save_path, dpi=150, bbox_inches="tight")

    plt.show()
    plt.close(fig)  # close this specific figure to release its memory

    # Combined grid (all models, dict order) for wandb logging.
    all_images = torch.cat(list(comparison_images.values()), dim=0)
    grid_img = make_grid(
        all_images,
        nrow=max(1, CONFIG["n_comparison_samples"] // 4),  # nrow=0 would fail for n < 4
        normalize=True,
    )

    return grid_img, save_path


def pretrain_base_gan(train_loader):
    """Train the generation-0 GAN on *train_loader* and return copies of its weights.

    Runs ``CONFIG["epochs_pretrain"]`` epochs of alternating updates. First the
    discriminator learns to map real images to 1 and detached fakes to 0.
    Then the generator is updated so that D(G(z)) moves toward 1. Both
    optimizers are Adam with ``lr=CONFIG["lr"]`` and ``betas=(0.5, 0.999)``.

    Returns:
        (generator_state, discriminator_state): deep copies of the two
        state dicts, independent of the trained module objects.
    """
    print(
        f"--- Pre-training Base GAN (Gen 0) for {CONFIG['epochs_pretrain']} epochs ---"
    )

    device = CONFIG["device"]
    adam_kwargs = dict(lr=CONFIG["lr"], betas=(0.5, 0.999))

    generator = Generator().to(device)
    discriminator = Discriminator().to(device)
    g_opt = optim.Adam(generator.parameters(), **adam_kwargs)
    d_opt = optim.Adam(discriminator.parameters(), **adam_kwargs)
    bce = nn.BCELoss()

    generator.train()
    discriminator.train()

    for _epoch in range(CONFIG["epochs_pretrain"]):
        for real_batch, _labels in train_loader:
            n = real_batch.size(0)
            real_batch = real_batch.to(device)
            ones = torch.ones(n, 1, device=device)
            zeros = torch.zeros(n, 1, device=device)

            # Discriminator update. Fakes are detached so G gets no gradient here.
            d_opt.zero_grad()
            fakes = generator(torch.randn(n, CONFIG["z_dim"], device=device))
            loss_on_real = bce(discriminator(real_batch), ones)
            loss_on_fake = bce(discriminator(fakes.detach()), zeros)
            d_loss = (loss_on_real + loss_on_fake) / 2
            d_loss.backward()
            d_opt.step()

            # Generator update, reusing the same fakes against the updated D.
            g_opt.zero_grad()
            g_loss = bce(discriminator(fakes), ones)
            g_loss.backward()
            g_opt.step()

    print("--- Base GAN Ready ---")
    g_state = copy.deepcopy(generator.state_dict())
    d_state = copy.deepcopy(discriminator.state_dict())
    return g_state, d_state


def _sample_synthetic_dataset(generator):
    """Sample ``CONFIG["samples_to_generate"]`` images from *generator* into a CPU dataset.

    The generator is switched to eval mode and sampled without autograd. The
    images are moved to the CPU. Labels are placeholder zeros because only
    the images are used for training.
    """
    n_samples = CONFIG["samples_to_generate"]
    generator.eval()
    with torch.no_grad():
        z = torch.randn(n_samples, CONFIG["z_dim"], device=CONFIG["device"])
        images = generator(z).cpu()
    return TensorDataset(images, torch.zeros(n_samples))


def _train_generation(
    generator,
    discriminator,
    g_optimizer,
    d_optimizer,
    adversarial_loss,
    train_loader,
    lambda_val,
    real_loader,
    real_iter,
):
    """Train G and D for ``CONFIG["epochs_per_gen"]`` epochs on *train_loader*.

    If ``lambda_val > 0``, the first ``int(batch_size * lambda_val)`` images of
    every training batch are overwritten with real images drawn from
    *real_iter*. When that iterator is exhausted, a new one is started from
    *real_loader*.

    Returns:
        (g_losses, d_losses, real_iter): per-step generator and discriminator
        loss values, plus the real-data iterator to keep using in the next
        generation.
    """
    device = CONFIG["device"]
    g_losses = []
    d_losses = []

    generator.train()
    discriminator.train()

    for _ in range(CONFIG["epochs_per_gen"]):
        for imgs, _labels in train_loader:
            batch_size = imgs.size(0)
            valid = torch.ones(batch_size, 1, device=device)
            fake = torch.zeros(batch_size, 1, device=device)
            real_imgs = imgs.to(device)

            # Mix in fresh real data.
            if lambda_val > 0:
                try:
                    orig_imgs, _labels = next(real_iter)
                except StopIteration:
                    real_iter = iter(real_loader)
                    orig_imgs, _labels = next(real_iter)
                orig_imgs = orig_imgs.to(device)
                # Clamp to the real batch's size. The real loader keeps its final,
                # smaller batch, and copying more rows than it holds would raise a
                # shape-mismatch error for some batch_size/lambda combinations.
                num_replace = min(int(batch_size * lambda_val), orig_imgs.size(0))
                if num_replace > 0:
                    real_imgs[:num_replace] = orig_imgs[:num_replace]

            # Discriminator step: (possibly mixed) "real" batch -> 1, fresh fakes -> 0.
            d_optimizer.zero_grad()
            z = torch.randn(batch_size, CONFIG["z_dim"], device=device)
            gen_imgs = generator(z)
            real_loss = adversarial_loss(discriminator(real_imgs), valid)
            fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
            d_loss = (real_loss + fake_loss) / 2
            d_loss.backward()
            d_optimizer.step()

            # Generator step: push D(G(z)) toward 1.
            g_optimizer.zero_grad()
            g_loss = adversarial_loss(discriminator(gen_imgs), valid)
            g_loss.backward()
            g_optimizer.step()

            g_losses.append(g_loss.item())
            d_losses.append(d_loss.item())

    return g_losses, d_losses, real_iter


def run_recursive_loop(
    lambda_val,
    seed,
    init_g_state,
    init_d_state,
    oracle_model,
    original_loader,
    original_dataset,
):
    """Run one recursive-retraining experiment and return the final generator.

    Starting from the given Gen-0 weights, each generation trains the GAN on
    ``CONFIG["samples_to_generate"]`` images sampled from the previous
    generation's generator. Entropy and confidence from ``evaluate_generation``,
    the mean losses, and a sample grid are logged to a wandb run. A generator
    checkpoint is written to ``checkpoints/`` every
    ``CONFIG["comparison_interval"]`` generations.

    Args:
        lambda_val: fraction of each training batch replaced by real images
            (the count is ``int(batch_size * lambda_val)``; 0 disables mixing).
        seed: used for the wandb run name/config and checkpoint file names.
            RNG seeding is left to the caller.
        init_g_state: generator state dict to start from.
        init_d_state: discriminator state dict to start from.
        oracle_model: classifier passed to ``evaluate_generation``.
        original_loader: unused; kept for call compatibility.
        original_dataset: real dataset that mixing batches are drawn from.

    Returns:
        The generator after the last generation, left in eval mode.
    """
    device = CONFIG["device"]
    group_name = f"Group_Lambda_{lambda_val}"
    run_name = f"L_{lambda_val}_Seed_{seed}"

    wandb.init(
        project=CONFIG["project_name"],
        group=group_name,
        name=run_name,
        config={**CONFIG, "lambda": lambda_val, "seed": seed},
        reinit=True,
    )

    # finish() is in `finally` so a crash does not leave this wandb run open
    # for the next lambda/seed experiment.
    try:
        print(f"\n=== Starting Recursive Loop: {run_name} ===")

        # 1. Load the shared Gen-0 weights onto the device.
        generator = Generator().to(device)
        discriminator = Discriminator().to(device)
        generator.load_state_dict(init_g_state)
        discriminator.load_state_dict(init_d_state)

        g_optimizer = optim.Adam(
            generator.parameters(), lr=CONFIG["lr"], betas=(0.5, 0.999)
        )
        d_optimizer = optim.Adam(
            discriminator.parameters(), lr=CONFIG["lr"], betas=(0.5, 0.999)
        )
        adversarial_loss = nn.BCELoss()

        # 2. Stream of real batches for mixing. The iterator is created up front
        #    (also when lambda_val == 0) so the RNG stream matches earlier runs.
        real_loader_for_mix = DataLoader(
            original_dataset,
            batch_size=CONFIG["batch_size"],
            shuffle=True,
            num_workers=CONFIG["num_workers"],
            pin_memory=CONFIG["pin_memory"],
        )
        real_data_iter = iter(real_loader_for_mix)

        # 3. First synthetic dataset, sampled from the base model.
        print("Generating dataset for Generation 1 from Base Model...")
        current_dataset = _sample_synthetic_dataset(generator)

        # Log Gen 0 (the base model).
        entropy, confidence, viz_imgs = evaluate_generation(generator, oracle_model)
        grid_img = make_grid(viz_imgs, nrow=8, normalize=True)
        wandb.log(
            {
                "generation": 0,
                "metrics/entropy": entropy,
                "metrics/confidence": confidence,
                "examples": wandb.Image(grid_img, caption="Gen 0 (Base)"),
            }
        )

        # Do not rely on the caller having created the checkpoint directory.
        os.makedirs("checkpoints", exist_ok=True)

        # 4. Recursive training.
        for gen in range(1, CONFIG["total_generations"] + 1):
            train_loader = DataLoader(
                current_dataset,
                batch_size=CONFIG["batch_size"],
                shuffle=True,
                drop_last=True,
                num_workers=CONFIG["num_workers"],
                pin_memory=CONFIG["pin_memory"],
            )

            g_losses, d_losses, real_data_iter = _train_generation(
                generator,
                discriminator,
                g_optimizer,
                d_optimizer,
                adversarial_loss,
                train_loader,
                lambda_val,
                real_loader_for_mix,
                real_data_iter,
            )
            avg_g_loss = np.mean(g_losses)
            avg_d_loss = np.mean(d_losses)

            entropy, confidence, viz_imgs = evaluate_generation(generator, oracle_model)
            print(f"[Gen {gen}] Entropy: {entropy:.3f} | Confidence: {confidence:.3f}")

            grid_img = make_grid(viz_imgs, nrow=8, normalize=True)
            wandb.log(
                {
                    "generation": gen,
                    "metrics/entropy": entropy,
                    "metrics/confidence": confidence,
                    "loss/generator": avg_g_loss,
                    "loss/discriminator": avg_d_loss,
                    "examples": wandb.Image(grid_img, caption=f"Gen {gen}"),
                }
            )

            # Checkpoint the generator for later comparison figures.
            if gen % CONFIG["comparison_interval"] == 0:
                torch.save(
                    {
                        "generation": gen,
                        "generator_state": generator.state_dict(),
                        "lambda": lambda_val,
                        "seed": seed,
                    },
                    f"checkpoints/gen_{gen}_lambda_{lambda_val}_seed_{seed}.pt",
                )

            # Next generation trains on this generation's samples.
            current_dataset = _sample_synthetic_dataset(generator)

        return generator
    finally:
        wandb.finish()


# ==========================================
# 5. 主程序入口
# ==========================================

if __name__ == "__main__":
    # Output directories: generator checkpoints and comparison figures.
    os.makedirs("checkpoints", exist_ok=True)
    os.makedirs("comparisons", exist_ok=True)

    # 1. Original MNIST data, normalised with mean 0.5 / std 0.5.
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
    )
    original_dataset = datasets.MNIST(
        "./data", train=True, download=True, transform=transform
    )
    original_loader = DataLoader(
        original_dataset,
        batch_size=CONFIG["batch_size"],
        shuffle=True,
        num_workers=CONFIG["num_workers"],
        pin_memory=CONFIG["pin_memory"],
    )

    # 2. Train the Oracle once, under its own fixed seed.
    set_seed(42)
    oracle_model = train_oracle_model(original_loader)

    # 3. Main experiment loop.
    lambda_settings = [0, 0.2]
    # First generation that run_recursive_loop checkpoints.
    comparison_gen = CONFIG["comparison_interval"]

    for seed in CONFIG["seeds"]:
        print("\n################################################")
        print(f"### Processing Seed: {seed} ###")
        print("################################################\n")

        set_seed(seed)

        # Shared Gen-0 GAN for every lambda branch of this seed.
        base_g_state, base_d_state = pretrain_base_gan(original_loader)

        # Final generator of each lambda experiment, keyed "lambda_<value>".
        final_generators = {}

        for lambda_val in lambda_settings:
            # Re-seed so each branch starts from the same RNG state.
            set_seed(seed)

            print(
                f"\n >>> Branching Experiment: Lambda={lambda_val} (Seed={seed}) <<< \n"
            )

            final_generators[f"lambda_{lambda_val}"] = run_recursive_loop(
                lambda_val=lambda_val,
                seed=seed,
                init_g_state=base_g_state,
                init_d_state=base_d_state,
                oracle_model=oracle_model,
                original_loader=original_loader,
                original_dataset=original_dataset,
            )

        print("\n--- Generating comparison images ---")
        fixed_z = torch.randn(
            CONFIG["n_comparison_samples"], CONFIG["z_dim"], device=CONFIG["device"]
        )

        # Intermediate comparison from the first checkpoint. It is skipped when
        # it coincides with the final generation, which is plotted below anyway
        # and would otherwise be written to the same file twice.
        if 0 < comparison_gen < CONFIG["total_generations"]:
            ckpt_generators = {}
            for lambda_val in lambda_settings:
                checkpoint_path = (
                    f"checkpoints/gen_{comparison_gen}_lambda_{lambda_val}_seed_{seed}.pt"
                )
                if os.path.exists(checkpoint_path):
                    ckpt_gen = Generator().to(CONFIG["device"])
                    # map_location lets a checkpoint saved on another device load here.
                    checkpoint = torch.load(checkpoint_path, map_location=CONFIG["device"])
                    ckpt_gen.load_state_dict(checkpoint["generator_state"])
                    ckpt_generators[f"lambda_{lambda_val}"] = ckpt_gen

            if len(ckpt_generators) == len(lambda_settings):
                _, save_path = generate_comparison_images(
                    ckpt_generators, comparison_gen, seed, fixed_z
                )
                print(f"Generation {comparison_gen} comparison saved to: {save_path}")

        # Comparison of the final generation.
        _, save_path = generate_comparison_images(
            final_generators, CONFIG["total_generations"], seed, fixed_z
        )
        print(f"Final generation comparison saved to: {save_path}")

        # Release cached GPU memory between seeds.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    print("\nAll experiments finished!")

    # Report GPU memory usage at the end of the run.
    if torch.cuda.is_available():
        print(f"Final GPU memory used: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
        print(f"Final GPU memory cached: {torch.cuda.memory_reserved() / 1e9:.2f} GB")
