# Add the project root (the parent of this file's directory) to sys.path so project-local modules can be imported.
import sys,os
project_root = os.path.abspath(
    os.path.join(os.path.dirname(__file__), '..')
)  # absolute path of the parent of this file's directory
if project_root not in sys.path:
    # Prepend rather than append so this directory is searched first; the
    # membership check keeps repeated imports from adding duplicate entries.
    sys.path.insert(0, project_root)

import numpy as np
from sklearn.metrics import auc
import torch.nn.functional as F
import torch

def MSEContrastiveLoss(reconstructed, target, embeddings, labels, alpha=0.2):
    """
    Convex blend of reconstruction MSE and a supervised contrastive loss.

    Args:
        reconstructed: [B, T, D], decoder output
        target:        [B, T, D], target sequence
        embeddings:    [B, D],   encoder output embeddings
        labels:        [B],      0/1 for anomaly detection
        alpha:         float,    weight for contrastive loss

    Returns:
        Scalar tensor ``(1 - alpha) * MSE + alpha * SupervisedContrastiveLoss``.
    """
    # Element-wise reconstruction error, mean-reduced, scaled by (1 - alpha).
    weighted_recon = (1 - alpha) * F.mse_loss(reconstructed, target)
    # Contrastive term over the embeddings grouped by label, scaled by alpha.
    weighted_contrast = alpha * SupervisedContrastiveLoss(embeddings, labels)
    return weighted_recon + weighted_contrast

def SupervisedContrastiveLoss(embeddings, labels, reconstructed=None, target=None, temperature=0.1):
    """
    Supervised contrastive loss over a batch of embeddings.

    With ``s_ij = cos(z_i, z_j) / temperature``, each anchor ``i`` that has at
    least one *other* sample with the same label contributes

        -log( sum_{p != i, y_p == y_i} exp(s_ip) / sum_{a != i} exp(s_ia) )

    and the result is the mean over those anchors. Anchors without any
    positive are ignored.

    Args:
        embeddings:    [B, D] embedding tensor.
        labels:        tensor with B elements (reshaped to [B, 1]).
        reconstructed: unused; accepted only for signature compatibility.
        target:        unused; accepted only for signature compatibility.
        temperature:   float, divisor applied to the cosine similarities.

    Returns:
        Scalar tensor. If no anchor has a positive (all labels distinct, or
        B < 2) a zero tensor with ``requires_grad=True`` is returned.
    """
    device = embeddings.device
    B = embeddings.size(0)
    z = F.normalize(embeddings, p=2, dim=1)  # unit rows -> dot product is cosine similarity

    labels = labels.to(device).reshape(-1, 1)  # [B, 1]
    self_mask = torch.eye(B, dtype=torch.bool, device=device)
    # Positive pairs: same label, excluding the anchor itself.
    pos_mask = torch.eq(labels, labels.T) & ~self_mask  # [B, B]

    # Select anchors that have a positive *before* taking any log. Computing
    # log(0) for positive-less anchors and dropping them afterwards still
    # back-propagates 0/0 = NaN into every embedding's gradient.
    valid = pos_mask.any(dim=1)
    if not bool(valid.any()):
        return torch.tensor(0.0, device=device, requires_grad=True)

    logits = torch.matmul(z[valid], z.T) / temperature  # [V, B]
    # Drop self-similarity from both numerator and denominator.
    logits = logits.masked_fill(self_mask[valid], float('-inf'))
    # Log-sum-exp keeps small temperatures from overflowing exp() to inf.
    log_denominator = torch.logsumexp(logits, dim=1)
    log_numerator = torch.logsumexp(
        logits.masked_fill(~pos_mask[valid], float('-inf')), dim=1
    )
    return (log_denominator - log_numerator).mean()

