import torch
import time
import torch.optim as optim
from src import train, eval_model, feature_ablation, graph_smoothing_level, DataLoader, dirichlet_energy
from src.graph_utils import *
from src.utils import *
from calculate_influence import GraphInfluenceModule
import os.path as osp


def summarizing_result(ori_run_best_results, rewire_run_best_results, args):
    """Reduce per-run best results to mean/std statistics for both graphs.

    Each entry of ``ori_run_best_results`` / ``rewire_run_best_results`` is a
    per-run result sequence laid out as
    (train_loss, val_loss, test_loss, train_acc, val_acc, test_acc);
    entries ``0 .. args.runs - 1`` are read from each list.

    Returns:
        ``(ori_results, re_results)``, each a list
        ``[mean_val_loss, mean_val_acc, mean_test_loss, mean_test_acc,
        std_val_loss, std_val_acc, std_test_loss, std_test_acc]``
        of values produced by ``torch.std_mean``.
    """
    # Position of each reported metric inside one per-run result sequence.
    metric_slots = {"val_loss": 1, "test_loss": 2, "val_acc": 4, "test_acc": 5}
    # Order in which means (then standard deviations) are reported.
    report_order = ("val_loss", "val_acc", "test_loss", "test_acc")

    def reduce_runs(run_results):
        spread, centre = {}, {}
        for metric, slot in metric_slots.items():
            column = torch.tensor([run_results[r][slot] for r in range(args.runs)])
            spread[metric], centre[metric] = torch.std_mean(column)
        return [centre[m] for m in report_order] + [spread[m] for m in report_order]

    return reduce_runs(ori_run_best_results), reduce_runs(rewire_run_best_results)


def _rewire_requested(args, ratio, count):
    """Return True when ``args`` asks for at least one edit of one kind.

    ``ratio`` is consulted for the 'percentages' rewire type, ``count`` for the
    count-based types ('numbers' and 'attack'; the latter derives its counts
    from ratios in the driver before influence is requested).
    """
    if args.rewire_type == 'percentages':
        return ratio > 0
    if args.rewire_type in ('numbers', 'attack'):
        return count > 0
    return False


def _select_influence(args, approx_inf, retrain_inf):
    """Pick which influence tensor is cached for ``args.hessian_type``."""
    if args.hessian_type == "hessian":
        return retrain_inf
    if args.hessian_type == "GNH":
        return approx_inf
    raise ValueError(f"Unsupported hessian_type: {args.hessian_type!r}")


