import time
import ot
import os
from BAPG import *
from robust_gw import *
from unbalancedgw.vanilla_ugw_solver import log_ugw_sinkhorn
from collections import defaultdict
import pickle
import warnings
from partial_gw import pu_gw_emd
from GromovWassersteinFramework import *
import GromovWassersteinGraphToolkit as GwGt
from gromovWassersteinAveraging import *
import spectralGW as sgw
from srGW import *
from tqdm import tqdm 
from MGGW import mask_guided_gw_solver

# Suppress every warning category for the whole run.
warnings.filterwarnings("ignore")
# Seed the Python, NumPy and PyTorch global RNGs so runs are repeatable.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# Dataset key. The loader further down handles 'reddit', 'enzymes', 'bzr',
# 'Fingerprint' and 'dhfr'; any other value leaves the graph list empty.
database = 'bzr'
# Fraction of node pairs of each full graph whose weight is overwritten with noise.
noise_ratio = 0.01
# Forwarded to the subgraph sampler when the source graphs are built.
subgraph_ratio = 0.5
# When False, build_subgraph_noisy_pairs skips a graph unless random.random() < 0.05.
use_full_data = True

def nx_graph_sparse_random_noise(
    G,
    noise_ratio=0.05,
    random_range=(0.0, 1.0),
    seed=None,
    keep_diagonal_zero=True,
):
    """Overwrite a random subset of node-pair weights of ``G`` with uniform noise.

    ``floor(noise_ratio * M)`` of the ``M = n * (n - 1) / 2`` unordered node
    pairs (the strict upper triangle of the adjacency matrix) are drawn without
    replacement. Each drawn pair gets a fresh weight sampled uniformly from
    ``random_range``, written symmetrically to both ``(i, j)`` and ``(j, i)``.

    Args:
        G: graph converted with ``nx.to_numpy_array``.
        noise_ratio: fraction of node pairs to corrupt (rounded down).
        random_range: ``(low, high)`` bounds handed to the uniform sampler.
        seed: seed for ``np.random.default_rng``.
        keep_diagonal_zero: if true, the diagonal of the noisy matrix is
            reset to ``0.0`` before returning.

    Returns:
        C_noisy : noisy weighted adjacency matrix
        C_clean : original adjacency matrix (float64)
        corrupted_edges : set of (i, j), i < j, corrupted entries
    """
    generator = np.random.default_rng(seed)

    C_clean = nx.to_numpy_array(G, dtype=np.float64)
    C_noisy = C_clean.copy()
    corrupted_edges = set()

    # Candidate pairs are the strict upper triangle, so every pair has i < j.
    rows, cols = np.triu_indices(C_clean.shape[0], k=1)
    num_pairs = len(rows)
    num_corrupt = int(np.floor(noise_ratio * num_pairs))

    if num_corrupt > 0:
        chosen = generator.choice(num_pairs, size=num_corrupt, replace=False)
        src, dst = rows[chosen], cols[chosen]

        low, high = random_range
        new_weights = generator.uniform(low, high, size=num_corrupt)

        # The chosen pairs are distinct, so one fancy-indexed write per
        # triangle is equivalent to assigning them one at a time.
        C_noisy[src, dst] = new_weights
        C_noisy[dst, src] = new_weights
        corrupted_edges.update(zip(src.tolist(), dst.tolist()))

    if keep_diagonal_zero:
        np.fill_diagonal(C_noisy, 0.0)

    return C_noisy, C_clean, corrupted_edges

