
import os
# Restrict this process to physical GPU 3. This must happen before torch
# initialises CUDA, hence its position ahead of `import torch`.
os.environ["CUDA_VISIBLE_DEVICES"] = "3"
import json
import torch
# GPU diagnostics. torch.cuda.current_device() and get_device_name() raise when
# no CUDA device is usable, which would abort the script before it reaches its
# CPU fallback for `device`; only query them when CUDA is available.
print("Available GPUs:", torch.cuda.device_count())
if torch.cuda.is_available():
    # Indices are relative to CUDA_VISIBLE_DEVICES, so physical GPU 3 shows as 0.
    print("Current GPU:", torch.cuda.current_device())
    print("GPU Name:", torch.cuda.get_device_name(0))
from PIL import Image
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm
from transformers import CLIPProcessor, CLIPModel
from torch.utils.data import Dataset, DataLoader
import argparse
from models.ESA_try4 import UpdatedModel
def parse_args(argv=None):
    """Parse command-line options for evaluating a fine-tuned CLIP model.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) reads
            ``sys.argv[1:]`` exactly as before; passing a list lets the parser
            be driven from tests or notebooks.

    Returns:
        argparse.Namespace with ``output_path``, ``model_select``,
        ``model_path``, ``batch_size`` and ``num_workers``.
    """
    parser = argparse.ArgumentParser(
        description="Evaluate a fine-tuned CLIP model on the Co-Purchase test set")

    # File paths
    parser.add_argument("--output_path", type=str, default="./result/jointly_train_degree_epoch_3",
                        help="Path to save the evaluation results")
    parser.add_argument("--model_select", type=str, default="UpdatedModel",
                        help="Model architecture to evaluate (only 'UpdatedModel' is supported)")
    parser.add_argument("--model_path", type=str, default="./model_output/jointly_train_degree_epoch_3.pt",
                        help="Path to the saved model state_dict")

    # Evaluation settings
    parser.add_argument("--batch_size", type=int, default=64, help="Batch size for evaluation")
    parser.add_argument("--num_workers", type=int, default=2,
                        help="Number of DataLoader worker processes")
    return parser.parse_args(argv)

class TestDataset(Dataset):
    """Dataset yielding one ranking group (query, positive, negatives) per index.

    Each entry of ``test_samples`` is a dict with keys ``node`` (query ASIN),
    ``edge_info``, ``positive_sample`` (ASIN) and ``negative_samples`` (list of
    ASINs). For every ASIN in the group, the main product image and the title
    are loaded. Both go through the CLIP processor together, so text ``i`` and
    image ``i`` always describe the same ASIN.
    """

    def __init__(self, clip_processor, test_samples, image_folder, asin2title):
        self.clip_processor = clip_processor
        self.image_folder = image_folder
        self.test_samples = test_samples
        self.asin2title = asin2title

    def __len__(self):
        return len(self.test_samples)

    def _load_image(self, asin):
        """Return the RGB main image for *asin*, or a black 224x224 placeholder if it cannot be read."""
        image_path = os.path.join(self.image_folder, f"{asin}_MAIN.jpg")
        try:
            # `with` closes the file handle; convert() returns a fully loaded
            # in-memory copy that stays valid after the file is closed.
            with Image.open(image_path) as img:
                return img.convert("RGB")
        except Exception as e:
            # Best effort: a missing/corrupt image must not abort evaluation.
            print(f"Error loading image: {image_path}, Error: {e}")
            return Image.new("RGB", (224, 224), (0, 0, 0))

    def __getitem__(self, idx):
        """Return ``{"edge_info", "asins", "inputs"}`` for the group at *idx*.

        ``asins`` is ``[node, positive] + negatives``; ``inputs`` is the CLIP
        processor output for the matching titles and images.
        """
        sample = self.test_samples[idx]
        node = sample["node"]
        edge_info = sample["edge_info"]
        positive = sample["positive_sample"]
        negatives = sample["negative_samples"]
        # Order matters downstream: item 0 is the query, item 1 the positive.
        asins = [node, positive] + negatives

        images = [self._load_image(asin) for asin in asins]
        # Titles are looked up independently of image loading. Previously a
        # failed image load also skipped the title, which left fewer texts
        # than images and misaligned every following pair. Missing or null
        # titles become "" so the tokenizer always receives a string.
        titles = [self.asin2title.get(asin) or "" for asin in asins]

        inputs = self.clip_processor(text=titles, images=images, return_tensors="pt", padding="max_length",
                                     truncation=True, max_length=77)

        return {
            "edge_info": edge_info,
            "asins": asins,
            "inputs": inputs
        }

