import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

import sys

sys.path.append('..')

from models.resnet_encoder import resnet18
from timm.layers.adaptive_avgmax_pool import SelectAdaptivePool2d

class FeatureExtractor(nn.Module):
    """ResNet-18 encoder producing one pooled feature vector per eye frame.

    ``forward`` expects ``data_dict['image']`` to be a 4-D tensor laid out as
    (batch B, window N, height H, width W). Every frame is encoded
    independently as a single-channel (grayscale) image.
    """

    def __init__(self, args):
        """Build the encoder and the global pooling layer.

        Args:
            args: configuration mapping; ``args['pretrained_resnet']`` is
                forwarded as ``pretrained`` to ``resnet18``.
        """
        super().__init__()

        # Frames are single-channel, hence in_chans=1.
        # TODO enable parametrizing pretrained weights
        self.enc = resnet18(pretrained=args['pretrained_resnet'], in_chans=1)
        # Global average pooling over the spatial dims of an NCHW feature map,
        # flattened to (batch, channels).
        self.global_pool = SelectAdaptivePool2d(
            pool_type='avg',
            flatten=True,
            input_fmt='NCHW',
        )

    def forward(self, data_dict, args):
        """Encode every frame of ``data_dict['image']``.

        Args:
            data_dict: mapping whose ``'image'`` entry is a (B, N, H, W)
                tensor; it is cast to float32 and moved to the module's device.
            args: unused; kept for interface compatibility with callers.

        Returns:
            ``(eye_enc_out, latent)``: ``eye_enc_out`` has shape (B, N, C) and
            ``latent`` holds the same features flattened to (B * N, C). C is the
            encoder's output channel count (presumably 512 for this ResNet-18 —
            confirm against models.resnet_encoder).
        """
        # Resolve the device on every call. Caching it (as done previously)
        # goes stale after model.to()/.cuda(), and the old per-call
        # self.to(cached_device) then moved the module back to the old device.
        # The attribute is still set for code that reads ``self.device``.
        self.device = next(self.parameters()).device

        x = data_dict['image'].to(device=self.device, dtype=torch.float32,
                                  non_blocking=True)
        B, N, _, _ = x.shape
        # Fold the window into the batch so each frame is encoded as its own
        # 1-channel image.
        x = rearrange(x, 'B N H W -> (B N) 1 H W')
        # The encoder returns two values; only the first (the feature map that
        # is pooled below) is used here.
        latent_map, _ = self.enc(x)
        latent = self.global_pool(latent_map)  # (B * N, C)

        # Restore the batch / window split.
        eye_enc_out = rearrange(latent, '(B N) C -> B N C', B=B, N=N)
        return eye_enc_out, latent

class res_18_gpm(nn.Module):
    """Per-frame 3D gaze regressor.

    Encodes each frame with ``FeatureExtractor``, maps its feature vector to a
    3D vector with a linear layer, and normalizes that vector to unit length.
    """

    def __init__(self,
                 args,
                 act_func=F.leaky_relu,
                 norm=nn.BatchNorm2d):
        """Build the feature extractor and the linear gaze head.

        Args:
            args: configuration mapping. Reads ``'frames'``,
                ``'net_rend_head'`` and ``'net_simply_head'`` here, and
                ``'pretrained_resnet'`` through FeatureExtractor.
            act_func: unused; kept for constructor compatibility.
            norm: unused; kept for constructor compatibility.
        """
        super().__init__()

        self.N_win = args['frames']
        self.net_rend_head = args['net_rend_head']
        self.net_simply_head = args['net_simply_head']
        assert self.net_rend_head or self.net_simply_head, (
            "at least one of args['net_rend_head'] / "
            "args['net_simply_head'] must be enabled")

        # Must match the feature width produced by FeatureExtractor
        # (presumably 512 for ResNet-18 — confirm against models.resnet_encoder).
        embed_dim = 512
        self.featureExtractor = FeatureExtractor(args)
        # Maps each frame's feature vector to an unnormalized 3D gaze vector.
        self.fc = nn.Linear(embed_dim, 3, bias=True)

        # Not used by this class; kept so external code that references these
        # accumulators keeps working.
        self.all_features = []
        self.GT = []

    def forward(self, data_dict, args):
        """Predict a unit-length 3D gaze vector for every frame.

        Args:
            data_dict: mapping passed to FeatureExtractor; must contain
                ``'image'`` shaped (B, N, H, W).
            args: forwarded to FeatureExtractor.forward.

        Returns:
            ``(out_dict_gaze, eye_feature)``: ``out_dict_gaze['gaze_vector_3D']``
            is a (B * N, 3) tensor of L2-normalized gaze vectors, and
            ``eye_feature`` is the (B, N, C) per-frame feature tensor.
        """
        # Resolve the device on every call instead of caching it: a cached
        # value goes stale after model.to()/.cuda(), and the old per-call
        # self.to(cached_device) then moved the model back to the old device.
        # The attribute is still set for code that reads ``model.device``.
        self.device = next(self.parameters()).device

        eye_feature, _ = self.featureExtractor(data_dict, args)
        gaze_predict = self.fc(eye_feature)  # (B, N, 3)
        # Scale to unit length; the epsilon avoids division by zero.
        norm_3D_gaze = gaze_predict / (
            torch.norm(gaze_predict, dim=-1, keepdim=True) + 1e-9)

        out_dict_gaze = {
            'gaze_vector_3D': rearrange(norm_3D_gaze, 'B N d -> (B N) d'),
        }
        # NOTE: the commented-out Isomap/matplotlib visualization that used to
        # live here referenced undefined names and was removed; recover it from
        # version control if needed.
        return out_dict_gaze, eye_feature