def edge_accuracy_soft_greedy(P, edges_src, edges_tgt, gt_idx, alpha, beta):
    """Fraction of source edges whose best-scoring target edge is the true one.

    Every source edge ``(i, j)`` is compared against every target edge
    ``(k, l)``. The score of a candidate is the larger of the two endpoint
    assignments, ``(P - alpha)[i, k] * (P - beta)[j, l]`` and
    ``(P - alpha)[i, l] * (P - beta)[j, k]``. The first candidate reaching the
    highest score is selected and counts as correct when it equals the sorted
    pair ``(gt_idx[i], gt_idx[j])``.

    Args:
        P: coupling matrix (``torch.Tensor`` or array-like) indexed
            ``[source node, target node]``; ``None`` yields ``0.0``.
        edges_src: iterable of source edges as node-index pairs.
        edges_tgt: iterable of target edges as node-index pairs.
        gt_idx: maps each source node index to its true target node index.
        alpha: matrix subtracted from ``P`` for the first endpoint; a scalar
            or ``None`` is treated as an all-zero matrix.
        beta: same as ``alpha`` but for the second endpoint.

    Returns:
        Accuracy in ``[0, 1]``; ``0.0`` when ``P`` is ``None`` or either edge
        list is empty.
    """
    if P is None or len(edges_src) == 0 or len(edges_tgt) == 0:
        return 0.0

    if isinstance(P, torch.Tensor):
        P = P.detach().cpu().numpy()
    else:
        P = np.asarray(P, dtype=np.float64)

    def _normalise_mask(mask):
        # Scalars and None carry no per-entry information: use a zero matrix.
        if mask is None or np.isscalar(mask):
            return np.zeros(P.shape, dtype=np.float64)
        if isinstance(mask, torch.Tensor):
            return mask.detach().cpu().numpy()
        return mask

    alpha = _normalise_mask(alpha)
    beta = _normalise_mask(beta)

    # Undirected edges: store endpoints in sorted order so tuples compare equal.
    candidates = [tuple(sorted((int(u), int(v)))) for (u, v) in edges_tgt]

    hits = 0
    for (i, j) in edges_src:
        truth = tuple(sorted((int(gt_idx[i]), int(gt_idx[j]))))

        top_score = -np.inf
        top_edge = None
        for candidate in candidates:
            k, l = candidate
            same_order = (P[i, k] - alpha[i, k]) * (P[j, l] - beta[j, l])
            swapped = (P[i, l] - alpha[i, l]) * (P[j, k] - beta[j, k])
            score = max(same_order, swapped)
            # Strict comparison keeps the earliest candidate on ties.
            if score > top_score:
                top_score, top_edge = score, candidate

        if top_edge == truth:
            hits += 1

    return hits / len(edges_src)

graphs = []
# Passed as both S_alpha and S_beta to the mask-guided GW solver.
alphas = 0.5

# Banner and pickle path for each supported value of `database`.
_DATASET_SOURCES = {
    'reddit': ('------------------Node Matching on REDDIT---------------',
               'data/REDDIT-BINARY/matching.pk'),
    'enzymes': ('------------------Node Matching on ENZYMES---------------',
                'data/ENZYMES_weighted.pkl'),
    'bzr': ('------------------Node Matching on BZR---------------',
            'data/BZR_weighted.pkl'),
    'Fingerprint': ('------------------Node Matching on Fingerprint---------------',
                    'data/Fingerprint_weighted.pkl'),
    'dhfr': ('------------------Node Matching on DHFR---------------',
             'data/DHFR_weighted.pkl'),
}

# An unrecognised key loads nothing and leaves `graphs` empty.
if database in _DATASET_SOURCES:
    banner, pickle_path = _DATASET_SOURCES[database]
    print(banner)
    with open(pickle_path, 'rb') as f:
        graphs = pickle.load(f)
    if database == 'reddit':
        # Only the first 50 REDDIT entries are used.
        graphs = graphs[:50]