def compute_metrics(text_fea, image_fea, edge_info, hit_at_k_dict, mrr_scores_dict, ndcg_scores_dict, k_values=(1, 5, 10), ndcg_k=5):
    """Accumulate Hit@K, MRR and NDCG@K for one batch of ranking groups.

    For each sample ``i``, only the features of edge ``edge_info[i]`` are used.
    Item 0 is the query and items ``1:`` are the candidates; the positive is
    the first candidate. Candidates are ranked by cosine similarity to the
    query under five schemes, which are the keys of the accumulator dicts:

    - ``text_sim``: text vs text
    - ``image_sim``: image vs image
    - ``text_img_sim``: text query vs image candidates
    - ``img_text_sim``: image query vs text candidates
    - ``text_img_avg``: mean of ``text_sim`` and ``image_sim``

    Args:
        text_fea, image_fea: tensors of shape
            (batch_size, edge_num, num_items, embed_dim); num_items must be 22.
        edge_info: per-sample edge index (anything ``int()`` accepts).
        hit_at_k_dict: ``{sim_type: {k: count}}``; counts are incremented in place.
        mrr_scores_dict, ndcg_scores_dict: ``{sim_type: list}``; one value is
            appended per sample and similarity type.
        k_values: Hit@K cut-offs; each must be a key of ``hit_at_k_dict[sim_type]``.
        ndcg_k: NDCG cut-off.

    Raises:
        ValueError: if the item dimension is not 22.
    """
    batch_size, edge_num, num_items, embed_dim = text_fea.shape
    if num_items != 22:  # query + positive + negatives, as built by TestDataset
        raise ValueError(f"Expected 22 items per group, got {num_items}")

    # Gather each sample's selected edge slice on its own device, then copy to
    # the CPU once per batch instead of 8 small transfers per sample.
    rows = torch.arange(batch_size)
    edge_idx = torch.tensor([int(edge_info[i]) for i in range(batch_size)], dtype=torch.long)
    text_np = text_fea[rows.to(text_fea.device), edge_idx.to(text_fea.device)].detach().cpu().numpy()
    image_np = image_fea[rows.to(image_fea.device), edge_idx.to(image_fea.device)].detach().cpu().numpy()

    positive_idx = 0  # the positive is candidate 0 (item 1 of the group)

    for i in range(batch_size):
        query_text = text_np[i, :1]   # (1, embed_dim): the query
        all_text = text_np[i, 1:]     # (num_items - 1, embed_dim): positive + negatives
        query_image = image_np[i, :1]
        all_image = image_np[i, 1:]

        # Cosine similarity of the query against every candidate
        text_sim = cosine_similarity(query_text, all_text)[0]
        image_sim = cosine_similarity(query_image, all_image)[0]
        text_img_sim = cosine_similarity(query_text, all_image)[0]
        img_text_sim = cosine_similarity(query_image, all_text)[0]
        text_img_avg = (text_sim + image_sim) / 2

        sim_results = {
            "text_sim": text_sim,
            "image_sim": image_sim,
            "text_img_sim": text_img_sim,
            "img_text_sim": img_text_sim,
            "text_img_avg": text_img_avg,
        }

        for sim_type, final_sim in sim_results.items():
            sorted_indices = np.argsort(final_sim)[::-1]  # descending similarity

            # Hit@K
            for k in k_values:
                if positive_idx in sorted_indices[:k]:
                    hit_at_k_dict[sim_type][k] += 1

            # MRR (1-based rank of the positive)
            rank = np.where(sorted_indices == positive_idx)[0][0] + 1
            mrr_scores_dict[sim_type].append(1 / rank)

            # NDCG@K with a single relevant item: IDCG = 1 / log2(2) = 1
            dcg = 1 / np.log2(rank + 1) if rank <= ndcg_k else 0
            idcg = 1 / np.log2(1 + 1)
            ndcg_scores_dict[sim_type].append(dcg / idcg)

