import torch
from copy import deepcopy

def _is_model_parameter(variable, model):
    # Check for direct parameter match
    if any(variable is param for param in model.parameters()):
        return True
    
    # Check if variable shares storage with any parameter
    var_ptr = variable.data_ptr()
    var_storage = variable.untyped_storage()
    
    for param in model.parameters():
        if variable is param:
            return True
            
        param_storage = param.untyped_storage()
        if var_storage.data_ptr() == param_storage.data_ptr():
            param_start = param.data_ptr()
            param_bytes = param.numel() * param.element_size()
            param_end = param_start + param_bytes
            
            var_bytes = variable.numel() * variable.element_size()
            var_end = var_ptr + var_bytes
            
            if var_ptr >= param_start and var_end <= param_end:
                return True
                
    return False

def _gradient_core(model, inputs, targets, criterion, variable):
    """Core gradient computation shared by all interfaces"""
    model.eval()
    predictions = model(inputs)
    loss = criterion(predictions, targets)
    return torch.autograd.grad(
        outputs=loss,
        inputs=variable,
        retain_graph=False,
        allow_unused=False
    )[0]

def calculate_gradients(model, inputs, targets, criterion, variable='inputs'):
    """
    Unified gradient calculation interface.

    The loss is evaluated with ``reduction='none'`` on a deep copy of
    ``criterion`` (the caller's object keeps its own setting). It is then
    back-propagated with an all-ones ``grad_outputs``. For ``'inputs'`` and
    ``'predictions'``, this yields each element's gradient of its own unreduced
    loss term rather than a batch-mean-scaled gradient. For a parameter
    tensor, it yields the sum of the per-sample gradients.

    Args:
        model: PyTorch model. The forward pass runs in eval mode; every
            submodule's previous training/eval mode is restored afterwards.
        inputs: Model inputs (anything accepted by ``torch.as_tensor``).
        targets: Ground truth targets (anything accepted by ``torch.as_tensor``).
        criterion: Loss function. If the copy has a ``reduction`` attribute,
            it is set to ``'none'``.
        variable: Gradient target: ``'inputs'``, ``'predictions'``, or a
            tensor (e.g. a model parameter) that participates in the loss.

    Returns:
        Gradient tensor with the same shape as the requested target.

    Raises:
        ValueError: If ``variable`` is a string other than ``'inputs'`` or
            ``'predictions'``.
    """
    by_name = isinstance(variable, str)
    if by_name and variable not in ('inputs', 'predictions'):
        raise ValueError(
            f"variable must be 'inputs', 'predictions' or a tensor, got {variable!r}"
        )

    # Handle device placement: follow the model's first parameter, if any.
    try:
        device = next(model.parameters()).device
    except StopIteration:
        device = torch.device('cpu')

    # Ensure no reduction (such as mean or sum) is applied by the loss; work on
    # a copy so the caller's criterion is not modified.
    loss_function = deepcopy(criterion)
    if hasattr(loss_function, 'reduction'):
        loss_function.reduction = 'none'

    inputs = torch.as_tensor(inputs, device=device)
    targets = torch.as_tensor(targets, device=device)

    if by_name and variable == 'inputs':
        # as_tensor may hand back the caller's own tensor; take a fresh leaf
        # (sharing storage) so the caller's requires_grad flag is untouched.
        inputs = inputs.detach().requires_grad_(True)

    # Forward pass in eval mode, restoring every submodule's mode afterwards.
    saved_modes = [(module, module.training) for module in model.modules()]
    model.eval()
    try:
        predictions = model(inputs)
        loss = loss_function(predictions, targets)
    finally:
        for module, mode in saved_modes:
            module.training = mode

    if not by_name:
        grad_target = variable
    elif variable == 'inputs':
        grad_target = inputs
    else:
        grad_target = predictions

    # The unreduced loss is generally non-scalar; all-ones grad_outputs
    # back-propagates every element's loss term at once.
    return torch.autograd.grad(
        outputs=loss,
        inputs=grad_target,
        grad_outputs=torch.ones_like(loss),
        retain_graph=False,
        allow_unused=False,
    )[0]


def calculate_gradients_over_inputs(model, inputs, targets, criterion):
    """Gradient of the loss with respect to the model inputs.

    Delegates to ``calculate_gradients`` with ``variable='inputs'``.
    """
    target_spec = 'inputs'
    return calculate_gradients(
        model, inputs, targets, criterion, variable=target_spec
    )

