import argparse

import numpy as np
import torch

import utils

# %%
def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``type=bool`` must not be used with argparse: ``bool('False')`` is ``True``,
    so every non-empty string would switch the flag on.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string stays falsy, as it was under ``type=bool``.
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


# Command-line interface of the fine-tuning script.
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', type=str, default='cora',
                    help='specify datasets as a string, \'cora\'')
parser.add_argument('--task', type=str, default='node_classification')
parser.add_argument('--model_path', type=str, default='')
parser.add_argument('--epochs', '-e', type=int, default=1000,
                    help='number of epochs for fine-tuning')
# ``nargs='?'`` + ``const=True`` accept both ``--save`` and ``--save True/False``.
parser.add_argument('--save', type=_str2bool, nargs='?', const=True, default=False,
                    help='whether to save model after fine-tuning')
parser.add_argument('--supervise', type=_str2bool, nargs='?', const=True, default=False,
                    help='whether to train a supervised baseline model')
args = parser.parse_args()

# Reject unknown tasks before any data or checkpoint is loaded.
if args.task not in ['node_classification', 'link_prediction', 'graph_classification']:
    raise ValueError('Invalid task type. Allowed values are \'node_classification\',  \'link_prediction\' and '
                     '\'graph_classification\'.')

# %%
# Fail fast with a clear message: the default --model_path is an empty string,
# which would otherwise only surface as an opaque error inside torch.load
# after the dataset has already been built.
if not args.model_path:
    raise ValueError('No checkpoint given: pass the pre-trained model file via --model_path.')

# dataset = args.dataset
# NOTE(review): `task` is not read anywhere in the visible script (args.task is
# used instead); kept only in case an interactive cell relies on it.
task = 'node_classification'

# Build the dataset object consumed by the downstream model.
data = utils.create_data_structure(
    datasets=args.dataset,
    ssl=False,
    task=args.task,
    model_type='transformer',
)

if 'hetphl' in args.dataset:
    # NOTE(review): these datasets get 15 extra input features; the reason is
    # not visible in this file -- confirm against the checkpoint's input size.
    data.num_node_features += 15

# ssl_model = torch.load("./models/29_07_23/expt_0/ssl_transformer_7_datasets_neighbourhood_aggregation.pt")
# The checkpoint is indexed below with 'configs', 'num_hops' and
# 'model_state_dict'. On a host without CUDA, remap stored tensors to the CPU so
# a GPU-saved checkpoint can still be deserialised; with CUDA available the
# default placement (map_location=None) is kept.
ssl_model = torch.load(
    args.model_path,
    map_location=None if torch.cuda.is_available() else torch.device('cpu'),
)

# %%
# Fine-tune the pre-trained model 10 independent times and collect one score
# per run so that a mean and standard deviation can be reported afterwards.

# Select the device once; fall back to the CPU instead of crashing on hosts
# without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

acc = []
for count_iter in range(10):
    print("#### iteration no", count_iter )
    # Every run starts again from the same pre-trained weights.
    downstream_model = utils.create_downstream_model(
        task=args.task,
        configs=ssl_model['configs'],
        num_hops=ssl_model['num_hops'],
        num_features=data.num_node_features,
        num_classes=data.num_classes,
        device=device,
        state_dict=ssl_model['model_state_dict'],
    )
    # Only the predictor's parameters are handed to the optimizer; the call
    # below also passes freeze_encoder=True.
    optimizer = torch.optim.Adam(params=downstream_model.predictor.parameters(), lr=1e-2)
    if args.save:
        model_name = args.task + '_' + args.dataset + 'transformer.pt'
        # Kept inside the loop: create_save_path is defined elsewhere and may
        # do per-call work (e.g. allocate a new directory) -- TODO confirm.
        save_path = utils.create_save_path(model_name)
    else:
        save_path = None
    results = downstream_model.train_model(
        dataset=data,
        loss=torch.nn.CrossEntropyLoss(),
        optimizer=optimizer,
        num_epochs=args.epochs,
        freeze_encoder=True,
        save_path=save_path,
        verbose=50,
    )
    # presumably results[1] is the evaluation accuracy -- verify in train_model
    acc.append(results[1])
acc = np.array(acc)

# %%
# if args.supervise:
#     acc_1 = []
#     for _ in range(10):
#         downstream_model_1 = utils.create_downstream_model(
#             task=args.task,
#             configs=ssl_model['configs'],
#             num_hops=ssl_model['num_hops'],
#             num_features=data.num_node_features,
#             num_classes=data.num_classes,
#             device=torch.device('cuda'),
#         )
#         optimizer_1 = torch.optim.Adam(downstream_model_1.parameters())
#         if args.save:
#             model_name = args.task + '_' + args.dataset + 'supervised_transformer.pt'
#             save_path_1 = utils.create_save_path(model_name)
#         else:
#             save_path_1 = None

#         results_1 = downstream_model_1.train_model(
#             dataset=data,
#             loss=torch.nn.CrossEntropyLoss(),
#             optimizer=optimizer_1,
#             num_epochs=args.epochs,
#             freeze_encoder=False,
#             save_path=save_path_1,
#             verbose=50,
#         )
#         acc_1.append(results_1[1])
#     acc_1 = np.array(acc_1)

# Report the collected scores as "mean ± std" on stdout and in a
# per-dataset text file.
summary = f"ssl: {np.mean(acc)} \u00B1 {np.std(acc)}"

# The context manager closes the file even if printing or writing fails.
# UTF-8 is requested explicitly because the summary contains the non-ASCII
# "±" sign, which cannot be encoded under ASCII/C locale defaults.
with open(args.dataset + "_1-datasets_pairdis" + ".txt", "w", encoding="utf-8") as file1:
    print(summary)
    file1.write(summary)

    # if args.supervise:
    #     print(f"supervised: {np.mean(acc_1)} \u00B1 {np.std(acc_1)}")
    #     file1.writelines("\n")
    #     file1.writelines(f"supervised: {np.mean(acc_1)} \u00B1 {np.std(acc_1)}")