def collate_fn(batch):
    """Merge per-group samples from TestDataset into one batch.

    Returns a 3-tuple:
      - the list of ``edge_info`` values, one per group;
      - the list of ASIN lists, one per group;
      - a dict with ``input_ids``, ``attention_mask`` and ``pixel_values``,
        each stacked along a new leading batch dimension.
    """
    edge_ids = [sample["edge_info"] for sample in batch]
    asin_lists = [sample["asins"] for sample in batch]

    # Stack the processor tensors field by field, keeping the key order
    # input_ids -> attention_mask -> pixel_values.
    stacked = {
        field: torch.stack([sample["inputs"][field] for sample in batch])
        for field in ("input_ids", "attention_mask", "pixel_values")
    }
    return edge_ids, asin_lists, stacked

args = parse_args()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load item information
print("Loading item title...")

# Other datasets that can be evaluated by adding them to `dataset_names`:
# dataset_names = [
#     "Baby_Products_new", "Electronics_new", "Industrial_and_Scientific_new",
#     "Musical_Instruments_new", "Automotive_new", "Office_Products_new",
#     "Video_Games_new", "Pet_Supplies_new", "Handmade_Products_new", "amazon_sports_new"
# ]
dataset_names = ["amazon_sports_new"]

# Each dataset lives in ../dataset/<name>
dataset_paths = [os.path.join("../dataset", name) for name in dataset_names]
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Load the CLIP backbone
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
embed_dim = clip_model.config.projection_dim
os.makedirs(args.output_path, exist_ok=True)

# Build and load the fine-tuned model once. The weights do not depend on the
# dataset, so rebuilding it inside the dataset loop only repeated the work.
if args.model_select == 'UpdatedModel':
    model = UpdatedModel(clip_model, embed_dim).to(device)
    model.load_state_dict(torch.load(args.model_path, map_location=device))
    print("Loaded pre-trained weights for UpdatedModel.")
else:
    raise ValueError("A valid model must be selected.")
# Put any dropout/normalisation layers into inference mode. torch.no_grad()
# below only disables gradient tracking; it does not do this.
model.eval()

sim_methods = ["text_sim", "image_sim", "text_img_sim", "img_text_sim", "text_img_avg"]

