import torch
import matplotlib.pyplot as plt

def laplacian_from_adjacency(A):
    """Return the combinatorial graph Laplacian ``D - A``.

    ``D`` is the diagonal matrix whose entries are the row sums of ``A``
    (taken over the last dimension). Any leading dimensions are treated as
    batch dimensions, e.g. ``A`` of shape (B, N, N) gives a (B, N, N) result.
    """
    degree = A.sum(dim=-1)
    degree_matrix = torch.diag_embed(degree)
    return degree_matrix - A

def adjacency_from_laplacian(L):
    """Recover adjacency weights from a Laplacian.

    Returns a new tensor equal to ``-L`` with the diagonal (over the last two
    dimensions) set to zero. The input ``L`` is left untouched.
    """
    adjacency = torch.neg(L)  # fresh tensor, so the in-place fill below is safe
    adjacency.diagonal(dim1=-2, dim2=-1).fill_(0)
    return adjacency

def matrix_sqrt(M, eps=1e-6):
    """Square root of the symmetric part of ``M`` via eigendecomposition.

    ``M`` is first symmetrised as ``0.5 * (M + M^T)``. Eigenvalues below
    ``eps`` (including negative ones) are raised to ``eps`` before the square
    root is taken; no upper bound is applied. Works on batches: the last two
    dimensions are the matrix dimensions.
    """
    symmetric_part = 0.5 * (M + M.transpose(-2, -1))
    eigvals, eigvecs = torch.linalg.eigh(symmetric_part)
    root_eigvals = eigvals.clamp(min=eps).sqrt()
    return eigvecs @ torch.diag_embed(root_eigvals) @ eigvecs.transpose(-2, -1)

def bw_mean_laplacian(L0, L1, t, num_iter=None):
    """
    Interpolate between two Laplacians through their pseudo-inverses.

    With S0 = pinv(L0) and S1 = pinv(L1), this evaluates

        S_t = L0^{1/2} ((1 - t) S0 + t (S0^{1/2} S1 S0^{1/2})^{1/2})^2 L0^{1/2}

    and returns pinv(S_t). All square roots go through ``matrix_sqrt`` and
    therefore inherit its eigenvalue floor.

    L0, L1: (B, N, N)
    t:      (B, 1, 1), or anything that broadcasts against (B, N, N)
    num_iter: accepted for signature compatibility; not used here.
    """
    sigma0 = torch.linalg.pinv(L0)
    sigma1 = torch.linalg.pinv(L1)

    root_L0 = matrix_sqrt(L0)
    root_sigma0 = matrix_sqrt(sigma0)

    # (S0^{1/2} S1 S0^{1/2})^{1/2}
    cross_root = matrix_sqrt(root_sigma0 @ sigma1 @ root_sigma0)

    blend = (1 - t) * sigma0 + t * cross_root
    sigma_t = root_L0 @ (blend @ blend) @ root_L0
    return torch.linalg.pinv(sigma_t)

def arithmetic_mean_adjacency(A0, A1, t, eps=1e-3):
    """Entry-wise weighted arithmetic mean ``(1 - t) * A0 + t * A1``.

    ``eps`` is accepted so the signature matches the other mean functions;
    it is not used.
    """
    weight_start = 1 - t
    weight_end = t
    return weight_start * A0 + weight_end * A1


def harmonic_mean_adjacency(A0, A1, t, eps=1e-3):
    """Entry-wise weighted harmonic mean with ``eps`` regularisation.

    ``eps`` is added to each input before inversion and once more to the
    summed reciprocals before the final inversion, so no division by zero
    occurs for non-negative inputs.
    """
    reciprocal_start = (1 - t) / (A0 + eps)
    reciprocal_end = t / (A1 + eps)
    return 1.0 / (reciprocal_start + reciprocal_end + eps)

def geometric_mean_adjacency(A0, A1, t, eps=1e-3):
    """Entry-wise weighted geometric mean ``(A0 + eps)^(1 - t) * (A1 + eps)^t``.

    ``eps`` shifts both inputs before the powers are taken.
    """
    start_factor = (A0 + eps) ** (1 - t)
    end_factor = (A1 + eps) ** t
    return start_factor * end_factor


