
import warnings
from sklearn.exceptions import UndefinedMetricWarning
warnings.filterwarnings("ignore", category=UndefinedMetricWarning)
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import json
import torch
from PIL import Image
from tqdm import tqdm
from transformers import CLIPProcessor, CLIPModel
from torch.utils.data import Dataset, DataLoader
import argparse
from models.ESA_try4 import UpdatedModel
import torch.nn.functional as F
from collections import defaultdict
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, classification_report
import numpy as np

def parse_args():
    """Parse the command-line options of the evaluation script.

    Returns:
        argparse.Namespace with ``output_path``, ``model_select``,
        ``model_path``, ``batch_size`` and ``num_workers``.
    """
    cli = argparse.ArgumentParser(description="Fine-tuning CLIP with Co-Purchase Dataset")

    option_specs = (
        # File path options
        ("--output_path", {
            "type": str,
            "default": "./result/edge_prediction_relgraph",
            "help": "Path to save the evaluation results",
        }),
        ("--model_select", {"type": str, "default": "UpdatedModel"}),
        ("--model_path", {
            "type": str,
            "default": "./model_output/jointly_train_degree_mix_new_3.pt",
        }),
        # Model and run-time hyper-parameters
        ("--batch_size", {
            "type": int,
            "default": 128,
            "help": "Batch size for training",
        }),
        ("--num_workers", {"type": int, "default": 4}),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)

    return cli.parse_args()

def load_image(asin, image_folder):
    """Load the main product image for ``asin`` as an RGB PIL image.

    The image is read from ``<image_folder>/<asin>_MAIN.jpg``. Loading is
    best-effort: if the file is missing or cannot be decoded, a black
    224x224 RGB placeholder is returned instead of raising.

    Args:
        asin: Product identifier used to build the file name.
        image_folder: Directory that contains the ``*_MAIN.jpg`` files.

    Returns:
        A PIL image in RGB mode.
    """
    path = os.path.join(image_folder, f"{asin}_MAIN.jpg")
    try:
        # The context manager closes the underlying file right away;
        # convert() returns an independent, fully loaded image.
        with Image.open(path) as img:
            return img.convert("RGB")
    except Exception:
        # Deliberately broad so any unreadable image falls back to the
        # placeholder, but (unlike a bare ``except:``) this no longer swallows
        # KeyboardInterrupt / SystemExit.
        return Image.new("RGB", (224, 224), (0, 0, 0))

class TestDataset(Dataset):
    """Dataset of (node, positive_sample) product pairs for evaluation.

    Each item bundles the processor output for the two products of a test
    sample (titles and main images) together with the sample's integer
    ``edge_info`` label and the two ASINs.

    Args:
        clip_processor: Callable invoked with ``text=``, ``images=`` and
            tokenisation options; its result is returned under ``"inputs"``.
        test_samples: Sequence of dicts with the keys ``"node"``,
            ``"positive_sample"`` and ``"edge_info"``.
        image_folder: Directory passed to ``load_image``.
        asin2title: Mapping from ASIN to product title.
    """

    def __init__(self, clip_processor, test_samples, image_folder, asin2title):
        self.clip_processor = clip_processor
        self.image_folder = image_folder
        self.test_samples = test_samples
        self.asin2title = asin2title

    def __len__(self):
        """Return the number of test samples."""
        return len(self.test_samples)

    def __getitem__(self, idx):
        """Return the processed inputs and label for the sample at ``idx``.

        Returns:
            dict with ``"edge_info"`` (int), ``"asins"`` (``[node, positive]``)
            and ``"inputs"`` (the processor output for both products).
        """
        sample = self.test_samples[idx]
        asins = [sample["node"], sample["positive_sample"]]
        edge_info = int(sample["edge_info"])

        images = [load_image(asin, self.image_folder) for asin in asins]
        # A missing (or null) title would hand ``None`` to the processor as
        # text; fall back to an empty string, mirroring the placeholder image
        # that load_image returns for a missing picture.
        titles = [self.asin2title.get(asin) or "" for asin in asins]

        inputs = self.clip_processor(
            text=titles,
            images=images,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=77,
        )

        return {
            "edge_info": edge_info,
            "asins": asins,
            "inputs": inputs
        }


