
import os

import numpy as np
import pandas as pd
from torch_geometric.datasets import Planetoid, Coauthor, Amazon, CoraFull, CitationFull
from torch_geometric.nn import LabelPropagation
from tqdm import tqdm

from utils import get_data_mask, evaluate_inductive
from stkr import STKR, STKR_inv


# Datasets to sweep over, keyed by the name recorded in the 'data' column of
# the results CSV. Only Cora is active; uncomment an entry to include it.
# NOTE(review): every constructor here runs at import time -- presumably it
# loads (and on first use downloads) the dataset under the given root.
data_dict = dict(
    Cora=Planetoid(root="data/Planetoid", name="Cora"),
    # CiteSeer=Planetoid(root="data/Planetoid", name="CiteSeer"),
    # PubMed=Planetoid(root="data/Planetoid", name="PubMed"),
    # Computers=Amazon(root="data/Amazon", name="Computers"),
    # Photo=Amazon(root="data/Amazon", name="Photo"),
    # CoraFull=CoraFull("data/CoraFull"),
    # Physics=Coauthor("data/Coauthor", name="Physics"),
    # CS=Coauthor("data/Coauthor", name="CS"),
    # DBLP=CitationFull("data/CitationFull", name="DBLP"),
)


# Hyper-parameter grid; the sweep loop iterates every list exhaustively.
sweep_dict = dict(
    num_layers=[1, 2, 4, 8, 16, 32],
    alpha=[0.7, 0.8, 0.85, 0.9, 0.99, 0.999],
    eta=[0],
    beta=[1e3, 1e2, 1e1, 1e0, 1e-1, 1e-2, 1e-3],
    random_seed=list(range(10)),
    # Each vector is k zeros followed by a single 1, for k in (1, 2, 4, 6, 8).
    s=[[0] * k + [1] for k in (1, 2, 4, 6, 8)],
    test_portion=[0.01],
)

def _evaluate(out, data, mask_list):
    """Convert per-node scores into hard labels and evaluate them.

    Args:
        out: model output scores; assumed to hold one score per class on the
            last axis, i.e. (num_nodes, num_classes) -- TODO confirm for
            STKR / STKR_inv.
        data: graph object, passed straight through to ``evaluate_inductive``.
        mask_list: [train_labeled, train_unlabeled, val, test] masks.

    Returns:
        Whatever dict ``evaluate_inductive`` produces.
    """
    # argmax over the last axis already yields one label per node; a
    # keepdim=True + squeeze() round-trip adds nothing and would collapse a
    # single-node prediction into a 0-d tensor.
    y_pred = out.argmax(dim=-1)
    return evaluate_inductive(y_pred, data, mask_list)


# Create the output folder before the sweep starts: otherwise a missing
# 'results/' directory only surfaces at to_csv, after every run has finished
# and all collected results are lost.
os.makedirs('results', exist_ok=True)

for test_portion in sweep_dict['test_portion']:
    # One CSV per test_portion, aggregating all datasets, seeds and methods.
    results = []
    for name, dataset in data_dict.items():
        print('dataset = ', name)
        data = dataset[0]

        # Draw a fresh random split for every seed.
        for random_seed in sweep_dict['random_seed']:
            print('random seed = ', random_seed)

            # get_data_mask returns [train_labeled, train_unlabeled, val, test].
            train_labeled_mask, train_unlabeled_mask, val_mask, test_mask = get_data_mask(
                data=data, test_portion=test_portion, random_seed=random_seed, data_per_class=20
            )
            mask_list = [train_labeled_mask, train_unlabeled_mask, val_mask, test_mask]

            print('--- Label Propagation ---')
            for num_layers in tqdm(sweep_dict['num_layers']):
                for alpha in sweep_dict['alpha']:
                    model = LabelPropagation(num_layers=num_layers, alpha=alpha)
                    out = model(data.y, data.edge_index, mask=train_labeled_mask)
                    # LPA has no beta; 0 is recorded as a placeholder.
                    results.append({
                        **_evaluate(out, data, mask_list),
                        'method': 'LPA', 'num_layers': num_layers, 'alpha': alpha,
                        'beta': 0, 'random_seed': random_seed, 'data': name,
                    })

            print('--- STKR Lap ---')
            for num_layers in tqdm(sweep_dict['num_layers']):
                for alpha in sweep_dict['alpha']:
                    for eta in sweep_dict['eta']:
                        model = STKR_inv(num_layers=num_layers, xi=[1 + eta, -1], r=1, alpha=alpha)
                        for beta in sweep_dict['beta']:
                            out = model(data.y, data.edge_index, mask=train_labeled_mask,
                                        normalize=False, beta=beta)
                            # eta is recorded too: without it, rows for different
                            # eta values would be indistinguishable in the CSV.
                            results.append({
                                **_evaluate(out, data, mask_list),
                                'method': 'SP-Lap', 'num_layers': num_layers, 'alpha': alpha,
                                'beta': beta, 'random_seed': random_seed, 'data': name,
                                'eta': eta,
                            })

            print('--- STKR poly---')
            for num_layers in tqdm(sweep_dict['num_layers']):
                for alpha in sweep_dict['alpha']:
                    for s in sweep_dict['s']:
                        model = STKR(num_layers=num_layers, s=s, alpha=alpha)
                        for beta in sweep_dict['beta']:
                            out = model(data.y, data.edge_index, mask=train_labeled_mask,
                                        normalize=False, beta=beta)
                            # s is encoded in the method name, e.g. 'SP-poly_s_[0, 1]'.
                            results.append({
                                **_evaluate(out, data, mask_list),
                                'method': f'SP-poly_s_{s}', 'num_layers': num_layers, 'alpha': alpha,
                                'beta': beta, 'random_seed': random_seed, 'data': name,
                            })

    df = pd.DataFrame(results)
    df.to_csv(os.path.join('results', f'STKR_transductive_p_test_{test_portion}.csv'))