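# Difference from jointly_train.py: during training this script also uses each user's (node's) degree
# information, lowering the weight of high-degree users.
# Difference from train_degree: datasets are not trained one after another; batches from different datasets are shuffled together.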
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
from tqdm import tqdm
import torch
from torch import nn
from transformers import CLIPProcessor, CLIPModel
from torch.cuda.amp import autocast, GradScaler
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
import pickle
import argparse
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from PIL import Image
torch.set_printoptions(threshold=100000, edgeitems=1000, linewidth=1000)
import json

def parse_args(argv=None):
    """Parse the command-line options for the mixed-dataset CLIP fine-tuning run.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) parses
            ``sys.argv[1:]`` exactly as before; passing an explicit list makes the
            parser usable from tests and notebooks.

    Returns:
        argparse.Namespace holding paths and training hyper-parameters.
    """
    parser = argparse.ArgumentParser(description="Fine-tuning CLIP with Co-Purchase Dataset")

    # File paths and optimisation hyper-parameters
    parser.add_argument("--batch_dir", type=str, default="/home/yqiao47/dataset/mixed/preprocessed_batches_edge_train")
    parser.add_argument("--output_dir", type=str, default="/home/yqiao47/multimodal_graph/model_output", help="Path to save the model output")
    parser.add_argument("--learning_rate", type=float, default=1e-5, help="Initial learning rate")
    parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay for optimizer")
    parser.add_argument("--lambda_weight", type=float, default=1, help="Weight for co-purchase loss")
    parser.add_argument("--num_epochs", type=int, default=5, help="Number of training epochs")
    parser.add_argument("--patience", type=int, default=3, help="Early stopping patience")
    parser.add_argument("--cosine_eta_min", type=float, default=1e-8, help="Minimum learning rate for cosine annealing scheduler")
    parser.add_argument("--cosine_t_max", type=int, default=20, help="T_max for cosine annealing scheduler")
    parser.add_argument("--num_workers", type=int, default=2)
    # NOTE: not used to size DataLoader batches; each dataset item is already a pre-built batch file.
    parser.add_argument("--batch_size", type=int, default=5803)
    parser.add_argument("--min_delta", type=float, default=0.01, help="Minimum delta for early stopping")
    parser.add_argument("--neg_factor", type=int, default=15, help="Negative sampling factor")
    parser.add_argument("--temperature", type=float, default=0.1, help="Temperature for NT-Xent loss")

    return parser.parse_args(argv)

class MixedBatchDataset(Dataset):
    """Dataset whose items are whole pre-built batches stored as ``.pkl`` files.

    Each pickle holds one batch from a single source dataset (the dataset name is the
    filename prefix before ``"_batch"``) with keys ``"titles"``, ``"asins"`` and
    ``"matrix"``. ``matrix[i, j] > 0`` marks an edge between items i and j of the
    batch, and its value is the 1-based cluster id of that edge. All ``.pkl`` files in
    the directory are listed together, so a shuffling DataLoader interleaves batches
    from different source datasets.
    """

    def __init__(self, mixed_batch_dir, neg_factor, edge_text_all, image_root="/home/yqiao47/dataset"):
        """
        Args:
            mixed_batch_dir: directory containing the ``*.pkl`` batch files.
            neg_factor: number of negatives sampled per positive edge.
            edge_text_all: ``{dataset_name: {cluster_id: description}}`` cluster texts.
            image_root: root folder; images are read from ``<root>/<dataset>/images``.
        """
        self.batch_files = [os.path.join(mixed_batch_dir, f) for f in os.listdir(mixed_batch_dir) if f.endswith(".pkl")]
        self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
        self.neg_factor = neg_factor
        self.edge_text_all = edge_text_all
        self.image_root = image_root

    def __len__(self):
        return len(self.batch_files)

    def _load_image(self, image_path):
        """Load one RGB image, or return a black 224x224 placeholder if it cannot be read."""
        try:
            # The context manager closes the underlying file; convert() loads pixel data first.
            with Image.open(image_path) as raw:
                return raw.convert("RGB")
        except OSError:  # also covers missing files and PIL's UnidentifiedImageError
            print(f"Warning: Skipping corrupted image {image_path}")
            return Image.new("RGB", (224, 224), (0, 0, 0))

    def _sample_negatives(self, co_purchase_label_matrix):
        """Collect positive edges and sample ``neg_factor`` negatives (i, k) for each of them.

        A negative for positive (i, j) is any node k != i with ``matrix[i, k] == 0``. If
        fewer than ``neg_factor`` candidates exist they are sampled with replacement.
        Positives whose source node has no candidate at all (connected to every other
        node) are dropped, because no negative can be drawn for them.

        Returns:
            positive_indices: (rows, cols) tensors of the kept positive edges.
            positive_clusters: 0-based cluster id of each kept positive.
            neg_samples: (heads, tails) tensors; ``neg_factor`` entries per kept positive,
                contiguous and in the same order as the positives.
            neg_clusters: cluster id of each negative (inherited from its positive).
        """
        device = co_purchase_label_matrix.device
        positive_indices = (co_purchase_label_matrix > 0).nonzero(as_tuple=True)
        positive_clusters = co_purchase_label_matrix[positive_indices] - 1  # cluster ids 1..K -> 0..K-1
        rows, cols = positive_indices

        keep = torch.ones(rows.numel(), dtype=torch.bool, device=rows.device)
        candidate_cache = {}  # source node -> its negative candidates (shared by all its positives)
        neg_heads, neg_tails, neg_cluster_chunks = [], [], []
        for p in range(rows.numel()):
            i = int(rows[p])
            candidates = candidate_cache.get(i)
            if candidates is None:
                candidates = (co_purchase_label_matrix[i] == 0).nonzero(as_tuple=True)[0]
                candidates = candidates[candidates != i]
                candidate_cache[i] = candidates

            if candidates.numel() == 0:
                keep[p] = False
                continue

            # Sample with replacement only when there are too few candidates.
            if candidates.numel() < self.neg_factor:
                picks = torch.randint(0, candidates.numel(), (self.neg_factor,))
            else:
                picks = torch.randperm(candidates.numel())[:self.neg_factor]

            neg_heads.append(torch.full((self.neg_factor,), i, dtype=torch.long, device=device))
            neg_tails.append(candidates[picks])
            neg_cluster_chunks.append(torch.full((self.neg_factor,), positive_clusters[p].item(), device=device))

        if not keep.all():
            positive_indices = (rows[keep], cols[keep])
            positive_clusters = positive_clusters[keep]

        if neg_heads:
            neg_samples = (torch.cat(neg_heads), torch.cat(neg_tails))
            neg_clusters = torch.cat(neg_cluster_chunks)
        else:
            neg_samples = (torch.tensor([], dtype=torch.long, device=device),
                           torch.tensor([], dtype=torch.long, device=device))
            neg_clusters = torch.tensor([], dtype=torch.long, device=device)
        return positive_indices, positive_clusters, neg_samples, neg_clusters

    def __getitem__(self, idx):
        """Load batch file ``idx`` and build model inputs plus positive/negative edge samples.

        Returns:
            (inputs, edge_inputs, positive_indices, positive_clusters, neg_samples,
             neg_clusters, node_weights, dataset_name)
        """
        batch_path = self.batch_files[idx]
        dataset_name = os.path.basename(batch_path).split("_batch")[0]  # source dataset name

        # NOTE(review): pickle executes arbitrary code on load; only use trusted batch files.
        with open(batch_path, "rb") as f:
            batch = pickle.load(f)

        titles = batch["titles"]
        co_purchase_label_matrix = batch["matrix"]
        # Degree-based down-weighting of high-degree nodes: 1 / (sqrt(degree + 1) + 0.5).
        node_degrees = (co_purchase_label_matrix > 0).sum(dim=1) + 1
        node_weights = 1 / (torch.sqrt(node_degrees) + 0.5)

        image_folder = os.path.join(self.image_root, dataset_name.replace("_new", ""), "images")
        images = [self._load_image(os.path.join(image_folder, f"{asin}_MAIN.jpg")) for asin in batch["asins"]]
        # Encode titles and images with the CLIP processor.
        inputs = self.clip_processor(text=titles, images=images, return_tensors="pt", padding=True, truncation=True,
                                     max_length=77)

        edge_text_dict = self.edge_text_all[dataset_name]
        edge_inputs = {
            cluster_id: self.clip_processor(text=text, return_tensors="pt", padding=True, truncation=True)
            for cluster_id, text in edge_text_dict.items()
        }

        positive_indices, positive_clusters, neg_samples, neg_clusters = self._sample_negatives(co_purchase_label_matrix)

        del batch, images
        return inputs, edge_inputs, positive_indices, positive_clusters, neg_samples, neg_clusters, node_weights, dataset_name

class CrossAttention(nn.Module):
    """Single-head attention that pools a key/value sequence with a query vector.

    The query attends over every position of each sample's key/value sequence; the
    attended vector is passed through an output projection and dropout, then
    L2-normalised.
    """

    def __init__(self, embed_dim=512, dropout_rate=0.1):
        super(CrossAttention, self).__init__()
        self.query_proj = nn.Linear(embed_dim, embed_dim)
        self.key_proj = nn.Linear(embed_dim, embed_dim)
        self.value_proj = nn.Linear(embed_dim, embed_dim)
        self.softmax = nn.Softmax(dim=-1)
        self.output_proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout_rate)
        # `norm` and `lambda_param` are not used in forward(); they are kept so the
        # state_dict keys stay compatible with previously saved checkpoints.
        self.norm = nn.LayerNorm(embed_dim)
        self.lambda_param = nn.Parameter(torch.tensor(0.5))

    def forward(self, query, key, value, attention_mask=None):
        """
        Args:
            query: ``[n_query, embed_dim]`` or ``[batch, n_query, embed_dim]``;
                ``torch.matmul`` broadcasts an unbatched query over the key batch.
            key, value: ``[batch, seq_len, embed_dim]``.
            attention_mask: optional ``[batch, seq_len]`` tensor; positions where it is
                0/False are excluded from the softmax. ``None`` (the default) attends to
                every position, as before.

        Returns:
            ``[batch, embed_dim]`` when ``n_query == 1`` (axis 1 is squeezed),
            L2-normalised along the last dimension.
        """
        q = self.query_proj(query)
        k = self.key_proj(key)
        v = self.value_proj(value)

        # Scaled dot-product scores: [batch, n_query, seq_len].
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / (key.size(-1) ** 0.5)
        if attention_mask is not None:
            valid = attention_mask.to(torch.bool).unsqueeze(-2)  # [batch, 1, seq_len]
            attn_scores = attn_scores.masked_fill(~valid, torch.finfo(attn_scores.dtype).min)
        attn_weights = self.softmax(attn_scores)

        # Weighted sum of values: [batch, n_query, embed_dim] -> [batch, embed_dim].
        attn_output = torch.matmul(attn_weights, v)
        attn_output = self.output_proj(attn_output.squeeze(1))
        attn_output = self.dropout(attn_output)
        return F.normalize(attn_output, p=2, dim=-1)

class UpdatedModel(nn.Module):
    """CLIP wrapper that also produces cluster(edge)-conditioned text and image embeddings.

    For every cluster description, the [EOS] embedding of the description is projected
    with CLIP's text projection and used as the query of a shared ``CrossAttention``
    over (a) the token hidden states of the titles and (b) the projected patch
    hidden states of the images.
    """

    # Token id of "<|endoftext|>" in the openai CLIP BPE vocabulary. Hard-coded
    # rather than read from the model config, matching the original implementation.
    EOS_TOKEN_ID = 49407

    def __init__(self, clip_model, embed_dim=512, dropout_rate=0.1):
        super(UpdatedModel, self).__init__()
        self.clip_model = clip_model
        self.clip_model.config.output_hidden_states = True
        # These alias sub-modules of clip_model; they stay registered under their own
        # names so existing checkpoints keep the same state_dict keys.
        self.text_projection = clip_model.text_projection
        self.text_encoder = clip_model.text_model
        self.cross_attention = CrossAttention(embed_dim, dropout_rate)
        self.vision_projection = clip_model.visual_projection

    def _encode_edge_text(self, edge_text):
        """Return the projected, L2-normalised [EOS] embedding(s) of one tokenised cluster text."""
        hidden = self.text_encoder(**edge_text)['last_hidden_state']  # [n_texts, seq_len, hidden]
        input_ids = edge_text['input_ids']
        eos_mask = (input_ids == self.EOS_TOKEN_ID) & (edge_text['attention_mask'] == 1)
        rows, cols = eos_mask.nonzero(as_tuple=True)
        if rows.numel() == 0:
            print("Warning: No EOS tokens found in input_ids!")
            # Without a fallback the pooled embedding would be undefined; use CLIP's
            # legacy pooling (position of the highest token id in each row).
            rows = torch.arange(input_ids.size(0), device=input_ids.device)
            cols = input_ids.argmax(dim=-1)
        eos_token_embeddings = hidden[rows, cols, :]
        return F.normalize(self.text_projection(eos_token_embeddings), dim=-1)

    def forward(self, edge_text_dict, inputs):
        """
        Args:
            edge_text_dict: ``{cluster_id: tokenised description}``; each value's tensors
                carry an extra singleton axis 1 that is squeezed here.
                NOTE(review): assumes one description string per cluster — TODO confirm.
            inputs: CLIP processor outputs for the batch titles and images.

        Returns:
            (text_embeddings, image_embeddings, text_out_list, image_out_list), where the
            two lists hold one cluster-conditioned ``[batch, embed_dim]`` tensor per cluster,
            in ``edge_text_dict`` iteration order.
        """
        clip_outputs = self.clip_model(**inputs)
        text_embeddings = clip_outputs.text_embeds
        image_embeddings = clip_outputs.image_embeds
        # NOTE(review): text hidden states are used unprojected while image hidden states
        # go through the visual projection; no padding mask is applied to text_outputs.
        text_outputs = clip_outputs.text_model_output.last_hidden_state
        image_outputs = self.vision_projection(clip_outputs.vision_model_output.last_hidden_state)

        text_out_list = []
        image_out_list = []
        for edge_text in edge_text_dict.values():
            edge_text = {k: v.squeeze(1) for k, v in edge_text.items()}
            edge_embds = self._encode_edge_text(edge_text)
            # The same attention module conditions both modalities on the cluster query.
            text_out_list.append(self.cross_attention(edge_embds, text_outputs, text_outputs))
            image_out_list.append(self.cross_attention(edge_embds, image_outputs, image_outputs))

        return text_embeddings, image_embeddings, text_out_list, image_out_list

def _compute_edge_losses(edge_text_list, edge_image_list, positive_indices, positive_clusters,
                         neg_samples, neg_clusters, node_weights, temperature, device):
    """Degree-weighted contrastive losses over the cluster-conditioned embeddings.

    For cluster c (the position in ``edge_text_list``), every positive pair (i, j) with
    cluster c is scored against the negative pairs sampled for it. Four similarity
    views are used: text->image, image->text, text->text and image->image. Each
    sample's cross-entropy is weighted by the mean node weight of i and j.

    Returns:
        (edge_loss, tri_loss): the mean of the two cross-modal losses and the mean of
        the two intra-modal losses; both are zero tensors when no cluster has positives.
    """
    # One list of [num_positive, 1 + negatives] logit rows per similarity view:
    # text->image, image->text, text->text, image->image.
    buckets = ([], [], [], [])
    all_labels = []
    all_weights = []

    for cluster_id, (edge_text_out, edge_image_out) in enumerate(zip(edge_text_list, edge_image_list)):
        cluster_mask_pos = positive_clusters == cluster_id
        pos_indices = (positive_indices[0][cluster_mask_pos], positive_indices[1][cluster_mask_pos])
        num_positive = pos_indices[0].shape[0]
        if num_positive == 0:
            continue  # no edges of this cluster in the batch: skip the B x B similarity matrices

        cluster_mask_neg = neg_clusters == cluster_id
        neg_indices = (neg_samples[0][cluster_mask_neg], neg_samples[1][cluster_mask_neg])

        edge_text_embeddings = F.normalize(edge_text_out, dim=-1)
        edge_image_embeddings = F.normalize(edge_image_out, dim=-1)
        similarity_views = (
            torch.matmul(edge_text_embeddings, edge_image_embeddings.t()),
            torch.matmul(edge_image_embeddings, edge_text_embeddings.t()),
            torch.matmul(edge_text_embeddings, edge_text_embeddings.t()),
            torch.matmul(edge_image_embeddings, edge_image_embeddings.t()),
        )
        # Column 0 holds the positive; the remaining columns hold its negatives. The
        # dataset emits a fixed number of negatives per positive, contiguously and in
        # positive order, so view() lines each positive up with its own negatives.
        for bucket, sims in zip(buckets, similarity_views):
            bucket.append(torch.cat([
                sims[pos_indices].unsqueeze(1),
                sims[neg_indices].view(num_positive, -1),
            ], dim=1))
        all_labels.append(torch.zeros(num_positive, dtype=torch.long, device=device))
        all_weights.append((node_weights[pos_indices[0]] + node_weights[pos_indices[1]]) / 2)

    if not all_labels:
        return torch.tensor(0.0, device=device), torch.tensor(0.0, device=device)

    final_labels = torch.cat(all_labels, dim=0)
    final_weights = torch.cat(all_weights, dim=0)
    text_image_loss, image_text_loss, text_text_loss, image_image_loss = [
        torch.mean(final_weights * F.cross_entropy(torch.cat(bucket, dim=0) / temperature, final_labels,
                                                   reduction='none'))
        for bucket in buckets
    ]
    edge_loss = (text_image_loss + image_text_loss) / 2
    tri_loss = (text_text_loss + image_image_loss) / 2
    return edge_loss, tri_loss


def main():
    """Fine-tune CLIP and the cluster-conditioned cross-attention on mixed co-purchase batches.

    Each epoch runs over all batch files in shuffled order, optimises
    ``clip_loss + lambda * (edge_loss + tri_loss)`` with AMP, saves a checkpoint, and
    applies early stopping on the epoch-average loss.
    """
    args = parse_args()
    os.makedirs(args.output_dir, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    no_significant_drop = 0  # consecutive epochs without a significant loss decrease
    best_loss = float('inf')

    dataset_names = [
        "Baby_Products_new", "Electronics_new", "Industrial_and_Scientific_new",
        "Musical_Instruments_new", "Video_Games_new", "Pet_Supplies_new", "Handmade_Products_new","amazon_sports_new", "Automotive_new", "Office_Products_new"
    ]

    # Load each dataset's cluster descriptions (cluster_10_description.json).
    # NOTE(review): cluster ids in the loss are positions in this JSON's key order — confirm
    # that order matches the 1..K labels stored in the batch matrices.
    edge_text_all = {}
    for name in dataset_names:
        cluster_path = os.path.join("/home/yqiao47/dataset", name, "cluster_10_description.json")
        if os.path.exists(cluster_path):
            with open(cluster_path, "r", encoding="utf-8") as f:
                edge_text_all[name] = json.load(f)
        else:
            print(f"⚠️ Missing edge_text for {name}")

    dataset = MixedBatchDataset(
        mixed_batch_dir=args.batch_dir,
        neg_factor=args.neg_factor,
        edge_text_all=edge_text_all,
        image_root="/home/yqiao47/dataset"
    )

    # Each item is already a full batch, so the DataLoader loads one file at a time.
    dataloader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=args.num_workers)

    # Checkpoint from a previous run to resume from (skipped when the file does not exist).
    checkpoint_path = os.path.join(args.output_dir, "none.pt")
    clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
    embed_dim = clip_model.config.projection_dim
    model = UpdatedModel(clip_model, embed_dim).to(device)
    if os.path.exists(checkpoint_path):
        print(f"🔄 Loading model parameters from {checkpoint_path}")
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))
        print(f"✅ Model parameters loaded. Continuing training...")

    # The freshly initialised cross-attention gets a 10x larger learning rate than CLIP.
    params = [
        {"params": model.clip_model.parameters()},
        {"params": model.cross_attention.parameters(), "lr": args.learning_rate * 10}
    ]
    optimizer = AdamW(params, lr=args.learning_rate, weight_decay=args.weight_decay)
    scheduler = CosineAnnealingLR(optimizer, T_max=args.cosine_t_max, eta_min=args.cosine_eta_min)
    clip_loss_fn = nn.CrossEntropyLoss()
    scaler = GradScaler()

    model.train()
    for epoch in range(args.num_epochs):
        total_loss = 0.0
        to_clip_loss = 0.0
        to_edge_loss = 0.0
        to_tri_loss = 0.0
        num_batches = 0

        for inputs, edge_inputs, positive_indices, positive_clusters, neg_samples, neg_clusters, node_weights, _dataset_name in tqdm(dataloader, desc="Training"):
            # Strip the DataLoader's leading batch axis of size 1 and move to device.
            inputs = {key: val.squeeze(0).to(device) for key, val in inputs.items()}
            edge_inputs = {
                key: {k: v.to(device) for k, v in val.items()} for key, val in edge_inputs.items()
            }
            node_weights = node_weights.squeeze(0).to(device)
            # Flatten so positives/negatives stay aligned with their cluster ids.
            positive_indices = (positive_indices[0].view(-1), positive_indices[1].view(-1))
            positive_clusters = positive_clusters.view(-1)
            neg_samples = (neg_samples[0].view(-1), neg_samples[1].view(-1))
            neg_clusters = neg_clusters.view(-1)

            with autocast():  # mixed-precision forward pass
                text_embeddings, image_embeddings, edge_text_list, edge_image_list = model(edge_inputs, inputs)

                # Standard CLIP contrastive loss over matching (title, image) pairs.
                logits_per_text = torch.matmul(text_embeddings, image_embeddings.t())
                logits_per_image = torch.matmul(image_embeddings, text_embeddings.t())
                labels = torch.arange(len(logits_per_text), device=device)
                clip_loss = (clip_loss_fn(logits_per_text, labels) + clip_loss_fn(logits_per_image, labels)) / 2

                edge_loss, tri_loss = _compute_edge_losses(
                    edge_text_list, edge_image_list, positive_indices, positive_clusters,
                    neg_samples, neg_clusters, node_weights, args.temperature, device,
                )
                total_batch_loss = clip_loss + args.lambda_weight * edge_loss + args.lambda_weight * tri_loss

            total_loss += total_batch_loss.item()
            to_clip_loss += clip_loss.item()
            to_edge_loss += edge_loss.item()
            to_tri_loss += tri_loss.item()
            num_batches += 1

            optimizer.zero_grad()
            scaler.scale(total_batch_loss).backward()
            # Gradients are still multiplied by the loss scale; unscale before clipping so
            # the threshold applies to the true gradient values.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=100)
            scaler.step(optimizer)
            scaler.update()

        # Learning rate is updated once per epoch.
        scheduler.step()
        # Average over the batches actually processed in this epoch.
        denom = max(num_batches, 1)
        avg_loss = total_loss / denom
        avg_clip_loss = to_clip_loss / denom
        avg_edge_loss = to_edge_loss / denom
        avg_tri_loss = to_tri_loss / denom
        print(f"Epoch {epoch + 1}/{args.num_epochs}, Average Loss: {avg_loss:.4f}，\
        Average CLIP Loss: {avg_clip_loss:.4f}, Average edge Loss: {avg_edge_loss:.4f}, Average tri Loss: {avg_tri_loss:.4f}")

        # Save this epoch's model.
        model_save_path = os.path.join(args.output_dir, f"jointly_train_degree_mix_new_{epoch + 1}.pt")
        torch.save(model.state_dict(), model_save_path)
        print(f"Model for epoch {epoch + 1} saved to {model_save_path}")

        # Early stopping: a drop larger than min_delta resets the patience counter.
        print(f"best_loss - avg_loss: {best_loss - avg_loss:.4f}, no_significant_drop: {no_significant_drop}")
        if best_loss - avg_loss > args.min_delta:
            best_loss = avg_loss
            no_significant_drop = 0
            model_save_path = os.path.join(args.output_dir, "jointly_train_degree_mix_new_best.pt")
            torch.save(model.state_dict(), model_save_path)
            print(f"Best model saved to {model_save_path}")
        else:
            no_significant_drop += 1
            print(
                f"No significant improvement: {best_loss - avg_loss:.4f} (threshold: {args.min_delta}), no_significant_drop: {no_significant_drop}")
            if no_significant_drop >= args.patience:
                print(f"Early stopping triggered at epoch {epoch + 1}")
                break

    print("Training complete.")

if __name__ == "__main__":
    main()