def collate_fn(batch):
    """Merge a list of dataset items into one batch.

    Args:
        batch: List of dicts with the keys ``"edge_info"``, ``"asins"`` and
            ``"inputs"``; ``"inputs"`` must provide ``"input_ids"``,
            ``"attention_mask"`` and ``"pixel_values"`` tensors.

    Returns:
        Tuple ``(edge_infos, asins, tensors)``: the per-sample labels, the
        per-sample ASIN lists, and a dict in which each input tensor has been
        stacked along a new leading batch dimension.
    """
    edge_infos = [item["edge_info"] for item in batch]  # one label per sample
    asin_pairs = [item["asins"] for item in batch]  # one ASIN list per sample

    tensor_fields = ("input_ids", "attention_mask", "pixel_values")
    stacked = {
        field: torch.stack([item["inputs"][field] for item in batch])
        for field in tensor_fields
    }

    return edge_infos, asin_pairs, stacked

def compute_similarity_matrix(text_fea, image_fea, sim_type='avg'):
    """Cosine similarity between the two items of each pair, for every edge.

    Both feature tensors are indexed as ``[batch, edge, item, dim]`` where
    ``item`` 0 and 1 are the two members of a pair. The whole matrix is
    computed in one vectorised call instead of one ``.item()`` round trip per
    (batch, edge) element.

    Args:
        text_fea: 4-D tensor of text features.
        image_fea: 4-D tensor of image features, indexed like ``text_fea``.
        sim_type: Which features to compare:
            ``'text'`` (text1 vs text2), ``'image'`` (image1 vs image2),
            ``'text_img'`` (text1 vs image2), ``'img_text'`` (image1 vs text2),
            ``'avg'`` / ``'text_img_avg'`` (text1+image1 vs text2+image2).

    Returns:
        ``numpy.ndarray`` of dtype float64 and shape ``[batch, edge]``.

    Raises:
        ValueError: If ``sim_type`` is unknown or ``text_fea`` is not 4-D.
    """
    if text_fea.dim() != 4:
        raise ValueError(
            f"Expected 4-D features [batch, edge, item, dim], got {text_fea.dim()}-D"
        )

    t1, t2 = text_fea[:, :, 0], text_fea[:, :, 1]
    i1, i2 = image_fea[:, :, 0], image_fea[:, :, 1]

    if sim_type == 'text':
        left, right = t1, t2
    elif sim_type == 'image':
        left, right = i1, i2
    elif sim_type == 'text_img':
        left, right = t1, i2
    elif sim_type == 'img_text':
        left, right = i1, t2
    elif sim_type in ['avg', 'text_img_avg']:
        left, right = t1 + i1, t2 + i2
    else:
        raise ValueError(f"Unknown sim_type: {sim_type}")

    sims = F.cosine_similarity(left, right, dim=-1)  # [batch, edge]
    # Move to CPU before widening so the result is float64, like the array of
    # Python floats the loop-based version produced.
    return sims.detach().cpu().double().numpy()

def save_confusion_matrix_txt(matrix, sim_type, output_path, dataset_base):
    """Write a confusion matrix as a tab-separated text table.

    The table goes to
    ``<output_path>/<dataset_base>_<sim_type>_confusion_matrix.txt``. The
    first line is a header of column indices, the second a dashed rule, and
    each following line is a row index followed by that row's values.

    Args:
        matrix: Sequence of rows (e.g. a 2-D array or list of lists).
        sim_type: Similarity type name used in the file name.
        output_path: Existing directory to write into.
        dataset_base: Dataset name used in the file name.
    """
    target = os.path.join(output_path, f"{dataset_base}_{sim_type}_confusion_matrix.txt")
    with open(target, "w") as out:
        n_rows = len(matrix)
        # Header line with the column indices
        column_labels = "\t".join(f"{col:>3}" for col in range(n_rows))
        out.write("GT \\ Pred\t" + column_labels + "\n")
        out.write("-" * (10 + 5 * n_rows) + "\n")
        out.writelines(
            f"{gt:>3}\t" + "\t".join([f"{count:>3}" for count in counts]) + "\n"
            for gt, counts in enumerate(matrix)
        )