def get_influence(model, data, args, checkpoint_dir, seed, device, eval_metric, removal_candidate_idxs=None, num_folds=None, removal_candidates = None, insertion_candidates = None, eval_node_idxs=None):
    """
    Compute (or load from cache) the influence scores of edge removal and insertion.

    Scores are cached per seed under
    ``<checkpoint_dir>/influence/<hessian_type>/<eval_metric>/<hyperparams>/<seed>.pth``.
    A cached tensor is reused only if its size matches the candidate set passed
    in (when one is passed); otherwise it is recomputed when requested.

    Parameters:
    ----------
    model : torch.nn.Module
        The trained graph neural network model to evaluate influence.
    data : torch_geometric.data.Data
        The graph data used for evaluation.
    args : argparse.Namespace
        Configuration (influence hyperparameters, rewire_type, ratios/counts).
    checkpoint_dir : str
        Directory under which influence scores are saved or loaded.
    seed : int
        Seed of the current run; used as the cache file name.
    device : torch.device or str
        Device the returned tensors are moved to.
    eval_metric : str
        Metric used for influence estimation (e.g. 'mean_validation_loss',
        'feature_ablation', 'GSL').
    removal_candidate_idxs (Will be removed in the updated version) : torch.Tensor or None, optional
        Indices used to subset a larger cached removal-influence tensor.
    num_folds : int or None, optional
        Number of folds, forwarded to ``GraphInfluenceModule``.
    removal_candidates : torch.Tensor or None, optional
        Candidate edges whose removal influence is computed.
    insertion_candidates : torch.Tensor or None, optional
        Candidate edges whose insertion influence is computed.
    eval_node_idxs : optional
        Evaluation nodes, forwarded to ``GraphInfluenceModule``.

    Returns:
    -------
    removal_inf : torch.Tensor
        Influence scores for each removal candidate (empty if not requested).
    insertion_inf : torch.Tensor
        Influence scores for each insertion candidate (empty if not requested).
    inf_dict : dict
        The dictionary written to the cache file.

    Raises:
    ------
    ValueError
        If influence must be computed and ``args.hessian_type`` is neither
        'hessian' nor 'GNH'.
    """
    influence_dir = osp.join(checkpoint_dir, "influence", args.hessian_type, eval_metric, f"{args.damp}_{args.scale}_{args.lissa_iter}_{args.pbrf_weight_decay}_{args.num_folds}_{args.num_removal_candidates}_{args.num_insertion_candidates}")
    os.makedirs(influence_dir, exist_ok=True)
    influence_path = osp.join(influence_dir, f"{seed}.pth")

    is_removal_inf_saved = False
    is_insertion_inf_saved = False
    save_inf = True
    inf_dict = {}

    # Load the influence if it exists
    if osp.isfile(influence_path) and save_inf:
        print(f'Load the Influence...')
        saved_influence = torch.load(influence_path, weights_only=True)

        if "removal_inf" in saved_influence:
            removal_inf = saved_influence["removal_inf"]
            is_removal_inf_saved = True
            inf_dict["removal_inf"] = removal_inf
            removal_inf = removal_inf.to(device)

            if removal_candidate_idxs is not None and removal_inf.numel() > removal_candidate_idxs.numel():
                removal_inf = removal_inf[removal_candidate_idxs]

            # A cache computed for a different candidate set is stale.
            if removal_candidates is not None and removal_inf.numel() / args.num_folds != removal_candidates.shape[0]:
                is_removal_inf_saved = False

        if "insertion_inf" in saved_influence:
            insertion_inf = saved_influence["insertion_inf"]
            is_insertion_inf_saved = True
            inf_dict["insertion_inf"] = insertion_inf
            insertion_inf = insertion_inf.to(device)

            # Same staleness check as for removal.
            if insertion_candidates is not None and insertion_inf.numel() / args.num_folds != insertion_candidates.shape[0]:
                is_insertion_inf_saved = False

    need_removal = not is_removal_inf_saved and _rewire_requested(args, args.removal_ratio, args.num_removals)
    need_insertion = not is_insertion_inf_saved and _rewire_requested(args, args.insertion_ratio, args.num_insertions)

    # Only build the (potentially expensive) influence module when something
    # actually has to be computed.
    if need_removal or need_insertion:
        influence_module = GraphInfluenceModule(model, data, args, eval_metric, num_folds, eval_node_idxs)

    # NOTE(review): for hessian_type 'hessian' the cache stores the *retrain*
    # influence while a fresh computation returns the first output, so a cache
    # hit and a cache miss return different tensors — confirm which is intended.
    if not is_removal_inf_saved:
        if need_removal:
            print(f'Calculate the Influence of edge removal...')
            start_time = time.time()
            removal_inf, retrain_removal_inf, _, _, _, _ = influence_module.calculate_influence(removal_candidates, 'edge_removal')
            print(f'Consumed time: {time.time()-start_time:.2f}s')
            inf_dict["removal_inf"] = _select_influence(args, removal_inf, retrain_removal_inf)
            removal_inf = removal_inf.to(device)
        else:
            removal_inf = torch.tensor([]).to(device)

    if not is_insertion_inf_saved:
        if need_insertion:
            print(f'Calculate the Influence of edge insertion...')
            start_time = time.time()
            insertion_inf, retrain_insertion_inf, _, _, _, _ = influence_module.calculate_influence(insertion_candidates, 'edge_insertion')
            print(f'Consumed time: {time.time()-start_time:.2f}s')
            inf_dict["insertion_inf"] = _select_influence(args, insertion_inf, retrain_insertion_inf)
            insertion_inf = insertion_inf.to(device)
        else:
            insertion_inf = torch.tensor([]).to(device)

    if save_inf:
        torch.save(inf_dict, influence_path)

    return removal_inf, insertion_inf, inf_dict


