import torch
from torch_geometric.data import Batch
import os
import pandas as pd

import yaml
from dataset.data_utils import SentenceEncoder
from dataset.task_constructor import UnifiedTaskConstructor
from tqdm import tqdm

# Pre-training data configuration. Top-level keys are setting names; each value
# is a mapping whose keys are the dataset names in that setting (the values are
# presumably sampling weights — confirm against config/pt_data.yaml).
with open("config/pt_data.yaml", "r") as stream:
    WEIGHT = yaml.safe_load(stream)

# setting name -> keys view of the dataset names belonging to that setting.
datasets = {setting: members.keys() for setting, members in WEIGHT.items()}


def refine_dataset(dataset):
    """Promote alternative attribute names onto the canonical ones, in place.

    Used for molecule graphs: when present (not None), ``node_embs`` is copied
    to ``node_text_feat``, ``edge_embs`` to ``edge_text_feat`` and
    ``pretrain_edge_index`` to ``edge_index``.

    Returns the same ``dataset`` object.
    """
    data = dataset.data
    for source, target in (
        ("node_embs", "node_text_feat"),
        ("edge_embs", "edge_text_feat"),
        ("pretrain_edge_index", "edge_index"),
    ):
        value = data.get(source)
        if value is not None:
            setattr(data, target, value)
    return dataset


def filter_unnecessary_attrs(dataset):
    """Clear every attribute of ``dataset.data`` except the ones pre-training uses.

    Kept: ``x``, ``xe``, ``edge_index``, ``node_text_feat``, ``edge_text_feat``
    and ``field``. Every other key is assigned ``None`` (which PyG treats as
    removing the attribute).

    Returns the same ``dataset`` object, modified in place.
    """
    keep = {"x", "xe", "edge_index", "node_text_feat", "edge_text_feat", "field"}
    # Loop over the to_dict() result rather than the live data object while
    # clearing keys.
    for key in dataset.data.to_dict():
        if key not in keep:
            dataset.data[key] = None
    return dataset


def span_node_and_edge_idx(dataset):
    """Replace node features with integer node ids and assign edge-type ids.

    If ``dataset.data.x`` is already 1-D it is treated as an id vector and kept.
    It is only converted to int64 when it is floating point.

    Otherwise (e.g. a 2-D embedding matrix) the original features are discarded:
    - ``x`` becomes ``arange(num_nodes)``.
    - ``xe`` becomes all zeros (int64) when ``edge_text_feat`` has a single row.
    - Otherwise ``xe`` is taken from ``data.edge_types``.

    Returns the same ``dataset`` object, modified in place.
    """
    data = dataset.data

    if data.x.ndim == 1:
        # Already an id vector. Cast floats to int64 so the ids have the same
        # dtype as the arange/zeros ids built below. Older PyTorch also rejects
        # int32 tensors as indices.
        if data.x.is_floating_point():
            data.x = data.x.long()
        return dataset

    # Drop the original embeddings and number the nodes instead, analogous to
    # the node ids used for biological graphs.
    data.x = torch.arange(data.x.shape[0])

    if data.edge_text_feat.shape[0] == 1:
        # Single edge type: every edge gets id 0.
        data.xe = torch.zeros([data.edge_index.shape[1]], dtype=torch.long)
    else:
        data.xe = data.edge_types
    return dataset


def get_task_constructor(data_path, graph_llm_name, llm_b_size, path):
    """Build a ``UnifiedTaskConstructor`` backed by a freshly created ``SentenceEncoder``.

    ``graph_llm_name``, ``llm_b_size`` and ``path`` are passed positionally to
    ``SentenceEncoder``. The encoder and ``data_path`` are then passed to
    ``UnifiedTaskConstructor``.
    """
    return UnifiedTaskConstructor(
        SentenceEncoder(graph_llm_name, llm_b_size, path),
        data_path,
    )


def idx2mask(idx, size):
    """Return a boolean mask of shape ``size`` that is True exactly at ``idx``."""
    selected = torch.zeros(size, dtype=torch.bool)
    selected[idx] = True
    return selected


def mask2idx(mask):
    """Return the positions (along the first dimension) where ``mask`` is True."""
    first_dim_positions, *_ = torch.nonzero(mask, as_tuple=True)
    return first_dim_positions