def top_k_accuracy(sim_matrix, y_true, k=3):
    """Fraction of rows whose true label is among the k highest scores.

    Args:
        sim_matrix: 2-D array-like of scores, one row per sample and one
            column per class index.
        y_true: Sequence of true class indices, one per row of ``sim_matrix``.
        k: Number of top-scoring columns that count as a hit. If ``k`` exceeds
            the number of columns, every column is considered.

    Returns:
        The top-k accuracy as a float; ``0.0`` when ``y_true`` is empty.

    Raises:
        ValueError: If ``k`` is not positive, or if the number of rows in
            ``sim_matrix`` differs from ``len(y_true)``.
    """
    if k <= 0:
        # ``[:, -0:]`` would select *all* columns and report a bogus accuracy.
        raise ValueError(f"k must be a positive integer, got {k}")

    labels = np.asarray(y_true)
    if labels.shape[0] == 0:
        return 0.0

    scores = np.asarray(sim_matrix)
    if scores.shape[0] != labels.shape[0]:
        raise ValueError(
            f"sim_matrix has {scores.shape[0]} rows but y_true has {labels.shape[0]} labels"
        )

    topk_preds = np.argsort(scores, axis=1)[:, -k:]  # indices of the k largest scores per row
    hits = (topk_preds == labels[:, None]).any(axis=1)
    return int(hits.sum()) / labels.shape[0]