def calculate_gradients_over_predictions(model, inputs, targets, criterion):
    """Gradient of the loss with respect to the model's predictions.

    Delegates to ``calculate_gradients`` with ``variable='predictions'``.
    """
    target_spec = 'predictions'
    return calculate_gradients(
        model, inputs, targets, criterion, variable=target_spec
    )

def calculate_gradients_over_parameters(model, inputs, targets, criterion, parameters):
    """
    Gradient of the loss with respect to one of the model's parameter tensors.

    Args:
        model: PyTorch model owning ``parameters``.
        inputs: Model inputs.
        targets: Ground truth targets.
        criterion: Loss function.
        parameters: A parameter tensor of ``model`` (or a view into one).

    Returns:
        Gradient tensor with the same shape as ``parameters``.

    Raises:
        ValueError: If ``parameters`` is not (a view into) a parameter of ``model``.
    """
    # Validate the tensor actually passed in; raise rather than assert so the
    # check survives ``python -O``.
    if not _is_model_parameter(parameters, model):
        raise ValueError("Invalid parameter tensor")
    return calculate_gradients(model, inputs, targets, criterion, variable=parameters)


def initialize_wB_new(WB):
    """
    Initialize ``wB_new`` as a uniform vector whose elements all equal the mean of ``WB``.

    Args:
        WB: Tensor whose first dimension gives the output length. The mean is
            taken over all of its elements and accumulated in float32.

    Returns:
        1-D tensor of length ``WB.shape[0]`` in the default dtype, allocated on
        ``WB``'s device.

    NOTE(review): the intent is "non-zero values", but if ``mean(WB) == 0``
    the result is all zeros — confirm whether callers need a guard.
    """
    fill_value = torch.mean(WB, dtype=torch.float32).item()
    # Allocate on WB's device so later dot products with tensors on the same
    # device do not hit a CPU/accelerator mismatch.
    return torch.full((WB.shape[0],), fill_value, device=WB.device)

# NOTE(review): the triple-quoted string below is a disabled earlier version of
# initialize_wA_new_and_bA_new, kept only as a no-op module-level string
# expression. The active implementation is defined further down in this file.
# Consider deleting it and relying on version-control history instead.
'''

# initialize parameter wB_new as a uniform vector with all elements same but non-zero values
def initialize_wA_new_and_bA_new(x_fold, x_target, dL_dx_before, dL_df_before, wB_new):
    delta_x = x_target - x_fold

    v = dL_dx_before.detach().clone()

    c = torch.dot(dL_df_before, wB_new)

    K = torch.dot(v, delta_x)

    if torch.abs(K) < 1e-10:
        # if K ≈ 0, any wA_new that satisfy the constraints can be used
        # here we simply choose a wA_new that is orthogonal to the delta_x
        proj_v = v - (K / torch.dot(delta_x, delta_x)) * delta_x
        if torch.norm(proj_v) > 1e-10:
            w = proj_v / torch.norm(proj_v)
        else:
            # if v is already parallel to the delta_x,
            # simply choose any unit vector that is orthogonal to the delta_x
            w = torch.zeros_like(delta_x)
            min_idx = torch.argmin(torch.abs(delta_x))

            w[min_idx] = 1.0
            w = w - (torch.dot(w, delta_x) / torch.dot(delta_x, delta_x)) * delta_x
            w = w / torch.norm(w)

    else:
        # calculate the projection of v onto the orthogonal complement of (x_t - x_f)
        proj_v_perp = v - (torch.dot(v, delta_x) / torch.dot(delta_x, delta_x)) * delta_x

        # optimal solution：w* = -(1/c) * proj_v_perp
        w = -(1/c) * proj_v_perp

    b = -torch.dot(w, x_fold)

    assert torch.allclose(torch.dot(w, x_fold)+b, torch.tensor(0.0)) and torch.allclose(torch.dot(w, x_target)+b, torch.tensor(0.0)), "{} - {}".format(torch.dot(w, x_fold)+b, torch.dot(w, x_target)+b)

    return w, b

'''

