import argparse
from utils import *
import warnings
from train_test import ARCDetector
import numpy as np
import pandas as pd


def set_seed(seed):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch's
    CPU generator. When CUDA is available, the CUDA generators are seeded as
    well and cuDNN is switched to deterministic behaviour.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        # Deterministic kernels alone are not enough: with benchmark mode on,
        # cuDNN may auto-tune to a different algorithm from one run to the
        # next, so it is disabled explicitly for reproducible trials.
        torch.backends.cudnn.benchmark = False

# Silence all warnings for the whole run.
warnings.filterwarnings("ignore")

# Command-line options:
#   --trials    number of repeated runs (each trial index is used as its seed)
#   --model     model name used in log lines, the config and the result file
#   --json_dir  directory handed to read_json() when loading a saved config
parser = argparse.ArgumentParser()
for option, value_type, fallback in (
    ('--trials', int, 5),
    ('--model', str, 'OWLEYE'),
    ('--json_dir', str, './best_param_pt'),
):
    parser.add_argument(option, type=value_type, default=fallback)
# parse_known_args() tolerates unrecognised arguments; only the parsed
# namespace is kept, the leftovers are discarded.
args, _unknown_args = parser.parse_known_args()

# Names of the datasets used for training and of the datasets used for
# evaluation.
datasets_train = ['pubmed', 'citeseer', 'questions', 'YelpChi']
datasets_test = ['cora', 'Flickr', 'ACM', 'BlogCatalog', 'Facebook', 'weibo', 'Reddit', 'Amazon']

model = args.model
model_result = {'name': model}
print('Training on {} datasets:'.format(len(datasets_train)), datasets_train)
print('Test on {} datasets:'.format(len(datasets_test)), datasets_test)

# Settings handed to ARCDetector alongside the model config; the per-trial
# 'seed' entry is added inside the trial loop.
train_config = {
    'device': 'cuda',
    'epochs': 100,
    'metric': 'AUPRC',
    'testdsets': datasets_test,
}
# Passed to every Dataset and later stored as model_config['in_feats'].
dims = 64
data_train = [Dataset(dims, name) for name in datasets_train]
data_test = [Dataset(dims, name) for name in datasets_test]
# NOTE(review): the saved config is always looked up under the literal name
# 'OWLEYE', not under `model` (args.model) -- confirm this is intended.
model_config = read_json('OWLEYE', args.json_dir)

if model_config is None:
    # Fallback used when read_json() returns None. These defaults do not
    # define 'tau', 'mask_ratio', 'n_support' or 'temperature', so the
    # validation below rejects them until real values are supplied.
    model_config = {'model': model, 'lr': 1e-4, 'drop_rate': 0., 'h_feats': 512, 'topk': 10, 'num_hops': 4,
                    'weight_decay': 5e-5, 'in_feats': dims}
    print('use default model config')
else:
    print('use saved best model config')
    print(model_config)

# Fail fast: every key below is read later in this script (propagation,
# normalization and the result file name). Without this check a missing key
# only surfaces as a bare KeyError midway through the run -- for some keys
# after all trials have already been trained.
_REQUIRED_CONFIG_KEYS = ('num_hops', 'tau', 'mask_ratio', 'n_support', 'temperature')
missing_keys = [key for key in _REQUIRED_CONFIG_KEYS if key not in model_config]
if missing_keys:
    raise KeyError(
        'model config is missing required key(s) {}; provide a complete JSON '
        'config via --json_dir'.format(missing_keys))

def _propagate_all(datasets, hops):
    """Call ``propagated(hops)`` on each dataset, in the given order."""
    for dataset in datasets:
        dataset.propagated(hops)


# First pass: propagate the training datasets, then the test datasets.
_propagate_all(data_train, model_config['num_hops'])
_propagate_all(data_test, model_config['num_hops'])

# Normalize both splits together; 'tau' comes from the model config.
data_train, data_test = normalization(data_train, data_test, model_config['tau'])

