import os
import copy
from random import random
from typing import Optional

import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from skimage import draw
from sklearn import metrics




def get_com(seg_map, temperature=4):
    # 自定义函数用于找到质心以获取检测到的瞳孔或虹膜的中心
    # 质心位于标准化空间中，取值范围为 -1 到 1。
    # seg_map: BXHXW - 单通道，对应瞳孔或虹膜预测

    device = seg_map.device

    B, H, W = seg_map.shape
    wtMap = F.softmax(seg_map.view(B, -1)*temperature, dim=1)  # [B, HXW]

    XYgrid = create_meshgrid(H, W, normalized_coordinates=True)  # 1xHxWx2
    xloc = XYgrid[0, :, :, 0].reshape(-1).to(device)
    yloc = XYgrid[0, :, :, 1].reshape(-1).to(device)

    xpos = torch.sum(wtMap*xloc, -1, keepdim=True)
    ypos = torch.sum(wtMap*yloc, -1, keepdim=True)
    predPts = torch.stack([xpos, ypos], dim=1).squeeze()

    return predPts

# 根据椭圆参数生成掩码
def construct_mask_from_ellipse(ellipses, res):
    # 检查椭圆参数数组的形状，如果是一维数组，则扩展为二维数组
    if len(ellipses.shape) == 1:
        ellipses = ellipses[np.newaxis, :]
        B = 1  # 扩展后的批次大小为1
    else:
        B = ellipses.shape[0]  # 获取批次大小

    # 初始化掩码数组，形状为(B, res[0], res[1])
    mask = np.zeros((B, ) + res)
    # 遍历每个批次中的椭圆参数
    for b in range(B):
        ellipse = ellipses[b, ...].tolist()  # 将椭圆参数转换为列表格式
        # 调用 draw.ellipse 函数绘制椭圆，并将像素位置标记为1
        [rr, cc] = draw.ellipse(round(ellipse[1]),  # y_center
                                round(ellipse[0]),  # x_center
                                round(ellipse[3]),  # r1 (major axis)
                                round(ellipse[2]),  # r2 (minor axis)
                                shape=res,  # 图像分辨率
                                rotation=-ellipse[4])  # 旋转角度
        # 将行和列索引限制在图像范围内
        rr = np.clip(rr, 0, res[0]-1)
        cc = np.clip(cc, 0, res[1]-1)
        # 在掩码数组中标记椭圆区域为True（1）
        mask[b, rr, cc] = 1
    return mask.astype(bool)  # 将掩码数组转换为布尔类型并返回


# 修正椭圆的轴长和角度
def fix_ellipse_axis_angle(ellipse):
    # 创建椭圆参数的深层副本
    ellipse = copy.deepcopy(ellipse)
    # 如果长轴大于短轴，则交换长轴和短轴的值，并将角度增加90度
    if ellipse[3] > ellipse[2]:
        ellipse[[2, 3]] = ellipse[[3, 2]]  # 交换长轴和短轴的值
        ellipse[4] += np.pi/2  # 将角度增加90度（π/2）

    # 将角度限制在0到π之间
    if ellipse[4] > np.pi:
        ellipse[4] += -np.pi
    elif ellipse[4] < 0:
        ellipse[4] += np.pi

    return ellipse  # 返回修正后的椭圆参数


# 一个用于生成网格坐标的网络
def create_meshgrid(
        height: int,
        width: int,
        normalized_coordinates: Optional[bool] = True) -> torch.Tensor:
    """Generates a coordinate grid for an image.

    When the flag `normalized_coordinates` is set to True, the grid is
    normalized to be in the range [-1,1] to be consistent with the pytorch
    function grid_sample.
    http://pytorch.org/docs/master/nn.html#torch.nn.functional.grid_sample

    Args:
        height (int): the image height (rows).
        width (int): the image width (cols).
        normalized_coordinates (Optional[bool]): whether to normalize
          coordinates in the range [-1, 1] in order to be consistent with the
          PyTorch function grid_sample.

    Return:
        torch.Tensor: returns a grid tensor with shape :math:`(1, H, W, 2)`.

    NOTE: Function taken from Kornia libary
    """
    # generate coordinates 生成坐标
    xs: Optional[torch.Tensor] = None
    ys: Optional[torch.Tensor] = None
    if normalized_coordinates:
        xs = torch.linspace(-1, 1, width)
        ys = torch.linspace(-1, 1, height)
    else:
        xs = torch.linspace(0, width - 1, width)
        ys = torch.linspace(0, height - 1, height)

    # Ensure that xs and ys do not require gradients
    xs.requires_grad = False
    ys.requires_grad = False

    # generate grid by stacking coordinates   # 通过堆叠坐标生成网格
    base_grid: torch.Tensor = torch.stack(
        torch.meshgrid([xs, ys])).transpose(1, 2)  # 2xHxW
    return torch.unsqueeze(base_grid, dim=0).permute(0, 2, 3, 1)  # 1xHxWx2


def getSizes(chz, growth, blks=4):
    """根据基本通道大小、增长率和块数量，计算所有层的输入和输出通道大小。"""
    # For a base channel size, growth rate and number of blocks,
    # this function computes the input and output channel sizes for
    # al layers.

    # Encoder sizes
    sizes = {'enc': {'inter': [], 'ip': [], 'op': []},
             'dec': {'skip': [], 'ip': [], 'op': []}}
    sizes['enc']['inter'] = np.array([chz*(i+1) for i in range(0, blks)])
    sizes['enc']['op'] = np.array([int(growth*chz*(i+1)) for i in range(0, blks)])
    sizes['enc']['ip'] = np.array([chz] + [int(growth*chz*(i+1)) for i in range(0, blks-1)])

    # Decoder sizes
    sizes['dec']['skip'] = sizes['enc']['ip'][::-1] + sizes['enc']['inter'][::-1]
    sizes['dec']['ip'] = sizes['enc']['op'][::-1]
    sizes['dec']['op'] = np.append(sizes['enc']['op'][::-1][1:], chz)
    return sizes


