import numpy as np
import torch
import random
from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T
from torch_geometric.utils import to_undirected
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
import pandas as pd


def set_seed(seed):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Value passed unchanged to ``random.seed``, ``np.random.seed``,
            ``torch.manual_seed`` and, when CUDA is available,
            ``torch.cuda.manual_seed``.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    cuda_present = torch.cuda.is_available()
    if cuda_present:
        torch.cuda.manual_seed(seed)


def edge_tensor_to_set(edge_index):
    """Collect the columns of ``edge_index`` as a set of sorted tuples.

    Each column is converted to a list, sorted ascending and stored as a
    tuple, so ``(u, v)`` and ``(v, u)`` collapse into the same entry.

    Args:
        edge_index: Array-like exposing ``.T`` whose rows (after the
            transpose) provide ``.tolist()``, e.g. a 2-D ``torch.Tensor``.

    Returns:
        A ``set`` of tuples, one per distinct (order-insensitive) column.
    """
    unique_edges = set()
    for column in edge_index.T:
        endpoints = column.tolist()
        endpoints.sort()
        unique_edges.add(tuple(endpoints))
    return unique_edges


# Class-name string -> integer class id (0..6), assigned in the order listed.
# parse_cora() uses this to turn the label column into numeric targets.
class_map = {
    label: index
    for index, label in enumerate((
        'Case_Based',
        'Genetic_Algorithms',
        'Neural_Networks',
        'Probabilistic_Methods',
        'Reinforcement_Learning',
        'Rule_Learning',
        'Theory',
    ))
}


def parse_cora(path='dataset/cora_orig/cora'):
    """Parse the raw ``.content`` / ``.cites`` files into NumPy arrays.

    Args:
        path: Common prefix of the two input files; ``f"{path}.content"`` and
            ``f"{path}.cites"`` are read. Defaults to the location the
            original implementation hard-coded.

    Returns:
        A tuple ``(data_X, data_Y, data_citeid, data_edges)``:

        * ``data_X`` -- float32 matrix built from every column of the
          ``.content`` file except the first and the last.
        * ``data_Y`` -- integer class ids obtained by looking up the last
          column in ``class_map``.
        * ``data_citeid`` -- the first column (paper ids, as strings).
        * ``data_edges`` -- unique index pairs, transposed so that each
          column is one edge; every citation is present in both directions.
          Ids are replaced by their row position in the ``.content`` file.

    Raises:
        KeyError: If a label in the ``.content`` file is not in ``class_map``.
        ValueError: If a mapped label falls outside ``[0, len(class_map) - 1]``.
    """
    # atleast_2d keeps a single-row file from collapsing to a 1-D array.
    idx_features_labels = np.atleast_2d(
        np.genfromtxt(f"{path}.content", dtype=np.dtype(str)))
    data_X = idx_features_labels[:, 1:-1].astype(np.float32)
    labels = idx_features_labels[:, -1]

    data_Y = np.array([class_map[l] for l in labels])
    data_citeid = idx_features_labels[:, 0]

    num_classes = len(class_map)
    if data_Y.min() < 0 or data_Y.max() >= num_classes:
        raise ValueError(
            f"Label values out of range. Expected [0, {num_classes - 1}], got [{data_Y.min()}, {data_Y.max()}].")

    # Paper id -> row position in the .content file.
    idx_map = {j: i for i, j in enumerate(data_citeid)}

    edges_unordered = np.genfromtxt(f"{path}.cites", dtype=np.dtype(str))
    if edges_unordered.size == 0:
        # Empty file: keep a well-formed (0, 2) array so the code below works.
        edges_unordered = edges_unordered.reshape(0, 2)
    else:
        # A file with a single citation is returned as a 1-D array.
        edges_unordered = np.atleast_2d(edges_unordered)
    width = edges_unordered.shape[1]

    # Drop citations that mention a paper id absent from the .content file.
    mapped_rows = ([idx_map.get(pid) for pid in row] for row in edges_unordered)
    known_rows = [row for row in mapped_rows if None not in row]
    data_edges = np.array(known_rows, dtype=int).reshape(-1, width)

    # Add the reverse direction of every edge.
    data_edges = np.vstack((data_edges, np.fliplr(data_edges)))

    num_nodes = len(data_Y)
    in_range = ((data_edges >= 0) & (data_edges < num_nodes)).all(axis=1)
    if not in_range.all():
        print(
            f"Warning: Edge index out of range. Expected [0, {num_nodes - 1}], got [{data_edges.min()}, {data_edges.max()}]. Truncating invalid edges.")
        # Remove the offending edges; clipping would silently rewire them to
        # node 0 or node num_nodes - 1 and fabricate citations.
        data_edges = data_edges[in_range]

    if data_edges.shape[0] == 0:
        # No usable edges: return an empty edge list with the usual layout.
        return data_X, data_Y, data_citeid, np.empty((width, 0), dtype=int)

    return data_X, data_Y, data_citeid, np.unique(data_edges, axis=0).transpose()


