import os
import os.path as osp
import torch
import time
import numpy as np
import logging
from greatx.attack.untargeted import GRBCDAttack, PRBCDAttack, Metattack, PGDAttack, \
                                     RandomAttack, DICEAttack, SpecAttack, GFAttack, \
                                     GCosAttack, UntargetedAttacker
from greatx.nn.models import GCN, GAT, GNNGUARD, RobustGCN, MedianGCN
from greatx.datasets import GraphDataset
from greatx.utils import split_nodes, BunchDict
from torch_geometric.data import InMemoryDataset, Data
from torch_geometric.transforms import NormalizeFeatures, LargestConnectedComponents, Compose, ToUndirected, ToSparseTensor
from torch_geometric.datasets import WikipediaNetwork
from torch_geometric.utils import mask_to_index
from csbm_dataset import dataset_ContextualSBM
from ogb.nodeproppred import PygNodePropPredDataset


def add_parser(parser):
    """Register the experiment's command-line options on ``parser``.

    Args:
        parser: An ``argparse``-style parser exposing ``add_argument``.

    Returns:
        The same ``parser`` object, with all options added.
    """
    # (flag, keyword arguments) pairs, registered in this exact order.
    option_table = (
        ('--dataset', dict(type=str, default='cora_ml', help='dataset')),
        ('--device', dict(type=int, default=0)),
        ('--model', dict(type=str, default='gcn',
                         choices=['gcn', 'gat', 'gnnguard', 'robustgcn', 'mediangcn'])),
        ('--attack', dict(type=str, default='random')),
        ('--ptb_rate', dict(type=float, default=0.20, help='pertubation rate')),
        ('--block_size', dict(type=int, default=250_000)),
        ('--base_lr', dict(type=float, default=1_000)),
        ('--threshold', dict(type=float, default=0.5)),
        ('--reg', dict(type=float, default=0.0)),
        ('--approx', dict(action="store_true")),
        ('--k', dict(type=int, default=-1)),
        ('--setting', dict(type=str, default='black', choices=['black', 'soft', 'white'])),
        ('--loss', dict(type=str, default='she')),
        ('--runs', dict(type=int, default=5)),
    )
    for flag, options in option_table:
        parser.add_argument(flag, **options)
    return parser


