import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F

def bce_loss(input, target, reduce=True):
    """
    Numerically stable binary cross-entropy computed from raw scores (logits).

    Uses the formulation max(x, 0) - x * z + log(1 + exp(-|x|)), as per
    https://github.com/pytorch/pytorch/issues/751
    See the TensorFlow docs for a derivation of this formula:
    https://www.tensorflow.org/api_docs/python/tf/nn/sigmoid_cross_entropy_with_logits

    Inputs:
    - input: PyTorch Tensor giving scores (documented upstream as shape (N,)).
    - target: PyTorch Tensor broadcastable against `input`, containing 0 and 1
      giving targets.
    - reduce: if True (default) return the mean over all elements; otherwise
      return the element-wise loss with the broadcast shape.

    Returns:
    - A PyTorch Tensor: the scalar mean BCE loss when `reduce` is True, else
      the per-element losses.
    """
    # log1p keeps the tail term accurate when exp(-|x|) is far below machine
    # epsilon; log(1 + tiny) would round to exactly 0 in float32.
    softplus_tail = (-input.abs()).exp().log1p()
    loss = input.clamp(min=0) - input * target + softplus_tail
    return loss.mean() if reduce else loss


def calculate_model_losses(args, pred, target, name, writer=None, counter=None):
    """
    Compute the L1 reconstruction loss between `pred` and `target`.

    Inputs:
    - args, writer, counter: accepted for interface compatibility; not read
      by this function.
    - pred, target: tensors passed straight to `F.l1_loss`.
    - name: key under which the loss value is stored in the returned dict.

    Returns:
    - (total_loss, losses): the running total produced by `add_loss`
      (starting from 0.0 with weight 1) and a dict {name: value}.
    """
    losses = {}
    reconstruction = F.l1_loss(pred, target)
    total = add_loss(0.0, reconstruction, losses, name, 1)
    return total, losses
def calculate_model_losses_triplet(
        args,
        pred,
        target,
        labels,
        name='shape',
        writer=None,
        counter=None,
        *,
        start_margin=0.3,
        end_margin=0.45,
        anneal_steps=20_000):
    """
    Batch-hard triplet loss on cosine distance with a linearly annealed margin.

    For every row i the positive is target[i]; the competing target is the
    closest *other* target sharing row i's label. Rows with no same-label
    partner in the batch fall back to the closest other target of any label.

    Inputs:
    - args: accepted for interface compatibility; not read here.
    - pred, target: 2-D tensors with matching first dimension B (they are
      L2-normalized on the last dim and multiplied as pred @ target.T).
    - labels: 1-D tensor of length B used to build the (B, B) same-label mask.
    - name: tag used in the TensorBoard scalar names and the returned dict key.
    - writer: optional object exposing `add_scalar(tag, value, step)`.
    - counter: optional global step; drives the margin schedule and logging.
      A falsy value (None or 0) is treated as step 0.
    - start_margin, end_margin: margin at step 0 and at/after `anneal_steps`.
    - anneal_steps: number of steps over which the margin is interpolated;
      a non-positive value uses `end_margin` immediately.

    Returns:
    - (triplet_loss, {f'{name}_metric': detached triplet_loss}).
    """
    # ---------- 1) pairwise cosine distance ----------
    pred_n = F.normalize(pred, dim=-1)
    target_n = F.normalize(target, dim=-1)
    dist_mat = 1.0 - pred_n @ target_n.T

    batch = dist_mat.size(0)
    pos_dist = dist_mat.diag()

    # ---------- 2) mask: same label, excluding self ----------
    same_cls = labels.unsqueeze(0) == labels.unsqueeze(1)   # (B,B) bool
    not_diag = ~torch.eye(batch, dtype=torch.bool, device=pred.device)
    mask = same_cls & not_diag

    # ---------- 3) hardest (closest) competing target ----------
    large_val = 1e9
    neg_same = dist_mat.masked_fill(~mask, large_val).min(dim=1).values
    neg_any = dist_mat.masked_fill(~not_diag, large_val).min(dim=1).values
    # Decide the fallback from the mask itself rather than by comparing
    # against the sentinel; this also avoids a data-dependent Python branch.
    neg_dist = torch.where(mask.any(dim=1), neg_same, neg_any)   # (B,)

    # ---------- 4) margin schedule ----------
    cur_step = counter or 0
    if anneal_steps > 0:
        alpha = min(cur_step / anneal_steps, 1.0)           # 0 -> 1
    else:
        alpha = 1.0
    cur_m = start_margin + (end_margin - start_margin) * alpha

    # ---------- 5) triplet loss ----------
    triplet_loss = F.relu(pos_dist - neg_dist + cur_m).mean()

    # ---------- 6) TensorBoard ----------
    if writer is not None and counter is not None:
        writer.add_scalar(f'Train_Loss_{name}_metric', triplet_loss, counter)
        writer.add_scalar(f'Train_Margin_{name}', cur_m, counter)

    return triplet_loss, {f'{name}_metric': triplet_loss.detach()}


