import torch
import torch.nn as nn
import torch.nn.functional as F

def manifold_prediction_loss(
    reconstructed: torch.Tensor,
    target: torch.Tensor
):
    """Return the mean squared error between ``reconstructed`` and ``target``.

    Args:
        reconstructed: Predicted tensor.
        target: Reference tensor compared against ``reconstructed``.

    Returns:
        The element-wise squared error averaged over all elements
        (``F.mse_loss`` with its default ``"mean"`` reduction).
    """
    mse = F.mse_loss(input=reconstructed, target=target, reduction="mean")
    return mse

def manifold_reconstruction_loss(
    reconstructed: torch.Tensor,
    target: torch.Tensor
):
    """Return the mean squared error of a reconstruction against its target.

    Args:
        reconstructed: Reconstructed tensor.
        target: Reference tensor compared against ``reconstructed``.

    Returns:
        The element-wise squared error averaged over all elements
        (``F.mse_loss`` with its default ``"mean"`` reduction).
    """
    mse = F.mse_loss(input=reconstructed, target=target, reduction="mean")
    return mse

def manifold_alignment_loss(
    observed_embeddings: torch.Tensor,
    individual_embeddings: torch.Tensor,
) -> torch.Tensor:
    """Return the mean cosine distance between paired embeddings.

    Both inputs are L2-normalized along their last dimension; the cosine
    similarity of each pair is the sum of the element-wise product over that
    dimension.

    Args:
        observed_embeddings: First set of embeddings.
        individual_embeddings: Second set of embeddings, paired with the
            first position by position.

    Returns:
        Scalar tensor holding the average of ``1 - cosine_similarity``.
    """
    unit_observed = F.normalize(observed_embeddings, dim=-1)
    unit_individual = F.normalize(individual_embeddings, dim=-1)
    cosine = (unit_observed * unit_individual).sum(dim=-1)
    cosine_distance = 1.0 - cosine
    return cosine_distance.mean()

class ManifoldTopologyLoss(nn.Module):
    """KL-divergence loss matching per-series neighbourhoods to a global one.

    Each set of embeddings is converted into a row-wise neighbour
    distribution: a softmax over the negative pairwise Euclidean distances
    divided by ``temperature``, with self-pairs masked out.  The distribution
    built from the global embeddings is detached and used as the target for
    the distribution of every individual series.

    Args:
        temperature: Divisor applied to the negative distances before the
            softmax.
        reduction: Passed unchanged to ``F.kl_div``.
    """

    def __init__(self, temperature=0.3, reduction='batchmean'):
        super().__init__()
        self.temperature = temperature
        self.reduction = reduction

    def _compute_pairwise_distribution(self, embeddings, mask_diagonal=True):
        """Return a softmax over negative pairwise distances of ``embeddings``.

        Args:
            embeddings: Tensor accepted by ``torch.cdist``; the second-to-last
                dimension indexes the samples that are compared pairwise.
            mask_diagonal: When true, self-pairs get zero probability.

        Returns:
            Tensor of pairwise probabilities normalized over the last
            dimension.  A row with no unmasked entry (a single sample with
            ``mask_diagonal=True``) comes out as NaN.
        """
        dists = torch.cdist(embeddings, embeddings, p=2)
        logits = -dists / self.temperature
        if mask_diagonal:
            b_size = dists.shape[-2]
            eye_mask = torch.eye(b_size, device=embeddings.device).bool()
            # -inf logits become exactly zero probability after the softmax.
            logits.masked_fill_(eye_mask, float('-inf'))

        probs = F.softmax(logits, dim=-1)

        return probs

    def forward(self, global_embeddings, individual_embeddings, num_series):
        """Compute KL(P_global || P_individual) for every series.

        Args:
            global_embeddings: Tensor whose first dimension is the batch size
                ``B``; pairwise distances are taken between its rows.
            individual_embeddings: Tensor viewable as ``(B, num_series, -1)``.
            num_series: Number of series ``N`` per batch element.

        Returns:
            The ``F.kl_div`` result under ``self.reduction``.  When ``B < 2``
            there are no neighbour pairs, so a zero is returned instead
            (shaped ``(N, B, B)`` for ``reduction='none'``, scalar otherwise).
        """
        B = global_embeddings.shape[0]
        N = num_series

        if B < 2:
            # With a single sample the diagonal mask removes every logit and
            # the softmax would yield NaN.  Return a zero that stays attached
            # to the autograd graph so ``backward()`` still works.
            zero = individual_embeddings.sum() * 0.0
            if self.reduction == 'none':
                return zero.expand(N, B, B)
            return zero

        # The global neighbourhood is a fixed target: no gradient flows to it.
        P_global = self._compute_pairwise_distribution(global_embeddings)
        P_global = P_global.detach()

        # (B, N, D') -> (N, B, D') so distances are taken within each series.
        indiv_view = individual_embeddings.view(B, N, -1)
        indiv_per_series = indiv_view.permute(1, 0, 2)

        P_indiv = self._compute_pairwise_distribution(indiv_per_series)

        P_global_expanded = P_global.unsqueeze(0).expand(N, -1, -1)
        # The epsilon keeps the masked (zero-probability) diagonal finite in
        # log space; those entries contribute nothing because the target is 0.
        log_P_indiv = torch.log(P_indiv + 1e-8)

        # NOTE: 'batchmean' divides the summed loss by the leading dimension
        # of the input, which here is N (series), not B.
        loss = F.kl_div(log_P_indiv, P_global_expanded, reduction=self.reduction)

        return loss