def MSELoss(reconstructed, target, embeddings=None, labels=None):
    """Return the mean squared error between ``reconstructed`` and ``target``.

    ``embeddings`` and ``labels`` are ignored; they exist so this loss has
    the same signature as ``WeightedMSELoss`` and ``FreqLoss``.
    """
    return F.mse_loss(reconstructed, target, reduction='mean')

def WeightedMSELoss(reconstructed, target, embeddings=None, labels=None, feature_weights=None):
    """
    Mean squared error with a per-feature weight on the last dimension.

    Args:
        reconstructed:   decoder output; the last dimension is the feature axis.
        target:          tensor with the same shape as ``reconstructed``.
        embeddings:      unused; kept so the signature matches the other losses.
        labels:          unused; kept so the signature matches the other losses.
        feature_weights: optional sequence or 1-D tensor whose length equals the
                         size of the last dimension. Defaults to 33 weights:
                         1.0 for the first 21 features, then four each of
                         1.5, 2.0 and 2.5.

    Returns:
        Scalar tensor: mean over all elements of
        ``feature_weights[d] * (reconstructed - target) ** 2``.
    """
    if feature_weights is None:
        # Original hard-coded weighting; assumes 33 features.
        # TODO confirm what the up-weighted trailing 12 features represent.
        feature_weights = [1.0] * 21 + [1.5] * 4 + [2.0] * 4 + [2.5] * 4
    # Put the weights on the inputs' device; a 1-D tensor broadcasts over the
    # last (feature) dimension for inputs of any rank.
    weights = torch.as_tensor(feature_weights, device=reconstructed.device)
    squared_error = (reconstructed - target) ** 2
    return (squared_error * weights).mean()

def FreqLoss(reconstructed, target, embeddings=None, labels=None, alpha=0.8):
    """
    Blend time-domain MSE with a spectral (real-FFT) magnitude error.

    The spectra are taken with ``torch.fft.rfft`` along dim 1 (presumably the
    time axis of a [B, T, D] input; TODO confirm). The spectral term is the mean
    absolute value of the complex difference between the two spectra.
    ``embeddings`` and ``labels`` are unused.

    Returns:
        ``alpha * MSE + (1 - alpha) * spectral_error`` as a scalar tensor.
    """
    time_term = F.mse_loss(reconstructed, target)
    spec_pred = torch.fft.rfft(reconstructed, dim=1)
    spec_true = torch.fft.rfft(target, dim=1)
    spectral_term = torch.abs(spec_pred - spec_true).mean()
    return alpha * time_term + (1 - alpha) * spectral_term

def compute_confusion_stats(test_data, ids):
    """
    Count true/false positives/negatives for a set of flagged ids.

    A row is flagged when its ``index`` value appears in ``ids``. A row is
    fraudulent when ``FraudType != 0`` and normal when ``FraudType == 0``.

    Note: a 0/1 ``Flag`` column is written into ``test_data`` in place; pass a
    copy if the caller's frame must stay untouched.

    Parameters
    ----------
    test_data : pandas df
        Test data set with ``index`` and ``FraudType`` columns.
    ids : list of int
        List of ids detected by a model.

    Returns
    -------
    tp, tn, fp, fn : int
        Numbers of true positives, true negatives, false positives and
        false negatives.
    """
    unique_ids = np.unique(ids)
    test_data['Flag'] = 0
    test_data.loc[test_data['index'].isin(unique_ids), 'Flag'] = 1

    is_fraud = test_data['FraudType'] != 0
    is_normal = test_data['FraudType'] == 0
    flagged = test_data['Flag'] == 1
    unflagged = test_data['Flag'] == 0

    tp = int((is_fraud & flagged).sum())
    tn = int((is_normal & unflagged).sum())
    fp = int((is_normal & flagged).sum())
    fn = int((is_fraud & unflagged).sum())
    return tp, tn, fp, fn