def _rank_influential_candidates(inf, lower_better, rewire_type, ratio, max_count):
    """Rank candidates whose influence has the desired sign in every column.

    ``inf`` is indexed as ``inf[candidate, column]`` (``.sum(dim=1)`` and
    ``.min(dim=1)`` below require at least 2 dims). A candidate qualifies if
    all of its entries are < 0 (``lower_better``) or all are > 0 (otherwise).

    Returns:
        ``(ranked_idxs, count)``: qualifying candidate indices ordered by summed
        influence (ascending if ``lower_better``, else descending) and how many
        of them should be kept.

    Raises:
        ValueError: for an unknown ``rewire_type``.
    """
    if inf.numel() == 0:
        return torch.tensor([]).to(inf.device).to(torch.long), 0

    if lower_better:
        consistent = (inf < 0).min(dim=1)[0]
    else:
        consistent = (inf > 0).min(dim=1)[0]

    summed = inf[consistent].sum(dim=1)
    _, order = summed.sort(descending=not lower_better)
    # nonzero(as_tuple=True)[0] is always 1-D; the previous .nonzero().squeeze()
    # collapsed to a 0-d tensor when exactly one candidate qualified, which
    # then could not be indexed by ``order``.
    ranked_idxs = consistent.nonzero(as_tuple=True)[0][order]

    num_consistent = consistent.sum()
    if rewire_type == 'percentages':
        count = (num_consistent * ratio).to(torch.int)
    elif rewire_type in ('numbers', 'attack'):
        count = torch.min(num_consistent, torch.tensor(max_count))
    else:
        raise ValueError(f"Unsupported rewire_type: {rewire_type!r}")
    return ranked_idxs, count


def get_removal_insertion_edges(removal_inf, insertion_inf, args, removal_candidates, insertion_candidates):
    """Pick the edges to remove and insert from their influence scores.

    The preferred sign comes from ``lower_is_better(args)`` and is flipped in
    'attack' mode. This way an attack selects edits that push the metric in the
    harmful direction. How many edits are kept depends on ``args.rewire_type``:
    a fraction (``*_ratio``) of the qualifying candidates for 'percentages', or
    at most ``num_*`` of them for 'numbers'/'attack'.

    Returns:
        ``(removal_edges, insertion_edges, removal_idxs, insertion_idxs)``:
        the selected candidate edges and their indices into the candidate sets.
    """
    lower_better = lower_is_better(args)
    if args.rewire_type == 'attack':
        lower_better = not lower_better

    removal_idxs, num_edge_removal = _rank_influential_candidates(
        removal_inf, lower_better, args.rewire_type, args.removal_ratio, args.num_removals)
    insertion_idxs, num_edge_insertion = _rank_influential_candidates(
        insertion_inf, lower_better, args.rewire_type, args.insertion_ratio, args.num_insertions)

    print(f"Num edge removal: {num_edge_removal}, num edge insertion: {num_edge_insertion}\n")

    removal_idxs = removal_idxs[:num_edge_removal]
    insertion_idxs = insertion_idxs[:num_edge_insertion]

    return removal_candidates[removal_idxs], insertion_candidates[insertion_idxs], removal_idxs, insertion_idxs

