"""
@Description :   局部特征匹配后分类器
@Author      :   tqychy 
@Time        :   2025/02/15 12:02:30
"""
import numpy as np
import torch
import torch.nn as nn
from numba import jit


class LinearClassify(nn.Module):
    def __init__(self, cfg, logger):
        super().__init__()
        self.cfg = cfg
        self.logger = logger
        self.seq_len = self.cfg.DATASET.CONTOUR_MAX_LEN
        self.dim = self.cfg.NET.FEATURE_EXTRACT_DIM

        # 定义一个简单的注意力层，用于计算每个轮廓点的重要性得分
        self.attention_fc = nn.Linear(self.dim, 1)

        # 特征融合后通道数：f1, f2, 以及它们的绝对差值
        combined_dim = self.dim * 3
        hidden_dim = self.dim

        self.fc1 = nn.Linear(combined_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, 1)  # 输出一个 logit
        self.dropout = nn.Dropout(0.5)
        self.activation = nn.ReLU()

    def weighted_pooling(self, f):
        """
        使用注意力权重对轮廓点特征进行加权池化。
        f: (bs, seq_len, dim)
        """
        # 计算每个轮廓点的重要性得分，shape: (bs, seq_len, 1)
        attn_scores = self.attention_fc(f)
        attn_weights = nn.functional.softmax(attn_scores, dim=1)
        # 加权求和，得到 (bs, dim)
        f_pool = torch.sum(f * attn_weights, dim=1)
        return f_pool

    def forward(self, f1, f2):
        # f1, f2: (bs, seq_len, dim)
        # 对每个碎片进行注意力加权池化，得到 (bs, dim) 的全局特征
        f1_pool = self.weighted_pooling(f1)
        f2_pool = self.weighted_pooling(f2)

        # 计算两个全局特征的绝对差值
        diff = torch.abs(f1_pool - f2_pool)  # (bs, dim)

        # 拼接 f1_pool, f2_pool 以及 diff，得到 (bs, 3*dim) 的特征向量
        combined = torch.cat([f1_pool, f2_pool, diff], dim=1)

        x = self.fc1(combined)
        x = self.activation(x)
        x = self.dropout(x)
        x = self.fc2(x)  # (bs, 1)
        return x.squeeze()


class CNNScoreEvaluator(nn.Module):
    def __init__(self, *args):
        super().__init__()
        self.cfg, self.logger = args
        self.seq_len = self.cfg.DATASET.CONTOUR_MAX_LEN

        # 定义 CNN 网络
        self.cnn = nn.Sequential(
            # 输出: bs * 32 * seq_len * seq_len
            nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.MaxPool2d(2),  # 输出: bs * 32 * (seq_len/2) * (seq_len/2)
            # 输出: bs * 64 * (seq_len/2) * (seq_len/2)
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.MaxPool2d(2),  # 输出: bs * 64 * (seq_len/4) * (seq_len/4)
            # 输出: bs * 128 * (seq_len/4) * (seq_len/4)
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(128),
            nn.AdaptiveAvgPool2d((1, 1)),  # 输出: bs * 128 * 1 * 1
        )
        self.fc = nn.Linear(128, 1)  # 输出: bs * 1
        self.sigmoid = nn.Sigmoid()

    def forward(self, preds, _):
        """
        输入:
            S: torch.Tensor, shape [bs, seq_len, seq_len], 相似度矩阵
            pad_mask: torch.Tensor, shape [bs, 2, seq_len], 填充掩码
        输出:
            logits: torch.Tensor, shape [bs], 每个样本的 logit 值
        """
        # 增加通道维度
        M = preds.unsqueeze(1)  # bs * 1 * seq_len * seq_len
        # 通过 CNN 提取特征
        features = self.cnn(M)  # bs * 128 * 1 * 1
        # 展平特征
        features = features.view(features.size(0), -1)  # bs * 128
        # 全连接层输出 logit
        logits = self.fc(features)  # bs * 1
        return self.sigmoid(logits.squeeze(1))  # bs


class MaxDiagScoreEvalutor(nn.Module):
    def __init__(self, *args):
        super().__init__()
        self.cfg, self.logger = args

    def forward(self, preds, length):
        preds = preds.cpu().numpy()
        scores = np.zeros(preds.shape[0])
        for batch, pred in enumerate(preds):
            scores[batch] = 20 * self.get_score(np.array(pred)) / length[batch]
        return torch.tensor(scores)

    @staticmethod
    @jit(nopython=True)
    def get_score(mat: np.ndarray) -> int:
        n = mat.shape[0]
        max_len = 0

        # 沿列方向翻转矩阵
        mat_flipped = mat[:, ::-1]

        # 遍历所有对角线偏移量
        for k in range(-n + 1, n):
            # 根据偏移量 k 提取对角线
            if k >= 0:
                diag_len = n - k
                diag = np.zeros(diag_len, dtype=np.int64)
                for i in range(diag_len):
                    diag[i] = mat_flipped[i, i + k]
            else:
                diag_len = n + k
                diag = np.zeros(diag_len, dtype=np.int64)
                for i in range(diag_len):
                    diag[i] = mat_flipped[i - k, i]

            # 检查对角线情况
            if np.all(diag == 1):
                current_max = diag_len
            elif not np.any(diag == 1):
                current_max = 0
            else:
                # 找到所有非 1 的位置
                indices = np.where(diag != 1)[0]
                # 添加首尾边界
                indices = np.concatenate(
                    (np.array([-1]), indices, np.array([diag_len])))
                # 计算间隔并取最大值
                diffs = indices[1:] - indices[:-1]
                current_max = np.max(diffs - 1)

            # 更新最大长度
            if current_max > max_len:
                max_len = current_max

        return max_len