def performance_matrix(test_set, detection_type, f_score_beta=4):
    """
    Sweep 100 score thresholds; report AUROC, AUC-PR and the best F-beta point.

    Thresholds are ``np.linspace(score.min(), score.max(), 100)``. For each
    threshold, the ids of rows with ``score > threshold`` are flagged and
    scored with ``compute_confusion_stats`` (on a copy, so ``test_set`` is not
    modified). Because the comparison is strict, rows holding the minimum
    score are never flagged by any threshold.

    Parameters
    ----------
    test_set : pandas df
        Must contain ``score``, ``index`` and ``FraudType`` columns.
    detection_type : str
        Prefix for the keys of the returned dict.
    f_score_beta : float
        Beta of the F-beta score used to pick the best threshold.

    Returns
    -------
    dict
        Keys ``<detection_type>_AUROC``, ``_AUC-PR``, ``_F_score``,
        ``_Precision``, ``_Recall`` and ``_Best_Threshold``. Each value is
        rounded to 3 decimals and stored as a string; the values are also
        printed.
    """
    precisions = []
    recalls = []
    tprs = []
    fprs = []
    max_threshold = test_set['score'].max()
    min_threshold = test_set['score'].min()
    thresholds = np.linspace(min_threshold, max_threshold, 100)
    for threshold in thresholds:
        ids = test_set.loc[test_set['score'] > threshold, 'index'].values
        tp, tn, fp, fn = compute_confusion_stats(test_set.copy(), ids)

        # Guard empty classes (e.g. no fraud, or no normal rows) instead of
        # raising ZeroDivisionError.
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        fpr = fp / (fp + tn) if (fp + tn) > 0 else 0.0

        precisions.append(precision)
        recalls.append(recall)
        tprs.append(recall)  # true-positive rate is recall
        fprs.append(fpr)

    final_results = {}
    roc = auc(fprs, tprs)
    final_results[detection_type+'_AUROC'] = str(round(roc, 3))
    print("AUROC: " + str(round(roc, 3)))

    pr = auc(recalls, precisions)
    final_results[detection_type+'_AUC-PR'] = str(round(pr, 3))
    print("AUC-PR: " + str(round(pr, 3)))

    # One F-beta value per threshold (0 where undefined). Keeping the array
    # aligned with precisions/recalls/thresholds makes argmax index all of them
    # consistently, and avoids np.max on an empty array when nothing is ever
    # correctly flagged.
    beta_sq = f_score_beta ** 2
    f_scores = np.array([
        (1 + beta_sq) * (p * r) / (beta_sq * p + r) if (beta_sq * p + r) != 0 else 0.0
        for p, r in zip(precisions, recalls)
    ])
    best_idx = int(np.argmax(f_scores))
    f_score = f_scores[best_idx]
    precision = precisions[best_idx]
    recall = recalls[best_idx]
    best_tau = thresholds[best_idx]
    final_results[detection_type+'_F_score'] = str(round(f_score, 3))
    final_results[detection_type+'_Precision'] = str(round(precision, 3))
    final_results[detection_type+'_Recall'] = str(round(recall, 3))
    final_results[detection_type+'_Best_Threshold'] = str(round(best_tau, 3))
    print('F score: ' + str(round(f_score, 3)))
    print('Precision: ' + str(round(precision, 3)))
    print('Recall: ' + str(round(recall, 3)))
    print('Best Threshold: ' + str(round(best_tau, 3)))

    return final_results