def _resolve_dataset_names(setting):
    """Expand ``setting`` into a flat list of dataset names.

    ``setting`` is a setting name, a dataset name, or a list/tuple of either.
    Names found in the module-level ``datasets`` mapping expand to that
    setting's dataset names. Any other name is taken as a single dataset name.

    Raises:
        TypeError: if ``setting`` is neither a string nor a list/tuple.
    """
    if isinstance(setting, str):
        entries = [setting]
    elif isinstance(setting, (list, tuple)):
        entries = setting
    else:
        raise TypeError(
            f"setting must be a str or a list of str, got {type(setting).__name__}"
        )

    names = []
    for entry in entries:
        if entry in datasets:
            names.extend(datasets[entry])
        else:
            # Not a known setting: treat it as one dataset name. Extending the
            # list with the bare string would add it character by character.
            names.append(entry)
    return names


def get_pt_data(data_path, setting, graph_llm_name, llm_b_size, path):
    """Load, normalize and merge the pre-training datasets into one ``Batch``.

    Args:
        data_path: data root passed to ``UnifiedTaskConstructor``.
        setting: a setting name, a dataset name, or a list/tuple of either.
            Setting names are expanded through the module-level ``datasets``
            mapping.
        graph_llm_name, llm_b_size, path: forwarded to ``SentenceEncoder`` via
            ``get_task_constructor``.

    Returns:
        A ``torch_geometric.data.Batch`` built from one data object per dataset.
        - ``x`` / ``xe`` of each dataset are offset by the cumulative number of
          rows in the previous datasets' ``node_text_feat`` / ``edge_text_feat``.
        - ``field`` is a per-node integer id of the dataset's field, with ids
          assigned in first-seen order.

    Raises:
        TypeError: if ``setting`` is neither a string nor a list/tuple.
    """
    dataset_names = _resolve_dataset_names(setting)
    print(f"Pre-training on {dataset_names}")

    tasks = get_task_constructor(data_path, graph_llm_name, llm_b_size, path)

    data_list = []
    field_ids = {}
    for dataset_name in dataset_names:
        dataset = tasks.get_ofa_data(dataset_name)
        dataset = refine_dataset(dataset)
        dataset = span_node_and_edge_idx(dataset)
        dataset = filter_unnecessary_attrs(dataset)

        data = dataset.data
        # Map each distinct field to a small integer id, in first-seen order.
        field_id = field_ids.setdefault(data.field, len(field_ids))
        data.field = torch.full((data.x.size(0),), field_id)
        data_list.append(data)

    # Shift node/edge ids so each dataset's ids follow those of the datasets
    # before it, matching the order of the concatenated feature tables.
    x_start, xe_start = 0, 0
    for data in data_list:
        data.x += x_start
        data.xe += xe_start
        x_start += data.node_text_feat.shape[0]
        xe_start += data.edge_text_feat.shape[0]

    return Batch.from_data_list(data_list)


def get_train_node_idx(data, weights):
    """Build the (possibly repeated) global node indices to train on.

    ``data.ptr`` marks the node-offset boundaries between the graphs of the
    batch. Each weight ``w = weights[i]`` is split into an integer part and a
    fractional part:
    - every node of graph ``i`` is included ``int(w)`` times;
    - additionally ``int(frac(w) * num_nodes)`` nodes of that graph are sampled
      without replacement (via ``torch.randperm``).

    Args:
        data: batched graph object exposing ``ptr``.
        weights: non-negative per-graph weights, indexable by graph position.

    Returns:
        1-D int64 tensor of node indices.

    Raises:
        ValueError: if ``data.ptr`` is None or a weight is negative.
    """
    ptr = data.ptr
    if ptr is None:
        raise ValueError("data.ptr is required to locate each graph's nodes")

    pieces = [torch.empty(0, dtype=torch.long)]
    for graph_id, (start, end) in enumerate(zip(ptr[:-1], ptr[1:])):
        weight = weights[graph_id]
        if weight < 0:
            # A negative fractional part would turn the slice below into
            # `[:-k]`, silently selecting almost every node of the graph.
            raise ValueError(
                f"weights must be non-negative, got {weight} for graph {graph_id}"
            )
        nodes = torch.arange(start, end)
        whole = int(weight)
        frac = weight - whole

        pieces.append(nodes.repeat(whole))
        # Always draws a permutation, even when frac == 0, so RNG consumption
        # depends only on the graph sizes.
        sampled = torch.randperm(nodes.size(0))[: int(frac * nodes.size(0))]
        pieces.append(nodes[sampled])

    # Concatenate once instead of re-growing the result on every iteration.
    return torch.cat(pieces)