def fullbank_info_nce(
        pred, target, labels,                 # per original annotation: (B,256)
        train_box_data,                       # {class_name: {obj_id: …}}
        code_dict,                            # {obj_id: np.array(256,)}
        label_classes,                        # {'chair': 1, 'table': 2, …}
        bank_cache,                           # {class_name: tensor(N,256)}
        tau: float = 0.05,
        topK: int = 256):                     # max number of hardest negatives kept per sample
    """
    Build a per-class embedding bank and compute a class-wise InfoNCE term,
    but currently return only the L1 loss between `pred` and `target`.

    NOTE: the InfoNCE term (`loss_sum` / `cls_cnt`) is accumulated in the loop
    below yet is NOT part of the return value; the combined return statement
    is commented out at the end of the function.

    Inputs:
    - pred, target: tensors; L2-normalized copies are used for similarities,
      the raw tensors for the final L1 loss.
    - labels: per-sample class ids, compared against each unique id to build
      a boolean selection mask.
    - train_box_data: mapping class name -> iterable of object ids.
    - code_dict: mapping object id -> vector stacked into the class bank.
    - label_classes: mapping class name -> integer id (inverted internally).
    - bank_cache: dict that is filled IN PLACE on the first call (when it is
      empty) with one normalized float32 bank tensor per class.
    - tau: temperature dividing the logits.
    - topK: upper bound on the number of highest-similarity negatives used
      when a class bank has more than `topK` rows.

    Returns:
    - `F.l1_loss(pred, target)`.

    Side effects: mutates `bank_cache` and prints a one-line summary when the
    bank is built.
    """

    device  = pred.device
    pred_n  = F.normalize(pred,   dim=-1)
    target_n= F.normalize(target, dim=-1)

    # Lazily build the bank once: an empty (falsy) cache triggers construction.
    if not bank_cache:
        for cls_name, id_dict in train_box_data.items():
            vecs = np.vstack([code_dict[oid] for oid in id_dict]).astype('float32')
            # Bank rows are normalized and moved to the device of `pred`.
            bank_cache[cls_name] = F.normalize(torch.from_numpy(vecs).to(device), dim=-1)
        print(f'[InfoNCE] Bank built: {sum(v.shape[0] for v in bank_cache.values())} shapes → GPU')


    # Inverse lookup: integer class id -> class name.
    id2name = {v: k for k, v in label_classes.items()} 

    loss_sum, cls_cnt = 0.0, 0
    #print('labels: ', np.unique(np.asarray(labels)))
    # NOTE(review): np.asarray(labels) presumably requires `labels` to be
    # convertible to NumPy (e.g. a CPU tensor) — confirm against the caller.
    for cls_id in np.unique(np.asarray(labels)):
        mask = (labels == cls_id)
        if not mask.any(): continue

        # Raises KeyError if the id is missing from `label_classes`.
        name = id2name[int(cls_id)]
        if name not in bank_cache:
            continue                                     

        p_cls, t_cls = pred_n[mask], target_n[mask]
        bank         = bank_cache[name]                  

        sims = p_cls @ bank.T                            # (b, N_cls)
        # Bank rows that are element-wise identical to a sample's normalized
        # target are treated as that sample's positive and masked out.
        same = (bank.unsqueeze(0) == t_cls.unsqueeze(1)).all(-1)  # (b , N_cls)
        # -1.1 lies below the `sims > -1` threshold used to count valid entries.
        sims = sims.masked_fill(same, -1.1) 

        # --- Hardest-K: keep the most similar negatives ---
        if bank.size(0) > topK:
            # k is capped by the smallest per-row count of unmasked entries so
            # that no masked (-1.1) entry is selected for any row.
            neg, _ = sims.topk(min(topK, (sims > -1).sum(1).min().item()),
                       dim=1, largest=True)   # remove the positive sample
        else:
            # Small bank: all columns are used, including masked (-1.1) ones.
            neg = sims                                   # N_cls ≤ K

        # --- InfoNCE logits: positive in column 0, negatives after it ---
        pos    = (p_cls * t_cls).sum(-1, keepdim=True)   # (b,1)
        logits = torch.cat([pos, neg], 1) / tau

        # Target index 0 selects the positive column for every row.
        tgt = torch.zeros(len(p_cls), dtype=torch.long, device=device)
        loss_sum += F.cross_entropy(logits, tgt)
        cls_cnt  += 1

    loss_l1  = F.l1_loss(pred, target)
    # Disabled combined objective (InfoNCE averaged over classes + L1):
    #return loss_sum / max(cls_cnt, 1)+ loss_l1
    return loss_l1