def manipulated_level_metrics(test_set, detection_type, f_score_beta=4):
    """
    Threshold-sweep metrics restricted to rows whose ManipulatedLevel is not 1.

    Uses the same logic as ``performance_matrix``:
    - Thresholds are 100 evenly spaced values between the min and max
      ``score`` of the *whole* ``test_set``.
    - At each threshold, an id is flagged when any row of the whole
      ``test_set`` with that ``index`` has ``score > threshold``.

    The confusion counts, however, only use rows with
    ``ManipulatedLevel != 1``; rows with a NaN level are kept. The ``_L2345``
    key suffix reflects the expectation that the remaining rows are levels 2-5
    (TODO confirm no other level values occur). ``test_set`` is not modified.

    Parameters
    ----------
    test_set : pandas df
        Must contain ``score``, ``index``, ``FraudType`` and
        ``ManipulatedLevel`` columns.
    detection_type : str
        Prefix for the keys of the returned dict.
    f_score_beta : float
        Beta of the F-beta score used to pick the best threshold.

    Returns
    -------
    dict
        ``<detection_type>_AUROC_L2345``, ``_AUC-PR_L2345``,
        ``_F_score_L2345``, ``_Precision_L2345``, ``_Recall_L2345`` and
        ``_Best_Threshold_L2345``, each rounded to 3 decimals as a string.

    Raises
    ------
    ValueError
        If ``score``, ``FraudType`` or ``ManipulatedLevel`` is missing.
    """
    if 'score' not in test_set.columns or 'FraudType' not in test_set.columns or 'ManipulatedLevel' not in test_set.columns:
        raise ValueError('test_set must contain score, FraudType, and ManipulatedLevel columns')

    scores = test_set['score']
    row_ids = test_set['index']
    thresholds = np.linspace(scores.min(), scores.max(), 100)

    # Rows that enter the confusion matrix and their labels do not depend on
    # the threshold, so build them once instead of copying per iteration.
    # Only ManipulatedLevel == 1 is excluded; NaN and all other levels stay.
    level_mask = test_set['ManipulatedLevel'] != 1
    subset = test_set[level_mask]
    subset_ids = subset['index']
    is_fraud = subset['FraudType'] != 0
    is_normal = subset['FraudType'] == 0

    precisions = []
    recalls = []
    tprs = []
    fprs = []
    for threshold in thresholds:
        # Ids are flagged from the full test_set, then looked up in the subset.
        flagged = subset_ids.isin(row_ids[scores > threshold].values)
        tp = int((is_fraud & flagged).sum())
        tn = int((is_normal & ~flagged).sum())
        fp = int((is_normal & flagged).sum())
        fn = int((is_fraud & ~flagged).sum())
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        fpr = fp / (fp + tn) if (fp + tn) > 0 else 0.0
        precisions.append(precision)
        recalls.append(recall)
        tprs.append(recall)  # true-positive rate is recall
        fprs.append(fpr)

    final_results = {}
    roc = auc(fprs, tprs)
    final_results[detection_type+'_AUROC_L2345'] = str(round(roc, 3))
    pr = auc(recalls, precisions)
    final_results[detection_type+'_AUC-PR_L2345'] = str(round(pr, 3))

    # One F-beta value per threshold (0 where undefined), aligned with
    # precisions/recalls/thresholds so argmax indexes all of them consistently.
    # If nothing is ever correctly flagged, argmax falls back to index 0.
    beta_sq = f_score_beta ** 2
    f_scores = np.array([
        (1 + beta_sq) * (p * r) / (beta_sq * p + r) if (beta_sq * p + r) != 0 else 0.0
        for p, r in zip(precisions, recalls)
    ])
    best_idx = int(np.argmax(f_scores))
    f_score = f_scores[best_idx]
    precision = precisions[best_idx]
    recall = recalls[best_idx]
    best_tau = thresholds[best_idx]

    final_results[detection_type+'_F_score_L2345'] = str(round(f_score, 3))
    final_results[detection_type+'_Precision_L2345'] = str(round(precision, 3))
    final_results[detection_type+'_Recall_L2345'] = str(round(recall, 3))
    final_results[detection_type+'_Best_Threshold_L2345'] = str(round(best_tau, 3))
    print(f"[L2345] AUROC: {roc:.3f}, AUC-PR: {pr:.3f}, F-score: {f_score:.3f}, Precision: {precision:.3f}, Recall: {recall:.3f}, Best Threshold: {best_tau:.3f}")
    return final_results
