#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
This started as a copy of https://bitbucket.org/RSKothari/multiset_gaze/src/master/ 
with additional changes and modifications to adjust it to our implementation. 

Copyright (c) 2021 Rakshit Kothari, Aayush Chaudhary, Reynold Bailey, Jeff Pelz, 
and Gabriel Diaz
"""

# This file contains definitions which are not applicable in regular scenarios.
# For general purposes functions, classes and operations - use helperfunctions.

import os
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

from random import random
from typing import Optional

from matplotlib import pyplot as plt
from sklearn import metrics

# from extern.FilterResponseNormalizationLayer.frn import FRN, TLU


def concat_list_of_data_dicts(list_dicts):
    '''
    将数据字典列表连接起来。

    Parameters
    ----------
    list_dicts : list
        包含数据字典的列表

    Returns
    -------
    dict
        连接后的数据字典
    '''
    keys = list(list_dicts[0].keys())
    out_dict = {key: [] for key in keys}
    for key in keys:
        if type(list_dicts[0][key]) == list:
            out_dict[key] += [ele[key] for ele in list_dicts]
        else:
            out_dict[key] = torch.cat([ele[key] for ele in list_dicts], dim=0)
    return out_dict


def get_selected_set(data_dict, ds_num):
    '''
    获取指定数据集编号的数据子集。

    Parameters
    ----------
    data_dict : dict
        包含所有数据的字典
    ds_num : int
        数据集编号

    Returns
    -------
    dict
        指定数据集编号的数据子集
    '''
    all_ds = data_dict['ds_num']
    idx = [i for i, x in enumerate(all_ds) if x == ds_num]

    out_dict = {}
    for key, value in data_dict.items():
        if type(value) == list:
            out_dict[key] = [value[ele] for ele in idx]
        else:
            out_dict[key] = value[idx, ...]
    return out_dict


def create_meshgrid(height: int, width: int, normalized_coordinates: Optional[bool] = True) -> torch.Tensor:
    '''
    生成图像的坐标网格。

    当 `normalized_coordinates` 标志设置为 True 时，网格将被标准化为范围 [-1,1]，以与 PyTorch
    函数 grid_sample 保持一致。

    Args:
        height (int): 图像高度（行）。
        width (int): 图像宽度（列）。
        normalized_coordinates (Optional[bool]): 是否将坐标标准化为范围 [-1, 1]，以与 PyTorch
          函数 grid_sample 保持一致。

    Returns:
        torch.Tensor: 返回一个形状为 :math:`(1, H, W, 2)` 的网格张量。

    注意：函数摘自 Kornia 库。
    '''
    # 生成坐标
    xs: Optional[torch.Tensor] = None
    ys: Optional[torch.Tensor] = None
    if normalized_coordinates:
        xs = torch.linspace(-1, 1, width)
        ys = torch.linspace(-1, 1, height)
    else:
        xs = torch.linspace(0, width - 1, width)
        ys = torch.linspace(0, height - 1, height)

    # 确保 xs 和 ys 不需要梯度
    xs.requires_grad = False
    ys.requires_grad = False

    # 通过堆叠坐标生成网格
    base_grid: torch.Tensor = torch.stack(torch.meshgrid([xs, ys])).transpose(1, 2)  # 2xHxW
    return torch.unsqueeze(base_grid, dim=0).permute(0, 2, 3, 1)  # 1xHxWx2



def get_nparams(model):
    '''
    计算模型的可训练参数数量。

    Parameters
    ----------
    model : torch.nn.Module
        模型对象。

    Returns
    -------
    int
        可训练参数的数量。
    '''
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def get_predictions(output):
    '''
    获取预测值的索引。

    Parameters
    ----------
    output : torch.tensor
        [B, C, *] 张量。返回独热编码的 argmax。

    Returns
    -------
    torch.tensor
        [B, *] 张量，包含索引。
    '''
    bs, c, h, w = output.size()
    values, indices = output.cpu().max(1)
    indices = indices.view(bs, h, w)  # bs x h x w
    return indices


class make_logger():
    def __init__(self, output_name, rank):
        self.rank_cond = rank == 0
        if self.rank_cond:
            dirname = os.path.dirname(output_name)
            if not os.path.exists(dirname):
                os.mkdir(dirname)
            self.dirname = dirname
            self.log_file = open(output_name, 'a+')

    def append(self, key, val):
        vals = self.infos.setdefault(key, [])
        vals.append(val)

    def log(self, info_dict, extra_msg=''):

        if self.rank_cond:
            msgs = [extra_msg]
            for key, vals in info_dict.iteritems():
                msgs.append('%s %.6f' % (key, np.mean(vals)))
            msg = '\n'.join(msgs)

            self.write(msg)
            print (msg)

    def write(self, msg, do_silent=False, do_warn=False):

        if self.rank_cond:
            msg = 'WARNING! '+msg if do_warn else msg
            self.log_file.write(msg+'\n')
            self.log_file.flush()

            if not do_silent:
                print (msg)

    def write_summary(self,msg):

        if self.rank_cond:
            self.log_file.write(msg)
            self.log_file.write('\n')
            self.log_file.flush()
            print (msg)

def get_seg_metrics(y_true, y_pred, cond, batch, frame):
    '''
    遍历每个批次，并确定哪些类别存在。如果没有类别存在，即全部为0，则从平均值中忽略该分数。
    注意：此函数计算 NaN 均值。这是因为数据集可能没有所有类别都出现。

    Parameters
    ----------
    y_true : numpy.ndarray
        实际标签，形状为 [B, *] 的矩阵。
    y_pred : numpy.ndarray
        预测标签，形状为 [B, *] 的矩阵。
    cond : numpy.ndarray
        布尔数组，指示每个批次是否有效。
    batch : int
        批次大小。
    frame : int
        帧数。

    Returns
    -------
    numpy.ndarray
        如果有任何有效批次，则返回分数列表，否则返回全零数组。
    '''
    assert y_pred.ndim == 3, 'Incorrect number of dimensions'
    assert y_true.ndim == 3, 'Incorrect number of dimensions'
    # # 定义类别颜色映射（示例：3个类别）
    # colormap = np.array([
    #     [0, 0, 0],  # 类别0：黑色（背景）
    #     [255, 0, 0],  # 类别1：红色
    #     [0, 255, 0]  # 类别2：绿色
    # ])
    #
    # # 将类别索引转换为 RGB 图像
    # mask_rgb = colormap[y_pred[0]]  # 形状变为 [H, W, 3]
    #
    # plt.imshow(mask_rgb)
    # plt.axis('off')
    # plt.show()

    cond = cond.astype(np.bool)
    B = y_true.shape[0]
    score_list = []

    for i in range(0, B):

        labels_present = np.unique(y_true[i, ...])
        # labels_present = labels_present[labels_present < 3]  # 只保留0,1,2
        score_vals = np.empty((3, ))
        score_vals[:] = np.nan

        if cond[i]:
            score = metrics.jaccard_score(y_true[i, ...].reshape(-1),
                                          y_pred[i, ...].reshape(-1),
                                          labels=labels_present,
                                          average=None)

            # Assign score to relevant location
            for j, val in np.ndenumerate(labels_present):
                score_vals[val] = score[j]

        score_list.append(score_vals)

    score_list = np.stack(score_list, axis=0)
    return score_list if np.any(cond) else np.zeros((len(cond), ))


def get_distance(y_true, y_pred, cond, metric='euclidean'):
    '''
    计算距离。

    Parameters
    ----------
    y_true : numpy.ndarray
        实际向量矩阵，形状为 [B, *]。
    y_pred : numpy.ndarray
        预测向量矩阵，形状为 [B, *]。
    cond : numpy.ndarray
        指示每个批次是否有效的布尔数组。
    metric : str, optional
        距离度量方式。默认为 'euclidean'。

    Returns
    -------
    numpy.ndarray
        距离。
    '''

    flag = cond.astype(np.bool)

    dist = np.linalg.norm(y_true - y_pred, axis=1)
    dist[~flag] = np.nan
    return dist if np.any(flag) else np.zeros((y_true.shape[0], ))



def getAng_metric(y_true, y_pred, cond):
    # 假设传入的角度测量值以弧度为单位
    flag = cond.astype(np.bool)
    dist = np.abs(y_true - y_pred)
    dist[~flag] = np.nan
    return dist if np.any(flag) else np.zeros((y_true.shape[0], ))


def normPts(pts, sz, by_max=False):
    if by_max:
        return 2*(pts/max(sz)) - 1
    else:
        return 2*(pts/sz) - 1

def unnormPts(pts, sz, by_max=False):
    if by_max:
        return 0.5*sz*(pts + 1)
    else:
        return 0.5*sz*(max(pts) + 1)

def compute_norm(model):
    list_of_norms = []
    for _, param in model.named_parameters():
        if param.grad is not None:
            list_of_norms.append(torch.norm(param.grad.detach(), 2).to('cpu'))
    list_of_norms = torch.stack(list_of_norms)
    total_norm = torch.norm(list_of_norms, 2)
    return total_norm

def points_to_heatmap(pts, std, res):
    # 给定图像分辨率和方差，为热图回归的兴趣点生成合成的高斯热图。
    # pts: [B, C, N, 2] 归一化点
    # H: [B, C, N, H, W] 输出热图
    B, C, N, _ = pts.shape
    pts = unnormPts(pts, res) #
    grid = create_meshgrid(res[0], res[1], normalized_coordinates=False)
    grid = grid.squeeze()
    X = grid[..., 0]
    Y = grid[..., 1]

    X = torch.stack(B*C*N*[X], axis=0).reshape(B, C, N, res[0], res[1])
    X = X - torch.stack(np.prod(res)*[pts[..., 0]], axis=3).reshape(B, C, N, res[0], res[1])

    Y = torch.stack(B*C*N*[Y], axis=0).reshape(B, C, N, res[0], res[1])
    Y = Y - torch.stack(np.prod(res)*[pts[..., 1]], axis=3).reshape(B, C, N, res[0], res[1])

    H = torch.exp(-(X**2 + Y**2)/(2*std**2))
    #H = H/(2*np.pi*std**2) # 这使得每个批次中的图像总和 == 1
    return H


def ElliFit(coords, mns):
    '''
    参数
    ----------
    coords：torch float32 [B, N, 2]
        椭圆周边的预测点
    mns：torch float32 [B, 2]
        中心点的预测均值

    返回
    -------
    PhiOp：与椭圆拟合相关的Phi分数。有关更多信息，请参阅ElliFit论文。
    '''
    B = coords.shape[0]  # 批量大小

    PhiList = []  # 用于存储每个批次的Phi值

    for bt in range(B):
        coords_norm = coords[bt, ...] - mns[bt, ...]  # 归一化坐标
        N = coords_norm.shape[0]  # 点的数量

        x = coords_norm[:, 0]  # x坐标
        y = coords_norm[:, 1]  # y坐标

        # 构建线性方程组的系数矩阵X和右侧向量Y
        X = torch.stack([-x**2, -x*y, x, y, -torch.ones(N, ).cuda()], dim=1)
        Y = y**2

        # 使用最小二乘法求解Phi参数
        a = torch.inverse(X.T.matmul(X))
        b = X.T.matmul(Y)
        Phi = a.matmul(b)
        PhiList.append(Phi)
    Phi = torch.stack(PhiList, dim=0)
    return Phi

def spatial_softmax_2d(input: torch.Tensor, temperature: torch.Tensor = torch.tensor(1.0)) -> torch.Tensor:
    r"""Applies the Softmax function over features in each image channel.
    Note that this function behaves differently to `torch.nn.Softmax2d`, which
    instead applies Softmax over features at each spatial location.
    Returns a 2D probability distribution per image channel.
    Arguments:
        input (torch.Tensor): the input tensor.
        temperature (torch.Tensor): factor to apply to input, adjusting the
          "smoothness" of the output distribution. Default is 1.
    Shape:
        - Input: :math:`(B, N, H, W)`
        - Output: :math:`(B, N, H, W)`
    """

    batch_size, channels, height, width = input.shape
    x: torch.Tensor = input.view(batch_size, channels, -1)

    x_soft: torch.Tensor = F.softmax(x * temperature, dim=-1)

    return x_soft.view(batch_size, channels, height, width)

class SpikeDetection():
    '''
    Custom spike detection module to skip learning on crappy batches
    尖峰检测模块，用于在不良批次上跳过学习
    '''
    def __init__(self,
                 patience=5,
                 threshold=2.5,
                 window_size=100,
                 ):

        import collections
        self.entries = collections.deque(maxlen=window_size)
        self.win_size = window_size
        self.threshold = threshold
        self.patience = patience
        self.count = 0

    def update(self, val):

        std_window = np.nanstd(self.entries)
        cond_std = np.abs(val - np.mean(self.entries)) > self.threshold*std_window

        if (len(self.entries) > self.win_size//2) and cond_std and (self.count > self.patience):
            is_spike = True
            self.count = 0
        else:
            self.count += 1
            is_spike = False
            self.entries.append(val)
        return is_spike


class EarlyStopping:
    """Early stops the training if validation loss/metric doesn't improve after a given patience."""
    # Modified by Rakshit Kothari.
    # Code taken from here: https://github.com/Bjarten/early-stopping-pytorch/blob/master/pytorchtools.py
    def __init__(self,
                metric = None,
                patience=7,
                verbose=False,
                rank_cond=True,
                delta=0,
                mode='min',
                fName = 'checkpoint.pt',
                path_save = '/srv/beegfs02/scratch/aegis_cvl/data/dchristodoul/Results/checkpoints'):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement.
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
            fName (str): Name of the checkpoint file.
            path_save (str): Location of the checkpoint file.
        """
        self.patience = patience
        self.verbose = verbose if rank_cond else False
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.Inf if mode == 'min' else -np.Inf
        self.delta = delta
        self.path_save = path_save
        self.fName = fName
        self.mode = mode
        self.rank_cond = rank_cond
        self.metric = metric

    def __call__(self, checkpoint):

        if '3D' in self.metric:
            val_score = checkpoint['valid_result']['gaze_3D_ang_deg_mean']
        elif '2D' in self.metric:
            val_score = checkpoint['valid_result']['gaze_3D_ang_deg_mean']
        else:
            val_score = checkpoint['valid_result']['gaze_3D_ang_deg_mean']
        score = -val_score if self.mode =='min' else val_score

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_score, checkpoint)

        elif score < (self.best_score + self.delta):
            self.counter += 1
            print('EarlyStopping counter: {} out of {}'.format(self.counter, self.patience))
            if self.counter >= self.patience:
                self.early_stop = True

        else:
            self.best_score = score
            self.save_checkpoint(val_score, checkpoint)
            self.counter = 0

    def save_checkpoint(self, val_score, model_dict,
                        update_val_score=True,
                        use_this_name_instead=False):

        '''Saves model when val_score decreases.'''
        if self.verbose and (self.mode == 'min'):
            print('Validation metric decreased ({:.6f} --> {:.6f}). Saving..'.format(self.val_loss_min,
                                                                                     val_score))
        elif self.verbose and (self.mode == 'max'):
            print('Validation metric increased ({:.6f} --> {:.6f}). Saving..'.format(self.val_loss_min,
                                                                                     val_score))
        if (self.rank_cond):
            if use_this_name_instead:
                torch.save(model_dict, os.path.join(self.path_save, use_this_name_instead))
            else:
                torch.save(model_dict, os.path.join(self.path_save, self.fName))

        if update_val_score:
            self.val_loss_min = val_score
        return



import torch
import torch.nn.functional as F
import numpy as np

def generate_pseudo_labels(input: torch.Tensor):
    r"""
    根据每个像素的熵生成伪标签

    Parameters
    ----------
    input : torch.Tensor
        分割网络的非softmax输出。

    Returns
    -------
    pseudo_labels : torch.Tensor
        根据分割输出生成的伪标签。
    confidence : torch.Tensor
        置信度，计算公式为 1 - (熵度量 / log K)；
        其中 K 是类别数。
    """
    num_classes = input.shape[1]
    sc = np.log(num_classes)

    logits = F.softmax(input, dim=1)
    logits_log = F.log_softmax(input, dim=1)  # 在对数几率上使用 log 比在对数上使用更好
    entropy = -torch.sum(logits*logits_log, dim=1).detach()

    pseudo_labels = torch.argmax(logits, dim=1).detach().to(torch.long)

    return pseudo_labels, 1 - (entropy/sc)


def remove_underconfident_psuedo_labels(conf,
                                        label_tracker=False,
                                        gt_dict=False):
    """
    根据置信度阈值移除不自信的伪标签。

    Parameters:
    -----------
    conf : torch.Tensor
        每个伪标签的置信度值。
    label_tracker : object or bool, optional
        用于跟踪标签的对象。默认为 False。
    gt_dict : dict or bool, optional
        包含地面真实信息的字典。默认为 False。

    Returns:
    --------
    mask : torch.Tensor
        布尔掩码，指示哪些伪标签足够自信。
    """
    # 如果没有提供标签跟踪器，则关闭地面真实信息
    gt_dict = gt_dict if label_tracker else False

    if gt_dict:
        # 这里与 gt_dict 有关的一些操作；看起来是不完整的片段

        return conf > label_tracker.threshold
    else:
        return conf > 0.95


def spatial_softargmax_2d(input: torch.Tensor, normalized_coordinates: bool = True) -> torch.Tensor:
    """
    计算给定输入热图的2D软最大值。
    假定输入热图表示有效的空间概率分布。

    Arguments:
    ----------
        input : torch.Tensor
            表示热图的输入张量。
        normalized_coordinates : bool, optional
            是否返回在[-1, 1]范围内归一化的坐标。
            否则，将返回在输入形状范围内的坐标。
            默认为 True。

    Shape:
    ------
        - Input: :math:`(B, N, H, W)`
        - Output: :math:`(B, N, 2)`

    Examples:
    ---------
        >>> heatmaps = torch.tensor([[[
            [0., 0., 0.],
            [0., 0., 0.],
            [0., 1., 0.]]]])
        >>> coords = spatial_softargmax_2d(heatmaps, False)
        tensor([[[1.0000, 2.0000]]])
    """
    batch_size, channels, height, width = input.shape

    # 创建坐标网格。
    grid: torch.Tensor = create_meshgrid(
        height, width, normalized_coordinates)
    grid = grid.to(device=input.device, dtype=input.dtype)

    pos_x: torch.Tensor = grid[..., 0].reshape(-1)
    pos_y: torch.Tensor = grid[..., 1].reshape(-1)

    input_flat: torch.Tensor = input.view(batch_size, channels, -1)

    # 计算坐标的期望值。
    expected_y: torch.Tensor = torch.sum(pos_y * input_flat, -1, keepdim=True)
    expected_x: torch.Tensor = torch.sum(pos_x * input_flat, -1, keepdim=True)

    output: torch.Tensor = torch.cat([expected_x, expected_y], -1)

    return output.view(batch_size, channels, 2)  # BxNx2

def soft_heaviside(x, sc, mode):
    '''
    给定输入和缩放因子（默认为 64），软海维赛德函数以可微的方式近似了 0 或 1 操作的行为。
    请注意海维赛德函数中的最大值已经缩放为 0.9。
    这种缩放是为了方便和稳定性，特别是与 bCE 损失一起使用时。
    '''
    sc = torch.tensor([sc]).to(torch.float32).to(x.device)
    if mode==1:
        # 原始软海维赛德函数
        # 尝试 sc = 64
        return 0.9/(1 + torch.exp(-sc/x))
    elif mode==2:
        # 一些奇特但具有良好梯度的函数
        # 尝试 sc = 0.001
        return 0.45*(1 + (2/np.pi)*torch.atan2(x, sc))
    elif mode==3:
        # 标准的缩放 sigmoid 函数。未来：使 sc 成为自由参数
        # 尝试 sc = 8
        return torch.sigmoid(sc*x)
    else:
        print('未定义的模式')


def _assert_no_grad(variables):
    '''
    断言变量不需要梯度。
    '''
    for var in variables:
        assert not var.requires_grad, \
            "nn criterions don't compute the gradient w.r.t. targets - please " \
            "mark these variables as volatile or not requiring gradients"


def move_to_multi(model_dict):
    '''
    将权重和键的字典转换为多 GPU 格式。
    它只是在键的前面添加了 'module.'。
    '''
    multiGPU_dict = {}
    for key, value in model_dict.items():
        multiGPU_dict['module.'+key] = value
    return multiGPU_dict



def move_to_single(model_dict, move_to_cpu=True):
    '''
    将权重和键的字典转换为单 GPU 格式。它删除键前面的 'module.'。
    另外，它将所有权重移动到 CPU，以防万一。
    '''
    singleGPU_dict = {}
    for key, value in model_dict.items():
        if move_to_cpu:
            singleGPU_dict[key.replace('module.', '')] = value.cpu()
        else:
            singleGPU_dict[key.replace('module.', '')] = value
    return singleGPU_dict


def detach_cpu_np(ip):
    '''
    分离张量并将其转换为 numpy 数组，同时将其移动到 CPU。
    '''
    return ip.detach().cpu().numpy()


def generaliz_mean(tensor, dim, p=-9, keepdim=False):
    '''
    计算沿某些轴的 generaliz mean。
    generaliz mean 在 p = -inf 时对应于最小值。
    https://en.wikipedia.org/wiki/Generalized_mean
    :param tensor: 任意维度的张量。
    :param dim: (int 或 int 元组) 要减少的维度。
    :param keepdim: (bool) 输出张量是否保留 dim。
    :param p: (float<0)。
    '''
    assert p < 0
    res = torch.mean((tensor + 1e-6)**p, dim, keepdim=keepdim)**(1./p)
    return res


def do_nothing(input):
    '''
    不执行任何操作，直接返回输入。
    '''
    return input

#
# class FRN_TLU(torch.nn.Module):
#     def __init__(self, channels, track_running_stats=False):
#         super(FRN_TLU, self).__init__()
#         self.FRN = FRN(channels, is_eps_learnable=True)
#         self.TLU = TLU(channels)
#
#     def forward(self, x):
#         x = self.FRN(x)
#         x = self.TLU(x)
#         return x
