import os
import numpy as np
import torch
from torch.utils.data import Dataset
from torch_geometric.data import Data
from torch_geometric.utils import dense_to_sparse

"""
Dataloader for our scene pyg dataset

"""

class ScenePyGLoader(Dataset):
    """Dataset over the ``.npy`` scene files stored directly inside a directory.

    Every file is read with ``np.load(..., allow_pickle=True).item()``, i.e. it
    is expected to hold one pickled, mapping-like Python object
    (``label_len`` reads its ``"one_hot_class_labels"`` entry).

    NOTE: ``allow_pickle=True`` unpickles the file contents, so only point
    this class at directories whose files are trusted.

    Args:
        dataset_dir: Directory scanned (non-recursively) for ``.npy`` files.
        text_cache_dir: Optional directory that is passed through unchanged
            to every sample under the ``"cache_base_dir"`` key.
    """

    def __init__(self, dataset_dir, text_cache_dir=None):
        super().__init__()
        self.dataset_dir = dataset_dir
        self.text_cache_dir = text_cache_dir
        # os.listdir() returns entries in arbitrary order. Sort the names so
        # that the index -> file mapping is deterministic across runs and
        # machines.
        self.data_files = sorted(
            name for name in os.listdir(dataset_dir) if name.endswith(".npy")
        )

    def __len__(self):
        """Return the number of ``.npy`` files found in ``dataset_dir``."""
        return len(self.data_files)

    def __getitem__(self, idx):
        """Load the ``idx``-th file (in sorted name order).

        Returns:
            A dict with the loaded object under ``"data"``, the dataset
            directory under ``"base_path"`` and the text cache directory
            (possibly ``None``) under ``"cache_base_dir"``.
        """
        file_path = os.path.join(self.dataset_dir, self.data_files[idx])
        return {
            "data": np.load(file_path, allow_pickle=True).item(),
            "base_path": self.dataset_dir,
            "cache_base_dir": self.text_cache_dir,
        }

    @property
    def label_len(self):
        """Size of the last axis of ``"one_hot_class_labels"`` in the first file.

        Raises:
            IndexError: If the directory contains no ``.npy`` files.
        """
        first_path = os.path.join(self.dataset_dir, self.data_files[0])
        data = np.load(first_path, allow_pickle=True).item()
        return data["one_hot_class_labels"].shape[-1]

class ScenePyGLoaderSeq(Dataset):
    """Dataset whose samples are addressed by file name ``<idx>.npy``.

    The dataset length is the number of ``.npy`` files in ``dataset_dir``,
    but ``__getitem__`` does not consult that listing: it opens
    ``f"{idx}.npy"`` directly. This presumably expects the files to be named
    ``0.npy`` ... ``N-1.npy`` -- verify against the code that writes them.

    Args:
        dataset_dir: Directory holding the ``.npy`` files.
        text_cache_dir: Optional directory forwarded to every sample under
            the ``"cache_base_dir"`` key.
    """

    def __init__(self, dataset_dir, text_cache_dir=None):
        super().__init__()
        self.dataset_dir = dataset_dir
        self.text_cache_dir = text_cache_dir
        # Keep every directory entry whose name ends in ".npy" (listing order).
        self.data_files = list(
            filter(lambda name: name.endswith(".npy"), os.listdir(dataset_dir))
        )

    def __len__(self):
        """Return how many ``.npy`` files were found in ``dataset_dir``."""
        return len(self.data_files)

    def __getitem__(self, idx):
        """Load ``<dataset_dir>/<idx>.npy`` and wrap it with the directory info."""
        sample_path = os.path.join(self.dataset_dir, f"{idx}.npy")
        scene = np.load(sample_path, allow_pickle=True).item()
        return {
            "data": scene,
            "base_path": self.dataset_dir,
            "cache_base_dir": self.text_cache_dir,
        }

    @property
    def label_len(self):
        """Size of the last axis of ``"one_hot_class_labels"`` in the first listed file."""
        first_path = os.path.join(self.dataset_dir, self.data_files[0])
        first_scene = np.load(first_path, allow_pickle=True).item()
        return first_scene["one_hot_class_labels"].shape[-1]

class ScenePyGLoader_LLM(Dataset):
    """Dataset over name-sorted ``.npy`` files that also reports the file name.

    Each sample carries the loaded object plus the name of the file it came
    from (``"load_file"``).

    Args:
        dataset_dir: Directory scanned (non-recursively) for ``.npy`` files.
        text_cache_dir: Optional directory forwarded to every sample under
            the ``"cache_base_dir"`` key.
    """

    def __init__(self, dataset_dir, text_cache_dir=None):
        super().__init__()
        self.dataset_dir = dataset_dir
        self.text_cache_dir = text_cache_dir
        # Collect the ".npy" entries, then order them by name.
        npy_names = []
        for entry in os.listdir(dataset_dir):
            if entry.endswith(".npy"):
                npy_names.append(entry)
        npy_names.sort()
        self.data_files = npy_names

    def __len__(self):
        """Return the number of ``.npy`` files found in ``dataset_dir``."""
        return len(self.data_files)

    def __getitem__(self, idx):
        """Load the ``idx``-th file in sorted order and attach its file name."""
        file_name = self.data_files[idx]
        scene = np.load(os.path.join(self.dataset_dir, file_name), allow_pickle=True).item()
        return {
            "data": scene,
            "base_path": self.dataset_dir,
            "cache_base_dir": self.text_cache_dir,
            "load_file": file_name,
        }

    @property
    def label_len(self):
        """Size of the last axis of ``"one_hot_class_labels"`` in the first file."""
        first_path = os.path.join(self.dataset_dir, self.data_files[0])
        first_scene = np.load(first_path, allow_pickle=True).item()
        return first_scene["one_hot_class_labels"].shape[-1]
    

