import torch
import torch.nn as nn
import torch.nn.functional as F

def split_func(data, view_nums, input_dims):
    """Slice a feature matrix into consecutive per-view column blocks.

    The last column of ``data`` is not assigned to any view (presumably a
    label column — confirm with callers); the remaining columns are split
    left-to-right according to ``input_dims``.

    Args:
        data: 2-D tensor whose column count is ``sum(input_dims) + 1``.
        view_nums: expected number of views; must equal ``len(input_dims)``.
        input_dims: number of columns belonging to each view, in order.

    Returns:
        List of tensors; entry ``k`` holds ``input_dims[k]`` columns of ``data``.

    Raises:
        ValueError: if ``data`` is not a 2-D tensor, or the dimension
            bookkeeping does not match.
    """
    is_matrix = isinstance(data, torch.Tensor) and data.dim() == 2
    if not is_matrix:
        raise ValueError("Input data must be a 2D PyTorch tensor.")
    if view_nums != len(input_dims):
        raise ValueError("Length of input_dims must match view_nums.")

    n_features = data.shape[1] - 1
    requested = sum(input_dims)
    if requested != n_features:
        raise ValueError(f"Sum of input_dims ({requested}) != {n_features}.")

    # Running column offsets: view k spans [offsets[k], offsets[k + 1]).
    offsets = [0]
    for width in input_dims:
        offsets.append(offsets[-1] + width)
    return [data[:, lo:hi] for lo, hi in zip(offsets[:-1], offsets[1:])]

class AutoEncoder(nn.Module):
    """One-layer autoencoder: ``Linear -> ReLU`` encoder, ``Linear -> Sigmoid`` decoder.

    ``forward`` returns both the reconstruction and the latent code.
    """

    def __init__(self, input_dim, latent_dim):
        super().__init__()
        # Layers are built encoder-first so parameter initialisation draws
        # from the RNG in a fixed order; wrapping them in Sequential keeps the
        # state_dict keys as ``encoder.0.*`` / ``decoder.0.*``.
        encode_layers = [nn.Linear(input_dim, latent_dim), nn.ReLU()]
        decode_layers = [nn.Linear(latent_dim, input_dim), nn.Sigmoid()]
        self.encoder = nn.Sequential(*encode_layers)
        self.decoder = nn.Sequential(*decode_layers)

    def forward(self, x):
        """Encode ``x`` and decode it back.

        Returns:
            Tuple ``(reconstruction, latent)``. The reconstruction is a sigmoid
            output (values in ``[0, 1]``); the latent code is a ReLU output
            (non-negative).
        """
        latent = self.encoder(x)
        return self.decoder(latent), latent

def neighbor_preserve_loss(X, Z, eta=1e-3):
    """Weighted mismatch between pairwise distances in input and latent space.

    Squared Euclidean distances are computed within ``X`` and within ``Z``,
    each divided by that tensor's size along dim 1. Each pair is weighted by
    ``1 / (d_x + eta)``, so pairs close in ``X`` dominate. The weights are
    normalised by their total (plus 1e-8). The weighted squared difference of
    the two distance matrices is summed and divided by ``N * N``, with ``N``
    the size of ``X`` along dim 0.

    Args:
        X: input features, assumed ``[N, Dx]``.
        Z: latent features, assumed ``[N, Dz]`` with the same ``N``.
        eta: offset keeping the weights finite where ``d_x == 0``.

    Returns:
        Scalar tensor.
    """
    n_rows, x_dim = X.shape[0], X.shape[1]
    z_dim = Z.shape[1]
    sq_dx = torch.cdist(X, X) ** 2 / x_dim
    sq_dz = torch.cdist(Z, Z) ** 2 / z_dim

    weights = 1.0 / (sq_dx + eta)
    weights = weights / (weights.sum() + 1e-8)

    mismatch = (sq_dz - sq_dx) ** 2
    return torch.sum(weights * mismatch) / (n_rows * n_rows)

@torch.no_grad()
def compute_phi(x, xh, alpha=0.1):
    """Map per-sample reconstruction error to a score of at most 1.

    ``phi = 2 * sigmoid(-alpha * sum_d (x - xh)^2)``, where the sum runs over
    dim 1. A perfect reconstruction scores exactly 1, and the score decays
    toward 0 as the error grows. Runs without autograd.

    Returns:
        Tensor with dim 1 kept as size 1 (``[N, 1]`` for 2-D inputs).
    """
    sq_err = ((x - xh) ** 2).sum(dim=1, keepdim=True)
    return torch.sigmoid(sq_err * -alpha) * 2

