import time
import torch
from torch import nn
import numpy as np
import pickle
import ot
from ot.gromov import gromov_wasserstein, fused_gromov_wasserstein
from joblib import Parallel, delayed
from tqdm import tqdm

# ALPHA = 0.0


def cost_mat(cost_s, cost_t, p_s, p_t, tran, emb_s, emb_t, alpha, device):
    """Return the (fused) Gromov-Wasserstein cost matrix for a coupling.

    The structural part is the square-loss expression
    ``(C_s**2) p_s 1^T + 1 p_t^T (C_t**2)^T - 2 C_s T C_t^T``.  When both
    embedding matrices are given, a feature part (pairwise squared Euclidean
    distance between embedding rows, divided by the squared embedding
    dimension) is blended in with weight ``alpha``.

    Args:
        cost_s: Source structure matrix.
        cost_t: Target structure matrix.
        p_s: Source node weights, used as a column vector.
        p_t: Target node weights, used as a column vector.
        tran: Current coupling; its shape fixes the shape of the result.
        emb_s: Source node embeddings, or ``None``.
        emb_t: Target node embeddings, or ``None``.
        alpha: Weight of the feature term (structure gets ``1 - alpha``).
        device: Unused here; helper tensors are created on ``cost_s.device``.

    Returns:
        A tensor with the same shape as ``tran``.
    """
    n_src, n_tgt = tran.size(0), tran.size(1)

    # Marginal-dependent terms, tiled to the shape of the coupling.
    row_term = ((cost_s ** 2) @ p_s).repeat(1, n_tgt)
    col_term = (p_t.t() @ (cost_t ** 2).t()).repeat(n_src, 1)
    structure_cost = (row_term + col_term) - 2 * cost_s @ tran @ cost_t.t()

    # Without embeddings on both sides the cost is purely structural.
    if emb_s is None or emb_t is None:
        return structure_cost

    feat_dim = emb_s.size(1)
    ones_s = torch.ones(feat_dim, 1, device=cost_s.device)
    ones_t = torch.ones(emb_t.size(1), 1, device=cost_s.device)

    # ||a||^2 + ||b||^2 - 2 a.b for every (source row, target row) pair.
    cross = 2 * emb_s @ emb_t.t()
    sq_norm_s = ((emb_s ** 2) @ ones_s).repeat(1, n_tgt)
    sq_norm_t = ((emb_t ** 2) @ ones_t).repeat(1, n_src)
    feature_cost = (sq_norm_s + sq_norm_t.t() - cross) / (feat_dim ** 2)

    return structure_cost * (1 - alpha) + alpha * feature_cost


