import torch
import torch.nn as nn
import pdb

def norm(data):
    """Scale every vector along the last dimension to unit L2 length.

    No epsilon is added, so a vector whose L2 norm is zero yields NaNs (0/0).
    """
    lengths = data.norm(p=2, dim=-1, keepdim=True)
    return data / lengths

def dynamic_k_(att, clusters, out_th):
    """Score each row by the mean of its per-cluster peaks above a quantile.

    For every row ``b`` of ``att``, the ``out_th`` quantile of that row is used
    as a threshold. Each distinct cluster id in ``clusters[b]`` contributes its
    maximum score if that maximum is strictly above the threshold. The row's
    score is the mean of the contributing maxima. If no cluster qualifies, the
    row's single highest score is used instead.

    Args:
        att: 2-D tensor ``(batch_size, num_segments)`` of segment scores.
        clusters: per-segment cluster ids, indexed as ``clusters[b]``; assumed
            to have the same length as ``att[b]`` (TODO confirm with caller).
        out_th: quantile in ``[0, 1]`` used as the per-row threshold.

    Returns:
        ``(scores, selected_k, selected_indices_info)`` where ``scores`` is a
        ``(batch_size,)`` tensor that stays differentiable w.r.t. ``att``,
        ``selected_k[b]`` is how many clusters passed the threshold, and
        ``selected_indices_info[b]`` lists the segment index of each selected
        cluster peak (in ``torch.unique`` order of the cluster ids).
    """
    out_th = float(out_th)
    batch_size, _ = att.shape  # enforces a 2-D input
    # Per-row threshold: the out_th quantile of that row's scores.
    thresholds = torch.quantile(att, out_th, dim=1)

    final_anomaly_scores = []
    selected_k = []
    selected_indices_info = []

    for b in range(batch_size):
        row = att[b]
        peaks = []
        peak_indices = []

        for cluster in torch.unique(clusters[b]):
            cluster_indices = (clusters[b] == cluster).nonzero(as_tuple=True)[0]
            cluster_scores = row[cluster_indices]
            max_score = cluster_scores.max()
            if max_score > thresholds[b]:
                peaks.append(max_score)
                peak_indices.append(cluster_indices[cluster_scores.argmax()].item())

        selected_k.append(len(peaks))
        selected_indices_info.append(peak_indices)

        if peaks:
            # torch.stack keeps the autograd graph back to ``att``.
            final_anomaly_scores.append(torch.stack(peaks).mean())
        else:
            # No cluster beat the threshold: fall back to the row's top-1 score.
            final_anomaly_scores.append(torch.topk(row, 1, dim=-1)[0].mean())

    return torch.stack(final_anomaly_scores), selected_k, selected_indices_info

def calculate_uncertainty(self, att, label, uncertainty, ab_sim, n_sim, u_th, sim_th=0.7):
    """BCE that pushes uncertain-but-abnormal-similar segments towards 1.

    For each row ``i`` of ``att``, segments with ``uncertainty[i] > u_th`` and
    ``ab_sim[i] > sim_th`` are selected and their scores are trained towards 1
    with ``self.bce__``. The per-row losses of rows that selected anything are
    averaged.

    Args:
        self: object exposing ``bce__`` (a mean-reduced ``nn.BCELoss``).
        att: 2-D tensor of segment scores, one row per video; values must lie
            in ``[0, 1]`` as required by ``BCELoss``.
        label: unused; kept for call-site compatibility.
        uncertainty: per-segment uncertainty indexable like ``att``; moved to
            ``att.device``.
        ab_sim: per-segment similarity indexable like ``att``.
        n_sim: unused; kept for call-site compatibility.
        u_th: uncertainty threshold (strict ``>``).
        sim_th: similarity threshold (strict ``>``). Defaults to ``0.7``, the
            value that was previously hard-coded.

    Returns:
        Scalar loss tensor. If no segment is selected in any row, returns a
        zero tensor with ``requires_grad=True`` so callers can still backprop.
    """
    uncertainty = uncertainty.to(att.device)
    loss_values = []

    for i in range(att.shape[0]):
        mask = (uncertainty[i] > u_th) & (ab_sim[i] > sim_th)
        selected_idx = torch.nonzero(mask, as_tuple=True)[0]
        if selected_idx.numel() == 0:
            continue
        preds = att[i, selected_idx]
        loss_values.append(self.bce__(preds, torch.ones_like(preds)))

    if loss_values:
        return torch.stack(loss_values).mean()
    return torch.tensor(0.0, device=att.device, requires_grad=True)
    