@torch.no_grad()
def compute_psi(M_list, K):
    """Per-sample neighbourhood agreement across views.

    For every pair of views ``(a, b)`` and every sample ``i``, counts the
    columns set (``> 0.5``) in row ``i`` of both ``M_list[a]`` and
    ``M_list[b]`` — the neighbours the two views share — divides by
    ``max(1, K)``, and averages over all view pairs.

    Args:
        M_list: list of ``V`` neighbour masks, each ``[N, N]``.
        K: neighbourhood size used to normalise the shared count.

    Returns:
        Tensor ``[N]`` clamped to ``[0, 1]``; all ones when ``V == 1``.
    """
    V = len(M_list)
    N = M_list[0].size(0)
    device = M_list[0].device
    if V == 1:
        return torch.ones(N, device=device)

    # Shared-neighbour count for row i equals the number of columns where
    # both boolean masks are set: one vectorised AND + row-sum per view pair
    # instead of building per-row index lists and looping in Python.
    masks = [M > 0.5 for M in M_list]
    scale = max(1, K)

    psi = torch.zeros(N, device=device)
    n_pairs = 0
    for a in range(V):
        for b in range(a + 1, V):
            n_pairs += 1
            shared = (masks[a] & masks[b]).sum(dim=1).float()
            psi += shared / scale
    psi = psi / max(1, n_pairs)
    psi = torch.nan_to_num(psi, nan=0.0, posinf=1.0, neginf=0.0).clamp(0.0, 1.0)
    return psi  # [N]

@torch.no_grad()
def compute_W_with_psi(phis, psi, beta=8.0, gamma=0.25):
    """Pairwise weights combining phi similarity with a psi factor.

    ``W[i, j] = exp(-beta * |phi_i - phi_j| ** 0.25) * clamp(0.5 * psi_j, 0, 1) ** gamma``,
    where ``phi`` is ``phis`` averaged over dim 0 with the trailing dim
    squeezed. NaN/inf values are cleaned and the result is clamped to
    ``[1e-4, 10]``.

    Args:
        phis: 3-D tensor, expected ``[V, N, 1]``.
        psi: tensor ``[N]``.
        beta: sharpness of the phi-similarity kernel.
        gamma: exponent applied to the psi factor.

    Returns:
        Tensor ``[N, N]``.

    Raises:
        ValueError: if ``phis`` is not 3-D.
    """
    if phis.dim() != 3:
        raise ValueError(f"phis must be 3-D [V, N, 1], got shape {tuple(phis.shape)}.")
    phi_mean = phis.mean(0).squeeze(-1)  # [N]
    # Take |.| before the fractional power: a negative base raised to 0.25 is
    # NaN, which would zero out (then floor to 1e-4) every pair with
    # phi_i < phi_j and make W asymmetric.
    diff = (phi_mean[:, None] - phi_mean[None, :]).abs().pow(0.25)
    W_phi = torch.exp(-beta * diff)

    # NOTE(review): only psi of the column sample j is used here, yet the name
    # "psi_pair" and the 0.5 factor suggest a symmetric mean
    # (psi_i + psi_j) / 2 may have been intended — confirm.
    psi_pair = psi[None, :] * 0.5
    W = W_phi * psi_pair.clamp(0.0, 1.0).pow(gamma)
    W = torch.nan_to_num(W, nan=0.0, posinf=10.0, neginf=0.0).clamp(min=1e-4, max=10.0)
    return W  # [N,N]

@torch.no_grad()
def build_neighbors(H, W, K):
    """K-nearest-neighbour mask under a weight-scaled squared distance.

    The score of pair ``(i, j)`` is ``||H_i - H_j||^2 / (W[i, j] + 1e-8)``, so
    a larger weight makes ``j`` a more attractive neighbour of ``i``. Each row
    of the result has ones at the ``K`` lowest-scoring columns, never the
    sample itself while ``K <= N - 1``.

    Args:
        H: embeddings, assumed ``[N, D]``.
        W: pair weights, expected ``[N, N]``.
        K: neighbours per row; ``torch.topk`` requires ``K <= N``.

    Returns:
        Tensor ``[N, N]`` (dtype of the score) with exactly ``K`` ones per row.
    """
    eps = 1e-8
    dist2 = torch.cdist(H, H).pow(2)
    score = dist2 * (1.0 / (W + eps))
    # Exclude self-matches with +inf rather than a finite offset: when W is
    # small (e.g. 1e-4) real scores easily exceed any fixed constant such as
    # 1e6, which would let a sample be chosen as its own neighbour.
    score.fill_diagonal_(float("inf"))
    _, idx = torch.topk(score, K, dim=1, largest=False)
    M = torch.zeros_like(score)
    M.scatter_(1, idx, 1.0)
    return M

