
import os

import numpy as np
import pandas as pd
import scipy.sparse as sp
import torch
from torch_geometric.datasets import Planetoid, Coauthor, Amazon, CoraFull, CitationFull
from torch_geometric.utils import one_hot
from tqdm import tqdm

from stkr_inductive import STKR_inductive
from utils import get_mask_wrt_train, get_data_mask, evaluate_inductive



# Hyper-parameter grid for the sweep. Every value is a list of candidates.
# NOTE(review): 'num_layers', 'alpha' and 'eta' are not read by the sweep
# loops visible in this file -- presumably kept for sibling scripts; confirm.
sweep_dict = {
    'num_layers': [1, 2, 4, 8, 16, 32],
    'alpha': [0.01, 0.1, 0.3, 0.7, 0.8, 0.85, 0.9, 0.99, 0.999],
    'eta': [0],
    'beta': [1e3, 1e2, 1e1, 1e0, 1e-1, 1e-2, 1e-3],
    'random_seed': list(range(10)),
    # Each candidate is a run of zeros terminated by a single 1; it is
    # forwarded unchanged as the `s` argument of STKR_inductive.fit.
    's': [[0] * n_zeros + [1] for n_zeros in (1, 2, 4, 6, 8)],
    'test_portion': [0.01],
    'gamma': [1e-2],
}

# Forwarded to STKR_inductive.fit as `max_iter` and recorded in the output
# file name and in every result row.
max_iter = 200

# Registry of datasets to evaluate, keyed by the name written to the 'data'
# column of the results. Uncomment a line to include that dataset in the run.
data_dict = {}
data_dict['Cora'] = Planetoid(root="data/Planetoid", name="Cora")
# data_dict['CiteSeer'] = Planetoid(root="data/Planetoid", name="CiteSeer")
# data_dict['PubMed'] = Planetoid(root="data/Planetoid", name="PubMed")
# data_dict['Computers'] = Amazon(root='data/Amazon', name='Computers')
# data_dict['Photo'] = Amazon(root='data/Amazon', name='Photo')
# data_dict['CoraFull'] = CoraFull('data/CoraFull')
# data_dict['Physics'] = Coauthor('data/Coauthor', name='Physics')
# data_dict['CS'] = Coauthor('data/Coauthor', name='CS')
# data_dict['DBLP'] = CitationFull('data/CitationFull', name='DBLP')

# Create the output directory up front so a missing folder cannot make the
# final to_csv call fail after the whole sweep has already been computed.
os.makedirs('results', exist_ok=True)

# Sweep: gamma -> test_portion -> dataset -> random seed -> beta -> s.
# One CSV is written per (gamma, test_portion) pair.
for gamma in sweep_dict['gamma']:
    for test_portion in sweep_dict['test_portion']:
        # Rows accumulate across all datasets of this (gamma, test_portion).
        results = []
        for name, dataset in data_dict.items():
            print('dataset = ', name)
            data = dataset[0]

            # Seed-independent quantities: edge endpoints, node count and the
            # unweighted adjacency (one entry of weight 1 per edge).
            row, col = data.edge_index
            n = data.x.shape[0]
            edge_weight = np.ones(data.edge_index.shape[1])
            adjacency = sp.csc_matrix((edge_weight, data.edge_index), shape=(n, n))

            for random_seed in tqdm(sweep_dict['random_seed']):
                # Draw the split for this seed.
                train_labeled_mask, train_unlabeled_mask, val_mask, test_mask = get_data_mask(
                    data=data, test_portion=test_portion, random_seed=random_seed, data_per_class=20)
                mask_list = [train_labeled_mask, train_unlabeled_mask, val_mask, test_mask]

                # Combined mask of labeled and unlabeled training nodes.
                # NOTE(review): assumes the masks are boolean so `+` acts as a
                # union -- confirm against get_data_mask.
                train_mask = train_labeled_mask + train_unlabeled_mask

                # Column sums of the adjacency restricted to training rows.
                deg = adjacency[train_mask, :].sum(axis=0)
                deg_inv_sqrt = torch.tensor(deg).pow(-0.5).reshape(-1)
                # Zero-degree columns give inf after pow(-0.5); zero them out.
                deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
                # Weight each edge (row, col) by deg^-1/2[row] * deg^-1/2[col].
                norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
                G = sp.csc_matrix((norm, data.edge_index), shape=(n, n))

                # Training-only sub-matrix used for fitting, and the
                # all-nodes x training-nodes slice used for prediction. Both
                # are independent of beta and s, so they are built once here.
                G_K = G[train_mask, :][:, train_mask]
                G_predict = G[:, train_mask]

                # One-hot targets of the labeled training nodes.
                y = one_hot(data.y)[train_labeled_mask]
                train_labeled_mask_GK, train_unlabeled_mask_GK = get_mask_wrt_train(
                    train_labeled_mask, train_unlabeled_mask)

                for beta in sweep_dict['beta']:
                    for s in sweep_dict['s']:
                        model = STKR_inductive()
                        try:
                            model.fit(G_K, y, train_labeled_mask_GK, s=s, normalize=False, beta=beta,
                                      gamma=gamma, max_iter=max_iter, target_eps=1e-3)
                            out = model.predict(G_predict, train_labeled_mask_GK)
                            pred = torch.tensor(out).argmax(dim=-1, keepdim=True).squeeze()
                            result = evaluate_inductive(pred, data, mask_list)
                            results.append({
                                **result,
                                'data': name,
                                'method': 'STKR_inductive' + '_s_' + str(s),
                                'beta': beta,
                                'random_seed': random_seed,
                                'eps': model.get_eps(),
                                'gamma': gamma,
                                'max_iter': max_iter,
                            })
                        except Exception as exc:
                            # Best effort: one failing configuration must not
                            # abort the sweep, but it is reported instead of
                            # being dropped silently. KeyboardInterrupt and
                            # SystemExit are deliberately not caught.
                            tqdm.write(
                                f'[skipped] data={name} random_seed={random_seed} '
                                f'beta={beta} s={s}: {type(exc).__name__}: {exc}'
                            )

            # Rewrite the CSV after every dataset so partial results survive
            # an interrupted run; the file holds all rows collected so far.
            df = pd.DataFrame(results)
            df.to_csv('results/inductive_STKR_test_portion_' + str(test_portion) + '_max_iter_' + str(max_iter) + '_gamma_' + str(gamma) + '.csv', index=False)
