import numpy as np
import random as rd
import scipy.sparse as sp
from time import time
import json
from utility.parser import parse_args
import torch  # 确保导入torch

# Command-line arguments are parsed once at import time; Data.__init__ reads
# args.data_path and args.dataset from this object.
args = parse_args()

# -------------------------- Global device configuration --------------------------
# Use CUDA when it is available, otherwise fall back to the CPU. Data stores
# this module-level device on each instance as self.device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
if torch.cuda.is_available():
    print(f"GPU Name: {torch.cuda.get_device_name(0)}")


class Data(object):
    def __init__(self, path: str, batch_size: int):
        """Load the interaction splits and precompute the sampling structures.

        Reads ``train.json``, ``val.json`` and ``test.json`` from *path* (each
        is iterated as a mapping of user id -> list of item ids), then loads
        ``text_feat.npy`` and the pickled semantic-embedding dict from
        ``args.data_path + args.dataset``. From these it builds the user-item
        matrix ``R``, a dense item co-occurrence matrix and a per-item semantic
        embedding lookup, all placed on the module-level ``device``.

        Args:
            path: Directory holding the three JSON split files. ``get_adj_mat``
                also uses it as the location of its adjacency cache.
            batch_size: Number of users drawn by each ``sample`` call.

        Raises:
            FileNotFoundError: If a split file or the semantic embedding file
                does not exist.
            Exception: If loading or cleaning the semantic embeddings fails for
                any other reason; the underlying error is chained as the cause.
        """
        self.path = path
        self.batch_size = batch_size
        self.device = device  # module-level device, reused by every method

        train_file = path + '/train.json'
        val_file = path + '/val.json'
        test_file = path + '/test.json'

        # Running maxima / counters; the id maxima become counts ("max id + 1")
        # once all three splits have been scanned.
        self.n_users, self.n_items = 0, 0
        self.n_train, self.n_test = 0, 0
        self.n_val = 0
        self.neg_pools = {}

        self.exist_users = []

        def load_split(file_name):
            # Context manager so the handle is closed as soon as parsing ends.
            with open(file_name) as handle:
                return json.load(handle)

        train = load_split(train_file)
        test = load_split(test_file)
        val = load_split(val_file)

        # Only users with at least one training interaction are sampled later.
        for uid, items in train.items():
            if len(items) == 0:
                continue
            uid = int(uid)
            self.exist_users.append(uid)
            self.n_items = max(self.n_items, max(items))
            self.n_users = max(self.n_users, uid)
            self.n_train += len(items)

        # Users with an empty list are skipped explicitly instead of relying on
        # a bare ``except`` around ``max([])``.
        for items in test.values():
            if not items:
                continue
            self.n_items = max(self.n_items, max(items))
            self.n_test += len(items)

        for items in val.values():
            if not items:
                continue
            self.n_items = max(self.n_items, max(items))
            self.n_val += len(items)

        self.n_items += 1
        self.n_users += 1

        # The text-feature table is authoritative for the item count: n_items
        # is overwritten with its number of rows.
        # NOTE(review): an interaction whose item id is >= that row count would
        # fall outside R below - confirm the feature file covers every item.
        text_feats = np.load(args.data_path + args.dataset + '/text_feat.npy')
        self.text_feats = torch.tensor(text_feats, dtype=torch.float32).to(self.device)
        self.n_items = self.text_feats.shape[0]

        self.print_statistics()

        # User-item interactions are assembled as scipy sparse matrices first.
        self.R = sp.dok_matrix((self.n_users, self.n_items), dtype=np.float32)
        self.R_Item_Interacts = sp.dok_matrix((self.n_items, self.n_items), dtype=np.float32)

        self.train_items, self.test_set, self.val_set = {}, {}, {}

        for uid, train_items in train.items():
            if len(train_items) == 0:
                continue
            uid = int(uid)
            for item_id in train_items:
                self.R[uid, item_id] = 1.
            self.train_items[uid] = train_items

        for uid, test_items in test.items():
            uid = int(uid)
            if len(test_items) == 0:
                continue
            self.test_set[uid] = test_items

        for uid, val_items in val.items():
            uid = int(uid)
            if len(val_items) == 0:
                continue
            self.val_set[uid] = val_items

        # -------------------------- Item co-occurrence matrix --------------------------
        # co_occur[i, j] = number of users whose training set contains both i
        # and j. It is computed densely as (R^T)(R^T)^T, so it materialises an
        # [n_items, n_users] and an [n_items, n_items] tensor on the device.
        self.iu_graph = self.R.T
        iu_dense = torch.tensor(self.iu_graph.toarray(), dtype=torch.float32).to(self.device)
        self.co_occur = torch.matmul(iu_dense, iu_dense.T)
        # An item trivially co-occurs with itself; zero the diagonal.
        self.co_occur.fill_diagonal_(0)
        print("Item co-occurrence matrix computed. Shape:", self.co_occur.shape, f"| Device: {self.co_occur.device}")

        # -------------------------- Structural sampling hyper-parameters --------------------------
        self.struct_support_topk = 5  # number of "core" positives used to define the structural region
        self.struct_co_theta = 5  # co-occurrence threshold for membership in that region

        # -------------------------- Semantic sampling initialisation --------------------------
        import pickle
        embedding_path = args.data_path + args.dataset + '/augmented_atttribute_embedding_dict'
        try:
            # SECURITY: pickle can execute arbitrary code while loading; only
            # point this at embedding files from a trusted source.
            with open(embedding_path, 'rb') as handle:
                raw_emb_dict = pickle.load(handle)
            print(f"Successfully loaded item semantic embedding file from: {embedding_path}")
            print(f"Original embedding dict has {len(raw_emb_dict)} key-value pairs")

            # Keep only entries whose key is an integer item id inside
            # [0, n_items) and whose value converts to a float tensor.
            self.item_semantic_emb = {}
            invalid_keys = []  # kept for the diagnostic print below
            for key, emb in raw_emb_dict.items():
                try:
                    item_id = int(key)
                except (ValueError, TypeError):
                    # Non-numeric keys such as attribute names.
                    invalid_keys.append(f"{key} (cannot convert to int)")
                    continue
                if not 0 <= item_id < self.n_items:
                    invalid_keys.append(f"{key} (out of item ID range: 0-{self.n_items-1})")
                    continue
                try:
                    self.item_semantic_emb[item_id] = torch.tensor(emb, dtype=torch.float32).to(self.device)
                except (ValueError, TypeError):
                    # Still skipped (best effort), but reported for what it is
                    # rather than as a key-conversion problem.
                    invalid_keys.append(f"{key} (embedding cannot be converted to a tensor)")

            print(f"Cleaned invalid keys: {len(invalid_keys)} (examples: {invalid_keys[:5]})")
            print(f"Valid semantic embedding count: {len(self.item_semantic_emb)}/{self.n_items} (total items)")

            # An empty lookup is not fatal: sample() falls back to random sampling.
            if len(self.item_semantic_emb) == 0:
                print("WARNING: No valid item semantic embeddings left! Semantic sampling will use fallback logic.")

        except FileNotFoundError:
            raise FileNotFoundError(f"Semantic embedding file not found at path: {embedding_path}\nPlease check if the file exists.")
        except Exception as e:
            # Chain the cause so the original traceback is not lost.
            raise Exception(f"Error loading semantic embedding file! Details: {str(e)}") from e

        # Item ids that own a semantic embedding (keys are already integers).
        self.valid_semantic_item_ids = list(self.item_semantic_emb.keys())
        # Samples drawn per user by every sampler in sample().
        self.semantic_sample_num = 1
        self.semantic_topk = 10  # size of the candidate pool for semantic positives
        self.semantic_neg_thresh = 0.3  # cosine-similarity threshold below which an item is a semantic negative

    def get_adj_mat(self):  # 获取图的邻接矩阵（迁移到GPU）
        try:
            t1 = time()
            adj_mat = sp.load_npz(self.path + '/s_adj_mat.npz')
            norm_adj_mat = sp.load_npz(self.path + '/s_norm_adj_mat.npz')
            mean_adj_mat = sp.load_npz(self.path + '/s_mean_adj_mat.npz')
            print('already load adj matrix', adj_mat.shape, time() - t1)
            
            # 转换为torch张量并移到GPU（稀疏张量转稠密，小数据集适用；大数据集用torch.sparse）
            adj_mat = torch.tensor(adj_mat.toarray(), dtype=torch.float32).to(self.device)
            norm_adj_mat = torch.tensor(norm_adj_mat.toarray(), dtype=torch.float32).to(self.device)
            mean_adj_mat = torch.tensor(mean_adj_mat.toarray(), dtype=torch.float32).to(self.device)

        except Exception:
            adj_mat, norm_adj_mat, mean_adj_mat = self.create_adj_mat()
            # 保存时转换为numpy数组（scipy不支持GPU张量保存）
            sp.save_npz(self.path + '/s_adj_mat.npz', sp.csr_matrix(adj_mat.cpu().numpy()))
            sp.save_npz(self.path + '/s_norm_adj_mat.npz', sp.csr_matrix(norm_adj_mat.cpu().numpy()))
            sp.save_npz(self.path + '/s_mean_adj_mat.npz', sp.csr_matrix(mean_adj_mat.cpu().numpy()))
        return adj_mat, norm_adj_mat, mean_adj_mat

    def create_adj_mat(self):  # 构建邻接矩阵（迁移到GPU）
        t1 = time()
        # 先构建scipy稀疏矩阵
        adj_mat = sp.dok_matrix((self.n_users + self.n_items, self.n_users + self.n_items), dtype=np.float32)
        adj_mat = adj_mat.tolil()
        R = self.R.tolil()

        adj_mat[:self.n_users, self.n_users:] = R
        adj_mat[self.n_users:, :self.n_users] = R.T
        adj_mat = adj_mat.todok()
        print('already create adjacency matrix', adj_mat.shape, time() - t1)

        t2 = time()

        def normalized_adj_single(adj):
            # adj为GPU张量
            rowsum = adj.sum(dim=1)  # torch张量求和（GPU上执行）
            d_inv = torch.pow(rowsum, -1).flatten()
            d_inv[torch.isinf(d_inv)] = 0.
            # 构建对角矩阵（GPU上）
            d_mat_inv = torch.diag(d_inv).to(self.device)
            # 矩阵乘法（GPU上）
            norm_adj = torch.matmul(d_mat_inv, adj)
            print('generate single-normalized adjacency matrix.')
            return norm_adj

        def get_D_inv(adj):
            rowsum = adj.sum(dim=1)
            d_inv = torch.pow(rowsum, -1).flatten()
            d_inv[torch.isinf(d_inv)] = 0.
            d_mat_inv = torch.diag(d_inv).to(self.device)
            return d_mat_inv

        def check_adj_if_equal(adj):
            degree = adj.sum(dim=1, keepdim=False)
            temp = torch.matmul(torch.diag(torch.pow(degree, -1)), adj)
            print('check normalized adjacency matrix whether equal to this laplacian matrix.')
            return temp

        # 转换为GPU张量后再归一化
        adj_mat_tensor = torch.tensor(adj_mat.toarray(), dtype=torch.float32).to(self.device)
        # 加单位矩阵（GPU上）
        adj_mat_tensor += torch.eye(adj_mat_tensor.shape[0]).to(self.device)
        # 归一化（全程GPU）
        norm_adj_mat = normalized_adj_single(adj_mat_tensor)
        mean_adj_mat = normalized_adj_single(torch.tensor(adj_mat.toarray(), dtype=torch.float32).to(self.device))

        print('already normalize adjacency matrix', time() - t2)
        return adj_mat_tensor, norm_adj_mat, mean_adj_mat

    def sample(self):  # 只改类型转换，不碰任何采样逻辑！
        # 1. 采样批次用户（原有逻辑完全保留，一丝不动）
        if self.batch_size <= self.n_users:
            users = rd.sample(self.exist_users, self.batch_size)
        else:
            users = [rd.choice(self.exist_users) for _ in range(self.batch_size)]

        # 2. 所有采样子函数（原有逻辑完全保留，一行没改！）
        def sample_pos_items_for_u(u, num):
            pos_items = self.train_items[u]
            n_pos_items = len(pos_items)
            pos_batch = []
            while True:
                if len(pos_batch) == num: break
                pos_id = np.random.randint(low=0, high=n_pos_items, size=1)[0]
                pos_i_id = pos_items[pos_id]
                if pos_i_id not in pos_batch:
                    pos_batch.append(pos_i_id)
            return pos_batch

        def sample_neg_items_for_u(u, num):
            neg_items = []
            while True:
                if len(neg_items) == num: break
                neg_id = np.random.randint(low=0, high=self.n_items, size=1)[0]
                if neg_id not in self.train_items[u] and neg_id not in neg_items:
                    neg_items.append(neg_id)
            return neg_items

        def sample_neg_items_for_u_from_pools(u, num):
            neg_items = list(set(self.neg_pools[u]) - set(self.train_items[u]))
            return rd.sample(neg_items, num)

        def sample_struct_pos_items_for_u(u, num):
            pos_items = self.train_items[u]
            n_total_pos = len(pos_items)
            if n_total_pos <= num:
                return pos_items[:num]
            support_scores = []
            for item_i in pos_items:
                other_pos = [item_j for item_j in pos_items if item_j != item_i]
                total_support = self.co_occur[item_i, other_pos].sum().item()
                support_scores.append((item_i, total_support))
            support_scores.sort(key=lambda x: x[1], reverse=True)
            return [item for item, _ in support_scores[:num]]

        def sample_struct_neg_items_for_u(u, num):
            pos_items = self.train_items[u]
            all_items = set(range(self.n_items))
            un交互_items = all_items - set(pos_items)
            core_pos = sample_struct_pos_items_for_u(u, num=min(self.struct_support_topk, len(pos_items)))
            struct_region = set()
            for core_item in core_pos:
                core_item_tensor = torch.tensor(core_item, dtype=torch.long).to(self.device)
                co_occur_scores = self.co_occur[core_item_tensor]
                related_items = torch.where(co_occur_scores >= self.struct_co_theta)[0].cpu().tolist()
                struct_region.update(related_items)
            struct_neg_candidates = un交互_items - struct_region
            if len(struct_neg_candidates) < num:
                fallback_candidates = [
                    item for item in un交互_items
                    if max(self.co_occur[core_item, item].item() for core_item in core_pos) == 0
                ]
                struct_neg_candidates.update(fallback_candidates)
                if len(struct_neg_candidates) < num:
                    struct_neg_candidates.update(un交互_items)
            return rd.sample(list(struct_neg_candidates), num)

        def cosine_similarity(vec1, vec2):
            dot_product = torch.matmul(vec1.unsqueeze(0), vec2.unsqueeze(1)).item()
            norm1 = torch.norm(vec1).item()
            norm2 = torch.norm(vec2).item()
            if norm1 == 0 or norm2 == 0:
                return 0.0
            return dot_product / (norm1 * norm2)

        def sample_semantic_pos_items_for_u(u, num):
            pos_items = self.train_items[u]
            all_items = set(range(self.n_items))
            un交互_items = all_items - set(pos_items)
            un交互_items = [iid for iid in un交互_items if iid in self.valid_semantic_item_ids]
            if len(un交互_items) == 0:
                return sample_pos_items_for_u(u, num)
            if len(un交互_items) <= num:
                return un交互_items[:num]
            valid_pos_emb = [self.item_semantic_emb[iid] for iid in pos_items if iid in self.valid_semantic_item_ids]
            if len(valid_pos_emb) == 0:
                return rd.sample(un交互_items, num)
            valid_pos_emb_tensor = torch.stack(valid_pos_emb).to(self.device)
            user_semantic_center = valid_pos_emb_tensor.mean(dim=0)
            item_sim_pairs = []
            for item_i in un交互_items:
                item_emb = self.item_semantic_emb[item_i]
                sim_score = cosine_similarity(user_semantic_center, item_emb)
                item_sim_pairs.append((item_i, sim_score))
            item_sim_pairs.sort(key=lambda x: x[1], reverse=True)
            top_similar_items = [item for item, sim in item_sim_pairs[:self.semantic_topk]]
            return rd.sample(top_similar_items, num)

        def sample_semantic_neg_items_for_u(u, num):
            pos_items = self.train_items[u]
            all_items = set(range(self.n_items))
            un交互_items = all_items - set(pos_items)
            un交互_items = [iid for iid in un交互_items if iid in self.valid_semantic_item_ids]
            if len(un交互_items) == 0:
                return sample_neg_items_for_u(u, num)
            if len(un交互_items) <= num:
                return un交互_items[:num]
            valid_pos_emb = [self.item_semantic_emb[iid] for iid in pos_items if iid in self.valid_semantic_item_ids]
            if len(valid_pos_emb) == 0:
                return rd.sample(un交互_items, num)
            valid_pos_emb_tensor = torch.stack(valid_pos_emb).to(self.device)
            user_semantic_center = valid_pos_emb_tensor.mean(dim=0)
            low_sim_items = []
            for item_i in un交互_items:
                item_emb = self.item_semantic_emb[item_i]
                sim_score = cosine_similarity(user_semantic_center, item_emb)
                if sim_score < self.semantic_neg_thresh:
                    low_sim_items.append(item_i)
            if len(low_sim_items) < num:
                item_sim_pairs = []
                for item_i in un交互_items:
                    item_emb = self.item_semantic_emb[item_i]
                    sim_score = cosine_similarity(user_semantic_center, item_emb)
                    item_sim_pairs.append((item_i, sim_score))
                item_sim_pairs.sort(key=lambda x: x[1])
                need = num - len(low_sim_items)
                fallback_items = [item for item, sim in item_sim_pairs[:need]]
                low_sim_items.extend(fallback_items)
            return rd.sample(low_sim_items, num)

        # 3. 生成样本列表（原有逻辑保留，一丝不动）
        random_pos, random_neg = [], []
        struct_pos, struct_neg = [], []
        semantic_pos, semantic_neg = [], []
        for u in users:
            random_pos += sample_pos_items_for_u(u, self.semantic_sample_num)
            random_neg += sample_neg_items_for_u(u, self.semantic_sample_num)
            struct_pos += sample_struct_pos_items_for_u(u, self.semantic_sample_num)
            struct_neg += sample_struct_neg_items_for_u(u, self.semantic_sample_num)
            semantic_pos += sample_semantic_pos_items_for_u(u, self.semantic_sample_num)
            semantic_neg += sample_semantic_neg_items_for_u(u, self.semantic_sample_num)

        # -------------------------- 唯一修改：列表→GPU张量（不改变任何数据/逻辑） --------------------------
        # 转换为张量（仅类型转换，数值完全不变，匹配PyTorch模型输入）
        def to_gpu_tensor(lst):
            # 直接转换为long类型GPU张量，确保主函数可调用.cpu()
            return torch.tensor(lst, dtype=torch.long, device=self.device)

        # 所有返回值统一转GPU张量，顺序和原来完全一致
        users_tensor = to_gpu_tensor(users)
        random_pos_tensor = to_gpu_tensor(random_pos)
        random_neg_tensor = to_gpu_tensor(random_neg)
        struct_pos_tensor = to_gpu_tensor(struct_pos)
        struct_neg_tensor = to_gpu_tensor(struct_neg)
        semantic_pos_tensor = to_gpu_tensor(semantic_pos)
        semantic_neg_tensor = to_gpu_tensor(semantic_neg)

        # 4. 返回GPU张量（主函数可直接用.cpu()和torch.cat，无类型错误）
        return users_tensor, random_pos_tensor, random_neg_tensor, struct_pos_tensor, struct_neg_tensor, semantic_pos_tensor, semantic_neg_tensor

    def print_statistics(self):  # 原有逻辑完全保留
        print('n_users=%d, n_items=%d' % (self.n_users, self.n_items))
        print('n_interactions=%d' % (self.n_train + self.n_test + self.n_val))
        print('n_train=%d, n_test=%d, n_val=%d, sparsity=%.5f' % (
            self.n_train, self.n_test, self.n_val,
            (self.n_train + self.n_test + self.n_val) / (self.n_users * self.n_items)
        ))