def add_loss(total_loss, curr_loss, loss_dict, loss_name, weight=1):
    """
    Weight a loss term, record its scalar value, and add it to a running total.

    Inputs:
    - total_loss: running total (number or tensor), or None to start a new one.
    - curr_loss: loss tensor; must support `.item()` after weighting.
    - loss_dict: dict updated in place with `loss_name -> weighted value`.
    - loss_name: key written into `loss_dict`.
    - weight: multiplier applied to `curr_loss` (default 1).

    Returns:
    - `total_loss + weighted loss`, or just the weighted loss when
      `total_loss` is None.
    """
    curr_loss_weighted = curr_loss * weight
    # Store a plain Python number so the dict does not hold graph references.
    loss_dict[loss_name] = curr_loss_weighted.item()
    if total_loss is None:
        return curr_loss_weighted
    return total_loss + curr_loss_weighted

class VQLoss(nn.Module):
    """
    Loss for a vector-quantized autoencoder: mean absolute reconstruction
    error plus a weighted codebook term.
    """

    def __init__(self, codebook_weight=1.0):
        """Store the multiplier applied to the mean codebook loss."""
        super().__init__()
        self.codebook_weight = codebook_weight

    def forward(self, codebook_loss, inputs, reconstructions, split="train"):
        """
        Inputs:
        - codebook_loss: tensor; its mean is scaled by `codebook_weight`.
        - inputs, reconstructions: tensors compared element-wise.
        - split: accepted for interface compatibility; not read here.

        Returns:
        - (loss, log): the total loss tensor and a dict of detached scalars
          with keys loss_total, loss_codebook, loss_nll, loss_rec.
        """
        # Element-wise absolute error; kept un-reduced for the log entry.
        abs_err = (inputs.contiguous() - reconstructions.contiguous()).abs()
        nll_loss = abs_err.mean()

        weighted_codebook = self.codebook_weight * codebook_loss.mean()
        loss = nll_loss + weighted_codebook

        log = dict(
            loss_total=loss.clone().detach().mean(),
            loss_codebook=codebook_loss.detach().mean(),
            loss_nll=nll_loss.detach().mean(),
            loss_rec=abs_err.detach().mean(),
        )
        return loss, log