def get_orthogonal_unit_vector(d: torch.Tensor) -> torch.Tensor:
    """Return a unit vector orthogonal to the 1-D tensor ``d``.

    Up to three random directions are tried first, then the standard basis
    vectors e_0, e_1, ... in order. For each candidate, its projection onto
    ``d`` is removed, and the first residual with norm above 1e-10 is
    normalised and returned.

    Raises:
        ValueError: If ``d·d`` is below 1e-10 (``d`` is essentially zero).
        RuntimeError: If no candidate yields a usable orthogonal residual.
    """
    norm_sq = torch.dot(d, d)
    if norm_sq < 1e-10:
        raise ValueError("輸入向量接近零，無法生成正交向量。")

    def _candidates():
        # Random directions first: almost surely not parallel to d.
        for _ in range(3):
            yield torch.randn_like(d)
        # Deterministic fallback: the standard basis vectors.
        for idx in range(len(d)):
            basis = torch.zeros_like(d)
            basis[idx] = 1.0
            yield basis

    for candidate in _candidates():
        residual = candidate - (torch.dot(candidate, d) / norm_sq) * d
        length = torch.norm(residual)
        if length > 1e-10:
            return residual / length

    raise RuntimeError("無法生成正交向量。")



def initialize_wA_new_and_bA_new(x_f, x_t, dL_dx_before, dL_df_before, wB_new, eps=1e-8, M_large=1e10, M_small=1e-5):
    """
    Solve for hyperplane parameters ``w`` and ``b`` (``w·x + b = 0``) passing
    through both ``x_f`` and ``x_t``.

    With ``v = dL_dx_before``, the scalar ``c = dL_df_before · wB_new`` and
    ``d = x_t - x_f``, the normal ``w`` is kept orthogonal to ``d``:

    * ``|c| < eps``: a unit vector orthogonal to ``d`` (from
      ``get_orthogonal_unit_vector``).
    * ``k = v·d > 0``: the minimiser of ``||c*w + v||``, i.e.
      ``w = -(v - (k / ||d||^2) d) / c``. If that is (near) zero, a unit
      orthogonal vector scaled by ``M_small`` is used instead.
    * ``k <= 0``: a unit orthogonal vector scaled by ``M_large`` (a large norm,
      intended to make ``||c*w + v||`` large).

    In every case ``b = -w·x_f``.

    Args:
        x_f, x_t: The two points the hyperplane must pass through (1-D tensors).
        dL_dx_before: The vector ``v`` (same length as ``x_f``). It is detached
            and copied, never modified.
        dL_df_before: 1-D tensor dotted with ``wB_new`` to form ``c``.
        wB_new: 1-D tensor with the same length as ``dL_df_before``.
        eps: Small constant guarding against degenerate cases.
        M_large: Scale of ``w`` when ``k <= 0``.
        M_small: Scale of ``w`` when ``k > 0`` but the analytic solution is near zero.

    Returns:
        (w, b): the hyperplane normal (same length as ``x_f``) and scalar offset.

    Raises:
        ValueError: If ``||x_t - x_f||^2 < eps`` (the points coincide).
        Exceptions from ``get_orthogonal_unit_vector`` propagate unchanged.
    """
    v = dL_dx_before.detach().clone()
    c = torch.dot(dL_df_before, wB_new)

    d = x_t - x_f
    d_sq = torch.dot(d, d)
    if d_sq < eps:
        raise ValueError("x_f 和 x_t 相同，無法定義超平面。")

    k = torch.dot(v, d)

    # c ≈ 0: c*w contributes nothing, so any orthogonal unit vector will do.
    if abs(c) < eps:
        w = get_orthogonal_unit_vector(d)
        return w, -torch.dot(w, x_f)

    # k > 0: minimise ||c*w + v|| subject to w·d = 0.
    if k > 0:
        w = -v / c + (k / (c * d_sq)) * d
        # Avoid a (near) zero normal vector, which would not define a hyperplane.
        if torch.norm(w) < eps:
            w = M_small * get_orthogonal_unit_vector(d)
        return w, -torch.dot(w, x_f)

    # k <= 0: take a large-norm w (to make ||c*w + v|| large).
    w = M_large * get_orthogonal_unit_vector(d)
    b = -torch.dot(w, x_f)

    # Sanity check: both points must lie on the hyperplane. Compare signed
    # distances (w·x + b) / ||w|| instead of raw residuals. With ||w|| ≈ M_large,
    # raw residuals carry rounding error proportional to M_large, so an
    # absolute allclose(..., 0) check would fail spuriously. Tensor-only
    # arithmetic also avoids a device mismatch with a CPU constant.
    w_norm = torch.norm(w)
    dist_f = (torch.dot(w, x_f) + b) / w_norm
    dist_t = (torch.dot(w, x_t) + b) / w_norm
    tol = 1e-4 * (1.0 + torch.norm(x_f) + torch.norm(x_t))
    assert torch.abs(dist_f) <= tol and torch.abs(dist_t) <= tol, "{} - {}".format(dist_f, dist_t)

    return w, b