def calculate_ceratinty(self, att, label, uncertainty, ab_sim, n_sim, u_th, q=0.94):
    """BCE that pushes confident, high-scoring segments towards 1.

    For each row ``i`` of ``att``, segments with ``uncertainty[i] < u_th`` and
    a score strictly above that row's ``q`` quantile are selected and their
    scores are trained towards 1 with ``self.bce__``. The per-row losses of
    rows that selected anything are averaged.

    Args:
        self: object exposing ``bce__`` (a mean-reduced ``nn.BCELoss``).
        att: 2-D tensor of segment scores, one row per video; values must lie
            in ``[0, 1]`` as required by ``BCELoss``.
        label: unused; kept for call-site compatibility.
        uncertainty: per-segment uncertainty indexable like ``att``; moved to
            ``att.device``.
        ab_sim: unused; kept for call-site compatibility.
        n_sim: unused; kept for call-site compatibility.
        u_th: uncertainty threshold (strict ``<``).
        q: per-row score quantile used as the selection threshold. Defaults to
            ``0.94``, the value that was previously hard-coded.

    Returns:
        Scalar loss tensor. If no segment is selected in any row, returns a
        zero tensor with ``requires_grad=True`` so callers can still backprop.
    """
    uncertainty = uncertainty.to(att.device)
    loss_values = []
    # Per-row score threshold.
    thresholds = torch.quantile(att, q, dim=1)

    for i in range(att.shape[0]):
        mask = (uncertainty[i] < u_th) & (att[i] > thresholds[i])
        selected_idx = torch.nonzero(mask, as_tuple=True)[0]
        if selected_idx.numel() == 0:
            continue
        preds = att[i, selected_idx]
        loss_values.append(self.bce__(preds, torch.ones_like(preds)))

    if loss_values:
        return torch.stack(loss_values).mean()
    return torch.tensor(0.0, device=att.device, requires_grad=True)


# Correctly spelled alias; the original (misspelled) name is kept for callers.
calculate_certainty = calculate_ceratinty

def margin_loss_(self, att, label):
    """Max-score-weighted BCE on each row's score spread.

    The spread ``max(att) - min(att)`` along the last dimension is used as a
    probability and compared with ``label`` by the element-wise ``self.bce_``
    (a ``BCELoss(reduction='none')``). Each row's term is scaled by that row's
    maximum score, and the batch **mean** (not sum) is returned.
    """
    peak, _ = att.max(dim=-1)
    trough, _ = att.min(dim=-1)
    per_row = self.bce_(peak - trough, label.float())
    # The weight is not detached, so gradient also flows through ``peak``.
    return (per_row * peak).mean()

