import concurrent
import os
import time

import torch.optim as optim
import pickle
import numpy as np
from torch.optim.lr_scheduler import StepLR
import torch
from collections import defaultdict
from torch import nn
from torch_geometric.data import DataLoader

from utils import *
from model import *
import torch
import random

def evaluate(model, dataloader, device):
    """Compute the sample-weighted MAE and MRE of ``model`` over ``dataloader``.

    Every batch is moved to ``device`` and fed to the model as
    ``model(batch.x, batch.dists_avg, batch.edge_index, batch.edge_label_index)``;
    the squeezed output is compared against ``batch.y``. The wall-clock time of
    each forward pass is printed.

    Args:
        model: Callable module taking (local_embedding, global_embedding,
            edge_index, edge_label_index). It is switched to eval mode.
        dataloader: Iterable of batches exposing ``x``, ``dists_avg``,
            ``edge_index``, ``edge_label_index``, ``y`` and ``.to(device)``.
        device: ``torch.device`` or device string the batches are moved to.

    Returns:
        ``(mae, mre)`` as Python floats, averaged over every predicted sample
        of every batch (not just the last batch). The relative error is
        ``|pred - y| / (y + 1e-8)``; the epsilon avoids division by zero.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    model.eval()
    total_absolute_error = 0.0
    total_relative_error = 0.0
    num_samples = 0
    # CUDA kernels run asynchronously; without synchronising, the timer would
    # only measure kernel launch, not the actual forward pass.
    is_cuda = torch.device(device).type == 'cuda'

    with torch.no_grad():  # evaluation only: do not build autograd graphs
        for data in dataloader:
            data = data.to(device)
            local_embedding = data.x  # Local embedding from GCN path
            global_embedding = data.dists_avg  # Global embedding from PGNN path
            edge_label_index = data.edge_label_index
            edge_index = data.edge_index

            if is_cuda:
                torch.cuda.synchronize()
            start_infer = time.perf_counter()
            predictions = model(local_embedding, global_embedding, edge_index, edge_label_index).squeeze()
            if is_cuda:
                torch.cuda.synchronize()
            infer_time = time.perf_counter() - start_infer
            print(f" {infer_time:.6f} ")

            # Flatten both sides so a (N,) vs (N, 1) mismatch cannot silently
            # broadcast to an (N, N) error matrix.
            predictions = predictions.reshape(-1)
            labels = data.y.reshape(-1)

            errors = predictions - labels
            total_absolute_error += torch.abs(errors).sum().item()
            total_relative_error += torch.abs(errors / (labels + 1e-8)).sum().item()
            num_samples += errors.numel()

    if num_samples == 0:
        raise ValueError("evaluate() received a dataloader with no samples")

    return total_absolute_error / num_samples, total_relative_error / num_samples



if __name__ == "__main__":
    EMBEDDING_DIM = 80
    NUM_EPOCHS = 1000
    BATCH_SIZE = 64
    LEARNING_RATE = 0.001
    EARLY_STOPPING_PATIENCE = 100
    NUM_TEST_PIVOTS = 40
    MODEL_PATH = 'best_distance_model.pth'

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # NOTE(review): pickle.load can execute arbitrary code; these files are
    # assumed to be trusted artifacts generated locally by this script.
    with open('dataset/Cora/cora.pkl', 'rb') as f:
        G = pickle.load(f)

    embedding_file = 'dataset/Cora/node_embeddings.pkl'
    pivot_file = 'dataset/Cora/pivot_nodes.pkl'
    test_embedding_file = 'dataset/Cora/node_embeddings_test.pkl'
    test_pivot_file = 'dataset/Cora/pivot_nodes_test.pkl'

    all_nodes = list(G.nodes())

    # --- Training pivots: load cached pivot distances or compute and cache them. ---
    if os.path.exists(embedding_file) and os.path.exists(pivot_file):
        with open(embedding_file, 'rb') as f:
            node_embeddings = pickle.load(f)
            print(node_embeddings)
        with open(pivot_file, 'rb') as f:
            pivot_nodes = pickle.load(f)
            print(pivot_nodes)
    else:
        print("...")

        num_pivots = EMBEDDING_DIM
        # Sample real node labels, not positional indices: pivots are used as
        # keys into shortest_paths and as node ids in the (pivot, node) samples.
        pivot_nodes = np.random.choice(all_nodes, num_pivots, replace=False)

        shortest_paths = compute_shortest_paths_from_pivots(G, pivot_nodes)

        # distances[i, j] = shortest_paths[i-th node][j-th pivot]
        distances = np.zeros((len(all_nodes), num_pivots))
        for i, node in enumerate(all_nodes):
            for j, pivot in enumerate(pivot_nodes):
                distances[i, j] = shortest_paths[node][pivot]

        node_embeddings = torch.tensor(distances, dtype=torch.float)

        with open(embedding_file, 'wb') as f:
            pickle.dump(node_embeddings, f)
        with open(pivot_file, 'wb') as f:
            pickle.dump(pivot_nodes, f)

    # --- Test pivots: drawn from nodes that are not training pivots. ---
    # NOTE(review): the two caches are checked independently; if only the
    # training cache is regenerated, stale cached test pivots may overlap the
    # new training pivots.
    if os.path.exists(test_embedding_file) and os.path.exists(test_pivot_file):
        with open(test_embedding_file, 'rb') as f:
            node_embeddings_test = pickle.load(f)
            print(node_embeddings_test)
        with open(test_pivot_file, 'rb') as f:
            pivot_nodes_test = pickle.load(f)
            print(pivot_nodes_test)
    else:
        print("...")

        train_pivot_set = set(np.asarray(pivot_nodes).tolist())
        candidates = [node for node in all_nodes if node not in train_pivot_set]
        pivot_nodes_test = random.sample(candidates, NUM_TEST_PIVOTS)

        shortest_paths_test = compute_shortest_paths_from_pivots(G, pivot_nodes_test)

        test_distances = np.zeros((len(all_nodes), len(pivot_nodes_test)))
        for i, node in enumerate(all_nodes):
            for j, pivot in enumerate(pivot_nodes_test):
                test_distances[i, j] = shortest_paths_test[node][pivot]

        node_embeddings_test = torch.tensor(test_distances, dtype=torch.float)

        with open(test_embedding_file, 'wb') as f:
            pickle.dump(node_embeddings_test, f)
        with open(test_pivot_file, 'wb') as f:
            pickle.dump(pivot_nodes_test, f)

    # One sample per (pivot, other node) pair; labels come from compute_labels.
    train_samples = [(pivot, node) for pivot in pivot_nodes for node in all_nodes if node != pivot]
    test_samples = [(pivot, node) for pivot in pivot_nodes_test for node in all_nodes if node != pivot]

    labels = compute_labels(node_embeddings, train_samples, pivot_nodes)

    data = create_data_from_graph(G, node_embeddings)
    data.y = labels
    data.edge_label_index = torch.tensor(train_samples, dtype=torch.long).t().contiguous()

    edge_index = data.edge_index
    num_nodes = len(all_nodes)
    data.dists = precompute_dist_data(edge_index, num_nodes)

    layer_num = 1
    anchor_num = 32
    anchor_size_num = 4
    preselect_anchor(data, layer_num=layer_num, anchor_num=anchor_num, anchor_size_num=anchor_size_num, device=device)

    dataloader = DataLoader([data], batch_size=BATCH_SIZE, shuffle=True)

    model = MV_GCN_Fusion_Model(in_channels=EMBEDDING_DIM, layer_num=layer_num).to(device)
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    scheduler = StepLR(optimizer, step_size=50, gamma=0.9)

    print("...")
    # NOTE(review): early stopping / model selection is driven by the training
    # loss; there is no held-out validation split.
    best_loss = float('inf')
    best_model_state_dict = None
    epochs_without_improvement = 0
    for epoch in range(NUM_EPOCHS):
        model.train()
        total_loss = 0.0
        num_batches = 0
        # Use a separate name so `data` keeps referring to the original graph,
        # which is reused for evaluation below.
        for batch in dataloader:
            batch = batch.to(device)
            optimizer.zero_grad()

            predictions = model(batch.x, batch.dists_avg, batch.edge_index, batch.edge_label_index).squeeze()

            # Flatten both sides so mismatched (N,) / (N, 1) shapes cannot broadcast.
            loss = torch.mean(torch.abs(predictions.reshape(-1) - batch.y.reshape(-1)))
            loss.backward()
            optimizer.step()

            # .item() detaches the value so the autograd graph is not retained.
            total_loss += loss.item()
            num_batches += 1

        avg_loss = total_loss / max(num_batches, 1)
        print(f"Epoch {epoch + 1}/{NUM_EPOCHS}, Loss: {avg_loss:.4f}")

        if avg_loss < best_loss:
            best_loss = avg_loss
            # state_dict() returns references to the live parameters; clone them
            # so later optimizer steps cannot overwrite the "best" snapshot.
            best_model_state_dict = {k: v.detach().clone() for k, v in model.state_dict().items()}
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1

        if epochs_without_improvement >= EARLY_STOPPING_PATIENCE:
            break

        scheduler.step()

    if best_model_state_dict is None:
        # No epoch improved on +inf (e.g. NaN loss or NUM_EPOCHS == 0): keep the
        # current weights rather than crashing.
        best_model_state_dict = {k: v.detach().clone() for k, v in model.state_dict().items()}

    torch.save(best_model_state_dict, MODEL_PATH)
    print(f" {best_loss:.4f}")

    print(node_embeddings_test.shape)
    print(pivot_nodes_test)
    print(len(test_samples))
    data.y = compute_labels(node_embeddings_test, test_samples, pivot_nodes_test)
    data.edge_label_index = torch.tensor(test_samples, dtype=torch.long).t().contiguous()
    test_loader = DataLoader([data], batch_size=BATCH_SIZE, shuffle=False)
    model.load_state_dict(best_model_state_dict)
    avg_mae, avg_mre = evaluate(model, test_loader, device)
    print(f"MAE: {avg_mae:.4f},MRE: {avg_mre:.4f}")
