import torch
import torch.optim as optim
from src import train, eval_model, DataLoader, feature_ablation, graph_smoothing_level, dirichlet_energy
from src.graph_utils import *
from src.utils import *
import os.path as osp
from improve_gnns import get_eval_node_idxs, lower_is_better
from torch_geometric.utils import dropout_edge, add_random_edge

def apply_dropedge(graph, num_edge_removal):
    """Return a copy of ``graph`` with edges randomly dropped (DropEdge baseline).

    The drop probability passed to ``dropout_edge`` is
    ``2 * num_edge_removal / E``, where ``E`` is the number of columns of
    ``graph.edge_index``. The factor 2 presumably converts an undirected edge
    count into directed ``edge_index`` entries (assumes both directions are
    stored -- TODO confirm). ``dropout_edge`` is called with its default
    ``force_undirected=False``, so each directed entry is dropped
    independently.

    Args:
        graph: graph object exposing ``edge_index``, ``edge_weight`` (may be
            ``None``) and ``clone()``.
        num_edge_removal: requested number of removals; a Python number or a
            0-dim tensor on any device.

    Returns:
        A cloned graph; ``graph`` itself is not modified. When an edge weight
        is present it is filtered with the same keep-mask as the edges.
    """
    new_graph = graph.clone()

    num_entries = graph.edge_index.size(1)
    if num_entries == 0:
        # Nothing to drop; also avoids dividing by zero below.
        return new_graph

    # float() accepts ints as well as CPU/CUDA 0-dim tensors. Clamp at 1.0 so
    # requesting more removals than there are edges drops everything instead
    # of producing an out-of-range probability.
    dropedge_ratio = min(float(num_edge_removal) * 2 / num_entries, 1.0)

    new_edge_index, kept_mask = dropout_edge(graph.edge_index, p=dropedge_ratio)
    new_graph.edge_index = new_edge_index

    if new_graph.edge_weight is not None:
        # Keep weights aligned with the surviving edges.
        new_graph.edge_weight = graph.edge_weight[kept_mask]

    return new_graph

def apply_addedge(graph, num_edge_insertion):
    """Return a copy of ``graph`` with random edges added (AddEdge baseline).

    The ratio passed to ``add_random_edge`` is
    ``2 * num_edge_insertion / E``, where ``E`` is the number of columns of
    ``graph.edge_index`` (the factor 2 presumably accounts for both directions
    of an undirected edge being stored -- TODO confirm). ``add_random_edge`` is
    called with its default ``force_undirected=False``.

    Args:
        graph: graph object exposing ``edge_index``, ``edge_weight`` (may be
            ``None``) and ``clone()``.
        num_edge_insertion: requested number of insertions; a Python number
            or a 0-dim tensor on any device.

    Returns:
        A cloned graph; ``graph`` itself is not modified. When an edge weight
        is present, every inserted edge receives weight 1.

    Raises:
        ValueError: if ``graph`` has no edges, since the insertion count
            cannot be expressed as a ratio of an empty edge set.
    """
    num_entries = graph.edge_index.size(1)
    if num_entries == 0:
        raise ValueError(
            "apply_addedge: graph has no edges, so the insertion count "
            "cannot be expressed as a ratio of existing edges"
        )

    # float() accepts ints as well as CPU/CUDA 0-dim tensors.
    addedge_ratio = float(num_edge_insertion) * 2 / num_entries

    new_graph = graph.clone()
    new_edge_index, added_edge_index = add_random_edge(graph.edge_index, p=addedge_ratio)
    new_graph.edge_index = new_edge_index

    if new_graph.edge_weight is not None:
        # Unit weights for the inserted edges, created with the same dtype and
        # device as the existing weights so torch.cat cannot fail on a
        # device/dtype mismatch. Appended after the old weights, which assumes
        # add_random_edge appends the new edges after the existing ones.
        added_weight = new_graph.edge_weight.new_ones((added_edge_index.size(1),))
        new_graph.edge_weight = torch.cat([new_graph.edge_weight, added_weight])

    return new_graph