def build_subgraph_noisy_pairs(graph_list, noise_ratio, seed, use_full_data, subgraph_ratio, subgraph_type):
    """Build (clean subgraph, noisy full graph) matching pairs with ground truth.

    For each graph in ``graph_list`` a subgraph is sampled and relabelled to
    ``0..k-1`` in the order returned by the sampler, and the full graph is
    corrupted with ``nx_graph_sparse_random_noise`` (seeded with
    ``seed + position in graph_list``). Graphs or subgraphs with fewer than
    two nodes are skipped.

    Args:
        graph_list: iterable of networkx graphs.
        noise_ratio: forwarded to ``nx_graph_sparse_random_noise``.
        seed: base seed; also re-seeds the global ``random``, NumPy and torch RNGs.
        use_full_data: when false, a graph is kept only if ``random.random() < 0.05``.
        subgraph_ratio: forwarded to the subgraph sampler.
        subgraph_type: ``'connected'`` or ``'random'``, selecting the sampler.

    Returns:
        ``(src_graphs, tgt_graphs, gt_list)`` — relabelled subgraphs, noisy
        full graphs, and for each pair an int array giving, per subgraph node,
        its position in the full graph's node order.

    Raises:
        ValueError: if ``subgraph_type`` is neither ``'connected'`` nor ``'random'``.
    """
    sources, targets, ground_truths = [], [], []

    # Reset the global RNGs so the sampling below is reproducible per call.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    for graph_no, full_graph in enumerate(graph_list):
        # Optional subsampling of the dataset: keep roughly 5% of the graphs.
        if not use_full_data and random.random() >= 0.05:
            continue
        if full_graph.number_of_nodes() < 2:
            continue

        if subgraph_type == 'connected':
            sampled, sampled_nodes = connected_subgraph(full_graph, subgraph_ratio)
        elif subgraph_type == 'random':
            sampled, sampled_nodes = random_subgraph(full_graph, subgraph_ratio)
        else:
            raise ValueError(f"Unknown subgraph_type: {subgraph_type}")

        if sampled.number_of_nodes() < 2:
            continue

        # Position of every original node in the full graph's node order; the
        # noisy graph rebuilt from the adjacency matrix uses these positions
        # as its node labels.
        position_in_full = {
            node: pos for pos, node in enumerate(list(full_graph.nodes()))
        }
        kept_nodes = list(sampled_nodes)
        local_label = {node: pos for pos, node in enumerate(kept_nodes)}
        relabelled = nx.relabel_nodes(sampled, local_label, copy=True)

        noisy_adjacency = nx_graph_sparse_random_noise(
            full_graph, noise_ratio, seed=seed + graph_no
        )[0]
        noisy_graph = nx.from_numpy_array(noisy_adjacency)

        truth = np.asarray([position_in_full[node] for node in kept_nodes], dtype=int)

        sources.append(relabelled)
        targets.append(noisy_graph)
        ground_truths.append(truth)

    return sources, targets, ground_truths

# Drop graphs with fewer than two nodes, then build the matching pairs:
# src_graphs[j] is a relabelled subgraph, tgt_graphs[j] the noisy full graph,
# gt_list[j] the true target index of every source node.
graphs = [G for G in graphs if G.number_of_nodes() >= 2]
src_graphs, tgt_graphs, gt_list = build_subgraph_noisy_pairs(
    graphs, noise_ratio=noise_ratio, seed=seed, use_full_data=use_full_data,
    subgraph_ratio=subgraph_ratio, subgraph_type='connected'
)

total_num_graphs = len(src_graphs)
print('total_num_graphs: ', total_num_graphs)
# SpecGWL only: time of the GW solve alone, heat-kernel construction excluded.
diff_times = []

# Per-method lists with one entry per processed pair:
# results -> node correctness, times -> wall-clock seconds, edge_results -> edge accuracy.
results, times = defaultdict(list), defaultdict(list)
edge_results = defaultdict(list)

