import argparse

import numpy as np
import torch

import model_gnn as model_management
import utils

# %%
def _str2bool(value):
    """Parse a command-line boolean flag value.

    ``type=bool`` is a trap with argparse: ``bool('False')`` is ``True``
    because every non-empty string is truthy. This accepts the usual
    spellings instead and rejects anything else.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


# Values accepted by --task.
# NOTE(review): the training cells below always build a GNNNodeClassifier with
# CrossEntropyLoss, so tasks other than node_classification may not actually be
# supported end-to-end -- confirm.
_TASKS = ('node_classification', 'link_prediction', 'graph_classification')

parser = argparse.ArgumentParser()
parser.add_argument('--dataset', type=str, default='citation',
                    help='specify datasets as a string, e.g. \'cora\'')
parser.add_argument('--task', type=str, default='node_classification')
# Required: an empty path can only fail later inside torch.load with an
# opaque error, so reject it up front.
parser.add_argument('--model_path', type=str, required=True,
                    help='path to the self-supervised checkpoint loaded with torch.load')
parser.add_argument('--epochs', '-e', type=int, default=1000,
                    help='number of epochs for fine-tuning')
# Both flags still accept a value ("--save True"), but "False"/"0"/"no" now
# really mean False.
parser.add_argument('--save', type=_str2bool, default=False,
                    help='whether to save model after fine-tuning')
parser.add_argument('--supervise', type=_str2bool, default=True,
                    help='whether to train a supervised baseline model')
args = parser.parse_args()

if args.task not in _TASKS:
    raise ValueError('Invalid task type. Allowed values are \'node_classification\', \'link_prediction\' and '
                     '\'graph_classification\'.')

# %%
# Downstream (non-SSL) data for the selected task, built for the GNN models.
# The keyword names are forwarded unchanged to utils.create_data_structure.
data_kwargs = dict(
    datasets=args.dataset,
    ssl=False,
    task=args.task,
    model_type='gnn',
)
data = utils.create_data_structure(**data_kwargs)

# Self-supervised checkpoint. The later cells read
# ssl_model['configs']['stems'], ssl_model['configs']['backbone'] and
# ssl_model['model_state_dict']['encoder'] from it.
# NOTE(review): depending on the installed torch version's weights_only
# default, torch.load may unpickle arbitrary objects -- only load trusted files.
checkpoint_path = args.model_path
ssl_model = torch.load(checkpoint_path)

# %%
# Fine-tune on top of the SSL encoder (freeze_encoder=True) for 10 independent
# runs. The second element of each train_model result is collected into `acc`
# (presumably test accuracy -- confirm against train_model).
acc = []
for _run in range(10):
    classifier = model_management.GNNNodeClassifier(
        stems=ssl_model['configs']['stems'],
        backbone=ssl_model['configs']['backbone'],
        num_features=data.num_node_features,
        num_classes=data.num_classes,
        device=torch.device('cuda'),
        state_dict=ssl_model['model_state_dict']['encoder'],
    )
    # The encoder is frozen during training, so only the predictor head's
    # parameters are handed to the optimizer.
    optimizer = torch.optim.Adam(params=classifier.predictor.parameters())

    # NOTE(review): every run builds the same model_name, so with --save each
    # run presumably overwrites the previous checkpoint -- confirm how
    # utils.create_save_path behaves.
    save_path = None
    if args.save:
        model_name = f'{args.task}_{args.dataset}_gnn.pt'
        save_path = utils.create_save_path(model_name)

    results = classifier.train_model(
        dataset=data,
        loss=torch.nn.CrossEntropyLoss(),
        optimizer=optimizer,
        num_epochs=args.epochs,
        freeze_encoder=True,
        save_path=save_path,
        verbose=50,
    )
    acc.append(results[1])
acc = np.array(acc)

# %%
# Supervised baseline: the same architecture built WITHOUT the SSL encoder
# state_dict and trained with freeze_encoder=False, over 10 runs. The second
# element of each train_model result is collected into `acc_1`.
if args.supervise:
    acc_1 = []
    for _ in range(10):
        classifier_1 = model_management.GNNNodeClassifier(
            stems=ssl_model['configs']['stems'],
            backbone=ssl_model['configs']['backbone'],
            num_features=data.num_node_features,
            num_classes=data.num_classes,
            device=torch.device('cuda'),
        )
        # Nothing is frozen here, so all parameters are optimized.
        optimizer_1 = torch.optim.Adam(classifier_1.parameters())
        if args.save:
            # The dataset name comes from the parsed CLI args; there is no
            # module-level `dataset` variable in this script.
            model_name = f'{args.task}_{args.dataset}_gnn_supervised.pt'
            save_path_1 = utils.create_save_path(model_name)
        else:
            save_path_1 = None
        results_1 = classifier_1.train_model(
            dataset=data,
            loss=torch.nn.CrossEntropyLoss(),
            optimizer=optimizer_1,
            num_epochs=500,  # fixed budget, independent of --epochs
            freeze_encoder=False,
            save_path=save_path_1,
            verbose=50,
        )
        acc_1.append(results_1[1])
    acc_1 = np.array(acc_1)

# Report mean ± std of the per-run results for the SSL fine-tuning (and, if
# enabled, the supervised baseline): once to stdout and once to a text file
# named after the dataset, one line per model variant.
ssl_summary = f"ssl: {np.mean(acc)} \u00B1 {np.std(acc)}"
summary_lines = [ssl_summary]
print(ssl_summary)

if args.supervise:
    supervised_summary = f"supervised: {np.mean(acc_1)} \u00B1 {np.std(acc_1)}"
    summary_lines.append(supervised_summary)
    print(supervised_summary)

# `with` guarantees the file is closed even if writing fails. The encoding is
# explicit because '±' is not encodable in every platform default encoding.
# Each summary gets its own line; previously they were run together.
with open(args.dataset + "_gat_1-datasets" + ".txt", "w", encoding="utf-8") as file1:
    file1.writelines(line + "\n" for line in summary_lines)