def ot_fgw(cost_s, cost_t, p_s, p_t, ot_method, gamma, num_layer, emb_s, emb_t, alpha, device):
    """Estimate a fused Gromov-Wasserstein coupling between two graphs.

    The coupling starts from normalised absolute Gaussian noise (so results
    are not deterministic unless the torch RNG is seeded) and is refined for
    ``num_layer`` outer iterations by the selected solver.

    Args:
        cost_s: Source structure matrix; ``cost_s.shape[0]`` is the number of
            source nodes.
        cost_t: Target structure matrix; ``cost_t.shape[0]`` is the number of
            target nodes.
        p_s: Source node weights, used as a column vector.
        p_t: Target node weights, used as a column vector.
        ot_method: One of ``'ppa'``, ``'b-admm'`` or ``'sinkhorn'``.
        gamma: Temperature of the exponential kernels.
        num_layer: Number of outer iterations.
        emb_s: Source node embeddings, or ``None``.
        emb_t: Target node embeddings, or ``None``.
        alpha: Weight of the embedding term, forwarded to ``cost_mat``.
        device: Device on which the coupling and helper tensors are created.

    Returns:
        ``(d_gw, tran)``: the scalar ``sum(cost_mat(tran) * tran)`` and the
        final coupling of shape ``(cost_s.shape[0], cost_t.shape[0])``.

    Raises:
        ValueError: If ``ot_method`` is not one of the supported names.
    """
    if ot_method not in ('ppa', 'b-admm', 'sinkhorn'):
        # Without this check an unknown name would fall through every branch
        # and hand back the random initial coupling as if it were a solution.
        raise ValueError(
            f"Unknown ot_method {ot_method!r}; expected 'ppa', 'b-admm' or 'sinkhorn'"
        )

    tran = torch.abs(torch.randn(cost_s.shape[0], cost_t.shape[0], device=device))
    tran /= torch.sum(tran)

    if ot_method == 'ppa':
        dual = torch.ones(p_s.size(), device=device) / p_s.size(0)
        for _ in range(num_layer):
            cost = cost_mat(cost_s, cost_t, p_s, p_t, tran, emb_s, emb_t, alpha, device)
            # The kernel is weighted by the previous coupling, which keeps the
            # new coupling close to the old one.
            kernel = torch.exp(-cost / gamma) * tran
            b = p_t / (torch.t(kernel) @ dual)
            # Fixed number of alternating scaling updates.
            for _ in range(5):
                dual = p_s / (kernel @ b)
                b = p_t / (torch.t(kernel) @ dual)
            tran = (dual @ torch.t(b)) * kernel
    elif ot_method == 'b-admm':
        # Helper tensors must live on the same device as the coupling.
        all1_s = torch.ones(p_s.size(), device=device)
        all1_t = torch.ones(p_t.size(), device=device)
        dual = torch.zeros(p_s.size(0), p_t.size(0), device=device)
        for _ in range(num_layer):
            # Auxiliary coupling: rescaled so its column sums match p_t.
            kernel_a = torch.exp((dual + 2 * torch.t(cost_s) @ tran @ cost_t) / gamma) * tran
            b = p_t / (torch.t(kernel_a) @ all1_s)
            aux = (all1_s @ torch.t(b)) * kernel_a
            dual = dual + gamma * (tran - aux)
            cost = cost_mat(cost_s, cost_t, p_s, p_t, aux, emb_s, emb_t, alpha, device)
            # Primary coupling: rescaled so its row sums match p_s.
            kernel_t = torch.exp(-(cost + dual) / gamma) * aux
            a = p_s / (kernel_t @ all1_t)
            tran = (a @ torch.t(all1_t)) * kernel_t
    else:  # 'sinkhorn'
        for _ in range(num_layer):
            fused_cost = cost_mat(cost_s, cost_t, p_s, p_t, tran, emb_s, emb_t, alpha, device)
            K = torch.exp(-fused_cost / gamma)
            u = torch.ones_like(p_s)
            v = torch.ones_like(p_t)

            for _ in range(10):  # inner Sinkhorn loop
                # The small constant guards against division by zero.
                u = p_s / (K @ v + 1e-9)
                v = p_t / (K.T @ u + 1e-9)

            # diag(u) @ K @ diag(v) written with broadcasting, which avoids
            # materialising two dense diagonal matrices.
            tran = u * K * torch.t(v)

    d_gw = (cost_mat(cost_s, cost_t, p_s, p_t, tran, emb_s, emb_t, alpha, device) * tran).sum()
    return d_gw, tran




def fgw_contribution(graph_k, graph_b, p_k, p_b, emb_k, emb_b, gamma, alpha, device):
    """Compute one graph's contribution to a barycenter update.

    The graph is coupled to the barycenter with ``ot_fgw`` (method ``"ppa"``,
    3 outer iterations) and its structure / embeddings are projected through
    that coupling.

    Args:
        graph_k: Structure matrix of the input graph.
        graph_b: Structure matrix of the current barycenter.
        p_k: Node weights of the input graph.
        p_b: Node weights of the barycenter.
        emb_k: Node embeddings of the input graph, or ``None``.
        emb_b: Node embeddings of the barycenter, or ``None``.
        gamma: Kernel temperature forwarded to ``ot_fgw``.
        alpha: Embedding weight forwarded to ``ot_fgw``.
        device: Device forwarded to ``ot_fgw``.

    Returns:
        ``(T^T @ graph_k @ T, T^T @ emb_k)`` where ``T`` is the coupling; the
        second entry is ``None`` when ``emb_k`` is ``None``.

    Raises:
        ValueError: If the coupling contains NaN.
    """
    # Only the coupling is needed; the distance value is discarded.
    _, coupling = ot_fgw(
        cost_s=graph_k, cost_t=graph_b,
        p_s=p_k, p_t=p_b,
        ot_method="ppa", gamma=gamma, num_layer=3,
        emb_s=emb_k, emb_t=emb_b,
        alpha=alpha, device=device,
    )

    has_nan = torch.isnan(coupling).any()
    if has_nan:
        raise ValueError("NaN detected in transport plan")

    structure_part = coupling.T @ graph_k @ coupling
    if emb_k is None:
        return structure_part, None
    return structure_part, coupling.T @ emb_k