@torch.no_grad()
def build_neighbors_unweighted(H, K):
    """K-nearest-neighbour mask under plain squared Euclidean distance.

    Each row of the result has ones at the ``K`` closest other samples; the
    sample itself is never selected while ``K <= N - 1``.

    Args:
        H: embeddings, assumed ``[N, D]``.
        K: neighbours per row; ``torch.topk`` requires ``K <= N``.

    Returns:
        Tensor ``[N, N]`` with exactly ``K`` ones per row.
    """
    dist2 = torch.cdist(H, H).pow(2)
    # +inf on the diagonal excludes self-matches regardless of the data
    # scale; a finite offset (e.g. 1e6) is beaten once distances exceed it.
    dist2.fill_diagonal_(float("inf"))
    _, idx = torch.topk(dist2, K, dim=1, largest=False)
    M = torch.zeros_like(dist2)
    M.scatter_(1, idx, 1.0)
    return M


import torch
import torch.nn.functional as F

def angular_loss(H_list, v, M, eps=1e-8):
    """Directional misalignment of view ``v``'s neighbour offsets vs. the other views.

    For each pair ``(i, j)`` with ``M[i, j] > 0.5``:
    - the unit offset ``H_v[i] - H_v[j]`` is compared with a consensus direction;
    - the consensus is the mean of the other views' unit offsets, re-normalised.

    The loss is ``mean(1 - cos)``, ranging from 0 (same direction) to 2
    (opposite direction).

    Args:
        H_list: list of per-view embeddings, each assumed ``[N, D]`` with a common ``D``.
        v: index of the view being aligned.
        M: neighbour mask, ``[N, N]``.
        eps: stabiliser for the normalisations.

    Returns:
        Scalar tensor; zero when ``M`` selects no pairs or there is no other view.
    """
    H_v = H_list[v]
    I, J = torch.nonzero(M > 0.5, as_tuple=True)
    others = [u for u in range(len(H_list)) if u != v]
    # Without pairs or without another view there is no consensus direction
    # to compare against; return zero instead of failing.
    if I.numel() == 0 or not others:
        return H_v.new_tensor(0.0)

    # Normalise view v's offsets as well: otherwise the dot product is
    # |dv| * cos(theta), which is unbounded and rewards stretching pairs
    # apart rather than aligning their direction.
    dv = F.normalize(H_v[I] - H_v[J], p=2, dim=1, eps=eps)

    dc_sum = None
    for u in others:
        du = H_list[u][I] - H_list[u][J]
        du = du / (du.norm(p=2, dim=1, keepdim=True) + eps)
        dc_sum = du if dc_sum is None else dc_sum + du
    dc = F.normalize(dc_sum / len(others), p=2, dim=1)

    cos_sim = (dv * dc).sum(1)
    loss = (1 - cos_sim).mean()
    return torch.nan_to_num(loss, nan=0.0, posinf=1e3, neginf=0.0)

def distance_loss(H_list, v, M):
    """Squared gap between view ``v``'s neighbour distances and the other views' mean.

    For each pair ``(i, j)`` with ``M[i, j] > 0.5``, the Euclidean distance in
    view ``v`` is compared with the average of the same pair's distance in
    every other view. The loss is the mean squared difference.

    Args:
        H_list: list of per-view embeddings, each assumed ``[N, D_u]``.
        v: index of the view being aligned.
        M: neighbour mask, ``[N, N]``.

    Returns:
        Scalar tensor; zero when ``M`` selects no pairs or there is no other view.
    """
    H_v = H_list[v]
    I, J = torch.nonzero(M > 0.5, as_tuple=True)
    others = [u for u in range(len(H_list)) if u != v]
    # With a single view there is nothing to average; return zero instead of
    # dividing an uninitialised accumulator.
    if I.numel() == 0 or not others:
        return H_v.new_tensor(0.0)

    dv = (H_v[I] - H_v[J]).norm(p=2, dim=1)
    dc = None
    for u in others:
        du = (H_list[u][I] - H_list[u][J]).norm(p=2, dim=1)
        dc = du if dc is None else dc + du
    dc = dc / len(others)

    loss = (dv - dc).pow(2).mean()
    return torch.nan_to_num(loss, nan=0.0, posinf=1e3, neginf=0.0)