def main():
    """Evaluate edge-type prediction on every configured dataset.

    For each dataset directory this loads the item titles, the test pairs and
    the cluster descriptions, runs the model over all pairs and, for each
    similarity type, predicts the index (along the model's per-edge output
    list) with the highest pair similarity. Accuracy, top-3/top-5 accuracy,
    macro/weighted F1 and the classification report are saved to
    ``<output_path>/<dataset>_result.jsonl``; one confusion matrix per
    similarity type is saved next to it. Datasets whose cluster file is
    missing or that have no test samples are skipped with a message.

    Raises:
        ValueError: If ``--model_select`` is not ``UpdatedModel``.
    """
    args = parse_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Dataset locations
    dataset_names = [
        "Baby_Products_new", "Electronics_new",
        "Musical_Instruments_new", "Automotive_new", "Office_Products_new",
        "Pet_Supplies_new", "amazon_sports_new"
    ]
    dataset_paths = [os.path.join("../dataset", name) for name in dataset_names]
    os.makedirs(args.output_path, exist_ok=True)

    # Load the model
    clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
    embed_dim = clip_model.config.projection_dim
    if args.model_select == 'UpdatedModel':
        model = UpdatedModel(clip_model, embed_dim).to(device)
        # NOTE(review): torch.load unpickles the checkpoint; only point
        # --model_path at trusted files.
        model.load_state_dict(torch.load(args.model_path, map_location=device))
    else:
        raise ValueError("A valid model must be selected.")
    model.eval()

    sim_types = ['text', 'image', 'text_img', 'img_text', 'text_img_avg']

    # Evaluate each dataset in turn
    for dataset_dir in dataset_paths:
        test_samples = []
        asin2title = {}
        print(f"🔄 Loading test data from {dataset_dir}...")
        test_data_file = os.path.join(dataset_dir, "test_data.json")
        dataset_base = os.path.basename(dataset_dir.replace("_new", ""))
        item_file = os.path.join(dataset_dir, f"{dataset_base}_item_test.jsonl")
        image_folder = os.path.join(dataset_dir.replace("_new", ""), "images")
        cluster_file = os.path.join(dataset_dir, "cluster_10_description.json")

        # Item metadata: ASIN -> title
        if os.path.exists(item_file):
            with open(item_file, "r", encoding="utf-8") as f:
                for line in f:
                    item = json.loads(line)
                    asin2title[item['parent_asin']] = item['title']

        # Test pairs, one JSON object per line (blank lines are ignored)
        if os.path.exists(test_data_file):
            with open(test_data_file, "r", encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    if line:
                        test_samples.append(json.loads(line))

        # Without samples or cluster descriptions there is nothing to
        # evaluate; skip instead of failing later in np.concatenate / open().
        if not test_samples:
            print(f"⚠️ Skipping {dataset_dir}: no test samples found in {test_data_file}")
            continue
        if not os.path.exists(cluster_file):
            print(f"⚠️ Skipping {dataset_dir}: missing cluster file {cluster_file}")
            continue

        # Cluster descriptions
        with open(cluster_file, "r", encoding="utf-8") as f:
            edge_clusters = json.load(f)

        # Keep only clusters whose id, reduced to its digits, is below 10.
        edge_clusters = {
            cluster_id: text
            for cluster_id, text in edge_clusters.items()
            if int(''.join(filter(str.isdigit, cluster_id))) < 10
        }

        edge_inputs = {
            cluster_id: {
                key: value.to(device) for key, value in
                clip_processor(text=text, return_tensors="pt", padding=True, truncation=True).items()
            }
            for cluster_id, text in edge_clusters.items()
        }

        # Named differently from the loop variable so the dataset path is not
        # shadowed by the Dataset object.
        test_dataset = TestDataset(clip_processor, test_samples, image_folder, asin2title)
        dataloader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False,
                                num_workers=args.num_workers, collate_fn=collate_fn)

        all_preds = {sim_type: [] for sim_type in sim_types}
        all_scores = {sim_type: [] for sim_type in sim_types}
        all_labels = []

        for edge_info, _asins, inputs in tqdm(dataloader, desc="Evaluating"):
            # No squeeze(0) here: it is a no-op for batches with more than one
            # sample and drops the batch axis (breaking the unpacking below)
            # when a batch contains exactly one sample.
            inputs = {key: val.to(device) for key, val in inputs.items()}
            batch_size, num_items, seq_len = inputs["input_ids"].shape
            _, _, channels, height, width = inputs["pixel_values"].shape

            # Flatten (batch, item) into one axis for the model
            inputs["input_ids"] = inputs["input_ids"].view(batch_size * num_items, seq_len)
            inputs["attention_mask"] = inputs["attention_mask"].view(batch_size * num_items, seq_len)
            inputs["pixel_values"] = inputs["pixel_values"].view(batch_size * num_items, channels, height, width)

            # Compute features
            with torch.no_grad():
                text_embeddings, image_embeddings, text_out_list, image_out_list = model(edge_inputs, inputs)
                edge_num = len(text_out_list)

                # Stack the per-edge outputs and rearrange them to
                # [batch, edge, item, embed_dim].
                text_fea = torch.stack(text_out_list)
                text_fea = text_fea.view(edge_num, batch_size, num_items, embed_dim)
                text_fea = text_fea.permute(1, 0, 2, 3)

                image_fea = torch.stack(image_out_list)
                image_fea = image_fea.view(edge_num, batch_size, num_items, embed_dim)
                image_fea = image_fea.permute(1, 0, 2, 3)

                for sim_type in sim_types:
                    sim_matrix = compute_similarity_matrix(text_fea, image_fea, sim_type=sim_type)
                    preds = sim_matrix.argmax(axis=1).tolist()
                    all_preds[sim_type].extend(preds)
                    all_scores[sim_type].append(sim_matrix)

                all_labels.extend(edge_info)

        results = defaultdict(dict)

        print("🔍 Final Evaluation Metrics:")
        for sim_type in sim_types:
            y_pred = all_preds[sim_type]
            y_true = all_labels

            sim_matrix_full = np.concatenate(all_scores[sim_type], axis=0)  # [num_samples, edge_num]
            top3 = top_k_accuracy(sim_matrix_full, y_true, k=3)
            top5 = top_k_accuracy(sim_matrix_full, y_true, k=5)

            acc = accuracy_score(y_true, y_pred)
            f1_macro = f1_score(y_true, y_pred, average='macro')
            f1_weighted = f1_score(y_true, y_pred, average='weighted')
            # NOTE(review): without ``labels=`` the matrix only covers labels
            # that occur in y_true/y_pred, so row/column i in the saved table
            # is the i-th observed label, not necessarily class id i.
            cm = confusion_matrix(y_true, y_pred)

            results[sim_type] = {
                "accuracy": acc,
                "top3_accuracy": top3,
                "top5_accuracy": top5,
                "macro_f1": f1_macro,
                "weighted_f1": f1_weighted,
                "report": classification_report(y_true, y_pred, output_dict=True)
            }

            print(f"\n=== {sim_type.upper()} Similarity ===")
            print(f"Accuracy: {acc:.4f}")
            print(f"Top-3 Accuracy: {top3:.4f}")
            print(f"Top-5 Accuracy: {top5:.4f}")
            print(f"Macro F1: {f1_macro:.4f}")
            print(f"Weighted F1: {f1_weighted:.4f}")

            # Save inside the loop so every similarity type gets its own
            # confusion matrix file (previously only the last one was saved).
            save_confusion_matrix_txt(cm, sim_type, args.output_path, dataset_base)

        # Save the metrics. NOTE: despite the .jsonl extension this is a single
        # indented JSON document; the file name is kept for compatibility.
        save_file = os.path.join(args.output_path, f"{dataset_base}_result.jsonl")
        with open(save_file, "w", encoding="utf-8") as f:
            json.dump(results, f, indent=4)
        print(f"\n✅ Saved all metrics to {save_file}")


# Run the evaluation only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()