def barycenter(graphs, init_barycenter, dim_embedding, ot_method, gamma, gwb_layers, ot_layers, alpha, device):
    """Perform one fused Gromov-Wasserstein barycenter update.

    Every graph is coupled to the current barycenter in a joblib worker
    (via ``fgw_contribution``); the projected structures and embeddings are
    averaged with equal weights and normalised by the barycenter node weights.

    The per-graph work is done on CPU tensors: the barycenter and all graph
    tensors are detached and moved to CPU before being handed to the workers,
    and the result is moved to ``device`` before it is returned.

    Args:
        graphs: Non-empty sequence of graphs, each indexable as
            ``(structure, node_weights, embeddings_or_None, ...)``.
        init_barycenter: ``[structure, node_weights, embeddings_or_None]`` of
            the current barycenter.
        dim_embedding: Embedding dimension (number of embedding columns).
        ot_method: Accepted for interface compatibility; not used here
            (``fgw_contribution`` always uses ``"ppa"``).
        gamma: Kernel temperature forwarded to the solver.
        gwb_layers: Accepted for interface compatibility; not used here.
        ot_layers: Accepted for interface compatibility; not used here.
        alpha: Embedding weight forwarded to the solver.
        device: Device on which the returned tensors are placed.

    Returns:
        ``[structure, node_weights, embeddings]`` of the updated barycenter;
        ``embeddings`` is ``None`` when the first graph has no embeddings.

    Raises:
        ValueError: If ``graphs`` is empty.
    """
    if len(graphs) == 0:
        raise ValueError("barycenter() requires at least one graph")

    work_device = torch.device("cpu")

    def _to_worker(tensor):
        # Detached CPU copies are what gets shipped to the joblib workers.
        return None if tensor is None else tensor.detach().to(work_device)

    def _to_output(tensor):
        if tensor is None or device is None:
            return tensor
        return tensor.to(device)

    graph_b = _to_worker(init_barycenter[0])
    p_b = _to_worker(init_barycenter[1])
    emb_b = _to_worker(init_barycenter[2])

    # Normalisers: outer product of the node weights for the structure, and
    # the node weights tiled across embedding columns for the embeddings.
    tmp1 = p_b @ p_b.T
    tmp2 = p_b @ torch.ones(1, dim_embedding, dtype=p_b.dtype, device=work_device)
    weight = 1 / len(graphs)

    noembedding = graphs[0][2] is None

    # Keep everything a worker touches on the same device as the barycenter.
    tasks = [(_to_worker(g[0]), _to_worker(g[1]), _to_worker(g[2])) for g in graphs]

    results = Parallel(n_jobs=-1)(
        delayed(fgw_contribution)(
            graph_k, graph_b, p_k, p_b, emb_k, emb_b, gamma, alpha, work_device
        ) for graph_k, p_k, emb_k in tasks
    )

    graph_b_tmp = sum(r[0] for r in results) * weight
    new_graph_b = _to_output(graph_b_tmp / tmp1)
    if noembedding:
        return [new_graph_b, _to_output(p_b), None]

    emb_b_tmp = sum(r[1] for r in results) * weight
    return [new_graph_b, _to_output(p_b), _to_output(emb_b_tmp / tmp2)]