# Run every solver on every (subgraph, noisy graph) pair and record its metrics.
for j in tqdm(range(total_num_graphs), desc="Overall Progress"):  # total_num_graphs
    print('graph id: ', j)
    G = src_graphs[j]
    G_noise = tgt_graphs[j]

    # Dense float32 adjacency matrices and edge lists of source and target.
    G_adj = nx.to_numpy_array(G).astype(np.float32)
    G_adj_noise = nx.to_numpy_array(G_noise).astype(np.float32)
    edges_src = list(G.edges())
    edges_tgt = list(G_noise.edges())
    m, n = G_adj.shape[0], G_adj_noise.shape[0]
    if m <= 1:
        continue
    idx = gt_list[j]
    # Uniform node weights as (m, 1) / (n, 1) column vectors; Xinit is their outer product.
    p = np.ones([m, 1]).astype(np.float32) / m
    q = np.ones([n, 1]).astype(np.float32) / n
    Xinit = p @ q.T
    ################################################Mask-Guided GW###########################
    # The mask-guided solver is given CUDA tensors, so this section requires a GPU.
    G_adj_gpu = torch.tensor(G_adj).cuda()
    G_adj_noise_gpu = torch.tensor(G_adj_noise).cuda()
    rho=1e-2

    start = time.time()
    coup_mggw, alpha_, beta_, obj_list_mggw = mask_guided_gw_solver(A=G_adj_gpu, B=G_adj_noise_gpu, a=None, b=None,
                                                            S_alpha=alphas, S_beta=alphas,
                                                            X_init=None, alpha_init=None, beta_init=None,
                                                            outer_epochs=1000, inner_steps_X=1,  # Increased default for inner steps
                                                            inner_loop_eps=1e-5,  # <<< NEW: Early stopping tolerance for X
                                                            eps=1e-3, rho=rho, min_rho=rho, scaling
                                                            =1.0, early_stop_patience=5,plot=False,graph_id=j)
    end = time.time()
    time_ = end - start
    method_name = f'Mask-Guided GW_{alphas}'
    # Node correctness is scored on the coupling minus alpha_; edge accuracy
    # receives the coupling together with both alpha_ and beta_.
    acc = node_correctness2(coup_mggw.cpu().numpy()-alpha_.cpu().numpy(), idx)
    edge_acc = edge_accuracy_soft_greedy(coup_mggw, edges_src, edges_tgt, idx, alpha_, beta_)

    times[method_name].append(time_)
    results[method_name].append(acc)
    edge_results[method_name].append(edge_acc)
    print(f'{method_name}: NC={acc:.4f}, EC={edge_acc:.4f}, Time={time_:.4f}')


    ######FW###########################################################################################################################
    # POT's gromov_wasserstein with the KL loss, recorded under the key 'FW'.
    p = np.ones([m, 1]).astype(np.float32) / m
    q = np.ones([n, 1]).astype(np.float32) / n
    # t = 10

    start = time.time()
    coup, log = ot.gromov.gromov_wasserstein(G_adj, G_adj_noise, p.squeeze(), q.squeeze(),
                                             loss_fun='kl_loss', log=True, max_iter=500)
    end = time.time()

    times['FW'].append(end - start)
    results['FW'].append(node_correctness2(coup, idx))
    # Baselines have no masks, so alpha and beta are passed as 0 (treated as zero matrices).
    edge_results['FW'].append(edge_accuracy_soft_greedy(coup, edges_src, edges_tgt, idx, 0, 0))

    ######BAPG###########################################################################################################################
    # BAPG on CPU, started from the product coupling Xinit.
    start = time.time()
    rho = 0.1
    coup_bap, obj_list_bap = BAPG_numpy(A=G_adj, B=G_adj_noise, a=p, b=q, X=Xinit, epoch=500, eps=1e-6,
                                        rho=rho)
    end = time.time()
    times['BAPGcpu'].append(end - start)
    results['BAPGcpu'].append(node_correctness2(coup_bap, idx))
    edge_results['BAPGcpu'].append(edge_accuracy_soft_greedy(coup_bap, edges_src, edges_tgt, idx, 0, 0))

    # #######BPG##############################################################################################################################
    ot_hyperpara_adj = {'loss_type': 'L2',  # the key hyperparameters of GW distance
                        'ot_method': 'proximal',
                        'beta': 0.2,  #
                        'outer_iteration': 200,
                        'iter_bound': 1e-10,
                        'inner_iteration': 500,  # origin: 1, BPG:500
                        'sk_bound': 1e-5,  # origin: 1e-30, BPG:1e-5
                        'node_prior': 0,
                        'max_iter': 500,  # iteration and error bound for calculating barycenter
                        'cost_bound': 1e-16,
                        'update_p': False,  # optional updates of source distribution
                        'lr': 0,
                        'alpha': 0}
    start = time.time()
    coup_adj, d_gw, p_s = gromov_wasserstein_discrepancy(G_adj, G_adj_noise, p, q, ot_hyperpara_adj)
    end = time.time()
    times['BPG'].append(end - start)
    results['BPG'].append(node_correctness2(coup_adj, idx))
    edge_results['BPG'].append(edge_accuracy_soft_greedy(coup_adj, edges_src, edges_tgt, idx, 0, 0))

    ######eBPG##############################################################################################################################
    p = np.ones([m, 1]).astype(np.float32) / m
    q = np.ones([n, 1]).astype(np.float32) / n
    # Reddit: 1e-1 Other: 1e-2
    # NOTE(review): epsilon below is fixed at 1e-1 regardless of `database`,
    # although the note above suggests 1e-2 for non-Reddit data — confirm.
    start = time.time()
    coup_adj, _ = ot.gromov.entropic_gromov_wasserstein(G_adj, G_adj_noise, p.squeeze(-1), q.squeeze(-1),
                                                        loss_fun='square_loss', epsilon=1e-1,
                                                        verbose=False, log=True, max_iter=500)
    end = time.time()
    times['eBPG'].append(end - start)
    results['eBPG'].append(node_correctness2(coup_adj, idx))
    edge_results['eBPG'].append(edge_accuracy_soft_greedy(coup_adj, edges_src, edges_tgt, idx, 0, 0))

    #########SpecGWL#######################################################################################################################
    # GW on heat-kernel matrices (parameter t) instead of raw adjacency matrices.
    p = np.ones([m, 1]).astype(np.float32) / m
    q = np.ones([n, 1]).astype(np.float32) / n
    t = 10
    start = time.time()
    G_hk = sgw.undirected_normalized_heat_kernel(G, t)
    G_hk_noise = sgw.undirected_normalized_heat_kernel(G_noise, t)
    # start2 marks the end of kernel construction, so end - start2 is the solve time only.
    start2 = time.time()
    coup_hk, log_hk = ot.gromov.gromov_wasserstein(G_hk, G_hk_noise, p.squeeze(), q.squeeze(),
                                                   loss_fun='square_loss', log=True)
    end = time.time()

    # The recorded SpecGWL time includes building both heat kernels.
    times['SpecGWL'].append(end - start)
    diff_times.append(end - start2)
    results['SpecGWL'].append(node_correctness2(coup_hk, idx))
    edge_results['SpecGWL'].append(edge_accuracy_soft_greedy(coup_hk, edges_src, edges_tgt, idx, 0, 0))

    ################srGW###############################################################################################################
    a = (np.ones(m) / m)
    b = (np.ones(n) / n)
    # NOTE(review): this timestamp is overwritten before the solver call below,
    # so the tensor conversions are not part of the recorded srGW time.
    start = time.time()
    G_adj_double = G_adj.astype('double')
    G_adj_noise_double = G_adj_noise.astype('double')
    CX, CY = torch.from_numpy(G_adj_double).float(), torch.from_numpy(G_adj_noise_double).float()
    # Only A (source weights) is passed to the semi-relaxed solver; B is not used in this section.
    A, B = torch.from_numpy(a).float(), torch.from_numpy(b).float()
    # mirror descent
    start = time.time()
    coup_srGW_md, _ = md_semirelaxed_gromov_wasserstein(C1=CX, p=A, C2=CY,gamma_entropy=2.0,eps=1e-6)
    end = time.time()
    coup_srGW_md = coup_srGW_md.cpu().data.numpy()
    times['srGW_md'].append(end - start)
    results['srGW_md'].append(node_correctness2(coup_srGW_md, idx))
    edge_results['srGW_md'].append(edge_accuracy_soft_greedy(coup_srGW_md, edges_src, edges_tgt, idx, 0, 0))

    ###########Robust GW###############################################################################################################
    p = np.ones([m, 1]).astype(np.float32) / m
    q = np.ones([n, 1]).astype(np.float32) / n
    start = time.time()
    # NOTE(review): the returned alpha and beta are not used afterwards; the
    # metrics below score the raw coupling with zero masks — confirm intended.
    coup_rgw, obj_list_rgw, alpha, beta = robust_gw(Ds=G_adj, Dt=G_adj_noise, a=p, b=q, PALM_maxiter=5000,
                                                    rho1=0.05, rho2=0.5, eta=0.05,
                                                    t1=0.1, t2=0.1,
                                                    tau1=1, tau2=0.5, relative_error=1e-6)
    end = time.time()
    times['rgw'].append(end - start)
    results['rgw'].append(node_correctness2(coup_rgw, idx))
    edge_results['rgw'].append(edge_accuracy_soft_greedy(coup_rgw, edges_src, edges_tgt, idx, 0, 0))

    ###########Unbalanced GW###############################################################################################################
    # Inputs are float64 torch tensors here (no .float() cast, unlike the srGW section).
    a = np.ones(m) / m
    b = np.ones(n) / n
    G_adj_double = G_adj.astype('double')
    G_adj_noise_double = G_adj_noise.astype('double')
    CX, CY = torch.from_numpy(G_adj_double), torch.from_numpy(G_adj_noise_double)
    A, B = torch.from_numpy(a), torch.from_numpy(b)
    rho = 0.05
    eps = 0.001
    # No initial plan is supplied to the solver.
    PI = None

    start = time.time()
    PI = log_ugw_sinkhorn(
        A,
        CX,
        B,
        CY,
        init=PI,
        eps=eps,
        rho=rho,
        rho2=rho,
        nits_plan=500,
        tol_plan=1e-6,
        nits_sinkhorn=500,
        tol_sinkhorn=1e-4,
    )
    end = time.time()
    coup_ugw = PI.cpu().data.numpy()
    times['ugw'].append(end - start)
    results['ugw'].append(node_correctness2(coup_ugw, idx))
    edge_results['ugw'].append(edge_accuracy_soft_greedy(coup_ugw, edges_src, edges_tgt, idx, 0, 0))

    ###########Partial GW###############################################################################################################
    p = np.ones([m, 1]) / m
    q = np.ones([n, 1]) / n
    # NOTE(review): ratio is computed but not used below.
    ratio = m / n
    start = time.time()
    # coup_partial = pu_gw_emd(C1=G_adj_noise, C2=G_adj, p=q.ravel(), q=p.ravel(), nb_dummies=1, G0=None, log=False,
    #                          max_iter=1000)
    # The target graph is passed first (C1), so the returned coupling is
    # target-by-source and is transposed before scoring.
    coup_partial = ot.partial.partial_gromov_wasserstein(C1=G_adj_noise, C2=G_adj, p=q.ravel(), q=p.ravel(),
                                                         m=0.9, log=False, numItermax=500)
    end = time.time()
    times['pgw'].append(end - start)
    results['pgw'].append(node_correctness2(coup_partial.T, idx))
    edge_results['pgw'].append(edge_accuracy_soft_greedy(coup_partial.T, edges_src, edges_tgt, idx, 0, 0))
    # Progress report: latest node correctness (NC) and edge accuracy (EC) of every method.
    for method, result in results.items():
            if len(results[method]):
                print(
                    'method: {} NC: {:.2f} EC: {:.2f}'.format(
                        method, results[method][-1], edge_results[method][-1]
                    )
                )
