import numpy as np
import networkx as nx
import random
import copy
import torch
import  matplotlib.pyplot as plt
def calculate_robust_gw_loss(C, D, X, alpha, beta):
    """Evaluate the robust Gromov-Wasserstein objective for a coupling.

    With ``Sa = X - alpha`` and ``Sb = X - beta`` the returned scalar is
    algebraically ``sum_{i,j,k,l} (C[i,k] - D[j,l])**2 * Sa[i,j] * Sb[k,l]``,
    computed from its three expanded terms instead of a 4-index sum.

    Args:
        C: Square matrix indexed by the rows of ``X``.
        D: Square matrix indexed by the columns of ``X``.
        X: Coupling matrix.
        alpha: Tensor subtracted from ``X`` on the left factor; ``None``
            is treated as all zeros.
        beta: Tensor subtracted from ``X`` on the right factor; ``None``
            is treated as all zeros.

    Returns:
        A 0-dim tensor holding the loss value.
    """
    left = X - (torch.zeros_like(X) if alpha is None else alpha)
    right = X - (torch.zeros_like(X) if beta is None else beta)

    # Row / column marginals of the two masked couplings.
    left_rows, right_rows = left.sum(dim=1), right.sum(dim=1)
    left_cols, right_cols = left.sum(dim=0), right.sum(dim=0)

    # Quadratic terms: squared costs weighted by the outer product of marginals.
    rows_outer = left_rows.view(-1, 1) @ right_rows.view(1, -1)
    cols_outer = left_cols.view(-1, 1) @ right_cols.view(1, -1)
    quad_c = torch.sum(C ** 2 * rows_outer)
    quad_d = torch.sum(D ** 2 * cols_outer)

    # Cross term: -2 * sum C[i,k] * D[j,l] * Sa[i,j] * Sb[k,l].
    cross = -2 * torch.trace(left.T @ C @ right @ D.T)

    return quad_c + quad_d + cross


def solve_lp_with_row_sum_limit(grad_matrix, upper_bounds_matrix, row_limits):
    """Greedily fill each row in ascending-gradient order under a row budget.

    For every row, entries are visited from the smallest gradient to the
    largest. Each visited entry receives its full upper bound while the
    running total stays within the row budget; the first entry that would
    overshoot receives whatever budget is left.

    Args:
        grad_matrix: ``(M, N)`` tensor of linear cost coefficients.
        upper_bounds_matrix: ``(M, N)`` tensor of per-entry caps
            (assumed non-negative so that cumulative sums are monotone).
        row_limits: Per-row budget. May be a Python number / sequence or a
            tensor; a single value is shared by all rows.

    Returns:
        ``(M, N)`` tensor with the same dtype/device as ``grad_matrix``.
    """
    n_rows, n_cols = grad_matrix.shape

    # Normalise the budget to an (M, 1) column.
    limits = row_limits
    if not isinstance(limits, torch.Tensor):
        limits = torch.tensor(limits, device=grad_matrix.device, dtype=grad_matrix.dtype)
    if limits.dim() == 1:
        limits = limits.view(-1, 1)
    if limits.numel() == 1:
        limits = limits.expand(n_rows, 1)

    # Work in the per-row ascending-gradient ordering.
    order = torch.sort(grad_matrix, dim=1).indices
    caps = torch.gather(upper_bounds_matrix, 1, order)
    running = torch.cumsum(caps, dim=1)

    # A row can never receive more than the sum of its caps.
    budget = torch.min(limits, running[:, -1].unsqueeze(1))

    # Entries whose cumulative cap still fits are saturated.
    saturated = running <= budget
    filled = torch.zeros_like(grad_matrix)
    filled[saturated] = caps[saturated]

    # The first non-saturated position (clamped into range) takes the leftover.
    boundary = torch.clamp(saturated.sum(dim=1, keepdim=True), 0, n_cols - 1)
    leftover = torch.clamp(budget - filled.sum(dim=1, keepdim=True), min=0)
    filled.scatter_add_(1, boundary, leftover)

    # Undo the sort so values line up with the original columns.
    return torch.zeros_like(grad_matrix).scatter_(1, order, filled)