def compute_fgw_for_graph(i, graph_i, centers, ps, embeddings, gamma, alpha, device):
    """Compute the FGW distance from one graph to every center.

    Args:
        i: Identifier of the graph; returned unchanged so results coming back
            from parallel workers can be matched to their input.
        graph_i: Graph indexable as ``(structure, node_weights, embeddings, ...)``.
        centers: Structure matrices of the centers.
        ps: Node weights of the centers (same indexing as ``centers``).
        embeddings: Embeddings of the centers (same indexing as ``centers``).
        gamma: Kernel temperature forwarded to ``ot_fgw``.
        alpha: Embedding weight forwarded to ``ot_fgw``.
        device: Device forwarded to ``ot_fgw``.

    Returns:
        ``(i, distances)`` where ``distances[k]`` is a Python float for
        center ``k``.
    """
    def _distance_to(center_idx):
        # Solver settings are fixed: "ppa" with 3 outer iterations.
        distance, _ = ot_fgw(
            cost_s=graph_i[0],
            cost_t=centers[center_idx],
            p_s=graph_i[1],
            p_t=ps[center_idx],
            ot_method="ppa",
            gamma=gamma,
            num_layer=3,
            emb_s=graph_i[2],
            emb_t=embeddings[center_idx],
            alpha=alpha,
            device=device,
        )
        return distance.item()

    return i, [_distance_to(center_idx) for center_idx in range(len(centers))]

def k_barycenter(graphs, size_centers, dim_embedding, ot_method, gamma, gwb_layers, ot_layers, n_iters, filename, alpha, device):
    """Cluster graphs around K fused Gromov-Wasserstein barycenters.

    Alternates ``n_iters`` times between (1) assigning each graph to the
    center with the smallest FGW distance and (2) updating each non-empty
    center with ``barycenter``.  The final centers are pickled to
    ``filename`` and returned.

    Args:
        graphs: Sequence of graphs, each indexable as
            ``(structure, node_weights, embeddings_or_None, ...)``.
        size_centers: Number of nodes of each center; its length is K.
        dim_embedding: Embedding dimension of the centers.
        ot_method: Forwarded to ``barycenter``.
        gamma: Kernel temperature forwarded to the solvers.
        gwb_layers: Forwarded to ``barycenter``.
        ot_layers: Forwarded to ``barycenter``.
        n_iters: Number of assignment/update rounds.
        filename: Path of the pickle file that receives
            ``[centers, ps, embeddings]``.
        alpha: Embedding weight forwarded to the solvers.
        device: Device on which the centers are kept.

    Returns:
        ``(centers, ps, embeddings)`` as nested Python lists (not tensors);
        an embedding entry is ``None`` if the center has no embeddings.
    """
    num_centers = len(size_centers)
    ps, centers, embeddings = [], [], []

    # Random initialisation: sigmoid-squashed Gaussian noise for structure and
    # (L1-normalised) embeddings, uniform node weights.
    for size in size_centers:
        center = torch.sigmoid(torch.randn(size, size)).to(device)
        embedding = nn.functional.normalize(
            torch.sigmoid(torch.randn(size, dim_embedding)), p=1
        ).to(device)
        dist = torch.ones(size, 1).to(device) / size
        ps.append(dist)
        centers.append(center)
        embeddings.append(embedding)

    for it in range(n_iters):
        time_inner = time.time()
        center2sample = {}

        # Parallel FGW computation
        results = Parallel(n_jobs=-1)(
            delayed(compute_fgw_for_graph)(i, graphs[i], centers, ps, embeddings, gamma, alpha, device)
            for i in tqdm(range(len(graphs)), desc=f"FGW Assignment Iter {it}")
        )

        for i, d_gws in results:
            d_gws = np.asarray(d_gws, dtype=float)
            # A graph is skipped only when it has no usable distance at all.
            if np.isnan(d_gws).all():
                print("NaN")
                continue
            # np.argmin would return the position of the first NaN, so a single
            # failed solve would discard the graph; nanargmin ignores NaNs.
            min_idx = int(np.nanargmin(d_gws))
            center2sample.setdefault(min_idx, []).append(i)

        time_bary1 = time.time()
        print("bary first part time:", time_bary1 - time_inner)

        # Update barycenters; centers without assigned graphs keep their state.
        for k in range(num_centers):
            if k not in center2sample:
                continue
            samples_k = [graphs[i] for i in center2sample[k]]
            graph_b, p_b, emb_b = barycenter(
                samples_k, [centers[k], ps[k], embeddings[k]],
                dim_embedding, ot_method, gamma, gwb_layers, ot_layers, alpha, device
            )
            # Keep updated centers on the same device as the untouched ones.
            centers[k] = graph_b.to(device)
            ps[k] = p_b.to(device)
            embeddings[k] = emb_b.to(device) if emb_b is not None else None

        time_bary2 = time.time()
        print("bary second part time:", time_bary2 - time_bary1)

    # Save results as plain nested lists so the file does not depend on torch.
    s_centers = [c.cpu().tolist() for c in centers]
    s_ps = [p.cpu().tolist() for p in ps]
    s_embeddings = [e.cpu().tolist() if e is not None else None for e in embeddings]

    with open(filename, 'wb') as f:
        pickle.dump([s_centers, s_ps, s_embeddings], f)

    return s_centers, s_ps, s_embeddings
        