def bures_wasserstein_mean_adjacency(A0, A1, t, eps=1e-3, num_iter=None):
    """
    Interpolate two adjacency matrices through their regularised Laplacians.

    A0, A1: (B, N, N)
    t:      (B,) or (B, 1, 1) — any tensor (or sequence) with exactly B
            elements; it is moved to the device of ``A0`` if necessary.
    eps:    accepted so the signature matches the other mean functions; the
            Laplacian regularisation below uses its own fixed constant.
    num_iter: forwarded to ``bw_mean_laplacian``.

    returns: (B, N, N) real-valued tensor — the negated off-diagonal part of
             the interpolated Laplacian, with a zero diagonal. The result is
             NOT thresholded or binarised here.
    """
    B, N, _ = A0.shape
    device, dtype = A0.device, A0.dtype

    # Ensure t is (B, 1, 1) and lives next to the matrices. ``reshape`` (rather
    # than ``view``) also accepts non-contiguous input.
    t = torch.as_tensor(t, device=device).reshape(B, 1, 1)

    L0 = laplacian_from_adjacency(A0)  # (B, N, N)
    L1 = laplacian_from_adjacency(A1)

    # enforce symmetry
    L0 = 0.5 * (L0 + L0.transpose(-1, -2))
    L1 = 0.5 * (L1 + L1.transpose(-1, -2))

    # Shift the spectrum by a small multiple of the identity before the
    # pseudo-inverses / square roots are taken.
    # stronger regularisation than 1e-6 for large graphs
    lap_eps = 1e-3  # or even 1e-2 depending on N
    I = torch.eye(N, device=device, dtype=dtype).unsqueeze(0)
    L0 = L0 + lap_eps * I
    L1 = L1 + lap_eps * I

    L_mean = bw_mean_laplacian(L0, L1, t, num_iter)
    return adjacency_from_laplacian(L_mean)


class MeanMetric:
    """Interpolate batches of adjacency matrices with a configurable mean.

    The mean function is selected from ``cfg.model.mean_metric`` ('bw',
    'arithmetic', 'harmonic' or 'geometric'). ``cfg.model.mix_rate`` chooses
    between thresholding and Bernoulli sampling of the interpolated values,
    and ``cfg.model.iterative_solve`` is forwarded as ``num_iter`` to mean
    functions that accept such an argument.
    """

    def __init__(self, cfg, threshold=0.5, eps=1e-3):
        self.threshold = threshold
        self.eps = eps
        self.mix_rate = cfg.model.mix_rate
        self.iter_solve = cfg.model.iterative_solve

        metric = cfg.model.mean_metric
        mean_fns = {
            'bw': bures_wasserstein_mean_adjacency,
            'arithmetic': arithmetic_mean_adjacency,
            'harmonic': harmonic_mean_adjacency,
            'geometric': geometric_mean_adjacency,
        }
        if metric not in mean_fns:
            raise ValueError(f"Unknown model type: {cfg.model.mean_metric}")
        if metric == 'bw':
            print("mean func constructed")
        self.mean_fn = mean_fns[metric]

    def _mean_fn_accepts_num_iter(self):
        """Return True when ``self.mean_fn`` can be called with ``num_iter=``."""
        import inspect

        try:
            params = inspect.signature(self.mean_fn).parameters.values()
        except (TypeError, ValueError):
            # Signature not introspectable: keep the historical behaviour of
            # always forwarding the argument.
            return True
        return any(
            p.name == "num_iter" or p.kind is inspect.Parameter.VAR_KEYWORD
            for p in params
        )

    def mean_fn_batched(self, A0_batch, A1_batch, t_float, ret_prob=False):
        """
        Batched mean of adjacency matrices using the configured mean function.

        A0_batch, A1_batch: adjacency batches; they are cast to float64,
            clamped to [0, 1] and symmetrised over the last two dimensions.
        t_float: interpolation weight. A 1-D or 2-D tensor is reshaped to
            (-1, 1, 1); Python scalars are converted to tensors. The value is
            moved to the device of the inputs.
        ret_prob: when True, return the interpolated values clamped to [0, 1]
            instead of a discrete graph.

        Returns either the clamped continuous interpolation (``ret_prob``) or
        an integer 0/1 tensor that is symmetric with a zero diagonal, obtained
        by thresholding at ``self.threshold`` (falsy ``mix_rate``) or by
        Bernoulli sampling (truthy ``mix_rate``).
        """
        # Ensure float64 for precision
        A0 = torch.clamp(A0_batch.to(torch.float64), 0.0, 1.0)
        A1 = torch.clamp(A1_batch.to(torch.float64), 0.0, 1.0)

        # enforce symmetry
        A0 = 0.5 * (A0 + A0.transpose(-1, -2))
        A1 = 0.5 * (A1 + A1.transpose(-1, -2))

        # Keep t next to the matrices so the mean functions never mix devices.
        if not torch.is_tensor(t_float):
            t_float = torch.as_tensor(t_float, dtype=A0.dtype)
        t_float = t_float.to(A0.device)
        if t_float.ndim in (1, 2):
            t_float = t_float.reshape(-1, 1, 1)  # (B, 1, 1)

        # Only the mean functions that declare ``num_iter`` may receive it;
        # the element-wise means would raise TypeError otherwise.
        extra = {}
        if self.iter_solve != 0 and self._mean_fn_accepts_num_iter():
            extra["num_iter"] = self.iter_solve
        A_mean_cont = self.mean_fn(A0, A1, t_float, self.eps, **extra)

        if ret_prob:
            return torch.clamp(A_mean_cont, min=0., max=1.)

        if not self.mix_rate:
            A_sym = (A_mean_cont > self.threshold).int()
        else:
            A_mean_cont = torch.clamp(A_mean_cont, min=0., max=1.)
            A_sym = torch.bernoulli(A_mean_cont).int()

        # Keep the strict upper triangle and mirror it: symmetric, no self-loops.
        A_sym = torch.triu(A_sym, diagonal=1)
        return A_sym + A_sym.transpose(-2, -1)