class HeteroDataset(InMemoryDataset):
    """In-memory dataset backed by a single ``<name>.npz`` archive.

    Raw and processed files live under ``<root>/<name>/raw`` and
    ``<root>/<name>/processed``. The archive is read for the keys
    ``node_features``, ``node_labels``, ``edges``, ``train_masks``,
    ``val_masks`` and ``test_masks``.
    """

    def __init__(self, root, name, transform=None, pre_transform=None):
        # root and the lower-cased name are stored before the base
        # initialiser runs; the directory properties below read both.
        self.root = root
        self.name = name.lower()
        super().__init__(root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_dir(self) -> str:
        """Directory that holds the raw ``.npz`` archive."""
        return osp.join(self.root, self.name, 'raw')

    @property
    def processed_dir(self) -> str:
        """Directory that holds the processed ``data.pt`` file."""
        return osp.join(self.root, self.name, 'processed')

    @property
    def raw_file_names(self):
        """Name of the raw archive file."""
        return f'{self.name}.npz'

    @property
    def processed_file_names(self):
        """Name of the processed file."""
        return 'data.pt'

    def process(self):
        """Convert the raw archive into a single ``Data`` object and save it."""
        archive = np.load(self.raw_paths[0])
        features = torch.tensor(archive['node_features'])
        labels = torch.tensor(archive['node_labels'])
        # The stored edge array is transposed before use as edge_index.
        edges = torch.tensor(archive['edges']).T
        # Archive keys are plural ("train_masks"); Data attributes are singular.
        masks = {
            f'{part}_mask': torch.tensor(archive[f'{part}_masks']).to(torch.bool)
            for part in ('train', 'val', 'test')
        }
        graph = Data(x=features, edge_index=edges, y=labels, **masks)
        if self.pre_transform is not None:
            graph = self.pre_transform(graph)
        torch.save(self.collate([graph]), self.processed_paths[0])


def load_dataset(args):
    """Build the dataset selected by ``args.dataset``.

    Data is stored under ``../data`` relative to the directory containing
    this file. Each dataset family gets its own transform pipeline.

    Args:
        args: Object with a ``dataset`` attribute holding the dataset name.

    Returns:
        The constructed dataset object.

    Raises:
        ValueError: If ``args.dataset`` matches none of the supported names.
    """
    dataset_name = args.dataset
    root = osp.join(osp.dirname(osp.realpath(__file__)), '../', 'data')
    if dataset_name in ['cora_ml', 'citeseer', 'pubmed']:
        transform = Compose([ToUndirected(), LargestConnectedComponents()])
        dataset = GraphDataset(root=root, name=dataset_name, transform=transform)
    elif dataset_name in ['chameleon', 'squirrel']:
        transform = Compose([NormalizeFeatures(), ToUndirected(), LargestConnectedComponents()])
        dataset = WikipediaNetwork(root=root, name=dataset_name, geom_gcn_preprocess=True, transform=transform)
    elif dataset_name in ['roman_empire']:
        transform = Compose([ToUndirected(), LargestConnectedComponents()])
        dataset = HeteroDataset(root=root, name=dataset_name, transform=transform)
    elif dataset_name in ['ogbn-arxiv']:
        dataset = PygNodePropPredDataset(root=root, name=dataset_name, transform=ToUndirected())
    elif dataset_name in ['ogbn-products']:
        dataset = PygNodePropPredDataset(root=root, name=dataset_name, transform=ToSparseTensor())
    elif 'csbm' in dataset_name:
        dataset = dataset_ContextualSBM(root=root, name=dataset_name)
    else:
        # Without this branch an unknown name surfaced as an opaque
        # UnboundLocalError on the return statement.
        raise ValueError(f'Unknown dataset: {dataset_name!r}')
    return dataset


def load_params(args):
    """Return the training hyper-parameters for ``args.dataset``.

    Args:
        args: Object with a ``dataset`` attribute holding the dataset name.

    Returns:
        A new dict with keys ``lr``, ``wd`` (weight decay), ``hid``
        (hidden width), ``gat_hid`` (hidden width used for GAT) and
        ``nlayers`` (number of hidden layers).

    Raises:
        ValueError: If ``args.dataset`` matches none of the known groups.
    """
    dataset_name = args.dataset
    if dataset_name in ['cora_ml', 'cora', 'citeseer', 'pubmed'] or 'csbm' in dataset_name:
        lr = 0.01
        weight_decay = 5e-4
        hid = 16
        gat_hid = 8
        num_layers = 1
    elif dataset_name in ['chameleon', 'squirrel', 'roman_empire']:
        lr = 0.05
        weight_decay = 0
        hid = 64
        gat_hid = 8
        num_layers = 1
    elif dataset_name in ['ogbn-arxiv', 'ogbn-products']:
        lr = 0.01
        weight_decay = 0
        hid = 256
        gat_hid = 16
        num_layers = 2
    else:
        # Previously an unknown name fell through and failed with an
        # UnboundLocalError while building the dict below.
        raise ValueError(f'Unknown dataset: {dataset_name!r}')
    return {
        'lr': lr,
        'wd': weight_decay,
        'hid': hid,
        'gat_hid': gat_hid,
        'nlayers': num_layers,
    }


def load_attack(args, data, splits, model, device, ptb_edges=None, seed=None):
    """Create the attacker selected by ``args.attack`` and run it.

    Args:
        args: Parsed options. ``attack`` and ``ptb_rate`` are always read;
            ``threshold``, ``approx``, ``block_size``, ``base_lr``, ``k``,
            ``setting``, ``loss`` and ``reg`` are read by the attacks that
            use them.
        data: Graph data handed to the attacker constructor.
        splits: Object with ``train_nodes``, ``val_nodes`` and ``test_nodes``.
        model: Surrogate model for the attacks that call ``setup_surrogate``.
        device: Device handed to the attacker constructor.
        ptb_edges: Optional precomputed perturbed edge index. When given, no
            attack is run: a bare ``UntargetedAttacker`` is created, reset,
            and ``attacker.data().edge_index`` is set to this value.
        seed: Seed handed to the attacker constructor.

    Returns:
        Tuple ``(attacker, elapsed_seconds)``. ``elapsed_seconds`` is ``0.0``
        when ``ptb_edges`` is supplied.

    Raises:
        ValueError: If ``args.attack`` is not one of the supported names.
    """
    if ptb_edges is not None:
        attacker = UntargetedAttacker(data, device=device)
        attacker.reset()
        attacker.data().edge_index = ptb_edges
        return attacker, 0.0

    start_time = time.time()
    if args.attack == 'random':
        attacker = RandomAttack(data, device=device, seed=seed)
        attacker.reset()
        attacker.attack(num_budgets=args.ptb_rate, threshold=args.threshold)
    elif args.attack == 'dice':
        attacker = DICEAttack(data, device=device, seed=seed)
        attacker.reset()
        attacker.attack(num_budgets=args.ptb_rate, threshold=args.threshold)
    elif args.attack == 'spec':
        attacker = SpecAttack(data, device=device, seed=seed)
        attacker.reset()
        attacker.attack(num_budgets=args.ptb_rate, approx=args.approx)
    elif args.attack == 'gf':
        attacker = GFAttack(data, device=device, seed=seed)
        attacker.reset()
        attacker.attack(num_budgets=args.ptb_rate)
    elif args.attack == 'prbcd':
        attacker = PRBCDAttack(data, device=device, seed=seed)
        attacker.setup_surrogate(model, victim_nodes=splits.test_nodes, ground_truth=True)
        attacker.reset()
        attacker.attack(num_budgets=args.ptb_rate, block_size=args.block_size, lr=args.base_lr)
    elif args.attack == "meta":
        attacker = Metattack(data, device=device, seed=seed)
        # Test and validation nodes together form the unlabeled set.
        attacker.setup_surrogate(model,
                                labeled_nodes=splits.train_nodes,
                                unlabeled_nodes=torch.cat((splits.test_nodes, splits.val_nodes)), lambda_=0.)
        attacker.reset()
        attacker.attack(args.ptb_rate)
    elif args.attack == 'grbcd':
        attacker = GRBCDAttack(data, device=device, seed=seed)
        attacker.setup_surrogate(model, victim_nodes=splits.test_nodes, ground_truth=True)
        attacker.reset()
        attacker.attack(num_budgets=args.ptb_rate, block_size=args.block_size, lr=args.base_lr)
    elif args.attack == "pgd":
        attacker = PGDAttack(data, device=device, seed=seed)
        attacker.setup_surrogate(model, victim_nodes=splits.test_nodes, ground_truth=True)
        attacker.reset()
        attacker.attack(num_budgets=args.ptb_rate)
    elif args.attack == 'gcos':
        attacker = GCosAttack(data, device=device, seed=seed)
        attacker.setup_surrogate(model)
        attacker.reset()
        attacker.attack(num_budgets=args.ptb_rate, block_size=args.block_size, lr=args.base_lr,
                        k=args.k, setting=args.setting,
                        loss=args.loss, approx=args.approx, reg=args.reg)
    else:
        # Previously an unknown name fell through and crashed on the return
        # statement with an UnboundLocalError.
        raise ValueError(f'Unknown attack: {args.attack!r}')
    # Wall-clock time covers attacker construction, setup and the attack.
    ex_time = time.time() - start_time
    return attacker, ex_time


def load_model(model_name, params, num_features, num_classes):
    """Instantiate the victim model named ``model_name``.

    Args:
        model_name: One of ``'gcn'``, ``'gat'``, ``'gnnguard'``,
            ``'robustgcn'`` or ``'mediangcn'``.
        params: Mapping with keys ``hid``, ``gat_hid`` and ``nlayers``.
        num_features: First positional argument of the model constructor.
        num_classes: Second positional argument of the model constructor.

    Returns:
        The constructed model, built with ``dropout=0.5``.

    Raises:
        KeyError: If ``params`` lacks one of the required keys.
        ValueError: If ``model_name`` is not a supported model.
    """
    hid = params['hid']
    num_layers = params['nlayers']
    gat_hid = params['gat_hid']
    if model_name == 'gcn':
        victim_model = GCN(num_features, num_classes, dropout=0.5, hids=[hid]*num_layers)
    elif model_name == 'gat':
        # GAT is the only model sized by gat_hid rather than hid.
        victim_model = GAT(num_features, num_classes, dropout=0.5, hids=[gat_hid]*num_layers)
    elif model_name == 'gnnguard':
        victim_model = GNNGUARD(num_features, num_classes, dropout=0.5, hids=[hid]*num_layers)
    elif model_name == 'robustgcn':
        victim_model = RobustGCN(num_features, num_classes, dropout=0.5, hids=[hid]*num_layers)
    elif model_name == 'mediangcn':
        victim_model = MedianGCN(num_features, num_classes, dropout=0.5, hids=[hid]*num_layers)
    else:
        # Previously an unknown name crashed on the return statement with
        # an UnboundLocalError.
        raise ValueError(f'Unknown model: {model_name!r}')
    return victim_model


def load_splits(args, dataset, seed):
    """Return train/validation/test node splits for ``args.dataset``.

    Args:
        args: Object with a ``dataset`` attribute holding the dataset name.
        dataset: Loaded dataset; ``dataset[0]`` supplies labels and masks,
            and ``get_idx_split()`` is called for the OGB datasets.
        seed: Random state passed to ``split_nodes`` for the datasets that
            are split randomly; unused for the other datasets.

    Returns:
        The splits object. The OGB and CSBM branches build a ``BunchDict``
        with ``train_nodes``, ``val_nodes`` and ``test_nodes``; the other
        branches return whatever ``split_nodes`` produces.

    Raises:
        ValueError: If ``args.dataset`` matches none of the supported names.
    """
    dataset_name = args.dataset
    data = dataset[0]
    y = data.y
    if dataset_name in ['cora_ml', 'citeseer', 'pubmed']:
        splits = split_nodes(y, random_state=seed)
    elif dataset_name in ['chameleon', 'squirrel']:
        splits = split_nodes(y, train=0.6, val=0.2, test=0.2, random_state=seed)
    elif dataset_name in ['roman_empire']:
        splits = split_nodes(y, train=0.5, val=0.25, test=0.25, random_state=seed)
    elif dataset_name in ['ogbn-arxiv', 'ogbn-products']:
        # Use the fixed split shipped with the dataset.
        split_idx = dataset.get_idx_split()
        splits = BunchDict(
                    dict(train_nodes=split_idx['train'],
                         val_nodes=split_idx['valid'],
                         test_nodes=split_idx['test']))
    elif 'csbm' in dataset_name:
        # Use the masks stored on the data object, converted to index tensors.
        train_idx = mask_to_index(data['train_mask'])
        val_idx = mask_to_index(data['val_mask'])
        test_idx = mask_to_index(data['test_mask'])
        splits = BunchDict(
                    dict(train_nodes=train_idx,
                         val_nodes=val_idx,
                         test_nodes=test_idx))
    else:
        # Previously an unknown name crashed on the return statement with
        # an UnboundLocalError.
        raise ValueError(f'Unknown dataset: {dataset_name!r}')
    return splits


def saved_file_name(args, seed):
    """Build the file name under which an attack result is stored.

    The name encodes the lower-cased attack name, the hyper-parameters
    relevant to that attack, and ``seed``.

    Seed here is not the real seed we use.
    It is just the number of the run, for simplicity.

    Args:
        args: Parsed options. ``attack`` is always read; the other
            attributes are read only by the attacks that encode them.
        seed: Run index appended to the file name.

    Returns:
        The file name, ending in ``.pt``.

    Raises:
        ValueError: If ``args.attack`` is not a known attack name.
    """
    attack_name = args.attack.lower()
    if attack_name in ['random', 'dice']:
        # The float product is formatted as-is (0.5 -> "50.0"); kept so
        # that names of already saved files do not change.
        file_name = f'{attack_name}_{args.threshold*100}_{seed}.pt'
    elif attack_name in ['prbcd', 'grbcd']:
        file_name = f'{attack_name}_{args.block_size}_{args.base_lr}_{seed}.pt'
    elif attack_name in ['meta', 'pgd', 'gf']:
        file_name = f'{attack_name}_{seed}.pt'
    elif attack_name in ['spec']:
        file_name = f'{attack_name}_{args.approx}_{seed}.pt'
    elif attack_name in ['gcos', 'pcos']:
        if args.loss in ['she']:
            # k == -1 is written as 0 in the file name.
            k = args.k if args.k != -1 else 0
            loss = f'{args.loss}{k}'
        else:
            loss = f'{args.loss}'
        file_name = f'{attack_name}_{loss}_{args.setting}_{args.reg}_{args.approx}_{args.block_size}_{args.base_lr}_{seed}.pt'
    else:
        # The original `assert "Attack Not Exist"` tested a non-empty
        # string, which is always true, so it never fired.
        raise ValueError(f'Attack Not Exist: {args.attack!r}')

    return file_name

class Logger:
    """Append experiment results to ``<log_dir>/<dataset>_<ptb_rate>.txt``.

    All instances share the ``logging`` logger named ``'logger'``.
    """

    def __init__(self, log_dir, args):
        """Attach a file handler for the result file to the shared logger.

        Args:
            log_dir: Directory for the log file; created, including any
                missing parents, if it does not exist.
            args: Object with ``dataset`` and ``ptb_rate`` attributes, used
                to build the log file name.
        """
        self.logger = logging.getLogger('logger')
        self.logger.setLevel(logging.INFO)

        # makedirs handles nested paths and an already existing directory.
        os.makedirs(f'{log_dir}', exist_ok=True)
        log_path = osp.abspath(f'{log_dir}/{args.dataset}_{args.ptb_rate}.txt')

        # The logger is shared, so creating a second Logger for the same
        # file must not add a second handler: every record would then be
        # written twice. FileHandler stores its target as an absolute path.
        already_attached = any(
            isinstance(handler, logging.FileHandler) and handler.baseFilename == log_path
            for handler in self.logger.handlers
        )
        if not already_attached:
            file_handler = logging.FileHandler(log_path)
            file_handler.setFormatter(logging.Formatter('%(message)s'))
            self.logger.addHandler(file_handler)

    def write_log(self, args, accuracies):
        """Write one result block: separator, attack/args header, statistics.

        Args:
            args: Parsed options; every attribute except those listed in
                ``not_saved_keys`` is written as ``key=value``.
            accuracies: Sequence of up to four value sequences, labelled in
                order ``Before``, ``Evasion``, ``Poison`` and ``Time``. The
                last entry holds times in seconds. The caller's sequence is
                not modified.
        """
        header = f'Attack: {args.attack}\n'
        # NOTE(review): 'run' does not match the '--runs' option, so 'runs'
        # is still written to the log; confirm whether that is intended.
        not_saved_keys = ['dataset', 'ptb_rate', 'attack', 'model', 'run', 'device']
        arg_str = ', '.join([f'{k}={v}' for k, v in vars(args).items() if k not in not_saved_keys])
        self.logger.info(f'{"-" * 80}')
        self.logger.info(f'{header}Args: {arg_str}')

        title = ['Before', "Evasion", "Poison", "Time"]
        # Work on a shallow copy so the caller's list keeps its raw times.
        # Times are pre-divided by 100 because the shared formatting below
        # multiplies every entry by 100 (accuracies become percentages).
        series = list(accuracies)
        series[-1] = [i/100 for i in series[-1]]  # The recorded time is in seconds (s)

        accuracy_str = ', '.join([f'{title[i]}: {np.mean(acc)*100:.2f}±{np.std(acc)*100:.2f}' for i, acc in enumerate(series)])
        self.logger.info(f'Accuracy: {accuracy_str}\n')
