import torch
import torch.nn.functional as F
from torch_geometric.utils import to_undirected, to_scipy_sparse_matrix
import ot
from dataset import get_dataset, get_dataset2, get_dataset3, get_dataset_condensed
from torch_geometric.nn.conv.gcn_conv import gcn_norm
from torch_geometric.utils import to_undirected, to_scipy_sparse_matrix

def data_sgc(ori_data):
    """Precompute two-hop propagated node features (A_hat @ A_hat @ X).

    The graph is moved to the CPU, its edge list is symmetrised with
    ``to_undirected`` and normalised with ``gcn_norm`` (library defaults),
    and the node features are multiplied by the resulting adjacency twice.

    The adjacency is kept as a sparse COO tensor instead of being expanded
    to a dense N x N matrix, so memory grows with the number of edges rather
    than with the square of the number of nodes. The result matches the
    dense product up to floating-point summation order.

    Args:
        ori_data: Graph object exposing ``x``, ``y``, ``edge_index``,
            ``edge_weight``, ``train_mask``, ``val_mask``, ``test_mask`` and
            a ``cpu()`` method (presumably a ``torch_geometric.data.Data``;
            confirm against the dataset loaders). The ``edge_index`` of the
            object returned by ``cpu()`` is overwritten with the undirected
            edge list.

    Returns:
        dict with keys ``'label'``, ``'emb'``, ``'train_mask'``,
        ``'test_mask'`` and ``'val_mask'``; ``'emb'`` holds the dense
        two-hop features, the other entries are taken from ``ori_data``.
    """
    ori_data = ori_data.cpu()
    ori_data.edge_index = to_undirected(ori_data.edge_index)

    num_nodes = ori_data.x.shape[0]
    edge_index, edge_weight = gcn_norm(
        ori_data.edge_index, ori_data.edge_weight, num_nodes
    )

    # Sparse adjacency: duplicate entries (if any) are summed by coalesce(),
    # which is what densifying a COO matrix would have done as well.
    adjacency = torch.sparse_coo_tensor(
        edge_index, edge_weight, (num_nodes, num_nodes)
    ).coalesce()

    one_hop = torch.sparse.mm(adjacency, ori_data.x)
    two_hop = torch.sparse.mm(adjacency, one_hop)

    return {
        'label': ori_data.y,
        'emb': two_hop,
        'train_mask': ori_data.train_mask,
        'test_mask': ori_data.test_mask,
        'val_mask': ori_data.val_mask,
    }


def generate_new_masks(data, train_ratio=0.2, generator=None):
    """Attach a fresh random train/val/test node split to ``data``.

    ``int(num_nodes * train_ratio)`` nodes go to the training set. The
    remaining nodes are split in half: the validation set gets the half
    rounded down and the test set gets everything that is left. The three
    masks are disjoint and together cover every node.

    Args:
        data: Object exposing ``num_nodes``; ``train_mask``, ``val_mask``
            and ``test_mask`` are (re)assigned on it in place.
        train_ratio: Fraction of nodes used for training, in ``[0, 1]``.
        generator: Optional ``torch.Generator`` driving the permutation, for
            a reproducible split that does not touch the global RNG state.
            ``None`` uses the global RNG.

    Returns:
        The same ``data`` object, with the three boolean masks set.

    Raises:
        ValueError: If ``train_ratio`` is outside ``[0, 1]``.
    """
    # An out-of-range ratio would otherwise yield negative slice sizes and a
    # silently wrong split instead of an error.
    if not 0.0 <= train_ratio <= 1.0:
        raise ValueError(f"train_ratio must be in [0, 1], got {train_ratio!r}")

    num_nodes = data.num_nodes
    shuffled = torch.randperm(num_nodes, generator=generator)

    train_end = int(num_nodes * train_ratio)
    val_end = train_end + (num_nodes - train_end) // 2

    def _mask_for(selected):
        mask = torch.zeros(num_nodes, dtype=torch.bool)
        mask[selected] = True
        return mask

    data.train_mask = _mask_for(shuffled[:train_end])
    data.val_mask = _mask_for(shuffled[train_end:val_end])
    data.test_mask = _mask_for(shuffled[val_end:])

    return data


def generate_condensed_z_y(data, P2):
    """Build a soft class label for every column of the assignment ``P2``.

    Each row of ``P2`` is hardened to its arg-max column. For every column,
    the training labels of the rows assigned to it are counted per class
    and the counts are L1-normalised, giving a class distribution. Columns
    that receive no training row end up with an all-zero label.

    Args:
        data: Object exposing ``y``, ``train_mask``, ``num_classes`` and
            ``num_nodes``.
        P2: 2-D assignment tensor with one row per node of ``data``; it is
            detached before use.

    Returns:
        Tensor of shape ``(P2.shape[1], data.num_classes)`` on ``P2``'s
        device.
    """
    assignment = P2.detach()
    device = assignment.device

    # Hard assignment: a single 1.0 per row, at the arg-max column.
    winners = assignment.argmax(dim=1, keepdim=True)
    hard_assignment = torch.zeros_like(assignment).scatter_(1, winners, 1.0)

    mask = data.train_mask
    n_classes = data.num_classes

    # One-hot labels for training nodes; every other row stays zero so that
    # non-training nodes contribute nothing to the counts.
    node_labels = torch.zeros(data.num_nodes, n_classes).to(device)
    node_labels[mask] = F.one_hot(data.y[mask], num_classes=n_classes).float().to(device)

    class_counts = hard_assignment.t().mm(node_labels)
    return F.normalize(class_counts.clamp(min=0), p=1, dim=1)