class BWVelocity:
    """Velocity / rate targets along the Laplacian pseudo-inverse interpolation.

    ``compute_velocity_E`` evaluates the interpolated adjacency ``At`` and a
    time derivative ``dA_dt``; ``compute_velocity`` turns these (together with
    the node states) into rate tensors with numerical blow-ups zeroed out.
    """

    # Magnitude above which a rate entry is treated as a numerical blow-up.
    _RATE_CAP = 1e5

    def __init__(self, cfg, threshold=0.5, eps=1e-6):
        # ``cfg`` is accepted for interface compatibility; it is not read here.
        self.threshold = threshold
        self.eps = eps

    def compute_T(self, L0, L1, L0_pinv):
        """Return ``L0^{1/2} (S0^{1/2} pinv(L1) S0^{1/2})^{1/2} L0^{1/2}``.

        ``S0^{1/2}`` is ``matrix_sqrt(L0_pinv)``. All inputs and the result
        are (B, N, N).
        """
        L0_sqrt = matrix_sqrt(L0)  # (B, N, N)

        L0_pinv_sqrt = matrix_sqrt(L0_pinv)
        L1_pinv = torch.linalg.pinv(L1)
        inner = matrix_sqrt(
            L0_pinv_sqrt @ L1_pinv @ L0_pinv_sqrt
        )  # (B, N, N)

        return L0_sqrt @ inner @ L0_sqrt  # (B, N, N)

    @classmethod
    def _zero_out_blowups(cls, rates):
        """Replace NaN/inf and entries with |value| > ``_RATE_CAP`` by zero."""
        rates = torch.nan_to_num(rates, nan=0.0, posinf=0.0, neginf=0.0)
        # Compare magnitudes: large negative entries are blow-ups as well, and
        # the stacked edge rate holds both R and -R.
        return rates.masked_fill(rates.abs() > cls._RATE_CAP, 0.0)

    def compute_velocity(self,
                         t,
                         X_1_onehot,
                         E_1_label,
                         X_t_onehot,
                         E_t_label,
                         E_0_label
                         ):
        """Compute node and edge rate targets at time ``t``.

        The edge tensors have ``squeeze(-1)`` applied before use — presumably
        they arrive as (B, N, N, 1); TODO confirm against the caller. ``t``
        must broadcast against both the node tensors and (B, N, N).

        Returns ``(R_t_X, R_t_E)``:
          * ``R_t_X = (X_1_onehot - X_t_onehot) / (1 - t)``
          * ``R_t_E`` stacks ``[-R, R]`` on a new last dimension, where
            ``R = (1 - 2 * Et) * dA_dt / (At * (1 - At))``.
        Non-finite entries and entries whose magnitude exceeds 1e5 are set to
        zero in both outputs.
        """
        R_t_X = (X_1_onehot - X_t_onehot) / (1.0 - t)

        E1 = E_1_label.squeeze(-1)
        E0 = E_0_label.squeeze(-1)
        Et = E_t_label.squeeze(-1)

        dA_dt, At = self.compute_velocity_E(t, E1, E0)
        R_t_E = (1 - 2 * Et) * dA_dt / (At * (1 - At))
        R_t_E = torch.stack([-R_t_E, R_t_E], dim=-1)

        return self._zero_out_blowups(R_t_X), self._zero_out_blowups(R_t_E)

    def compute_velocity_E(
            self,
            t,
            E_1_label,
            E_0_label
            ):
        """Return ``(dA_dt, At)`` for adjacency batches of shape (B, N, N).

        ``At`` is the adjacency read off ``Lt = pinv(M @ pinv(L0) @ M)`` with
        ``M = (1 - t) * I + t * T`` and ``T = compute_T(...)``. ``dA_dt`` is
        read off ``2 * Lt - T @ Lt - Lt @ T`` in the same way. ``t`` may be a
        Python scalar or a tensor that broadcasts against (B, N, N).
        """
        # The inputs must already be (B, N, N). No squeeze is applied here: on
        # a 3-D tensor it could only ever drop a genuine N == 1 dimension.
        _, N, _ = E_1_label.shape
        A1 = E_1_label
        A0 = E_0_label

        I = torch.eye(N, device=A1.device, dtype=A1.dtype).unsqueeze(0)
        # Multiplying by 1.0 promotes an integer-typed Laplacian to floating
        # point (and is a no-op for floating-point input).
        L1 = 1.0 * laplacian_from_adjacency(A1)  # (B, N, N)
        L0 = 1.0 * laplacian_from_adjacency(A0)  # (B, N, N)
        L0_pinv = torch.linalg.pinv(L0)
        T_mtx = self.compute_T(L0, L1, L0_pinv)

        M = (1 - t) * I + t * T_mtx
        Lt_pinv = M @ L0_pinv @ M
        Lt = torch.linalg.pinv(Lt_pinv)
        At = adjacency_from_laplacian(Lt)

        # NOTE(review): differentiating pinv(M @ pinv(L0) @ M) in t appears to
        # give -(Lt (T - I) M^{-1} + M^{-1} (T - I) Lt), which matches the
        # expression below only where M == I (t = 0) — please confirm whether
        # the missing M^{-1} factor is intentional.
        dL_dt = 2.0 * Lt - T_mtx @ Lt - Lt @ T_mtx  # (B, N, N)
        dA_dt = adjacency_from_laplacian(dL_dt)  # (B, N, N)

        return dA_dt, At