def fusedGW_featurize(graphs, centers, ps, embeddings, alpha, device, gamma=0.1):
    """Project every graph onto each center to obtain fixed-size features.

    For each graph and each center ``k`` (with ``n_k`` nodes) a coupling
    ``T`` from the center to the graph is computed with ``ot_fgw`` (method
    ``"ppa"``, 3 outer iterations) and two features are derived:

    * ``W_feat = n_k * T @ emb_i`` (``None`` if the graph has no embeddings)
    * ``GW_feat = n_k**2 * T @ structure_i @ T^T``

    Args:
        graphs: Iterable of graphs indexable as
            ``(structure, node_weights, embeddings_or_None, label)``.
        centers: Structure matrices of the centers (tensors or nested lists).
        ps: Node weights of the centers (tensors or nested lists).
        embeddings: Embeddings of the centers (tensors, nested lists or
            ``None`` entries).
        alpha: Embedding weight forwarded to ``ot_fgw``.
        device: Device on which all computation is done.
        gamma: Kernel temperature forwarded to ``ot_fgw``.

    Returns:
        A list with one ``[W_feats, GW_feats, label]`` entry per graph, where
        ``W_feats`` and ``GW_feats`` are lists indexed by center.
    """
    def _on_device(value):
        # Accepts tensors as well as the nested lists that k_barycenter
        # returns; None (no embeddings) is passed through.
        return None if value is None else torch.as_tensor(value, device=device)

    # Move the centers once instead of once per graph.
    num_centers = len(centers)
    centers_d = [_on_device(centers[k]) for k in range(num_centers)]
    ps_d = [_on_device(ps[k]) for k in range(num_centers)]
    embs_d = [_on_device(embeddings[k]) for k in range(num_centers)]

    G2Fs = []
    for graph in graphs:
        center_i = _on_device(graph[0])
        p_i = _on_device(graph[1])
        emb_i = _on_device(graph[2])
        label = graph[3]

        W_feats = []
        GW_feats = []

        for k in range(num_centers):
            center_k = centers_d[k]
            sig_k = center_k.shape[0]
            _, tran_ki = ot_fgw(
                cost_s=center_k, cost_t=center_i,
                p_s=ps_d[k], p_t=p_i,
                ot_method="ppa", gamma=gamma, num_layer=3,
                emb_s=embs_d[k], emb_t=emb_i,
                alpha=alpha, device=device,
            )

            if emb_i is not None:
                W_feat = sig_k * tran_ki @ emb_i
            else:
                W_feat = None
            W_feats.append(W_feat)

            GW_feat = sig_k * sig_k * tran_ki @ center_i @ torch.t(tran_ki)
            GW_feats.append(GW_feat)

        G2Fs.append([W_feats, GW_feats, label])
    return G2Fs