def solve_lp_with_box_and_sum_constraints(X, grad_matrix, target_sum):
    """Distribute ``target_sum`` over entries capped by ``X``, cheapest first.

    Entries are visited in ascending order of ``grad_matrix``. Each one is
    filled up to its cap ``X[i, j]`` until the total reaches ``target_sum``;
    the entry at which the budget runs out receives only the remainder.

    Args:
        X: Tensor of per-entry upper bounds.
        grad_matrix: Tensor of linear cost coefficients, same shape as ``X``.
        target_sum: Total mass to place (number or scalar tensor). A value
            below 0, or above ``X.sum()`` but within tolerance of it, is
            clamped into ``[0, X.sum()]``.

    Returns:
        Tensor shaped like ``grad_matrix`` holding the allocation.

    Raises:
        ValueError: If the shapes differ, or ``target_sum`` exceeds
            ``X.sum()`` by more than the closeness tolerance.
    """
    if X.shape != grad_matrix.shape:
        raise ValueError("X 和 grad_matrix 必须具有相同的形状。")

    budget = target_sum
    if not isinstance(budget, torch.Tensor):
        budget = torch.tensor(budget, device=X.device, dtype=X.dtype)

    within_range = 0 <= budget <= X.sum()
    if not within_range:
        # Only an overshoot that is not a rounding artefact is an error.
        if not torch.isclose(budget, X.sum(), atol=1e-6) and budget > X.sum():
            raise ValueError(f"target_sum ({budget.item()}) 必须在 [0, {X.sum().item()}] 范围内。")
        budget = torch.clamp(budget, 0, X.sum())

    caps = X.flatten()
    order = torch.argsort(grad_matrix.flatten())
    caps_sorted = caps[order]
    running = torch.cumsum(caps_sorted, dim=0)

    # First position whose cumulative cap reaches the budget.
    split = torch.searchsorted(running, budget)

    filled = torch.zeros_like(caps)
    if split > 0:
        # Everything strictly before the split is saturated.
        filled[:split] = caps_sorted[:split]
    if split < len(caps_sorted):
        already_placed = running[split - 1] if split > 0 else 0
        filled[split] = torch.min(budget - already_placed, caps_sorted[split])

    # Map the sorted allocation back to the original flat positions.
    allocation = torch.zeros_like(caps)
    allocation[order] = filled

    return allocation.reshape(grad_matrix.shape)
def compute_gw_gradient(C, D, T, other):
    """Return the negated partial gradient of the bilinear GW loss.

    With ``S = T - other`` held fixed, the loss
    ``sum_{i,j,k,l} (C[i,k] - D[j,l])**2 * P[i,j] * S[k,l]`` is linear in
    ``P``; this function returns minus its gradient with respect to ``P``:

        ``-( C**2 @ rowsum(S) + (D**2 @ colsum(S))^T - 2 * C @ S @ D^T )``

    The column term uses ``D2.T`` so the result is the exact gradient for
    non-symmetric ``D`` as well; for symmetric ``D`` it is identical to
    multiplying by ``D2``.

    Args:
        C: ``(M, M)`` matrix; its dtype is used for all computations.
        D: ``(N, N)`` matrix (cast to ``C.dtype``).
        T: ``(M, N)`` coupling (cast to ``C.dtype``).
        other: Tensor broadcastable to ``T``, or a Python number, that is
            subtracted from ``T``.

    Returns:
        ``(M, N)`` tensor with dtype ``C.dtype``.
    """
    dtype = C.dtype
    D = D.to(dtype)
    T = T.to(dtype)
    if torch.is_tensor(other):
        other = other.to(dtype)
    else:
        other = torch.tensor(other, device=T.device, dtype=dtype)

    S_b = T - other
    C2 = C ** 2
    D2 = D ** 2

    row_mass = S_b.sum(dim=1, keepdim=True)  # (M, 1)
    col_mass = S_b.sum(dim=0, keepdim=True)  # (1, N)

    # (M, 1) and (1, N) terms broadcast to (M, N) when added below.
    term_rows = C2 @ row_mass
    # entry j is sum_l D2[j, l] * col_mass[l]  ->  col_mass @ D2.T
    term_cols = col_mass @ D2.T
    term_cross = -2 * C @ S_b @ D.T

    grad = term_rows + term_cols + term_cross
    return -grad