def verify_geodesic(E0, E1, model, n_steps=1000, device='cuda', verbose=False):
    """
    Step t over ``i / n_steps`` for ``i = 0 .. n_steps - 1`` and compare the
    closed-form interpolant with both endpoints.

    At every step this
      * calls ``model.compute_velocity_E(t, E1, E0)`` to obtain the velocity
        ``v`` and the model's interpolated adjacency ``At``;
      * evaluates ``bures_wasserstein_mean_adjacency`` at the same ``t`` on the
        inputs prepared as in ``MeanMetric.mean_fn_batched`` (float64, clamped
        to [0, 1], symmetrised) and clamps the result to [0, 1];
      * records the norm of that interpolant minus ``E1`` and minus ``E0``.

    E0, E1: (B, N, N) adjacency batches; moved to ``device``.
    model:  object exposing ``compute_velocity_E(t, E1, E0) -> (v, At)``.
    verbose: print the interpolant and ``At - E1`` at every step.

    Returns ``(At, err_1, err_0)`` where ``At`` is the model's interpolated
    adjacency from the last step (``t = 1 - 1 / n_steps``) and ``err_1`` /
    ``err_0`` are float32 CPU tensors of length ``n_steps``.

    Raises ValueError if ``n_steps`` is smaller than 1.
    """
    if n_steps < 1:
        raise ValueError(f"n_steps must be a positive integer, got {n_steps}")

    dt = 1.0 / n_steps
    E0 = E0.to(device)
    E1 = E1.to(device)
    batch_size = E0.shape[0]

    # Loop-invariant inputs of the closed-form interpolant.
    A0_ref = torch.clamp(E0.to(torch.float64), 0.0, 1.0)
    A1_ref = torch.clamp(E1.to(torch.float64), 0.0, 1.0)
    A0_ref = 0.5 * (A0_ref + A0_ref.transpose(-1, -2))
    A1_ref = 0.5 * (A1_ref + A1_ref.transpose(-1, -2))

    # NOTE(review): the explicit-Euler state below is accumulated but was never
    # returned by this function; the return value is kept as-is for
    # backward compatibility — confirm whether ``A`` should be returned.
    A = E0.clone()

    err_1 = []
    err_0 = []
    At = None
    for i in range(n_steps):
        t = i * dt

        v, At = model.compute_velocity_E(t, E1, E0)   # shape (B,N,N)

        # Build t on the same device as the matrices to avoid a device mismatch.
        t_batch = torch.full(
            (batch_size, 1, 1), t, dtype=torch.float64, device=E0.device
        )
        At_construct = bures_wasserstein_mean_adjacency(
            A0_ref, A1_ref, t_batch
        ).clamp(0, 1)

        A = A + dt * v

        if verbose:
            print(At_construct)
            print(At - E1)

        err_1.append(torch.norm(At_construct - E1))
        err_0.append(torch.norm(At_construct - E0))

    # compute error
    err_1 = torch.stack(err_1).detach().float().cpu()
    err_0 = torch.stack(err_0).detach().float().cpu()
    return At, err_1, err_0