def get_topk_indices_from_two_tensors(A: torch.Tensor, B: torch.Tensor, K: int, largest: bool = True):
    """Select the K extreme values across two 1-D tensors and split them by origin.

    ``A`` and ``B`` are concatenated, and ``torch.topk`` picks ``K`` entries:
    the largest by default, or the smallest when ``largest=False``. Each
    selected entry is then mapped back to an index into the tensor it came from.

    Returns:
        ``(A_indices, B_indices)``: positions in ``A`` and in ``B`` of the
        selected values, each in the order ``torch.topk`` ranked them.

    Raises:
        ValueError: if either input is not 1-D, or if ``K`` exceeds the total
        number of elements.
    """
    if A.dim() != 1 or B.dim() != 1:
        raise ValueError("A and B must be 1D tensors.")
    len_A = A.numel()
    if K > len_A + B.numel():
        raise ValueError("K must be less than or equal to the total number of elements.")

    combined = torch.cat([A, B])
    _, topk_indices = torch.topk(combined, K, largest=largest)

    # Indices below len_A came from A; the rest are shifted back into B's range.
    in_A_mask = topk_indices < len_A
    A_indices = topk_indices[in_A_mask]
    B_indices = topk_indices[~in_A_mask] - len_A

    return A_indices, B_indices