def manifold_contrastive_loss_sameseries(
    individual_embeddings: torch.Tensor,
    series_ids: torch.Tensor,
    time_indices: torch.Tensor,
    pos_time_threshold: int = 10,
    neg_time_threshold: int = 20,
    temp: float = 0.3,
) -> torch.Tensor:
    """Contrastive loss whose positives and negatives share a series.

    Embeddings are L2-normalized and compared by dot product.  For each row,
    a log-softmax over ``similarity / temp`` is taken across the candidates
    kept for that row, and the loss is the negated mean log-probability of
    all positive pairs.

    Pair selection (self-pairs are always excluded):
        * positive: same series and ``|dt| < pos_time_threshold``
        * negative: same series and ``|dt| >= neg_time_threshold``
        * every other pair is left out of the softmax.

    Args:
        individual_embeddings: 2-D tensor with one embedding per row.
        series_ids: Series identifiers, flattened internally; expected to
            hold one entry per embedding row.
        time_indices: Time indices, flattened internally; expected to hold
            one entry per embedding row.
        pos_time_threshold: Exclusive upper bound on the time gap of positives.
        neg_time_threshold: Inclusive lower bound on the time gap of negatives.
        temp: Softmax temperature.

    Returns:
        Scalar loss tensor; a zero scalar on the embeddings' device when no
        positive pair exists.
    """
    # Unpacking into two names also enforces a 2-D input.
    num_rows, _ = individual_embeddings.shape
    dev = individual_embeddings.device

    unit = F.normalize(individual_embeddings, dim=-1)
    cosine = torch.matmul(unit, unit.t())

    off_diagonal = ~torch.eye(num_rows, dtype=torch.bool, device=dev)
    sid_col = series_ids.view(-1, 1)
    t_col = time_indices.view(-1, 1)

    shares_series = sid_col == sid_col.t()
    gap = (t_col - t_col.t()).abs()

    positives = shares_series & (gap < pos_time_threshold) & off_diagonal
    negatives = shares_series & (gap >= neg_time_threshold) & off_diagonal
    keep = positives | negatives

    # No positives also covers the "no valid pair at all" case, because the
    # kept set is a superset of the positives.
    if not positives.any():
        return torch.zeros((), device=dev)

    # Excluded pairs get -inf so they carry no mass in the row-wise softmax.
    scaled = cosine.masked_fill(~keep, float('-inf')) / temp
    log_p = F.log_softmax(scaled, dim=1)

    return -log_p[positives].mean()


def manifold_contrastive_loss(
    individual_embeddings: torch.Tensor,
    series_ids: torch.Tensor,
    time_indices: torch.Tensor,
    pos_time_threshold: int = 10,
    neg_time_threshold: int = 20,
    temp: float = 0.3,
) -> torch.Tensor:
    """Contrastive loss with same-series positives and cross-series negatives.

    Embeddings are L2-normalized and compared by dot product.  For each row,
    a log-softmax over ``similarity / temp`` is taken across the candidates
    kept for that row, and the loss is the negated mean log-probability of
    all positive pairs.

    Pair selection (self-pairs are always excluded):
        * positive: same series and ``|dt| < pos_time_threshold``
        * negative: different series, or same series and
          ``|dt| >= neg_time_threshold``
        * every other pair is left out of the softmax.

    Args:
        individual_embeddings: 2-D tensor with one embedding per row.
        series_ids: Series identifiers, flattened internally; expected to
            hold one entry per embedding row.
        time_indices: Time indices, flattened internally; expected to hold
            one entry per embedding row.
        pos_time_threshold: Exclusive upper bound on the time gap of positives.
        neg_time_threshold: Inclusive lower bound on the time gap of
            same-series negatives.
        temp: Softmax temperature.

    Returns:
        Scalar loss tensor; a zero scalar on the embeddings' device when no
        positive pair exists.
    """
    # Unpacking into two names also enforces a 2-D input.
    num_rows, _ = individual_embeddings.shape
    dev = individual_embeddings.device

    unit = F.normalize(individual_embeddings, dim=-1)
    cosine = torch.matmul(unit, unit.t())

    off_diagonal = ~torch.eye(num_rows, dtype=torch.bool, device=dev)
    sid_col = series_ids.view(-1, 1)
    t_col = time_indices.view(-1, 1)

    shares_series = sid_col == sid_col.t()
    gap = (t_col - t_col.t()).abs()

    positives = shares_series & (gap < pos_time_threshold) & off_diagonal
    # "different series OR (same series AND far apart)" reduces to
    # "different series OR far apart".
    negatives = (~shares_series | (gap >= neg_time_threshold)) & off_diagonal
    keep = positives | negatives

    # No positives also covers the "no valid pair at all" case, because the
    # kept set is a superset of the positives.
    if not positives.any():
        return torch.zeros((), device=dev)

    # The kept set never contains the diagonal, so this single fill also
    # removes self-similarity from the softmax.
    scaled = cosine.masked_fill(~keep, float('-inf')) / temp
    log_p = F.log_softmax(scaled, dim=1)

    return -log_p[positives].mean()