def get_cora_casestudy(seed=0):
    """Build a PyG ``Data`` object from the raw files plus a 60/20/20 node split.

    The ``Planetoid('dataset', 'cora')`` dataset is loaded and its first graph
    is used as the container; ``x``, ``edge_index``, ``y`` and ``num_nodes``
    are then overwritten with the arrays returned by ``parse_cora()``.

    Args:
        seed: Seed for the node shuffle that defines the train/val/test
            split. The shuffle uses a private ``np.random.RandomState`` so
            the split depends only on ``seed`` and the global NumPy RNG state
            is left untouched.

    Returns:
        ``(data, data_citeid)`` where ``data`` additionally carries
        ``train_id`` / ``val_id`` / ``test_id`` (sorted NumPy index arrays)
        and the matching boolean ``train_mask`` / ``val_mask`` / ``test_mask``
        tensors, and ``data_citeid`` is the paper-id array from
        ``parse_cora()``.
    """
    data_X, data_Y, data_citeid, data_edges = parse_cora()

    dataset = Planetoid('dataset', 'cora', transform=T.NormalizeFeatures())
    data = dataset[0]

    data.x = torch.tensor(data_X).float()
    data.edge_index = torch.tensor(data_edges).long()
    data.y = torch.tensor(data_Y).long()
    num_nodes = len(data_Y)
    data.num_nodes = num_nodes

    # Seeded, local RNG: previously `seed` was accepted but never used, so the
    # split silently depended on whatever the global NumPy state happened to be.
    rng = np.random.RandomState(seed)
    node_id = np.arange(num_nodes)
    rng.shuffle(node_id)
    split = [int(num_nodes * 0.6), int(num_nodes * 0.8)]
    data.train_id = np.sort(node_id[:split[0]])
    data.val_id = np.sort(node_id[split[0]:split[1]])
    data.test_id = np.sort(node_id[split[1]:])

    def _ids_to_mask(ids):
        # One vectorised assignment instead of an O(n) membership test per node.
        mask = torch.zeros(num_nodes, dtype=torch.bool)
        mask[torch.as_tensor(ids, dtype=torch.long)] = True
        return mask

    data.train_mask = _ids_to_mask(data.train_id)
    data.val_mask = _ids_to_mask(data.val_id)
    data.test_mask = _ids_to_mask(data.test_id)

    print(f"data_X shape: {data_X.shape}")
    print(f"data_edges shape: {data_edges.shape}")
    print(f"data_Y shape: {data_Y.shape}")

    # Diagnostic output: list the neighbours of one hard-coded node.
    data_edge_index = data.edge_index
    node = 2467
    outgoing_edges = data_edge_index[1, data_edge_index[0] == node]
    incoming_edges = data_edge_index[0, data_edge_index[1] == node]
    connected_nodes = torch.unique(torch.cat([outgoing_edges, incoming_edges]))
    connected_nodes = connected_nodes[connected_nodes != node]  # ignore self-loops
    print(f"All nodes connected to node {node}:", connected_nodes.tolist())
    return data, data_citeid


def _read_title_and_abstract(file_path):
    """Return ``"<title line>\\n<abstract line>"`` for one extraction file.

    The last line containing ``'Title:'`` and the last line containing
    ``'Abstract:'`` are used. A field that is not found, or a file that
    cannot be read, yields the ``"Missing Title"`` / ``"Missing Abstract"``
    placeholder for that field.
    """
    title = "Missing Title"
    abstract = "Missing Abstract"
    try:
        with open(file_path) as f:
            lines = f.read().splitlines()
    except (OSError, UnicodeError):
        # Best effort: an absent or undecodable file falls back to placeholders.
        return title + '\n' + abstract
    for line in lines:
        if 'Title:' in line:
            title = line
        if 'Abstract:' in line:
            abstract = line
    return title + '\n' + abstract


def get_raw_text_cora(use_text=False, seed=0):
    """Load the graph and, optionally, per-paper raw text and link splits.

    Args:
        use_text: When false, return immediately after building the graph.
            When true, also read each paper's title/abstract and attach
            train/val/test link-prediction splits to ``data``.
        seed: Forwarded to ``get_cora_casestudy``.

    Returns:
        * ``(data, None)`` when ``use_text`` is false.
        * ``(data, text, class_names)`` when ``use_text`` is true, where
          ``text`` has one ``"title\\nabstract"`` string per entry of the
          paper-id array (in that order) and ``class_names`` is
          ``list(class_map.keys())``.

        NOTE: the two branches return tuples of different length; this is
        kept as-is so existing callers keep working.

    Raises:
        KeyError: If a paper id from the graph has no entry in the
            ``papers`` index file.
    """
    data, data_citeid = get_cora_casestudy(seed)

    if not use_text:
        return data, None

    # Paper id -> extraction file name (first ':' replaced by '_').
    pid_filename = {}
    with open('dataset/cora_orig/mccallum/cora/papers') as f:
        for line in f:
            # Strip the newline so a two-column line does not leave '\n'
            # glued to the file name.
            fields = line.rstrip('\n').split('\t')
            if len(fields) < 2:
                continue  # blank or malformed line: nothing to index
            pid_filename[fields[0]] = fields[1].replace(':', '_', 1)

    path = 'dataset/cora_orig/mccallum/cora/extractions/'
    # Each paper is read independently, so a file lacking a title or abstract
    # can no longer inherit the previous paper's text.
    text = [_read_title_and_abstract(path + pid_filename[pid]) for pid in data_citeid]

    data.edge_index = to_undirected(data.edge_index)
    transform = T.Compose([
        T.NormalizeFeatures(),
        T.ToDevice('cuda' if torch.cuda.is_available() else 'cpu'),
        T.RandomLinkSplit(num_val=0.2, num_test=0.2, is_undirected=True, add_negative_train_samples=True),
    ])
    train_data, val_data, test_data = transform(data)

    # Copy the three link splits onto the original data object.
    data.train_edge_index = train_data.edge_index
    data.train_edge_label_index = train_data.edge_label_index
    data.train_edge_label = train_data.edge_label

    data.val_edge_index = val_data.edge_index
    data.val_edge_label_index = val_data.edge_label_index
    data.val_edge_label = val_data.edge_label

    data.test_edge_index = test_data.edge_index
    data.test_edge_label_index = test_data.edge_label_index
    data.test_edge_label = test_data.edge_label
    return data, text, list(class_map.keys())
