
import os

import numpy as np
import pandas as pd
import scipy.sparse as sp
import torch
from torch_geometric.datasets import Planetoid, Coauthor, Amazon, CoraFull, CitationFull
from torch_geometric.utils import one_hot
from tqdm import tqdm

from stkr_inductive import  STKR_inductive_sinv
from utils import  get_mask_wrt_train, get_data_mask, evaluate_inductive


# Candidate eta values; each one is negated and used as the second entry of
# an `s_inv` pair in the grid below.
eta_list = [1e-3, 1e-2, 0.1, 1]

# Hyper-parameter grid for the sweep, keyed by parameter name.
# NOTE(review): 'num_layers', 'alpha' and 'eta' are not referenced by the
# code shown in this script -- confirm whether they are still needed.
sweep_dict = {
    'num_layers': [1, 2, 4, 8, 16, 32],
    'alpha': [0.01, 0.1, 0.3, 0.7, 0.8, 0.85, 0.9, 0.99, 0.999],
    'eta': [0],
    'beta': [1e3, 1e2, 1e1, 1e0, 1e-1, 1e-2, 1e-3],
    'random_seed': list(range(10)),
    # One [1, -eta] pair per candidate eta.
    's_inv': [[1, -candidate] for candidate in eta_list],
    'r': [1],
    'test_portion': [0.01],
    'gamma': [1e-6],
}

# Iteration cap; also recorded in the result rows and the output file name.
max_iter = 200

# Datasets to evaluate, keyed by the name written to the 'data' column of the
# results. Only Cora is enabled; uncomment an entry below to add it to the run.
data_dict = {
    'Cora': Planetoid(root="data/Planetoid", name="Cora"),
    # 'CiteSeer': Planetoid(root="data/Planetoid", name="CiteSeer"),
    # 'PubMed': Planetoid(root="data/Planetoid", name="PubMed"),
    # 'Computers': Amazon(root='data/Amazon', name='Computers'),
    # 'Photo': Amazon(root='data/Amazon', name='Photo'),
    # 'CoraFull': CoraFull('data/CoraFull'),
    # 'Physics': Coauthor('data/Coauthor', name='Physics'),
    # 'CS': Coauthor('data/Coauthor', name='CS'),
    # 'DBLP': CitationFull('data/CitationFull', name='DBLP'),
}
# Create the output directory up front: writing the CSV into a missing
# directory would fail only after a whole dataset's sweep had been computed.
os.makedirs('results', exist_ok=True)

for gamma in sweep_dict['gamma']:
    for test_portion in sweep_dict['test_portion']:
        # One result table (and one CSV file) per (gamma, test_portion) pair.
        results = []
        for name, dataset in data_dict.items():
            print('dataset = ', name)
            data = dataset[0]

            # Draw a fresh random split for every seed.
            for random_seed in tqdm(sweep_dict['random_seed']):
                train_labeled_mask, train_unlabeled_mask, val_mask, test_mask = get_data_mask(
                    data=data,
                    test_portion=test_portion,
                    random_seed=random_seed,
                    data_per_class=20,
                )
                mask_list = [train_labeled_mask, train_unlabeled_mask, val_mask, test_mask]

                # All training nodes, labeled or not. Computed once per split
                # instead of at every use.
                # NOTE(review): presumably boolean masks, so `+` acts as a
                # union -- confirm against get_data_mask.
                train_mask = train_labeled_mask + train_unlabeled_mask

                # Build the normalized graph matrix used as the kernel base.
                row, col = data.edge_index
                edge_weight = np.ones(data.edge_index.shape[1])
                n = data.x.shape[0]
                G = sp.csc_matrix((edge_weight, data.edge_index), shape=(n, n))

                # Per-node degree counted over the training rows only.
                deg = G[train_mask, :].sum(axis=0)
                deg_inv_sqrt = torch.tensor(deg).pow(-0.5).reshape(-1)
                # Nodes with zero degree yield inf under pow(-0.5); zero them
                # so their edges get weight 0.
                deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0

                # Each edge (row, col) is weighted by
                # deg^-1/2[row] * deg^-1/2[col].
                norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
                G = sp.csc_matrix((norm, data.edge_index), shape=(n, n))

                # Train-by-train sub-matrix handed to `fit`.
                G_K = G[train_mask, :][:, train_mask]
                y = one_hot(data.y)[train_labeled_mask]
                train_labeled_mask_GK, train_unlabeled_mask_GK = get_mask_wrt_train(
                    train_labeled_mask, train_unlabeled_mask)

                for beta in sweep_dict['beta']:
                    for s_inv in sweep_dict['s_inv']:
                        for r in sweep_dict['r']:
                            model = STKR_inductive_sinv()

                            try:
                                model.fit(
                                    G_K, y, train_labeled_mask_GK,
                                    s_inv=s_inv, r=r, normalize=False,
                                    beta=beta, gamma=gamma,
                                    max_iter=max_iter, target_eps=1e-3,
                                )
                                # All rows of G restricted to the training
                                # columns, so every node gets a prediction.
                                out = model.predict(G[:, train_mask], train_labeled_mask_GK)
                                pred = torch.tensor(out).argmax(dim=-1, keepdim=True).squeeze()
                                result = evaluate_inductive(pred, data, mask_list)
                                results.append({
                                    **result,
                                    'data': name,
                                    'method': 'STKR_inductive' + '_sinv_' + str(s_inv) + '_r_' + str(r),
                                    'beta': beta,
                                    'random_seed': random_seed,
                                    'eps': model.get_eps(),
                                    'gamma': gamma,
                                    'max_iter': max_iter,
                                    'eta': -1 * s_inv[1],
                                })
                            except Exception as exc:
                                # Best effort: one failing configuration must
                                # not abort the sweep, but it is reported
                                # rather than dropped silently. Catching
                                # Exception (not a bare except) lets
                                # KeyboardInterrupt/SystemExit stop the run.
                                tqdm.write(
                                    f"[skipped] data={name} seed={random_seed} "
                                    f"beta={beta} s_inv={s_inv} r={r}: {exc!r}"
                                )

            # `results` accumulates across datasets, so each write replaces the
            # file with everything collected so far for this pair.
            df = pd.DataFrame(results)
            df.to_csv('results/inductive_STKR_sinv_test_portion_'+str(test_portion)+'_max_iter_' + str(max_iter)+ '_gamma_' + str(gamma) + '.csv', index=False)