def preprocess_split(split):
    """Turn a multi-split mask dict into a list of single-split dicts.

    ``split`` has ``train`` / ``valid`` / ``test`` entries. ``train`` and
    ``valid`` must be iterables yielding one mask per split. ``test`` is
    either:
    - a list (or tuple) with one mask per split, or
    - a single 1-D mask shared by every split.

    Returns:
        ``[{"train": ..., "valid": ..., "test": ...}, ...]`` with one dict per
        split. ``zip`` stops at the shortest input.

    Raises:
        ValueError: if ``test`` is neither a list/tuple nor a 1-D array-like.
            Such input previously produced an empty list without any error.
    """
    test = split["test"]

    if isinstance(test, (list, tuple)):
        return [
            {"train": train, "valid": valid, "test": test_mask}
            for train, valid, test_mask in zip(split["train"], split["valid"], test)
        ]

    if getattr(test, "ndim", None) == 1:
        # One shared test mask reused for every split.
        return [
            {"train": train, "valid": valid, "test": test}
            for train, valid in zip(split["train"], split["valid"])
        ]

    raise ValueError(
        f"Unsupported test split layout: {type(test).__name__} "
        f"with ndim={getattr(test, 'ndim', None)}"
    )


def get_node_data(tasks, dataset_name):
    """Load a node-classification dataset together with its split and labels.

    Returns:
        ``(dataset, split, labels, num_classes, None, None)``.
        - For cora, pubmed, instagram, reddit and wikics, ``split`` is the list
          produced by ``preprocess_split`` (one dict per split).
        - For the other supported datasets, ``split`` is a dict holding the
          ``train`` / ``valid`` / ``test`` masks as stored on the data object.
        - ``num_classes`` is the number of distinct label values.

    Raises:
        NotImplementedError: for datasets without a known split layout.
    """
    dataset = tasks.get_ofa_data(dataset_name)
    data = dataset[0]

    if dataset_name in ("cora", "pubmed", "instagram", "reddit"):
        # Several splits stored as train_masks / val_masks / test_masks.
        split = preprocess_split(
            {"train": data.train_masks, "valid": data.val_masks, "test": data.test_masks}
        )
        labels = data.y
    elif dataset_name == "wikics":
        # Transpose so that iterating a mask yields one split per row. This
        # presumes the masks are stored as (num_nodes, num_splits); confirm.
        # The test mask is only transposed when it is 2-D, because `.T` on a
        # 1-D tensor is a deprecated no-op.
        test_mask = data.test_mask.T if data.test_mask.ndim == 2 else data.test_mask
        split = preprocess_split(
            {"train": data.train_mask.T, "valid": data.val_mask.T, "test": test_mask}
        )
        labels = data.y
    elif dataset_name in ("arxiv", "citeseer"):
        split = {"train": data.train_mask, "valid": data.val_mask, "test": data.test_mask}
        # Squeeze away singleton label dimensions (presumably (num_nodes, 1)).
        labels = data.y.squeeze()
    elif dataset_name in ("bookhis", "elecomp", "elephoto", "sportsfit", "products"):
        split = {"train": data.train_mask, "valid": data.val_mask, "test": data.test_mask}
        labels = data.y
    else:
        raise NotImplementedError(f"Dataset {dataset_name} is not supported for node classification task")

    num_classes = labels.unique().shape[0]

    print(f"{dataset_name:} nodes: {dataset.data.x.size(0):} edges: {dataset.data.edge_index.size(1):} "
          f"num_classes: {num_classes}")

    return dataset, split, labels, num_classes, None, None