def get_num_removal(args, seed=None):
    """Return how many edges the DropEdge baseline should remove.

    If ``args.num_removals`` is not -1 it is used directly. Otherwise the
    count is derived from the saved influence scores of run ``seed``:

    1. Count the removal candidates whose influence has the improving sign
       (negative when ``lower_is_better(args)`` is true, positive otherwise).
    2. Multiply by ``args.removal_ratio``.
    3. Truncate to an integer.

    Args:
        args: parsed CLI namespace. The checkpoint/influence path is built
            from its model, training and influence-function hyperparameters.
        seed: seed of the run whose ``<seed>.pth`` influence file is read.
            Defaults to the module-level ``seed`` global assigned by the
            ``__main__`` loop, which is how the script itself calls this.

    Returns:
        A 0-dim integer tensor.

    Raises:
        ValueError: if the influence file is needed but no seed is available.
    """
    if args.num_removals != -1:
        return torch.tensor(args.num_removals)

    if seed is None:
        # Backward compatibility: the script's run loop sets `seed` globally.
        seed = globals().get("seed")
    if seed is None:
        raise ValueError(
            "get_num_removal: pass `seed` explicitly (no module-level `seed` is set)"
        )

    checkpoint_dir = osp.join('checkpoints', args.dataset, f"{args.model}_{args.num_layers}_{args.hidden_dim}_{args.linear}_{args.bias}", f"{args.lr}_{args.epochs}_{args.weight_decay}")
    influence_dir = osp.join(checkpoint_dir, "influence", args.hessian_type, args.eval_metric, f"{args.damp}_{args.scale}_{args.lissa_iter}_{args.pbrf_weight_decay}_{args.num_folds}_{args.num_removal_candidates}_{args.num_insertion_candidates}")
    influence_path = osp.join(influence_dir, f"{seed}.pth")

    # map_location="cpu" lets influence files saved from a GPU load anywhere.
    influence = torch.load(influence_path, map_location="cpu")
    removal_influence = influence['removal_inf']

    improving = removal_influence < 0 if lower_is_better(args) else removal_influence > 0
    return (improving.sum() * args.removal_ratio).to(torch.int)

def get_num_insertion(args, seed=None):
    """Return how many random edges the AddEdge baseline should insert.

    The count is derived from the saved influence scores of run ``seed``:

    1. Count the insertion candidates whose influence has the improving sign
       (negative when ``lower_is_better(args)`` is true, positive otherwise).
    2. Multiply by ``args.insertion_ratio``.
    3. Truncate to an integer.

    Args:
        args: parsed CLI namespace. The checkpoint/influence path is built
            from its model, training and influence-function hyperparameters.
        seed: seed of the run whose ``<seed>.pth`` influence file is read.
            Defaults to the module-level ``seed`` global assigned by the
            ``__main__`` loop, which is how the script itself calls this.

    Returns:
        A 0-dim integer tensor.

    Raises:
        ValueError: if no seed is passed and no module-level ``seed`` exists.
    """
    if seed is None:
        # Backward compatibility: the script's run loop sets `seed` globally.
        seed = globals().get("seed")
    if seed is None:
        raise ValueError(
            "get_num_insertion: pass `seed` explicitly (no module-level `seed` is set)"
        )

    checkpoint_dir = osp.join('checkpoints', args.dataset, f"{args.model}_{args.num_layers}_{args.hidden_dim}_{args.linear}_{args.bias}", f"{args.lr}_{args.epochs}_{args.weight_decay}")
    influence_dir = osp.join(checkpoint_dir, "influence", args.hessian_type, args.eval_metric, f"{args.damp}_{args.scale}_{args.lissa_iter}_{args.pbrf_weight_decay}_{args.num_folds}_{args.num_removal_candidates}_{args.num_insertion_candidates}")
    influence_path = osp.join(influence_dir, f"{seed}.pth")

    # map_location="cpu" lets influence files saved from a GPU load anywhere.
    influence = torch.load(influence_path, map_location="cpu")
    insertion_influence = influence['insertion_inf']

    improving = insertion_influence < 0 if lower_is_better(args) else insertion_influence > 0
    return (improving.sum() * args.insertion_ratio).to(torch.int)