# Evaluate each dataset independently
for dataset in dataset_paths:
    print(f"🔄 Loading test data from {dataset}...")
    test_data_file = os.path.join(dataset, "test_data.json")
    dataset_base = os.path.basename(dataset.replace("_new", ""))
    item_file = os.path.join(dataset, f"{dataset_base}_item_test.jsonl")
    image_folder = os.path.join(dataset.replace("_new", ""), "images")
    cluster_file = os.path.join(dataset, "cluster_10_description.json")

    # Item metadata: one JSON object per line, mapped to ASIN -> title
    asin2title = {}
    if os.path.exists(item_file):
        with open(item_file, "r", encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    continue  # tolerate blank lines (e.g. a trailing newline)
                item = json.loads(line)
                asin2title[item['parent_asin']] = item['title']

    # Test groups: one JSON object per line
    test_samples = []
    if os.path.exists(test_data_file):
        with open(test_data_file, "r", encoding="utf-8") as f:
            test_samples = [json.loads(line) for line in f if line.strip()]
    total_samples = len(test_samples)
    if total_samples == 0:
        # The metrics are normalised by total_samples; nothing to evaluate.
        print(f"No test samples found in {test_data_file}, skipping.")
        continue

    # Edge-cluster descriptions, tokenised once per dataset and moved to `device`
    with open(cluster_file, "r", encoding="utf-8") as f:
        edge_clusters = json.load(f)

    edge_inputs = {
        cluster_id: {
            key: value.to(device) for key, value in
            clip_processor(text=text, return_tensors="pt", padding=True, truncation=True).items()
        }
        for cluster_id, text in edge_clusters.items()
    }

    # Named `test_dataset` so it does not shadow the loop variable `dataset`
    # (the dataset path), which is printed in the results header below.
    test_dataset = TestDataset(clip_processor, test_samples, image_folder, asin2title)
    dataloader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers,
                            collate_fn=collate_fn)

    # Accumulators for Hit@K, MRR and NDCG@5, per similarity method
    hit_at_k_dict = {sim: {1: 0, 5: 0, 10: 0} for sim in sim_methods}
    mrr_scores_dict = {sim: [] for sim in sim_methods}
    ndcg_scores_dict = {sim: [] for sim in sim_methods}

    for edge_info, _asins, inputs in tqdm(dataloader, desc="Evaluating"):
        # collate_fn already yields [batch, num_items, ...]. The former
        # squeeze(0) removed the batch dim of a final batch of size 1 and
        # broke the shape unpacking below.
        inputs = {key: val.to(device) for key, val in inputs.items()}

        batch_size, num_items, seq_len = inputs["input_ids"].shape
        _, _, channels, height, width = inputs["pixel_values"].shape

        # Flatten groups so the model sees batch_size * num_items items
        inputs["input_ids"] = inputs["input_ids"].view(batch_size * num_items, seq_len)
        inputs["attention_mask"] = inputs["attention_mask"].view(batch_size * num_items, seq_len)
        inputs["pixel_values"] = inputs["pixel_values"].view(batch_size * num_items, channels, height, width)

        with torch.no_grad():
            _text_embeddings, _image_embeddings, text_out_list, image_out_list = model(edge_inputs, inputs)
            edge_num = len(text_out_list)
            # Per-edge outputs -> (batch_size, edge_num, num_items, embed_dim)
            text_fea = torch.stack(text_out_list).view(edge_num, batch_size, num_items, embed_dim)
            text_fea = text_fea.permute(1, 0, 2, 3)
            image_fea = torch.stack(image_out_list).view(edge_num, batch_size, num_items, embed_dim)
            image_fea = image_fea.permute(1, 0, 2, 3)

        compute_metrics(text_fea, image_fea, edge_info, hit_at_k_dict, mrr_scores_dict, ndcg_scores_dict)

    # Final metrics: Hit@K is a fraction of all groups; MRR/NDCG are means
    hit_at_k_dict = {sim: {k: v / total_samples for k, v in hit_k.items()} for sim, hit_k in hit_at_k_dict.items()}
    mrr_dict = {sim: np.mean(mrr) for sim, mrr in mrr_scores_dict.items()}
    ndcg_5_dict = {sim: np.mean(ndcg) for sim, ndcg in ndcg_scores_dict.items()}

    # Print results
    for sim in sim_methods:
        print(f"=== {sim}:{dataset} ===")
        print(f"Hit@1: {hit_at_k_dict[sim][1]:.4f}")
        print(f"Hit@5: {hit_at_k_dict[sim][5]:.4f}")
        print(f"Hit@10: {hit_at_k_dict[sim][10]:.4f}")
        print(f"MRR: {mrr_dict[sim]:.4f}")
        print(f"NDCG@5: {ndcg_5_dict[sim]:.4f}")

    # Save results
    results = {
        sim: {
            "Hit@1": hit_at_k_dict[sim][1],
            "Hit@5": hit_at_k_dict[sim][5],
            "Hit@10": hit_at_k_dict[sim][10],
            "MRR": mrr_dict[sim],
            "NDCG@5": ndcg_5_dict[sim]
        }
        for sim in sim_methods
    }

    # NOTE: the file keeps its existing `.jsonl` suffix although it holds a
    # single indented JSON object; the name is left as-is for existing consumers.
    with open(os.path.join(args.output_path, f"{dataset_base}_result.jsonl"), "w") as f:
        json.dump(results, f, indent=4)

    print(f"Results saved to {args.output_path}")