import time
import ot
import os
from BAPG import *
from robust_gw import *
from unbalancedgw.vanilla_ugw_solver import log_ugw_sinkhorn
from collections import defaultdict
import pickle
import warnings
from partial_gw import pu_gw_emd
from GromovWassersteinFramework import *
import GromovWassersteinGraphToolkit as GwGt
from gromovWassersteinAveraging import *
import spectralGW as sgw
from srGW import *
from tqdm import tqdm 
from MGGW import mask_guided_gw_solver

warnings.filterwarnings("ignore")

# ---- Reproducibility ------------------------------------------------------
seed = 321
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# ---- Experiment configuration ---------------------------------------------
# Dataset to load. Supported (case-sensitive) values are the ones handled by
# the loading code: 'enzymes', 'bzr', 'Fingerprint', 'dhfr'.
database = 'bzr'
# Passed as the ratio argument to connected_subgraph / random_subgraph;
# presumably the fraction of nodes kept in the query subgraph — confirm.
subgraph_ratio = 0.1
# Subgraph extraction strategy: 'connected' or 'random'.
# NOTE: the misspelled name is intentional-by-now; later code reads it.
subgrpah_type = 'connected'

# Index-aligned per-graph buffers filled by the filtering loop: the full
# graphs, the extracted subgraphs, and the index object returned with each
# subgraph (later handed to node_correctness2 for scoring).
graphs, new_graphs = [], []
ans = []
# Used as both S_alpha and S_beta for mask_guided_gw_solver.
alphas = 0.9


# Supported `database` values -> (banner printed before loading, pickle path).
# The filtering loop below expects the pickle to hold an iterable of networkx
# graphs — TODO confirm against the data-preparation script.
_DATASET_SOURCES = {
    'enzymes': ('------------------Node Matching on ENZYMES---------------',
                'data/ENZYMES_weighted.pkl'),
    'bzr': ('------------------Node Matching on BZR---------------',
            'data/BZR_weighted.pkl'),
    'Fingerprint': ('------------------Node Matching on Fingerprint---------------',
                    'data/Fingerprint_weighted.pkl'),
    'dhfr': ('------------------Node Matching on DHFR---------------',
             'data/DHFR_weighted.pkl'),
}

if database not in _DATASET_SOURCES:
    # An unknown name would otherwise leave `graphs` empty and the run would
    # silently finish with no results.
    raise ValueError('Unknown database {!r}; expected one of {}'.format(
        database, sorted(_DATASET_SOURCES)))

_banner, _pkl_path = _DATASET_SOURCES[database]
print(_banner)
# NOTE: pickle.load can execute arbitrary code; only load trusted local files.
with open(_pkl_path, 'rb') as f:
    graphs = pickle.load(f)



# Validate the extraction strategy before touching any data; an unknown value
# would otherwise skip both branches below and crash with a NameError on subG.
if subgrpah_type not in ('connected', 'random'):
    raise ValueError("subgrpah_type must be 'connected' or 'random', got {!r}"
                     .format(subgrpah_type))

# Pair each usable graph with a subgraph extracted from it. Graphs with fewer
# than two nodes, or whose extracted subgraph has fewer than two nodes, are
# dropped so that `graphs`, `new_graphs` and `ans` stay index-aligned.
filtered_graphs = []
for G in graphs:
    if G.number_of_nodes() < 2:
        continue
    if subgrpah_type == 'connected':
        subG, idx = connected_subgraph(G, subgraph_ratio)
    else:  # 'random' (validated above)
        subG, idx = random_subgraph(G, subgraph_ratio)
    if subG.number_of_nodes() < 2:
        continue
    filtered_graphs.append(G)
    new_graphs.append(subG)
    ans.append(idx)

graphs = filtered_graphs

# Bookkeeping for the benchmark loop: per-method accuracy and wall-clock
# lists keyed by method name, plus SpecGWL's solver-only timings.
total_num_graphs = len(graphs)
print('total_num_graphs: ', total_num_graphs)
results = defaultdict(list)
times = defaultdict(list)
diff_times = []

