import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import os
import clip
import random
from dataset import DegradationDataset
from model import Classification
from tqdm import tqdm

from tqdm import tqdm


def train(model, dataloader, optimizer, device, model_clip, epoch, num_epochs=10):
    """Train ``model`` for one epoch to align its outputs with CLIP text embeddings.

    For each batch the texts are tokenized with ``clip.tokenize`` and encoded by
    ``model_clip.encode_text``. The model's output for the images is then pushed
    toward those embeddings by minimizing ``1 - mean(cosine_similarity)``.

    Args:
        model: Trainable network. Its output for a batch of images is compared
            with the CLIP text embedding along the last dimension.
        dataloader: Iterable of ``(images, texts)`` pairs. ``texts`` must be
            accepted by ``clip.tokenize`` (presumably a list of caption strings;
            confirm against the dataset).
        optimizer: Optimizer over ``model``'s parameters.
        device: Device that images and tokens are moved to.
        model_clip: Loaded CLIP model. Only ``encode_text`` is used, and it is
            evaluated without gradient tracking, so it is never trained here.
        epoch: Zero-based epoch index, used for display only.
        num_epochs: Total number of epochs, used only in the progress-bar label.
            Defaults to 10, the value that was previously hard-coded.

    Returns:
        The average loss over all batches of this epoch.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    model.train()
    running_loss = 0.0
    num_batches = 0

    # Wrap the dataloader in tqdm to show a progress bar.
    with tqdm(dataloader, desc=f"Epoch {epoch + 1}/{num_epochs}", unit="batch") as pbar:
        for img_tensor, texts in pbar:
            img_tensor = img_tensor.to(device)
            tokens = clip.tokenize(texts).to(device)

            # CLIP only supplies fixed targets. Encoding without a graph keeps
            # backward() from traversing CLIP and piling up .grad on weights the
            # optimizer never updates or zeroes.
            with torch.no_grad():
                # CLIP may return half-precision features; compare in float32.
                text_features = model_clip.encode_text(tokens).float()

            optimizer.zero_grad()

            # Forward pass.
            features = model(img_tensor).float()
            cos_sim = F.cosine_similarity(features, text_features, dim=-1)

            # Minimizing 1 - mean similarity drives each pair toward cos_sim == 1.
            loss = 1 - cos_sim.mean()

            loss.backward()
            optimizer.step()

            loss_value = loss.item()
            running_loss += loss_value
            num_batches += 1
            avg = running_loss / num_batches
            # The loss is purely the cosine term, so both averages coincide.
            pbar.set_postfix(loss=loss_value, avg_loss=avg, avg_cos_loss=avg)

    if num_batches == 0:
        raise ValueError("dataloader yielded no batches; cannot compute an average loss")

    # Report the average loss for this epoch.
    avg_loss = running_loss / num_batches
    print(f"\nEpoch {epoch + 1} - Average Loss: {avg_loss:.4f}")

    return avg_loss


def main():
    """Train a ``Classification`` model against frozen CLIP text embeddings.

    Builds a shuffled ``DataLoader`` over ``DegradationDataset`` and loads CLIP
    ``ViT-L/14@336px``. It then trains ``Classification`` with Adam for
    ``num_epochs`` epochs via ``train``. The model's ``state_dict`` is saved every
    ``save_every`` epochs and always after the final epoch.
    """
    # Configuration
    img_dir = "./Classification_dataset"  # replace with your image directory
    batch_size = 8
    num_epochs = 10
    target_size = (224, 224)
    learning_rate = 1e-4
    save_every = 5  # checkpoint interval, in epochs

    # Device selection (GPU if available, else CPU)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Dataset and data loader
    dataset = DegradationDataset(img_dir, target_size)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)

    # Load CLIP. It only provides fixed text targets, so put it in eval mode and
    # freeze its weights. No gradients are then computed for or stored on it.
    model_clip, _ = clip.load("ViT-L/14@336px", device=device)
    model_clip.eval()
    model_clip.requires_grad_(False)

    # Model being trained
    model = Classification().to(device)

    # Optimizer: only the trainable model's parameters are updated.
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Training loop
    for epoch in range(num_epochs):
        print(f"Epoch {epoch+1}/{num_epochs}")
        train(model, dataloader, optimizer, device, model_clip, epoch)

        # Checkpoint periodically and always after the last epoch, so the final
        # weights are kept even when num_epochs is not a multiple of save_every.
        is_last_epoch = epoch + 1 == num_epochs
        if (epoch + 1) % save_every == 0 or is_last_epoch:
            torch.save(model.state_dict(), f"abmodel_epoch_{epoch+1}.pth")
            print(f"Model saved after epoch {epoch+1}")


if __name__ == "__main__":
    # Start training only when this file is executed as a script, not on import.
    main()