class AD_Loss(nn.Module):
    """Composite weakly-supervised anomaly-detection loss.

    The batch given to :meth:`forward` is the concatenation of a normal half
    followed by an abnormal half (this is how ``train`` builds it), so rows
    ``[0, b)`` are treated as normal and rows ``[b, end)`` as abnormal, with
    ``b = len(_label) // 2``.
    """

    def __init__(self) -> None:
        super().__init__()
        # Mean-reduced BCE used for the scalar loss terms in ``forward``.
        self.bce = nn.BCELoss()
        # Element-wise BCE; ``margin_loss_`` re-weights each row before averaging.
        self.bce_ = nn.BCELoss(reduction='none')
        # Mean-reduced BCE used by the uncertainty/certainty selection losses.
        self.bce__ = nn.BCELoss(reduction='mean')

    def forward(self, result, _label, is_cluster, is_uncertainty, index, _idx, _cluster, tau_a, _au, un_th, alpha):
        """Compute the total loss and a dict of its components.

        Args:
            result: dict produced by the network; must contain the keys
                ``triplet_margin``, ``frame``, ``A_att``, ``N_att``, ``A_Natt``,
                ``N_Aatt``, ``kl_loss`` and ``distance``.
            _label: video labels for the whole (normal + abnormal) batch.
            is_cluster: use cluster-based ``dynamic_k_`` scoring for the
                abnormal half (while ``index < 3001``) and top-1 pooling of
                ``A_att``.
            is_uncertainty: add the uncertainty-selected BCE term.
            index: training step / iteration counter.
            _idx: sample indices (unused here).
            _cluster: cluster ids for the abnormal half (used if ``is_cluster``).
            tau_a: quantile threshold forwarded to ``dynamic_k_``.
            _au: uncertainty for the abnormal half (used if ``is_uncertainty``).
            un_th: uncertainty threshold forwarded to ``calculate_uncertainty``.
            alpha: weight of the uncertainty term.

        Returns:
            ``(cost, loss)`` where ``cost`` is the scalar total loss and ``loss``
            maps component names to their values for logging.
        """
        loss = {}

        _label = _label.float()

        triplet = result["triplet_margin"]
        att = result['frame']
        A_att_ = result["A_att"]
        N_att = result["N_att"]
        A_Natt = result["A_Natt"]
        N_Aatt = result["N_Aatt"]
        kl_loss = result["kl_loss"]
        distance = result["distance"]

        b = _label.size(0) // 2  # rows [0, b) normal, [b, end) abnormal
        t = att.size(1)
        device = att.device
        # Top-k size used for MIL-style pooling of segment scores.
        k = t // 16 + 1

        # --- video-level anomaly loss ---
        if is_cluster and index < 3001:
            anomaly_aa, _, _ = dynamic_k_(att[b:, :], _cluster, tau_a)
            # max(1, .) avoids topk(k=0) -> mean of an empty tensor -> NaN when t < 16.
            anomaly_nn = torch.topk(att[:b, :], max(1, t // 16), dim=-1)[0].mean(-1)
            a_temp = torch.cat((anomaly_nn, anomaly_aa), dim=0)
            anomaly_loss = self.bce(a_temp, _label)
        else:
            anomaly = torch.topk(att, k, dim=-1)[0].mean(-1)
            anomaly_loss = self.bce(anomaly, _label)

        # --- uncertainty-selected BCE on the abnormal half ---
        if is_uncertainty:
            un_bce = calculate_uncertainty(self, att[b:, :], _label[b:], _au, A_att_, N_Aatt, un_th)
        else:
            un_bce = torch.tensor(0.0, device=device, requires_grad=True)

        panomaly = torch.topk(1 - N_Aatt, k, dim=-1)[0].mean(-1)
        panomaly_loss = self.bce(panomaly, torch.ones(b, device=device))

        if is_cluster:
            # Top-1 per abnormal row. Computed from the abnormal row count rather
            # than from dynamic_k_'s output, so it is also defined when
            # index >= 3001 (previously a NameError on ``selected_k``).
            n_abnormal = att.size(0) - b
            A_att = torch.topk(A_att_[:n_abnormal], 1, dim=-1)[0].mean(-1)
        else:
            A_att = torch.topk(A_att_, k, dim=-1)[0].mean(-1)
        A_loss = self.bce(A_att, torch.ones(b, device=device))

        N_loss = self.bce(N_att, torch.ones_like(N_att))
        A_Nloss = self.bce(A_Natt, torch.zeros_like(A_Natt))

        margin_loss = margin_loss_(self, att, _label)

        cost = (
            anomaly_loss
            + alpha * un_bce
            + 0.1 * (A_loss + panomaly_loss + N_loss + A_Nloss)
            + 0.1 * triplet
            + 0.001 * kl_loss
            + 0.0001 * distance
            + 0.01 * margin_loss
        )

        loss['total_loss'] = cost
        loss['att_loss'] = anomaly_loss
        loss['un_loss'] = un_bce
        loss['margin_loss'] = margin_loss
        loss['N_Aatt'] = panomaly_loss
        loss['A_loss'] = A_loss
        loss['N_loss'] = N_loss
        loss['A_Nloss'] = A_Nloss
        loss["triplet"] = triplet
        loss['kl_loss'] = kl_loss
        loss['distance'] = distance
        return cost, loss



def _unpack_batch(batch, is_cluster, is_uncertainty):
    """Split a loader batch into ``(input, label, idx, cluster, uncertainty)``.

    The batch holds ``input, label, idx``, then a cluster tensor when
    ``is_cluster``, then an uncertainty tensor when ``is_uncertainty``. Absent
    fields come back as ``None``. Raises ``ValueError`` on a length mismatch.
    """
    items = tuple(batch)
    expected = 3 + int(is_cluster) + int(is_uncertainty)
    if len(items) != expected:
        raise ValueError(
            f"expected {expected} items per batch (is_cluster={is_cluster}, "
            f"is_uncertainty={is_uncertainty}), got {len(items)}"
        )
    inputs, labels, idx, *extra = items
    cluster = extra.pop(0) if is_cluster else None
    uncertainty = extra.pop(0) if is_uncertainty else None
    return inputs, labels, idx, cluster, uncertainty


def train(net, normal_loader, abnormal_loader, optimizer, criterion, wandb_viz, index, is_cluster=False, is_uncertainty=False, tau_a=0.9, un_th=0.5, alpha=0.5):
    """Run one optimisation step on a concatenated normal + abnormal batch.

    ``normal_loader`` / ``abnormal_loader`` are single batches (not loaders)
    laid out as described in ``_unpack_batch``. Normal samples are placed
    first, so ``AD_Loss`` treats the first half of the batch as normal. The
    loss dict returned by ``criterion`` is logged via ``wandb_viz.run.log``.
    """
    net.train()
    net.flag = "Train"
    ninput, nlabel, nidx, _nc, _nu = _unpack_batch(normal_loader, is_cluster, is_uncertainty)
    ainput, alabel, aidx, ac, au = _unpack_batch(abnormal_loader, is_cluster, is_uncertainty)
    device = next(net.parameters()).device

    _data = torch.cat((ninput, ainput), 0).to(device)
    _label = torch.cat((nlabel, alabel), 0).to(device)
    # NOTE(review): normal indices are offset by 8100, presumably so they cannot
    # collide with abnormal indices — confirm against the dataset size.
    _idx = torch.cat((8100 + nidx, aidx), 0)
    # Only the abnormal half's cluster ids / uncertainty are passed on; detach
    # them so no gradient flows back into these inputs. (Previously the
    # detached ``_au`` was immediately overwritten by the raw ``au``.)
    _cluster = ac.detach() if ac is not None else None
    _au = au.detach() if au is not None else None

    predict = net(_data)
    cost, loss = criterion(predict, _label, is_cluster, is_uncertainty, index, _idx, _cluster, tau_a, _au, un_th, alpha)
    optimizer.zero_grad()
    cost.backward()
    optimizer.step()
    wandb_viz.run.log(loss)
