

import os

import torch
import pandas as pd
import numpy as np
from torch_geometric.datasets import Planetoid, Coauthor, Amazon, CoraFull, CitationFull
from tqdm import tqdm
from sklearn.decomposition import TruncatedSVD
from sklearn.linear_model import Ridge
from torch_geometric.utils import one_hot
import scipy.sparse as sp

from utils import evaluate4, get_data_mask, evaluate_inductive

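# Kernel-PCA node features + ridge classifier; the setting (inductive or
# transductive) is selected by `method` below.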

def get_top_d(G, d, random_state=None):
    """Embed the rows of ``G`` into its top-``d`` TruncatedSVD components.

    The projection is rescaled column-wise. Each component is first
    normalized to unit L2 norm, then weighted by its explained-variance
    ratio relative to the largest ratio. The component with the largest
    ratio therefore gets weight 1.

    Args:
        G: matrix accepted by ``TruncatedSVD.fit_transform`` (dense array or
            scipy sparse matrix) of shape (n_samples, n_features).
        d: number of components to keep. ``TruncatedSVD`` raises
            ``ValueError`` when ``d`` is too large for ``G``.
        random_state: forwarded to ``TruncatedSVD``. The default ``None``
            keeps the previous (unseeded) behavior. Pass an int for
            reproducible features.

    Returns:
        np.ndarray of shape (n_samples, d). Components that cannot be
        normalized (zero norm, or all-zero explained variance) are zeros.
    """
    svd = TruncatedSVD(n_components=d, random_state=random_state)
    X = svd.fit_transform(G)
    explained_variance_ratio = svd.explained_variance_ratio_
    # 0/0 for an all-zero component (or an all-zero ratio vector) is expected
    # here and mapped to 0 below, so silence numpy's RuntimeWarnings.
    with np.errstate(divide='ignore', invalid='ignore'):
        # Unit-L2-normalize every component (column).
        X = X / np.linalg.norm(X, axis=0)
        # Down-weight components by their relative explained variance.
        X = X * (explained_variance_ratio / explained_variance_ratio.max()).reshape(1, -1)
    # Replace the NaN produced by the 0/0 cases above with 0.
    return np.nan_to_num(X)



# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------

# Setting used to build the kernel: 'transductive' (alternative: 'inductive').
method = 'transductive'

# Datasets to evaluate, keyed by the name recorded in the results CSV.
# Only Cora is enabled; uncomment entries to include more datasets.
data_dict = dict(
    Cora=Planetoid(root="data/Planetoid", name="Cora"),
    # CiteSeer=Planetoid(root="data/Planetoid", name="CiteSeer"),
    # PubMed=Planetoid(root="data/Planetoid", name="PubMed"),
    # Computers=Amazon(root='data/Amazon', name='Computers'),
    # Photo=Amazon(root='data/Amazon', name='Photo'),
    # CoraFull=CoraFull('data/CoraFull'),
    # Physics=Coauthor('data/Coauthor', name='Physics'),
    # CS=Coauthor('data/Coauthor', name='CS'),
    # DBLP=CitationFull('data/CitationFull', name='DBLP'),
)

# Hyper-parameter grid swept by the main loop.
sweep_dict = dict(
    d=[2 ** k for k in range(5, 11)],  # 32, 64, 128, 256, 512, 1024
    test_portion=[0.01],  # wider sweep: [0.01, 0.05, 0.10, 0.20, 0.50]
    beta=[1e3, 1e2, 1e1, 1e0, 1e-1, 1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8],
    random_seed=list(range(10)),
)
# Main sweep. For every test split, embedding size d, dataset and seed:
#   1. build the normalized adjacency,
#   2. embed the selected nodes with get_top_d,
#   3. fit one ridge classifier per beta on the labeled nodes,
#   4. record the metrics returned by evaluate_inductive.
# One CSV is written per test_portion.
if method not in ('inductive', 'transductive'):
    # Fail fast instead of hitting an unbound `mask` deep inside the loop.
    raise ValueError(f"method must be 'inductive' or 'transductive', got {method!r}")

# to_csv does not create parent directories; ensure the output directory
# exists before spending time on the sweep.
os.makedirs('results', exist_ok=True)

for test_portion in sweep_dict['test_portion']:
    results = []
    for d in sweep_dict['d']:
        print('d = ', d)
        for data_name, dataset in data_dict.items():
            for random_seed in tqdm(sweep_dict['random_seed']):
                data = dataset[0]

                train_labeled_mask, train_unlabeled_mask, val_mask, test_mask = get_data_mask(
                    data=data, test_portion=test_portion, random_seed=random_seed, data_per_class=20
                )
                mask_list = [train_labeled_mask, train_unlabeled_mask, val_mask, test_mask]

                # Base kernel: unweighted adjacency with G[src, dst] = 1 per edge
                # (scipy sums repeated edges).
                row, col = data.edge_index
                edge_weight = np.ones(data.edge_index.shape[1])
                n = data.x.shape[0]
                G = sp.csc_matrix((edge_weight, data.edge_index), shape=(n, n))

                # Nodes whose sub-kernel is decomposed:
                #   - inductive: only the unlabeled training nodes;
                #   - transductive: the unlabeled training, val and test nodes.
                # `|` is the element-wise union (masks assumed boolean).
                if method == 'inductive':
                    mask = train_unlabeled_mask
                else:
                    mask = train_unlabeled_mask | val_mask | test_mask

                # Symmetric normalization D^{-1/2} G D^{-1/2}. The degree of node j
                # counts only edges whose source row lies in `mask` or the labeled set.
                deg = G[mask | train_labeled_mask, :].sum(axis=0)
                deg_inv_sqrt = torch.tensor(deg).pow(-0.5).reshape(-1)
                deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0  # zero degree -> weight 0
                norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
                G = sp.csc_matrix((norm, data.edge_index), shape=(n, n))

                try:
                    # Embed the masked sub-kernel, then project every node onto it.
                    G_m = G[mask, :][:, mask]
                    v_m = get_top_d(G_m, d)
                    X = G[:, mask] @ v_m
                    y = one_hot(data.y.view(-1))
                    n_labeled = train_labeled_mask.sum().item()
                    for beta in sweep_dict['beta']:
                        # Regularization strength scales with the number of labeled nodes.
                        reg = Ridge(alpha=n_labeled * beta).fit(X[train_labeled_mask], y[train_labeled_mask])
                        out = reg.predict(X)
                        pred = torch.tensor(out.argmax(axis=-1).squeeze())
                        result = evaluate_inductive(pred, data, mask_list)
                        results.append({
                            **result,
                            'data': data_name,
                            'method': 'kernel_pca_' + str(method),
                            'beta': beta,
                            'random_seed': random_seed,
                            'd': d,
                        })
                except Exception as err:
                    # Best-effort sweep: a failing configuration (for example
                    # TruncatedSVD rejecting a d that is too large) is skipped.
                    # It is reported instead of silently swallowed, and
                    # KeyboardInterrupt/SystemExit still propagate.
                    tqdm.write(f'skipped {data_name} d={d} seed={random_seed}: {type(err).__name__}: {err}')

    df = pd.DataFrame(results)
    df.to_csv('results/kernel_pca_' + str(method) + '_test_portion_' + str(test_portion) + '.csv', index=False)