def init_by_degree(Adj1, Adj2, a, b, rho=0.1):
    """Build an initial coupling that favours nodes with similar degree.

    A kernel ``exp(-|deg1_i - deg2_j| / mean_diff)`` is formed from the row
    sums of the two matrices and then alternately rescaled (5 rounds) so its
    row sums match ``a`` and its column sums match ``b``.

    Args:
        Adj1: ``(M, M)`` matrix; row sums are used as node degrees.
        Adj2: ``(N, N)`` matrix; row sums are used as node degrees.
        a: ``(M, 1)`` target row marginal.
        b: ``(N, 1)`` target column marginal.
        rho: Unused; kept for interface compatibility.

    Returns:
        ``(M, N)`` tensor.
    """
    deg1 = Adj1.sum(dim=1).view(-1, 1)
    deg2 = Adj2.sum(dim=1).view(1, -1)
    degree_diff = torch.abs(deg1 - deg2)

    # When every degree pair matches (e.g. two regular graphs of the same
    # degree) the mean difference is 0 and 0/0 would turn the whole kernel
    # into NaN; guarding the scale yields the uniform kernel instead.
    scale = torch.clamp(degree_diff.mean(), min=1e-10)
    X_init = torch.exp(-degree_diff / scale)

    # Loop-invariant summation vectors.
    ones_cols = torch.ones(b.shape, device=b.device, dtype=b.dtype)
    ones_rows = torch.ones(a.shape, device=a.device, dtype=a.dtype)

    for _ in range(5):
        # Scale rows to sum to a, then columns to sum to b.
        X_init = X_init * (a / (X_init @ ones_cols))
        X_init = X_init * (b.T / (X_init.T @ ones_rows).T)
    return X_init
def init_alpha_beta_by_degree(A, B, S_alpha, S_beta):
    """Initialise ``alpha`` and ``beta`` from degree mismatch between graphs.

    For every node of ``A`` the smallest absolute degree difference to any
    node of ``B`` is taken (and vice versa), offset by ``1e-2``. ``alpha``
    repeats the per-row score of ``A`` across columns, ``beta`` repeats the
    per-column score of ``B`` across rows; each is then rescaled so that its
    entries sum to ``S_alpha`` / ``S_beta`` respectively.

    Works for any floating dtype of ``A``/``B``: the scores are broadcast
    directly rather than multiplied by a default-dtype ones vector, which
    would fail for non-float32 inputs.

    Args:
        A: ``(M, M)`` matrix; row sums are used as degrees.
        B: ``(N, N)`` matrix; row sums are used as degrees.
        S_alpha: Total mass of the returned ``alpha``.
        S_beta: Total mass of the returned ``beta``.

    Returns:
        Tuple ``(alpha_init, beta_init)`` of ``(M, N)`` tensors.
    """
    M = A.shape[0]
    N = B.shape[0]

    d1 = A.sum(dim=1)
    d2 = B.sum(dim=1)
    diff_mat = torch.abs(d1.unsqueeze(1) - d2.unsqueeze(0))  # (M, N)

    # Small offset keeps perfectly matched nodes from getting zero mass.
    score_A = diff_mat.min(dim=1).values.view(-1, 1) + 1e-2  # (M, 1)
    score_B = diff_mat.min(dim=0).values.view(1, -1) + 1e-2  # (1, N)

    alpha_init = score_A.expand(M, N)
    beta_init = score_B.expand(M, N)

    # Normalise to the requested total mass (1e-10 guards an all-zero sum).
    alpha_init = alpha_init / (alpha_init.sum() + 1e-10) * S_alpha
    beta_init = beta_init / (beta_init.sum() + 1e-10) * S_beta

    return alpha_init, beta_init