if __name__ == '__main__':
    # Script entry point. For each seeded run, train a GNN (or load its cached
    # checkpoint) on the original or rewired graph, evaluate it, and report
    # mean losses/accuracies plus the mean of the selected eval metric.
    import argparse
    import os  # `import os.path as osp` at file level does not bind `os`

    from src import SGC, GCN, GNN

    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, default='Cora_public')
    parser.add_argument('--model', type=str, default='GCN', choices=['SGC', 'GCN', 'GAT'])
    parser.add_argument('--num_layers', type=int, default=2)
    parser.add_argument('--lr', type=float, default=0.1)
    parser.add_argument('--hidden_dim', type=int, default=16)
    parser.add_argument('--epochs', type=int, default=1000)
    parser.add_argument('--weight_decay', type=float, default=0.001)
    parser.add_argument('--eval_metric', type=str, default='GSL', choices=['dirichlet_energy', 'feature_ablation', 'GSL', 'mean_validation_loss', 'k_hop_grad', 'k_hop_grad_square', 'local_k_hop_grad', 'indiv_k_hop_grad'])
    parser.add_argument('--linear', type=int, default=0)
    parser.add_argument('--bias', type=int, default=0)
    parser.add_argument("--runs", type=int, default=1)
    parser.add_argument("--method", type=str, default="vanilla", choices=["vanilla", "dropedge", "addedge"])

    parser.add_argument("--removal_ratio", type=float, default=0.0)
    parser.add_argument("--insertion_ratio", type=float, default=0.1)

    # Influence Function Arguments
    parser.add_argument('--hessian_type', type=str, default='GNH', choices=['hessian', 'GNH'])
    parser.add_argument('--damp', type=float, default=0.1)
    parser.add_argument('--scale', type=float, default=1.0)
    parser.add_argument('--lissa_iter', type=int, default=10000)
    parser.add_argument('--pbrf_weight_decay', type=float, default=0.0)
    parser.add_argument("--num_folds", type=int, default=1)
    parser.add_argument("--num_insertion_candidates", type=int, default=10000)
    parser.add_argument("--num_removal_candidates", type=int, default=10000)
    parser.add_argument("--num_heads", type=int, default=8)

    parser.add_argument("--num_removals", type=int, default=-1)

    SEEDS = [1941488137, 4198936517, 983997847, 4023022221, 4019585660, 2108550661, 1648766618, 629014539, 3212139042, 2424918363]

    args = parser.parse_args()
    # Each run uses one fixed seed, so more runs than seeds would IndexError.
    if not 1 <= args.runs <= len(SEEDS):
        parser.error(f"--runs must be between 1 and {len(SEEDS)} (one fixed seed per run)")
    args.linear = bool(args.linear)
    args.bias = bool(args.bias)
    print(args)

    ori_model_dir = osp.join('checkpoints', args.dataset, f"{args.model}_{args.num_layers}_{args.hidden_dim}_{args.linear}_{args.bias}", f"{args.lr}_{args.epochs}_{args.weight_decay}", args.method)
    os.makedirs(ori_model_dir, exist_ok=True)

    dataset = DataLoader(args.dataset, root='datasets')
    args.num_classes = dataset.num_classes
    data = dataset[0]
    # Unit weights, so the rewiring helpers always have an edge_weight to keep
    # aligned with edge_index.
    data.edge_weight = torch.ones((data.edge_index.shape[1], ))

    set_seed(SEEDS[0])
    eval_node_idxs = get_eval_node_idxs(data, args.eval_metric, SEEDS[0])
    if args.eval_metric == "feature_ablation":
        exact_k_hop = find_k_hop_neighborhoods(data, args.num_layers)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    ori_run_best_results, run_metric = [], []

    for run in range(args.runs):
        # NOTE: `seed` is a module-level global in this block; get_num_removal /
        # get_num_insertion read it when locating the influence file.
        seed = SEEDS[run]
        model_path = osp.join(ori_model_dir, f"{seed}.pth")

        if 'public' not in args.dataset:
            # Random 60/20/20 split per seed (60% of nodes spread over classes for training).
            percls_trn = int(round(0.6*len(data.y)/dataset.num_classes))
            val_lb = int(round(0.2*len(data.y)))
            data = random_planetoid_splits(data, dataset.num_classes, percls_trn, val_lb, seed)

        if args.method == "vanilla":
            rewired_data = data
        elif args.method == "dropedge":
            set_seed(seed)
            num_removal = get_num_removal(args)
            rewired_data = apply_dropedge(data, num_removal)
        elif args.method == "addedge":
            set_seed(seed)
            num_insertion = get_num_insertion(args)
            rewired_data = apply_addedge(data, num_insertion)
        else:
            raise ValueError

        set_seed(seed)
        model = GNN(
                    name=args.model,
                    in_dim=dataset.num_node_features,
                    hidden_dim=args.hidden_dim,
                    num_classes=dataset.num_classes,
                    num_layers=args.num_layers,
                    linear=args.linear,
                    bias=args.bias,
                    num_heads=args.num_heads
                )

        # Randomly rewired graphs differ between invocations, so only the
        # vanilla setting may reuse a cached checkpoint.
        if osp.isfile(model_path) and args.method not in ["dropedge", "addedge"]:
            # map_location lets a checkpoint saved on GPU load on a CPU-only host.
            ori_best_state_dict = torch.load(model_path, weights_only=True, map_location=device)
            model.load_state_dict(ori_best_state_dict)
            model = model.to(device)
            ori_best_result = eval_model(rewired_data, model, device)
        else:
            model = model.to(device)
            optimizer = optim.SGD(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

            ori_best_val_loss = torch.inf
            ori_best_result, ori_best_state_dict = None, None
            for epoch in range(1, args.epochs+1):
                # result := train_loss, val_loss, test_loss, train_acc, val_acc, test_acc
                train(rewired_data, model, optimizer, device)
                result = eval_model(rewired_data, model, device)
                train_loss, val_loss, test_loss, train_acc, val_acc, test_acc = result

                # Model selection on validation loss: keep a detached copy of
                # the best weights seen so far.
                if ori_best_val_loss > val_loss:
                    ori_best_result = result
                    ori_best_val_loss = val_loss
                    ori_best_state_dict = {k: v.clone().detach() for k, v in model.state_dict().items()}

                if epoch % 100 == 0:
                    print("-----------------------------------------------")
                    print(f"Epoch: {epoch}, train loss: {train_loss:.4f}, val loss: {val_loss:.4f}, test_loss: {test_loss:.4f}")
                    print(f"Train acc: {train_acc*100:.2f}%, val acc: {val_acc*100:.2f}%, test_acc: {test_acc*100:.2f}%")
                    print("-----------------------------------------------")

            if ori_best_state_dict is None:
                # Happens with --epochs 0 or when every val loss is NaN/inf.
                raise RuntimeError(f"Validation loss never improved on +inf for seed {seed}; nothing to checkpoint.")

            print("-----------------------------------------------")
            print(f"Best results, train loss: {ori_best_result[0]:.4f}, val loss: {ori_best_result[1]:.4f}, test_loss: {ori_best_result[2]:.4f}")
            print(f"Train acc: {ori_best_result[3]*100:.2f}%, val acc: {ori_best_result[4]*100:.2f}%, test_acc: {ori_best_result[5]*100:.2f}%")
            print("-----------------------------------------------")

            torch.save(ori_best_state_dict, model_path)
        ori_run_best_results.append(ori_best_result)
        model.load_state_dict(ori_best_state_dict)

        if args.eval_metric == "feature_ablation":
            with torch.no_grad():
                best_metric = feature_ablation(eval_node_idxs, model, rewired_data, exact_k_hop)
            run_metric.append(best_metric)
        elif args.eval_metric == "GSL":
            with torch.no_grad():
                best_metric = graph_smoothing_level(model, rewired_data)
            run_metric.append(best_metric)
        elif args.eval_metric == "dirichlet_energy":
            with torch.no_grad():
                # Energy is measured on the ORIGINAL graph's edges.
                best_metric = dirichlet_energy(model, rewired_data, data.edge_index)
            run_metric.append(best_metric)
        elif args.eval_metric == 'mean_validation_loss':
            with torch.no_grad():
                best_metric = eval_model(rewired_data, model, device)[1]
            run_metric.append(best_metric)

    # Result tuples: (train_loss, val_loss, test_loss, train_acc, val_acc, test_acc).
    val_loss_list = torch.tensor([res[1] for res in ori_run_best_results])
    test_loss_list = torch.tensor([res[2] for res in ori_run_best_results])
    val_acc_list = torch.tensor([res[4] for res in ori_run_best_results])
    test_acc_list = torch.tensor([res[5] for res in ori_run_best_results])

    mean_val_acc = val_acc_list.mean()
    mean_val_loss = val_loss_list.mean()
    mean_test_loss = test_loss_list.mean()
    test_std, mean_test_acc = torch.std_mean(test_acc_list)

    print("\n--------------------------------------")
    print("Original result")
    print(f"val_loss: {mean_val_loss:.4f}, test_loss: {mean_test_loss:.4f}, val_acc: {mean_val_acc*100:.2f}, test_acc : {mean_test_acc*100:.2f}+-{test_std*100:.2f}")
    print("--------------------------------------")

    if run_metric:
        run_metric = torch.tensor(run_metric)
        ori_std, ori_mean = torch.std_mean(run_metric)
        print(f"\n Original result: {ori_mean}+-{ori_std}\n")
    else:
        # Metrics such as k_hop_grad have no summary branch above.
        print(f"\n No summary metric is computed for eval_metric={args.eval_metric!r}\n")