def get_link_data(tasks, dataset_name):
    """Load a link (edge-type) classification dataset.

    Only WN18RR and FB15K237 are supported. Labels are the ``edge_types`` of
    ``dataset[0]``, and ``num_classes`` is the number of distinct edge types.

    Returns:
        ``(dataset, split, labels, num_classes, None, None)``.

    Raises:
        NotImplementedError: for any other dataset name. The check happens
            before anything is loaded.
    """
    if dataset_name not in ("WN18RR", "FB15K237"):
        raise NotImplementedError(f"Dataset {dataset_name} is not supported for link classification task")

    dataset = tasks.get_ofa_data(dataset_name)
    split = tasks.get_data_split(dataset_name)
    labels = dataset[0].edge_types
    num_classes = labels.unique().shape[0]

    print(f"{dataset_name:} nodes: {dataset.data.x.size(0):} edges: {dataset.data.edge_index.size(1):} "
          f"num_classes: {num_classes}")

    return dataset, split, labels, num_classes, None, None


def get_graph_clf_graph(tasks, dataset_name):
    """Load a molecule graph-classification dataset and its labels.

    - chembace / chembbbp / chemhiv: one target, ``dataset.y`` used as is.
    - Multi-task datasets: ``dataset.y`` reshaped to ``(-1, num_tasks)``.

    The dataset and split are loaded before the name is validated.

    Returns:
        ``(dataset, split, labels, num_classes, None, None)``.

    Raises:
        NotImplementedError: for unsupported dataset names.
    """
    dataset = tasks.get_ofa_data(dataset_name)
    split = tasks.get_data_split(dataset_name)

    # Multi-task datasets: number of label columns per graph.
    label_widths = {
        "chempcba": 128,
        "chemcyp450": 5,
        "chemmuv": 17,
        "chemtoxcast": 588,
        "chemtox21": 12,
    }

    if dataset_name in ("chembace", "chembbbp", "chemhiv"):
        num_classes = 1
        labels = dataset.y
    elif dataset_name in label_widths:
        num_classes = label_widths[dataset_name]
        labels = dataset.y.reshape(-1, num_classes)
    else:
        raise NotImplementedError(f"Dataset {dataset_name} is not supported for graph classification task")

    return dataset, split, labels, num_classes, None, None


def get_GQA_graph(tasks, dataset_name):
    """Load a GQA-style dataset with its split and side data.

    For ``scene_graphs``, each of the ``train`` / ``val`` / ``test`` split
    entries is replaced by ``.index.tolist()``, i.e. the list of its index
    labels (presumably pandas DataFrames). Other datasets keep the split
    exactly as returned by ``tasks.get_data_split``.

    Returns:
        ``(dataset, split, label, None, question, desc)``, where ``question``,
        ``desc`` and ``label`` are ``dataset.side_data[1]``, ``[2]`` and
        ``[3]`` respectively.
    """
    dataset = tasks.get_ofa_data(dataset_name)
    split = tasks.get_data_split(dataset_name)

    if dataset_name == "scene_graphs":
        split = {part: split[part].index.tolist() for part in ("train", "val", "test")}

    side_data = dataset.side_data
    question, desc, label = side_data[1], side_data[2], side_data[3]

    # NOTE: a commented-out routine that pre-built and cached a scene_graphs
    # "finetune_graph.pt" (per-graph text features, question, CSV description,
    # answer) used to live here. Recover it from the file's history if that
    # cache ever needs to be regenerated.
    return dataset, split, label, None, question, desc


def get_finetune_graph(task, data_path, dataset_name, graph_llm_name, llm_b_size, path):
    """Load a fine-tuning dataset for ``task``.

    Supported tasks are ``'node'``, ``'link'``, ``'graph'`` and ``'GQA'``.
    The task is validated before the sentence encoder and task constructor
    are built, so an unsupported task fails fast without loading any model.

    Returns:
        Whatever the task-specific loader returns:
        ``(dataset, split, labels, num_classes, question, desc)``.

    Raises:
        NotImplementedError: if ``task`` is not supported (or the loader
            rejects ``dataset_name``).
    """
    supported = ("node", "link", "graph", "GQA")
    if task not in supported:
        raise NotImplementedError(
            f"Task {task} is not supported (dataset {dataset_name}); "
            f"expected one of {supported}"
        )

    tasks = get_task_constructor(data_path, graph_llm_name, llm_b_size, path)

    loaders = {
        "node": get_node_data,
        "link": get_link_data,
        "graph": get_graph_clf_graph,
        "GQA": get_GQA_graph,
    }
    return loaders[task](tasks, dataset_name)