def _mirror_sinkhorn_step(A, B, X, a, b, rho, ones_rows, ones_cols,
                          first_other, second_other):
    """Apply one two-phase multiplicative update to the coupling ``X``.

    Phase one multiplies ``X`` element-wise by ``exp(g / rho)`` with
    ``g = compute_gw_gradient(A, B, X, first_other)`` and rescales the rows
    towards ``a``; phase two repeats this with ``second_other`` and rescales
    the columns towards ``b``. A ``1e-10`` guard is added to both
    denominators. No tensor is modified in place.
    """
    cost = compute_gw_gradient(A, B, X, first_other)
    X = torch.exp(cost / rho) * X
    X = X * (a / (X @ ones_cols + 1e-10))
    cost = compute_gw_gradient(A, B, X, second_other)
    X = torch.exp(cost / rho) * X
    X = X * (b.T / (X.T @ ones_rows + 1e-10).T)
    return X


def mask_guided_gw_solver(A, B, a=None, b=None, S_alpha=0.25, S_beta=0.25,
                          X_init=None, alpha_init=None, beta_init=None,
                          outer_epochs=20, inner_steps_X=1000, inner_loop_eps=1e-5,
                          rho_decay=0.9, warm=0,
                          eps=1e-5, rho=1e-1, min_rho=1e-1, scaling=1.0,
                          early_stop_patience=2000, plot=False, graph_id=1,
                          degree_init=False, degree_init_alpha=False):
    """Alternately optimise a coupling ``X`` and two masks ``alpha``/``beta``.

    Each outer epoch (after the optional warm-up) runs up to
    ``inner_steps_X`` multiplicative updates of ``X`` with marginal
    rescaling, then re-solves ``alpha`` and ``beta`` as linear programs with
    box constraint ``0 <= mask <= X`` and total mass ``S_alpha``/``S_beta``.

    Args:
        A, B: Square matrices of the two graphs; ``B`` is moved to the
            device and dtype of ``A``.
        a, b: ``(M, 1)`` / ``(N, 1)`` marginals; uniform when ``None``.
        S_alpha, S_beta: Total mass assigned to ``alpha`` / ``beta``.
        X_init: Optional initial coupling. When given it is always used,
            regardless of ``degree_init``.
        alpha_init, beta_init: Optional initial masks.
        outer_epochs: Number of outer iterations.
        inner_steps_X: Maximum inner updates of ``X`` per outer epoch.
        inner_loop_eps: Relative L1 change of ``X`` below which the inner
            loop stops (checked on even inner iterations).
        rho_decay, rho, min_rho: ``rho`` is multiplied by ``rho_decay`` at
            the start of every epoch and floored at ``min_rho``.
        warm: Number of leading epochs that only update ``X`` once (with no
            mask subtracted) and skip the mask update.
        eps: Relative-change threshold on the loss for early stopping.
        scaling: Unused; kept for interface compatibility.
        early_stop_patience: The loss is only evaluated (and appended to
            ``obj_list``) on epochs with ``ii > early_stop_patience`` and
            ``ii % 10 == 0``; with the defaults this never happens.
        plot: If true and at least two losses were recorded, save a plot to
            ``photo/cost_plot_{graph_id}_{S_alpha}.png``.
        graph_id: Identifier used in the plot filename.
        degree_init: Use ``init_by_degree`` for ``X`` when ``X_init`` is
            ``None``.
        degree_init_alpha: Use ``init_alpha_beta_by_degree`` for any mask
            whose initial value is ``None``.

    Returns:
        Tuple ``(X, alpha, beta, obj_list)``.
    """
    device = A.device
    dtype = A.dtype
    # `.to` is a no-op when B already matches, and also covers the case where
    # only the device (not the dtype) differs.
    B = B.to(device=device, dtype=dtype)
    plot_filename = 'photo/' + f'cost_plot_{graph_id}_{S_alpha}.png'

    if a is None:
        a = torch.ones(A.shape[0], 1, device=device, dtype=dtype) / A.shape[0]
    else:
        a = a.to(device=device, dtype=dtype)
    if b is None:
        b = torch.ones(B.shape[0], 1, device=device, dtype=dtype) / B.shape[0]
    else:
        b = b.to(device=device, dtype=dtype)

    # --- coupling initialisation -------------------------------------------
    # An explicit X_init always wins; previously combining it with
    # degree_init=True left X undefined.
    if X_init is not None:
        X = X_init.to(device=device, dtype=dtype).clone()
    elif degree_init:
        X = init_by_degree(A, B, a, b, rho)
    else:
        product = a @ b.T
        X = product + 1e-5 * torch.rand_like(product)
    X = X.to(device=device, dtype=dtype)

    # --- mask initialisation -----------------------------------------------
    degree_alpha = degree_beta = None
    if degree_init_alpha and (alpha_init is None or beta_init is None):
        print("Initializing alpha/beta using degree mismatch...")
        degree_alpha, degree_beta = init_alpha_beta_by_degree(A, B, S_alpha, S_beta)
    numel = max(X.numel(), 1)

    if alpha_init is not None:
        alpha = alpha_init.to(device=device, dtype=dtype).clone()
    elif degree_alpha is not None:
        alpha = degree_alpha.to(device=device, dtype=dtype)
    else:
        alpha = torch.full_like(X, S_alpha / numel)

    if beta_init is not None:
        beta = beta_init.to(device=device, dtype=dtype).clone()
    elif degree_beta is not None:
        beta = degree_beta.to(device=device, dtype=dtype)
    else:
        beta = torch.full_like(X, S_beta / numel)

    # Loop-invariant summation vectors for the marginal rescaling.
    ones_cols = torch.ones(b.shape, device=b.device, dtype=b.dtype)
    ones_rows = torch.ones(a.shape, device=a.device, dtype=a.dtype)

    obj_list = []
    for ii in range(outer_epochs):
        rho = max(min_rho, rho * rho_decay)

        if ii < warm:
            # Warm-up: single update without masks, no alpha/beta refresh.
            X = _mirror_sinkhorn_step(A, B, X, a, b, rho, ones_rows, ones_cols, 0, 0)
            continue

        for iii in range(inner_steps_X):
            # The step builds new tensors, so keeping a reference is enough.
            X_old = X
            X = _mirror_sinkhorn_step(A, B, X, a, b, rho, ones_rows, ones_cols,
                                      beta, alpha)
            if iii % 2 == 0:
                diff = torch.sum(torch.abs(X - X_old)) / (torch.sum(torch.abs(X_old)) + 1e-10)
                if diff < inner_loop_eps:
                    break

        # Re-solve each mask with the other one held fixed.
        grad_alpha = compute_gw_gradient(A, B, X, beta)
        alpha = solve_lp_with_box_and_sum_constraints(X, grad_alpha, S_alpha)
        grad_beta = compute_gw_gradient(A, B, X, alpha)
        beta = solve_lp_with_box_and_sum_constraints(X, grad_beta, S_beta)

        if ii > early_stop_patience and ii % 10 == 0:
            objective = calculate_robust_gw_loss(A, B, X, alpha, beta)
            if len(obj_list) > 0 and abs(obj_list[-1]) > 1e-9:
                relative_change = abs(objective.item() - obj_list[-1]) / abs(obj_list[-1])
                if relative_change < eps:
                    print(f'iter:{ii+1}, Robust GW loss relative change ({relative_change:.2e}) < eps ({eps:.2e})')
                    break
            obj_list.append(objective.item())

    if plot and len(obj_list) > 1:
        import os

        # Make sure the output directory exists so savefig cannot fail after
        # the whole optimisation has already run.
        plot_dir = os.path.dirname(plot_filename)
        if plot_dir:
            os.makedirs(plot_dir, exist_ok=True)

        plt.figure(figsize=(10, 6))
        plt.plot(range(1, len(obj_list) + 1), obj_list, marker='.', linestyle='-')
        plt.title('Cost vs. Iterations')
        plt.xlabel('Outer Iteration')
        plt.ylabel('Robust GW Loss (Cost)')
        plt.grid(True)
        plt.savefig(plot_filename)
        plt.close()

    return X, alpha, beta, obj_list
