import sys
sys.path.append('../')

import os
import os.path as osp
import argparse
import numpy as np
import random
import torch
from greatx.training import Trainer
from greatx.training.callbacks import ModelCheckpoint
from torch_geometric.utils import remove_self_loops
from utils import Logger, add_parser, load_model, load_dataset, load_splits, \
                  load_params, load_attack, saved_file_name


### Command line & device ###
# The experiment options are registered on the parser by utils.add_parser.
parser = add_parser(argparse.ArgumentParser())
args = parser.parse_args()
if torch.cuda.is_available():
    device = torch.device(f'cuda:{args.device}')
else:
    device = torch.device('cpu')

# One entry per run is appended to each of these by the experiment loop.
before_acc_list, evasion_acc_list, poison_acc_list, time_list = [], [], [], []

### Run seeds ###
# A fixed master seed makes the per-run seed list reproducible.
random.seed(529)
seeds = []
for _ in range(args.runs):
    seeds.append(random.randint(0, 2**32-1))

### Dataset ###
dataset = load_dataset(args)
data = dataset[0]

# Dataset-specific preprocessing.
if args.dataset in ('chameleon', 'squirrel'):
    # Strip self-loops from the edge list.
    data.edge_index = remove_self_loops(data.edge_index)[0]
elif args.dataset == 'ogbn-arxiv':
    # Drop the trailing label dimension (presumably labels arrive as (N, 1) — confirm).
    data.y = data.y.squeeze(dim=-1)

num_features = data.x.size(-1)
num_classes = int(data.y.max()) + 1
params = load_params(args)

### Attacked-graph cache location ###
# The folder depends only on CLI arguments, so it is built once before the runs.
# NOTE(review): `args.ptb_rate*100` is interpolated as a raw float, so e.g. a
# rate of 0.07 yields '7.000000000000001'; kept unchanged so existing cache
# folders keep matching.
folder_path = osp.join('../atk_data', f'{args.dataset}_{args.ptb_rate*100}', args.attack)
os.makedirs(folder_path, exist_ok=True)

for i, seed in enumerate(seeds):
    ### Fix Seed ###
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        torch.cuda.manual_seed_all(seed)

    ### Load Splits ###
    splits = load_splits(args, dataset, seed)
    victim_model = load_model(args.model, params, num_features, num_classes)

    ### Before Attack ###
    # Train on the clean graph, checkpointing on validation accuracy.
    # NOTE(review): the checkpoint path only encodes the dataset name, so two
    # concurrent runs on the same dataset from the same directory share it.
    trainer_before = Trainer(victim_model, device=device)
    trainer_before.reset_optimizer(lr=params['lr'], weight_decay=params['wd'])
    ckp = ModelCheckpoint(f'{args.dataset}_model.pth', monitor='val_acc')
    trainer_before.fit(data, mask=(splits.train_nodes, splits.val_nodes),
                       callbacks=[ckp])
    logs = trainer_before.evaluate(data, splits.test_nodes)
    before_acc = logs.acc
    print(before_acc)

    ### Load Attack ###
    # Perturbed edges are cached per run index: reuse them when present,
    # otherwise run the attack and store its edge_index.
    filename = osp.join(folder_path, saved_file_name(args, i))
    if osp.exists(filename):
        # map_location lets a cache written on one device (e.g. a GPU) be read
        # on another (a different GPU index or a CPU-only machine).
        ptb_edge_index = torch.load(filename, map_location=device)
        attacker, atk_time = load_attack(args, data, splits, trainer_before.model, device,
                                         ptb_edges=ptb_edge_index, seed=seed)
    else:
        attacker, atk_time = load_attack(args, data, splits, trainer_before.model, device,
                                         seed=seed)
        # Save on CPU so the cache file does not depend on the device used here.
        torch.save(attacker.data().edge_index.cpu(), filename)

    ### Fix Seed ###
    # Reset the global RNGs so the evasion/poison phase starts from the same
    # state regardless of how much randomness the attack step consumed.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Build the perturbed graph once and reuse it for evaluation and retraining.
    atk_data = attacker.data()

    ### Evasion ###
    # The clean-trained model evaluated on the perturbed graph.
    logs = trainer_before.evaluate(atk_data, splits.test_nodes)
    evasion_acc = logs.acc
    print(evasion_acc)

    ### Poison ###
    # A new model from load_model, trained and evaluated on the perturbed graph.
    victim_model = load_model(args.model, params, num_features, num_classes)
    trainer_after = Trainer(victim_model, device=device)
    trainer_after.reset_optimizer(lr=params['lr'], weight_decay=params['wd'])
    ckp = ModelCheckpoint(f'{args.dataset}_model_after.pth', monitor='val_acc')
    trainer_after.fit(atk_data, mask=(splits.train_nodes, splits.val_nodes),
                      callbacks=[ckp])
    logs = trainer_after.evaluate(atk_data, splits.test_nodes)
    poison_acc = logs.acc
    print(poison_acc)

    before_acc_list.append(before_acc)
    evasion_acc_list.append(evasion_acc)
    poison_acc_list.append(poison_acc)
    time_list.append(atk_time)

### Persist results ###
# Each list holds one value per run, in run order.
all_results = [before_acc_list, evasion_acc_list, poison_acc_list, time_list]
logger = Logger(log_dir=f'./logs/{args.model}', args=args)
logger.write_log(args, all_results)