print('---------------------------------Completed---------------------------------------')
from datetime import datetime
TIMESTAMP = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
# NOTE(review): the file name says "subratio" but embeds noise_ratio — confirm
# which one is intended before relying on the name to identify a run.
output_filename = f"exp_edge/final_stats_{database}_subratio_{noise_ratio}_{TIMESTAMP}.txt"
# open() does not create parent directories, so make sure the directory that
# actually receives the report exists; otherwise the whole run ends in
# FileNotFoundError after all solvers have finished.
os.makedirs(os.path.dirname(output_filename), exist_ok=True)

# One line per method: mean/std of node correctness (NC) and edge accuracy (EC)
# over all processed pairs, plus the total solver time in seconds.
stats_template = (
    'Method: {} NC Mean: {:.4f}, NC Std: {:.4f}, '
    'EC Mean: {:.4f}, EC Std: {:.4f}, Time: {:.4f}'
)

# Open the report once and mirror every line to stdout.
with open(output_filename, 'a+') as f:
    for method in results:
        stats_line = stats_template.format(
            method,
            np.mean(results[method]),
            np.std(results[method]),
            np.mean(edge_results[method]),
            np.std(edge_results[method]),
            np.sum(times[method]),
        )
        f.write(stats_line + '\n')
        print(stats_line)