def loss_SRC(phi_sample, psi_sample, labels):
    """Regress both ``phi`` and ``psi`` toward ``1 - label``.

    Samples with label 0 are pushed toward 1 and samples with label 1 toward
    0, for both scores. The loss is the per-sample sum of the two squared
    errors, averaged over samples.

    All inputs are flattened to 1-D first. ``compute_phi`` produces ``[N, 1]``
    scores, while ``psi`` and labels are ``[N]``. Subtracting an ``[N]`` target
    from an ``[N, 1]`` tensor would silently broadcast to ``[N, N]`` and
    average over every pair of samples.

    Args:
        phi_sample: tensor of ``N`` scores (``[N]`` or ``[N, 1]``).
        psi_sample: tensor of ``N`` scores (``[N]`` or ``[N, 1]``).
        labels: tensor of ``N`` labels.

    Returns:
        Scalar tensor.
    """
    target = 1 - labels.float().reshape(-1)
    phi = phi_sample.reshape(-1)
    psi = psi_sample.reshape(-1)
    return ((phi - target) ** 2 + (psi - target) ** 2).mean()

def loss_SCLP(H, labels, num_pairs=256, margin=1.0):
    """Label-supervised contrastive loss on the view-averaged embedding.

    ``Z`` is the mean of the per-view embeddings in ``H``. Two terms are added,
    each estimated from ``num_pairs`` pairs drawn with replacement:

    * normal–normal: mean squared distance, pulling normal samples together
      (requires at least two normal samples);
    * normal–anomaly: squared hinge ``relu(margin - d) ** 2`` on the Euclidean
      distance, pushing pairs at least ``margin`` apart.

    Args:
        H: list of per-view embedding tensors of identical shape.
        labels: tensor with 0 for normal and 1 for anomalous samples.
        num_pairs: number of random pairs sampled per term.
        margin: target minimum normal–anomaly distance.

    Returns:
        Scalar tensor; zero when neither term has enough samples.
    """
    Z = torch.stack(H, dim=0).mean(0)
    y = labels
    norm_idx = (y == 0).nonzero(as_tuple=True)[0]
    anom_idx = (y == 1).nonzero(as_tuple=True)[0]

    # Tensor accumulator so the degenerate case still returns a tensor.
    loss = Z.new_zeros(())

    # normal-normal (pull together)
    if len(norm_idx) > 1:
        i = norm_idx[torch.randint(0, len(norm_idx), (num_pairs,))]
        j = norm_idx[torch.randint(0, len(norm_idx), (num_pairs,))]
        loss = loss + (Z[i] - Z[j]).pow(2).sum(1).mean()

    # normal-anomaly (push apart)
    if len(norm_idx) > 0 and len(anom_idx) > 0:
        i = norm_idx[torch.randint(0, len(norm_idx), (num_pairs,))]
        j = anom_idx[torch.randint(0, len(anom_idx), (num_pairs,))]
        # Clamp before sqrt: d/ds sqrt(s) is infinite at s == 0, so a pair
        # with identical embeddings would otherwise yield NaN gradients.
        d = (Z[i] - Z[j]).pow(2).sum(1).clamp_min(1e-12).sqrt()
        loss = loss + torch.relu(margin - d).pow(2).mean()

    return loss


def loss_SOSL(S, labels, num_pairs=256, margin=0.5):
    """Hinge ranking loss asking anomaly scores to exceed normal scores by ``margin``.

    ``num_pairs`` (anomaly, normal) index pairs are drawn with replacement.
    The loss is the mean of ``relu(margin - (S[anomaly] - S[normal]))`` over
    those pairs.

    Args:
        S: tensor of per-sample scores.
        labels: tensor with 0 for normal and 1 for anomalous samples.
        num_pairs: number of sampled pairs.
        margin: required score gap.

    Returns:
        Scalar tensor; ``torch.tensor(0.0)`` on ``S.device`` if either class is absent.
    """
    anom = (labels == 1).nonzero(as_tuple=True)[0]
    norm = (labels == 0).nonzero(as_tuple=True)[0]
    if not (len(norm) and len(anom)):
        return torch.tensor(0.0, device=S.device)

    # Anomaly indices are drawn before normal ones; keep this order so seeded
    # runs reproduce the same pairs.
    anom_pick = anom[torch.randint(0, len(anom), (num_pairs,))]
    norm_pick = norm[torch.randint(0, len(norm), (num_pairs,))]
    gap = S[anom_pick] - S[norm_pick]
    return torch.relu(margin - gap).mean()


@torch.no_grad()
def compute_pi(err_per_view, tau_pi=0.1):
    """Softmax weights over dim 0 that favour entries with lower error.

    ``pi = softmax(-err / (tau_pi + 1e-6))``. Any NaN in the result is
    replaced by the uniform weight ``1 / len(err_per_view)``. Runs without
    autograd.
    """
    temperature = tau_pi + 1e-6
    logits = -err_per_view / temperature
    weights = logits.softmax(dim=0)
    uniform = 1.0 / len(err_per_view)
    return torch.nan_to_num(weights, nan=uniform)