# Run every solver on each (subgraph, full graph) pair, recording the node
# correctness score (node_correctness2) and wall-clock time per method.
for j in tqdm(range(total_num_graphs), desc="Overall Progress"):
    print('graph id: ', j)
    # Query = extracted subgraph; target = the full graph it was cut from.
    G = new_graphs[j]
    G_noise = graphs[j]
    idx = ans[j]  # reference object consumed by node_correctness2

    G_adj = nx.to_numpy_array(G).astype(np.float32)
    G_adj_noise = nx.to_numpy_array(G_noise).astype(np.float32)
    m, n = G_adj.shape[0], G_adj_noise.shape[0]
    if m <= 1:
        continue
    # Uniform marginals as (m, 1) / (n, 1) column vectors; their outer product
    # is the initial coupling passed to BAPG.
    p = np.ones([m, 1]).astype(np.float32) / m
    q = np.ones([n, 1]).astype(np.float32) / n
    Xinit = p @ q.T

    ######## Mask-Guided GW (requires a CUDA device) ########
    G_adj_gpu = torch.tensor(G_adj).cuda()
    G_adj_noise_gpu = torch.tensor(G_adj_noise).cuda()
    rho = 1e-2
    start = time.time()
    coup_mggw, alpha_, beta_, obj_list_mggw = mask_guided_gw_solver(
        A=G_adj_gpu, B=G_adj_noise_gpu, a=None, b=None,
        S_alpha=alphas, S_beta=alphas,
        X_init=None, alpha_init=None, beta_init=None,
        outer_epochs=1000, inner_steps_X=1,
        inner_loop_eps=1e-5,  # early-stopping tolerance for the X updates
        eps=1e-1, rho=rho, min_rho=rho, scaling=1.0,
        early_stop_patience=5, plot=False, graph_id=j)
    end = time.time()
    # The coupling is scored after subtracting the returned alpha_ term.
    acc_mggw = node_correctness2(coup_mggw.cpu().numpy() - alpha_.cpu().numpy(), idx)
    times['Mask-Guided GW'].append(end - start)
    results['Mask-Guided GW'].append(acc_mggw)

    ######## FW ########
    p = np.ones([m, 1]).astype(np.float32) / m
    q = np.ones([n, 1]).astype(np.float32) / n
    start = time.time()
    coup, log = ot.gromov.gromov_wasserstein(G_adj, G_adj_noise, p.squeeze(), q.squeeze(),
                                             loss_fun='kl_loss', log=True, max_iter=500)
    end = time.time()
    times['FW'].append(end - start)
    results['FW'].append(node_correctness2(coup, idx))

    ######## BAPG ########
    start = time.time()
    rho = 0.1
    coup_bap, obj_list_bap = BAPG_numpy(A=G_adj, B=G_adj_noise, a=p, b=q, X=Xinit, epoch=500, eps=1e-6,
                                        rho=rho)
    end = time.time()
    times['BAPGcpu'].append(end - start)
    results['BAPGcpu'].append(node_correctness2(coup_bap, idx))

    ######## BPG ########
    ot_hyperpara_adj = {'loss_type': 'L2',  # the key hyperparameters of GW distance
                        'ot_method': 'proximal',
                        'beta': 0.2,
                        'outer_iteration': 200,
                        'iter_bound': 1e-10,
                        'inner_iteration': 500,  # origin: 1, BPG:500
                        'sk_bound': 1e-5,  # origin: 1e-30, BPG:1e-5
                        'node_prior': 0,
                        'max_iter': 500,  # iteration and error bound for calculating barycenter
                        'cost_bound': 1e-16,
                        'update_p': False,  # optional updates of source distribution
                        'lr': 0,
                        'alpha': 0}
    start = time.time()
    coup_adj, d_gw, p_s = gromov_wasserstein_discrepancy(G_adj, G_adj_noise, p, q, ot_hyperpara_adj)
    end = time.time()
    times['BPG'].append(end - start)
    results['BPG'].append(node_correctness2(coup_adj, idx))

    ######## eBPG (entropic GW) ########
    p = np.ones([m, 1]).astype(np.float32) / m
    q = np.ones([n, 1]).astype(np.float32) / n
    # epsilon guidance from earlier runs: Reddit 1e-1, other datasets 1e-2.
    start = time.time()
    coup_adj, _ = ot.gromov.entropic_gromov_wasserstein(G_adj, G_adj_noise, p.squeeze(-1), q.squeeze(-1),
                                                        loss_fun='square_loss', epsilon=1e-1,
                                                        verbose=False, log=True, max_iter=500)
    end = time.time()
    times['eBPG'].append(end - start)
    results['eBPG'].append(node_correctness2(coup_adj, idx))

    ######## SpecGWL (GW on heat kernels) ########
    p = np.ones([m, 1]).astype(np.float32) / m
    q = np.ones([n, 1]).astype(np.float32) / n
    t = 10
    start = time.time()  # total time includes building the heat kernels
    G_hk = sgw.undirected_normalized_heat_kernel(G, t)
    G_hk_noise = sgw.undirected_normalized_heat_kernel(G_noise, t)
    start2 = time.time()  # solver-only time, tracked separately in diff_times
    coup_hk, log_hk = ot.gromov.gromov_wasserstein(G_hk, G_hk_noise, p.squeeze(), q.squeeze(),
                                                   loss_fun='square_loss', log=True)
    end = time.time()
    times['SpecGWL'].append(end - start)
    diff_times.append(end - start2)
    results['SpecGWL'].append(node_correctness2(coup_hk, idx))

    ######## srGW (mirror descent) ########
    # Only the source marginal is passed (semi-relaxed formulation).
    a = (np.ones(m) / m)
    b = (np.ones(n) / n)
    G_adj_double = G_adj.astype('double')
    G_adj_noise_double = G_adj_noise.astype('double')
    CX, CY = torch.from_numpy(G_adj_double).float(), torch.from_numpy(G_adj_noise_double).float()
    A = torch.from_numpy(a).float()
    start = time.time()
    coup_srGW_md, _ = md_semirelaxed_gromov_wasserstein(C1=CX, p=A, C2=CY, gamma_entropy=2.0, eps=1e-6)
    end = time.time()
    coup_srGW_md = coup_srGW_md.cpu().data.numpy()
    times['srGW_md'].append(end - start)
    results['srGW_md'].append(node_correctness2(coup_srGW_md, idx))

    ######## Robust GW ########
    p = np.ones([m, 1]).astype(np.float32) / m
    q = np.ones([n, 1]).astype(np.float32) / n
    start = time.time()
    coup_rgw, obj_list_rgw, alpha, beta = robust_gw(Ds=G_adj, Dt=G_adj_noise, a=p, b=q, PALM_maxiter=5000,
                                                    rho1=0.05, rho2=0.5, eta=0.05,
                                                    t1=0.1, t2=0.1,
                                                    tau1=1, tau2=0.5, relative_error=1e-6)
    end = time.time()
    times['rgw'].append(end - start)
    results['rgw'].append(node_correctness2(coup_rgw, idx))

    ######## Unbalanced GW (double precision) ########
    a = np.ones(m) / m
    b = np.ones(n) / n
    G_adj_double = G_adj.astype('double')
    G_adj_noise_double = G_adj_noise.astype('double')
    CX, CY = torch.from_numpy(G_adj_double), torch.from_numpy(G_adj_noise_double)
    A, B = torch.from_numpy(a), torch.from_numpy(b)
    rho = 0.05
    eps = 0.001
    PI = None  # no warm start

    start = time.time()
    PI = log_ugw_sinkhorn(
        A,
        CX,
        B,
        CY,
        init=PI,
        eps=eps,
        rho=rho,
        rho2=rho,
        nits_plan=500,
        tol_plan=1e-6,
        nits_sinkhorn=500,
        tol_sinkhorn=1e-4,
    )
    end = time.time()
    coup_ugw = PI.cpu().data.numpy()
    times['ugw'].append(end - start)
    results['ugw'].append(node_correctness2(coup_ugw, idx))

    ######## Partial GW ########
    # Solved target -> query (C1 = full graph, C2 = subgraph), so the returned
    # (n, m) plan is transposed back to (m, n) before scoring.
    p = np.ones([m, 1]) / m
    q = np.ones([n, 1]) / n
    start = time.time()
    coup_partial = ot.partial.partial_gromov_wasserstein(C1=G_adj_noise, C2=G_adj, p=q.ravel(), q=p.ravel(),
                                                         m=0.9, log=False, numItermax=500)
    end = time.time()
    times['pgw'].append(end - start)
    results['pgw'].append(node_correctness2(coup_partial.T, idx))

    # Per-graph progress: latest score of every method recorded so far.
    for method, accs in results.items():
        if accs:
            print('method: {} NC: {:.2f} '.format(method, accs[-1]))
print('---------------------------------Completed---------------------------------------')
from datetime import datetime
TIMESTAMP = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
os.makedirs("expall", exist_ok=True)
output_filename = f"expall/final_stats_{database}_subratio_{subgraph_ratio}_{TIMESTAMP}.txt"

# One summary line per method: mean and standard deviation of the per-graph
# node-correctness scores, and the total wall-clock time spent in the method.
summary_lines = [
    'Method: {} Mean: {:.4f}, Std: {:.4f}, Time: {:.4f}'.format(
        method, np.mean(accs), np.std(accs), np.sum(times[method]))
    for method, accs in results.items()
]

# Open the stats file once (append mode, as before); it is only created when
# there is something to write.
if summary_lines:
    with open(output_filename, 'a+') as f:
        f.writelines(line + '\n' for line in summary_lines)

for line in summary_lines:
    print(line)