# Second pass on the objects returned by normalization().
# NOTE(review): propagated() runs both before and after normalization --
# confirm the repeated pass is intentional and not a leftover.
_propagate_all(data_train, model_config['num_hops'])
_propagate_all(data_test, model_config['num_hops'])

# Entries set here override whatever the loaded config contained.
model_config.update(beta=0.01, model=model, in_feats=dims)

import pickle

# Per-dataset score histories: one list per test dataset name, with one
# entry appended per trial (pre_dict holds the AUPRC values).
auc_dict = {}
pre_dict = {}
rec_dict = {}
for trial in range(args.trials):
    print("Model {}, Trial {}".format(model, trial))
    # The trial index doubles as the random seed.
    set_seed(trial)
    train_config['seed'] = trial
    for te_data in data_test:
        te_data.few_shot(10)
    data = {'train': data_train, 'test': data_test}
    detector = ARCDetector(train_config, model_config, data)
    test_score_list, similarity = detector.train()
    # The same file is rewritten on every trial, so only the last trial's
    # object remains on disk once the loop finishes.
    with open('attetnion_map.pkl', 'wb') as dump_file:
        pickle.dump(similarity, dump_file)
    # Append this trial's scores to the history of each test dataset.
    for test_data_name, test_score in test_score_list.items():
        rec_dict.setdefault(test_data_name, [])
        auc_dict.setdefault(test_data_name, []).append(test_score['AUROC'])
        pre_dict.setdefault(test_data_name, []).append(test_score['AUPRC'])
        print(f'Test on {test_data_name}, AUC is {auc_dict[test_data_name]}')

# Mean and standard deviation over trials, keyed by test dataset name.
auc_mean_dict = {name: np.mean(scores) for name, scores in auc_dict.items()}
auc_std_dict = {name: np.std(scores) for name, scores in auc_dict.items()}
pre_mean_dict = {name: np.mean(pre_dict[name]) for name in auc_dict}
pre_std_dict = {name: np.std(pre_dict[name]) for name in auc_dict}

# Build one row per test dataset and print a summary line for each.
results = []
avg_auroc = []
avg_pre = []
for test_data_name, auroc_mean in auc_mean_dict.items():
    auroc_std = auc_std_dict[test_data_name]
    auprc_mean = pre_mean_dict[test_data_name]
    auprc_std = pre_std_dict[test_data_name]

    # Row layout: dataset name, AUROC mean/std, AUPRC mean/std.
    result = [test_data_name, auroc_mean, auroc_std, auprc_mean, auprc_std]
    results.append(result)
    avg_auroc.append(auroc_mean)
    avg_pre.append(auprc_mean)

    str_result = 'AUROC:{:.4f}+-{:.4f}, AUPRC:{:.4f}+-{:.4f}'.format(
        auroc_mean, auroc_std, auprc_mean, auprc_std)
    print('-' * 50 + test_data_name + '-' * 50)
    print('str_result', str_result)

# Write the per-dataset summary rows to results/<model>_auc_mean_<...>.csv.
header = ['dataset', 'auroc_mean', 'auroc_std', 'auprc_mean', 'auprc_std']
# Naming the columns at construction keeps the CSV header identical while
# also handling an empty result list, where header aliases passed to
# to_csv() would not match the zero-column frame and raise.
results = pd.DataFrame(results, columns=header)
# exist_ok avoids the check-then-create race of os.path.exists + os.mkdir.
os.makedirs('results', exist_ok=True)
# Hyperparameters embedded in the file name. A key absent from the config
# becomes 'NA' so that finished results are still written instead of being
# lost to a KeyError at the very end of the run.
filename_tags = [model_config.get(key, 'NA')
                 for key in ('mask_ratio', 'n_support', 'temperature', 'tau')]
results.to_csv('results/{}_auc_mean_{}_{}_{}_{}.csv'.format(model, *filename_tags), index=False)