def generate_condensed_z_y_regre(data, P2):
    """Build a standardised regression target for every column of ``P2``.

    The targets ``data.y`` are standardised with the mean and standard
    deviation of *all* entries of ``data.y``. Each row of ``P2`` is hardened
    to its arg-max column, and every column receives the average
    standardised target of the training rows assigned to it. Columns with no
    training row get ``0``.

    Integer targets are cast to the default floating-point dtype first,
    because ``torch.mean`` / ``torch.std`` reject integer tensors.

    Side effects:
        Sets ``data.y_mean``, ``data.y_std`` and ``data.y_regre_std`` (the
        standardised targets for all rows).

    Args:
        data: Object exposing a 1-D ``y`` and a boolean ``train_mask``.
        P2: 2-D assignment tensor with one row per entry of ``data.y``; it
            is detached before use.

    Returns:
        1-D tensor with ``P2.shape[1]`` entries.
    """
    assignment = P2.detach()

    # Hard assignment: a single 1.0 per row, at the arg-max column.
    hard_assignment = torch.zeros_like(assignment)
    hard_assignment[torch.arange(assignment.shape[0]), assignment.argmax(dim=1)] = 1.0

    targets = data.y
    if not torch.is_floating_point(targets):
        targets = targets.to(torch.get_default_dtype())

    data.y_mean = torch.mean(targets)
    data.y_std = torch.std(targets)
    data.y_regre_std = (targets - data.y_mean) / data.y_std
    train_targets = data.y_regre_std[data.train_mask]

    # L1-normalising each column turns the 0/1 membership into averaging
    # weights over the training rows assigned to that column.
    train_weights = F.normalize(hard_assignment[data.train_mask], p=1, dim=0)

    return torch.mm(train_targets.unsqueeze(0), train_weights).squeeze(0)






def _propagated_features(data):
    """Return one hop of ``gcn_norm``-normalised propagation (A_hat @ X) for ``data``."""
    num_nodes = data.x.shape[0]
    edge_index, edge_weight = gcn_norm(data.edge_index, data.edge_weight, num_nodes)
    # Keep the adjacency sparse; a dense N x N matrix is never needed here.
    adjacency = torch.sparse_coo_tensor(
        edge_index, edge_weight, (num_nodes, num_nodes)
    ).coalesce()
    return torch.sparse.mm(adjacency, data.x)


def task_transfer(ori_data, sys_data, task='regression'):
    """Transfer the labels of ``ori_data`` onto the nodes of ``sys_data``.

    Both graphs are moved to the CPU and their features are propagated one
    hop through the ``gcn_norm``-normalised adjacency (the original graph is
    symmetrised first). An exact optimal-transport plan with uniform
    marginals is computed between the two sets of propagated features using
    Euclidean cost. The plan is handed to the label helper, which hardens
    each original node to its arg-max column and aggregates the training
    labels per column.

    Args:
        ori_data: Original graph exposing ``x``, ``y``, ``edge_index``,
            ``edge_weight``, ``train_mask`` and ``cpu()``. The
            ``edge_index`` of the object returned by ``cpu()`` is
            overwritten with the undirected edge list, and the label helper
            may attach extra attributes to it.
        sys_data: Graph exposing ``x``, ``edge_index``, ``edge_weight`` and
            ``cpu()`` (presumably the condensed/synthetic graph; confirm
            with the caller). Its ``y`` is replaced.
        task: ``'regression'`` (default) uses
            ``generate_condensed_z_y_regre``; ``'classification'`` uses
            ``generate_condensed_z_y`` and additionally requires
            ``ori_data.num_classes`` and ``ori_data.num_nodes``.

    Returns:
        ``sys_data`` with ``y`` set to the transferred labels.

    Raises:
        ValueError: If ``task`` is not one of the two supported values.
    """
    # Fail before the expensive transport solve if the mode is invalid.
    if task not in ('regression', 'classification'):
        raise ValueError(
            f"task must be 'regression' or 'classification', got {task!r}"
        )

    ori_data = ori_data.cpu()
    ori_data.edge_index = to_undirected(ori_data.edge_index)
    condensed_data = sys_data.cpu()

    ori_emb = _propagated_features(ori_data)
    condensed_emb = _propagated_features(condensed_data)

    # Uniform mass on both sides of the transport problem.
    ori_mass = ot.unif(ori_emb.shape[0], type_as=ori_data.x)
    condensed_mass = ot.unif(condensed_emb.shape[0], type_as=condensed_data.x)

    cost = ot.dist(ori_emb, condensed_emb, metric='euclidean')
    plan = ot.emd(ori_mass, condensed_mass, cost, numItermax=500000)

    # Both helpers detach the plan and harden it with a row-wise arg-max, so
    # no separate row normalisation / one-hot step is required here.
    if task == 'regression':
        sys_data.y = generate_condensed_z_y_regre(ori_data, plan)
    else:
        sys_data.y = generate_condensed_z_y(ori_data, plan)

    return sys_data



# Script section: precompute the two-hop features for one dataset and cache
# them on disk as a pickle.
# NOTE(review): these statements sit at module top level, so they also run
# whenever this file is imported (e.g. to reuse task_transfer); consider
# moving them under an `if __name__ == "__main__":` guard.
import pickle
dataset = 'pubmed_0.25_4'
# Alternative loader, currently disabled:
# ori_data = get_dataset(dataset, normalize_features=False, transform=None)
ori_data = get_dataset3(dataset)
dic = data_sgc(ori_data)
# Output goes to "<dataset>_sgc.pkl", relative to the current working directory.
with open(f"{dataset}_sgc.pkl", "wb") as f:
    pickle.dump(dic, f)