def create_fully_connected_edge_index(num_nodes, loop=True):
    """
    Build the edge index of a fully connected graph.

    Args:
        num_nodes (int): Number of nodes in the graph.
        loop (bool): Whether every node also gets a self-loop.

    Returns:
        edge_index (Tensor): A 2 x E tensor, where E is the number of edges.
    """
    # Start from a dense all-ones adjacency matrix: every pair is connected.
    adjacency = torch.ones((num_nodes, num_nodes))

    # Without self-loops, clear the diagonal before converting.
    if not loop:
        adjacency.fill_diagonal_(0)

    # dense_to_sparse returns (edge_index, edge_attr); only the index is needed.
    return dense_to_sparse(adjacency)[0]

def text_ball_collate_fn_pyg(batch):
    """Collate dataset samples into per-scene PyG graphs plus ragged metadata.

    Args:
        batch: Sequence of sample dicts with the keys ``"data"``,
            ``"base_path"`` and ``"cache_base_dir"``. The two directory
            entries are taken from the first sample only.

    Returns:
        A dict with
        ``"data"``: one ``Data`` object per sample, carrying the attributes set
        by ``get_data_xy_from_one_data`` plus ``one_hot_class_labels``,
        ``edge_index`` (fully connected, with self-loops) and, when a cache
        directory is given, ``text_emb``;
        ``"irregular_data"``: a dict of per-sample lists (``label``,
        ``raw_model_path``, ``texture_image_path``, ``text_des``,
        ``furni_rel``, ``scene_id``, ``scene_centroid``).

    Raises:
        IndexError: If ``batch`` is empty.
    """
    graphs = []  # PyG data list
    irregular_data = {
        "label": [],
        "raw_model_path": [],
        "texture_image_path": [],
        "text_des": [],
        "furni_rel": [],
        "scene_id": [],
        "scene_centroid": [],
    }  # dict of lists, one entry per sample
    base_path = batch[0]["base_path"]
    cache_base_dir = batch[0]["cache_base_dir"]

    for sample in batch:
        scene = sample["data"]
        furni = scene["all_furni"]

        # Per-furniture tensors (x, y, ...) come from the helper.
        graph = get_data_xy_from_one_data(furni)
        graph.one_hot_class_labels = torch.tensor(
            scene["one_hot_class_labels"], dtype=torch.float32
        )

        # Ragged / non-tensor fields are collected as plain Python lists.
        irregular_data["label"].append(furni["label"])
        irregular_data["furni_rel"].append(scene["furni_rel"])
        irregular_data["raw_model_path"].append(furni["raw_model_path"])  # 2D irregular list
        irregular_data["texture_image_path"].append(furni["texture_image_path"])  # 2D irregular list
        irregular_data["scene_id"].append(scene["scene_id"])
        irregular_data["scene_centroid"].append(scene["centroid"])

        # The description lives in a text file relative to the dataset dir.
        # NOTE(review): the file is opened with the platform default encoding.
        text_path = scene["text_path"]
        with open(os.path.join(base_path, text_path), 'r') as text_file:
            irregular_data["text_des"].append(text_file.read())

        if cache_base_dir is not None:
            # The cached embedding sits next to the text path, with ".pt" instead of ".txt".
            # NOTE(review): weights_only=False unpickles arbitrary objects; the
            # cache directory must be trusted.
            emb_file = os.path.join(cache_base_dir, text_path.replace(".txt", ".pt"))
            cached_emb = torch.load(emb_file, map_location="cpu", weights_only=False)
            # Expected to be [1, feat_dim]; broadcast it to [num_nodes, feat_dim].
            graph.text_emb = cached_emb.expand(graph.x.size(0), -1)

        graph.edge_index = create_fully_connected_edge_index(graph.x.size(0), loop=True)
        graphs.append(graph)

    return {"data": graphs, "irregular_data": irregular_data}
        
def get_data_xy_from_one_data(one_all_furni) -> Data:
    """
    Convert the per-furniture dict of one scene into a PyG ``Data`` object.

    In our dataloader design, all furniture data are treated as ``data.x``;
    the remaining dict entries are unpacked into further attributes.

    Reads the keys ``"furni_pc"``, ``"z_angle"``, ``"translation"``,
    ``"furni_loc"``, ``"furni_single_pc_bbox"`` and ``"scale"``; every
    resulting attribute is a float32 tensor.
    """
    def _float_tensor(values):
        # Go through NumPy first so nested lists become one dense array.
        return torch.tensor(np.array(values), dtype=torch.float32)

    graph = Data()

    # Stack the per-furniture point clouds; they are the input for the GNN.
    stacked_pc = np.stack(one_all_furni["furni_pc"], axis=0)
    graph.x = torch.tensor(stacked_pc, dtype=torch.float32)

    # Learning target: z-angle as a column ([N] -> [N, 1]) followed by the
    # translation (expected [N, 3]), concatenated along axis 1.
    angles = np.array(one_all_furni["z_angle"])[:, np.newaxis]
    translations = np.array(one_all_furni["translation"])
    target = np.concatenate([angles, translations], axis=1)
    graph.y = torch.tensor(target, dtype=torch.float32)

    # Remaining attributes are copied over as float tensors under the same
    # name: furni_loc (the original location of the furniture), the
    # single-point-cloud bounding box, and the scale.
    for key in ("furni_loc", "furni_single_pc_bbox", "scale"):
        setattr(graph, key, _float_tensor(one_all_furni[key]))

    return graph