if __name__ == '__main__':
    # Experiment driver. For each seed:
    #   1. train (or load) a GNN on the original graph;
    #   2. score candidate edge removals/insertions with influence functions;
    #   3. rewire the graph with the selected edges;
    #   4. retrain the model from the same seed on the rewired graph.
    # Mean/std results for both settings are printed at the end.
    from src import SGC, GCN, GNN
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, default='Cora_public')

    # Model Arguments
    parser.add_argument('--model', type=str, default='GCN', choices=['SGC', 'GCN'])
    parser.add_argument('--num_layers', type=int, default=2)
    parser.add_argument('--hidden_dim', type=int, default=32)
    parser.add_argument('--linear', type=int, default=0)
    parser.add_argument('--bias', type=int, default=0)

    # Learning Arguments
    parser.add_argument('--lr', type=float, default=0.1)
    parser.add_argument('--epochs', type=int, default=1000)
    parser.add_argument('--weight_decay', type=float, default=0.001)

    # Influence Function Arguments
    parser.add_argument('--hessian_type', type=str, default='GNH', choices=['hessian', 'GNH'])
    parser.add_argument('--damp', type=float, default=0.1)
    parser.add_argument('--scale', type=float, default=1.0)
    parser.add_argument('--lissa_iter', type=int, default=10000)
    parser.add_argument('--eval_metric', type=str, default='mean_validation_loss', choices=['dirichlet_energy', 'mvl_with_kl_reg', 'feature_ablation', 'GSL', 'mean_validation_loss', 'k_hop_grad', 'MVL_FA'])
    parser.add_argument('--pbrf_weight_decay', type=float, default=0.0)
    
    # Re-wiring arguments
    # 'percentages' keeps a *_ratio fraction of the sign-consistent candidates,
    # 'numbers' keeps at most num_* of them, and 'attack' derives num_removals
    # from removal_ratio below and flips the preferred influence sign.
    parser.add_argument("--rewire_type", type=str, default="numbers", choices=["numbers", "percentages", "attack"])
    parser.add_argument("--rewire_method", type=str, default="random", choices=["top", "random"])
    parser.add_argument("--insertion_ratio", type=float, default=0.0)
    parser.add_argument("--removal_ratio", type=float, default=0.0)
    parser.add_argument("--num_removals", type=int, default=100)
    parser.add_argument("--num_insertions", type=int, default=100)
    
    # Other Arguments
    parser.add_argument("--runs", type=int, default=10)
    parser.add_argument("--num_folds", type=int, default=1)
    parser.add_argument("--num_insertion_candidates", type=int, default=10000)
    parser.add_argument("--num_removal_candidates", type=int, default=10000)

    args = parser.parse_args()
    # These flags are passed on the command line as 0/1 integers; turn them into bools.
    args.linear = bool(args.linear)
    args.bias = bool(args.bias)
    print(args)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Checkpoints are keyed by dataset, model architecture and optimiser settings.
    checkpoint_dir = osp.join('checkpoints', args.dataset, f"{args.model}_{args.num_layers}_{args.hidden_dim}_{args.linear}_{args.bias}", f"{args.lr}_{args.epochs}_{args.weight_decay}")
    ori_model_dir = osp.join(checkpoint_dir, "vanilla")

    os.makedirs(ori_model_dir, exist_ok=True)

    # NOTE(review): WD and PBRF_WD are not read anywhere below.
    WD = args.weight_decay
    PBRF_WD = args.pbrf_weight_decay
    # With the exact Hessian, damping is forced to equal the training weight decay.
    if args.hessian_type == 'hessian':
        args.damp = args.weight_decay
        print('Warning. args.damp should be the same with args.weight_decay when args.hessian_type is hessian.')

    dataset = DataLoader(args.dataset, root='datasets')
    args.num_classes = dataset.num_classes
    data = dataset[0]
    # Every edge starts with weight 1; then all graph tensors are moved to the device.
    data.edge_weight = torch.ones((data.edge_index.shape[1], ))
    data.x = data.x.to(device)
    data.edge_index = data.edge_index.to(device)
    data.edge_weight = data.edge_weight.to(device)
    data.y = data.y.to(device)

    # Attack mode: the removal budget is removal_ratio * edge_index.shape[1] / 2.
    # Presumably this is because each undirected edge is stored twice — confirm.
    # With GNH, the same number of insertions is allowed.
    if args.rewire_type == "attack":
        args.num_removals = int(data.edge_index.shape[1]/2 * args.removal_ratio)

        if args.hessian_type == "GNH":
            args.num_insertions = args.num_removals

    # One fixed seed per run. NOTE(review): args.runs > len(SEEDS) raises IndexError.
    SEEDS=[1941488137,4198936517,983997847,4023022221,4019585660,2108550661,1648766618,629014539,3212139042,2424918363]
    set_seed(SEEDS[0])

    # Evaluation nodes are chosen once, with the first seed, and shared by every run.
    eval_node_idxs = get_eval_node_idxs(data, args.eval_metric, SEEDS[0])
    ori_run_best_results, rewire_run_best_results = [], []
    val_removal_inf_list, test_removal_inf_list = [], []
    ori_run_metrics, rewire_run_metrics = [], []

    # k-hop neighbourhoods are computed on the original graph only. They are
    # reused for both the original and the rewired model.
    if args.eval_metric == "feature_ablation":
        exact_k_hop = find_k_hop_neighborhoods(data, args.num_layers)
    
    for run in range(args.runs):
        seed = SEEDS[run]
        model_path = osp.join(ori_model_dir, f"{seed}.pth")

        # Non-public datasets are re-split for every seed. The arguments are
        # presumably a per-class training count (~60% of nodes overall) and a
        # validation size (~20% of nodes) — confirm in random_planetoid_splits.
        if 'public' not in args.dataset:
            percls_trn = int(round(0.6*len(data.y)/dataset.num_classes))
            val_lb = int(round(0.2*len(data.y)))
            data = random_planetoid_splits(data, dataset.num_classes, percls_trn, val_lb, seed)
        
        set_seed(seed)
        model = GNN(
                    name=args.model,
                    in_dim=dataset.num_node_features, 
                    hidden_dim=args.hidden_dim, 
                    num_classes=dataset.num_classes, 
                    num_layers=args.num_layers,
                    linear=args.linear,
                    bias=args.bias
                )
        # Reuse this seed's cached checkpoint if present. Otherwise train with
        # SGD and keep the weights from the epoch with the lowest validation loss.
        if osp.isfile(model_path):
            best_state_dict = torch.load(model_path, weights_only=True)
            model.load_state_dict(best_state_dict)
            model = model.to(device)
            ori_best_result = eval_model(data, model, device)
        else:
            model = model.to(device)
            optimizer = optim.SGD(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

            ori_best_val_loss = torch.inf
            for epoch in range(1,args.epochs+1):
                train(data, model, optimizer, device)
                result = eval_model(data, model, device)
                train_loss, val_loss, test_loss, train_acc, val_acc, test_acc = result

                if ori_best_val_loss > val_loss:
                    ori_best_result = result
                    ori_best_val_loss = val_loss
                    ori_best_state_dict = {k: v.clone().detach() for k, v in model.state_dict().items()}

                if epoch % 400 == 0:
                    print("-----------------------------------------------")
                    print(f"Epoch: {epoch}, train loss: {train_loss:.4f}, val loss: {val_loss:.4f}, test_loss: {test_loss:.4f}")
                    print(f"Train acc: {train_acc*100:.2f}%, val acc: {val_acc*100:.2f}%, test_acc: {test_acc*100:.2f}%")
                    print("-----------------------------------------------")

            print("-----------------------------------------------")
            print(f"Best results, train loss: {ori_best_result[0]:.4f}, val loss: {ori_best_result[1]:.4f}, test_loss: {ori_best_result[2]:.4f}")
            print(f"Train acc: {ori_best_result[3]*100:.2f}%, val acc: {ori_best_result[4]*100:.2f}%, test_acc: {ori_best_result[5]*100:.2f}%")
            print("-----------------------------------------------")

            torch.save(ori_best_state_dict, model_path)
            model.load_state_dict(ori_best_state_dict)
        
        # Record the selected structural metric for the original model.
        if args.eval_metric == "feature_ablation":
            with torch.no_grad():
                ori_best_fa = feature_ablation(eval_node_idxs, model, data, exact_k_hop)
                ori_run_metrics.append(ori_best_fa)
        elif args.eval_metric == "GSL":
            with torch.no_grad():
                ori_best_gsl = graph_smoothing_level(model, data)
                ori_run_metrics.append(ori_best_gsl)
        elif args.eval_metric == "dirichlet_energy":
            with torch.no_grad():
                ori_best_dl = dirichlet_energy(model, data, data.edge_index)
                ori_run_metrics.append(ori_best_dl)
            
        ori_run_best_results.append(ori_best_result)

        # Re-seed so that candidate sampling is reproducible for each run.
        set_seed(seed)
        removal_candidates = get_edge_removal_candidates(data, args.num_removal_candidates)
        insertion_candidates = get_edge_insertion_candidates(data, args.num_insertion_candidates)
        removal_inf, insertion_inf, inf_dict = get_influence(model, data, args, checkpoint_dir, seed, device, args.eval_metric, num_folds=args.num_folds, removal_candidates=removal_candidates, insertion_candidates=insertion_candidates, eval_node_idxs=eval_node_idxs)
        
        removal_edges, insertion_edges, removal_idxs, insertion_idxs = get_removal_insertion_edges(removal_inf, insertion_inf, args, removal_candidates, insertion_candidates)
        
        # Attack mode: out of removals and insertions combined, keep only the
        # removal_idxs.numel() edits with the largest summed influence.
        if args.rewire_type == "attack":
            removal_idxs_idxs, insertion_idxs_idxs =get_topk_indices_from_two_tensors(removal_inf.sum(dim=1)[removal_idxs], insertion_inf.sum(dim=1)[insertion_idxs], removal_idxs.numel())
            removal_edges = removal_edges[removal_idxs_idxs]
            insertion_edges = insertion_edges[insertion_idxs_idxs]
        # NOTE(review): in attack mode this sums over every removal_idxs entry,
        # not over the subset kept above. The aggregate is also never printed.
        val_removal_inf_list.append(removal_inf[removal_idxs].sum().item())
        rewired_graph = edge_rewiring(data, removal_edges, insertion_edges).detach()

        # Retrain from scratch, with the same seed, on the rewired graph.
        set_seed(seed)
        new_model = GNN(
                    name=args.model,
                    in_dim=dataset.num_node_features, 
                    hidden_dim=args.hidden_dim, 
                    num_classes=dataset.num_classes, 
                    num_layers=args.num_layers,
                    linear=args.linear,
                    bias=args.bias
                )
        new_model = new_model.to(device)
        new_optimizer = optim.SGD(new_model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

        rewire_best_val_loss = torch.inf
        for epoch in range(1,args.epochs+1):
            train(rewired_graph, new_model, new_optimizer, device)
            result = eval_model(rewired_graph, new_model, device)
            train_loss, val_loss, test_loss, train_acc, val_acc, test_acc = result

            if rewire_best_val_loss > val_loss:
                rewire_best_result = result
                rewire_best_val_loss = val_loss
                rewire_best_state_dict = {k: v.clone().detach() for k, v in new_model.state_dict().items()}

            if epoch % 100 == 0:
                print("-----------------------------------------------")
                print(f"Epoch: {epoch}, train loss: {train_loss:.4f}, val loss: {val_loss:.4f}, test_loss: {test_loss:.4f}")
                print(f"Train acc: {train_acc*100:.2f}%, val acc: {val_acc*100:.2f}%, test_acc: {test_acc*100:.2f}%")
                print("-----------------------------------------------")

        print("-----------------------------------------------")
        print(f"Best results, train loss: {rewire_best_result[0]:.4f}, val loss: {rewire_best_result[1]:.4f}, test_loss: {rewire_best_result[2]:.4f}")
        print(f"Train acc: {rewire_best_result[3]*100:.2f}%, val acc: {rewire_best_result[4]*100:.2f}%, test_acc: {rewire_best_result[5]*100:.2f}%")
        print("-----------------------------------------------")
        
        # Restore the best-validation-loss weights before measuring the metric.
        new_model.load_state_dict(rewire_best_state_dict)
        if args.eval_metric == "feature_ablation":
            with torch.no_grad():
                rewire_best_fa = feature_ablation(eval_node_idxs, new_model, rewired_graph, exact_k_hop)
            rewire_run_metrics.append(rewire_best_fa)
        elif args.eval_metric == "GSL":
            with torch.no_grad():
                rewire_best_gsl = graph_smoothing_level(new_model, rewired_graph)
            rewire_run_metrics.append(rewire_best_gsl)
        elif args.eval_metric == "dirichlet_energy":
            # NOTE(review): this is evaluated with the original data.edge_index
            # rather than the rewired edges — confirm that this is intended.
            with torch.no_grad():
                rewire_best_dl = dirichlet_energy(new_model, rewired_graph, data.edge_index)
            rewire_run_metrics.append(rewire_best_dl)
        
        rewire_run_best_results.append(rewire_best_result)

    # Aggregate mean/std over all runs for the original and rewired settings.
    ori_results, re_results = summarizing_result(ori_run_best_results, rewire_run_best_results, args)
    ori_mvl, ori_mva, ori_mtl, ori_mta, ori_svl, ori_sva, ori_stl, ori_sta = ori_results
    rewire_mvl, rewire_mva, rewire_mtl, rewire_mta, rewire_svl, rewire_sva, rewire_stl, rewire_sta = re_results

    # NOTE(review): computed but never printed or saved.
    mean_val_removal_inf = sum(val_removal_inf_list)/args.runs

    print("--------------------------------------")
    print("Original result")
    print(f"val_loss: {ori_mvl:.4f}, test_loss: {ori_mtl:.4f}, val_acc: {ori_mva*100:.2f}, test_acc : {ori_mta*100:.2f}+-{ori_sta*100:.2f}")
    print("--------------------------------------")

    print("--------------------------------------")
    print("Rewired result")
    print(f"val_loss: {rewire_mvl:.4f}, test_loss: {rewire_mtl:.4f}, val_acc: {rewire_mva*100:.2f}, test_acc : {rewire_mta*100:.2f}+-{rewire_sta*100:.2f}")
    print("--------------------------------------")
    
    # Mean/std of the structural metric, if one was tracked.
    if args.eval_metric in ["feature_ablation", "GSL", "dirichlet_energy"]:
        ori_run_metrics = torch.tensor(ori_run_metrics)
        rewire_run_metrics = torch.tensor(rewire_run_metrics)

        ori_std, ori_mean = torch.std_mean(ori_run_metrics)
        rewire_std, rewire_mean = torch.std_mean(rewire_run_metrics)
        print(f"\n Original result: {ori_mean}+-{ori_std} Rewire result: {rewire_mean}+-{rewire_std}\n")