import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

import sys

from matplotlib import pyplot as plt
from sklearn.manifold import Isomap

sys.path.append('..')

from models.resnet_encoder import resnet18
from timm.layers.adaptive_avgmax_pool import SelectAdaptivePool2d

class FeatureExtractor(nn.Module):
    """ResNet-18 backbone that turns windows of single-channel images into
    per-frame pooled embedding vectors.

    Args:
        args: configuration mapping; only ``args['pretrained_resnet']`` is read
            here (forwarded to ``resnet18(pretrained=...)``).
    """

    def __init__(self, args):
        super(FeatureExtractor, self).__init__()

        # One input channel: ``forward`` feeds (B*N, 1, H, W) images.
        self.enc = resnet18(pretrained=args['pretrained_resnet'],
                            **{'in_chans': 1})  # TODO enable parametrizing pretrained weights
        # Adaptive average pooling over the spatial dims of the NCHW feature map,
        # flattened to a (batch, channels) matrix.
        self.global_pool = SelectAdaptivePool2d(
            pool_type='avg',
            flatten=True,
            input_fmt='NCHW',
        )
        # NOTE(review): an optional linear "rend" head on top of the pooled
        # features was sketched here but is disabled; the gaze head lives in the
        # model that owns this extractor.

    def forward(self, data_dict, args):
        """Encode a batch of image windows.

        Args:
            data_dict: mapping whose ``'image'`` entry is a 4-D tensor of shape
                (B, N, H, W): B windows of N frames each.
            args: unused; kept for interface compatibility.

        Returns:
            (eye_enc_out, latent): the pooled features reshaped to (B, N, C),
            and the same features as a flat (B * N, C) tensor.
        """
        # Look the device up on every call instead of caching it once: a cached
        # value plus ``self.to(self.device)`` would silently undo a later
        # ``model.to(...)`` by the caller.
        self.device = next(self.parameters()).device

        # Cast to float32 and move to the model's device.
        x = data_dict['image'].to(torch.float32).to(self.device,
                                                    non_blocking=True)
        B, N, _, _ = x.shape
        # Merge batch and window dims and add the dummy channel dimension.
        x = rearrange(x, 'B N H W -> (B N) 1 H W')
        # The encoder returns a pair; only the first element (presumably the
        # final feature map — TODO confirm against models.resnet_encoder) is pooled.
        latent_org, _ = self.enc(x)
        latent = self.global_pool(latent_org)

        # Un-merge batch and window dims.
        # TODO Rearange x, latent, enc_op
        eye_enc_out = rearrange(latent, '(B N) C -> B N C', B=B, N=N)

        return eye_enc_out, latent

# Random rotation-matrix generator.
def random_rotation_matrix(max_angle=30, rng=None):
    """Return a random 3x3 rotation matrix about a random axis.

    The axis is a normalised isotropic Gaussian sample, i.e. uniform on the unit
    sphere. The angle is drawn uniformly from [-max_angle, max_angle] degrees.
    The matrix is built with Rodrigues' rotation formula.

    Args:
        max_angle: maximum absolute rotation angle, in degrees.
        rng: optional random source exposing ``standard_normal`` and ``uniform``
            (e.g. ``np.random.default_rng(seed)`` or ``np.random.RandomState``).
            Defaults to NumPy's global legacy random state, so existing callers
            see the same random stream as before.

    Returns:
        A (3, 3) float64 ndarray.
    """
    # ``np.random.standard_normal(3)`` draws exactly what ``np.random.randn(3)``
    # drew, so the default path is unchanged.
    rng = np.random if rng is None else rng

    # Random unit rotation axis.
    axis = rng.standard_normal(3)
    axis = axis / np.linalg.norm(axis)
    # Random rotation angle.
    theta = np.radians(rng.uniform(-max_angle, max_angle))

    # Rodrigues: R = I + sin(theta) K + (1 - cos(theta)) K^2, K = cross-product matrix of axis.
    K = np.array([[0, -axis[2], axis[1]],
                  [axis[2], 0, -axis[0]],
                  [-axis[1], axis[0], 0]])
    return np.eye(3) + np.sin(theta) * K + (1 - np.cos(theta)) * (K @ K)

class res_18_gaze(nn.Module):
    """3D gaze regressor on top of a ResNet-18 ``FeatureExtractor``.

    Each frame's 512-d embedding is mapped by a linear layer to a 3-vector and
    L2-normalised.

    As a debugging aid, ``forward`` also:

    * stores Isomap projections of the embeddings;
    * stores yaw/pitch angles (degrees) of randomly rotated ground-truth gaze
      vectors;
    * shows a yaw/pitch scatter plot each time ``PLOT_EVERY_N_SAMPLES`` angles
      have been collected.

    Args:
        args: configuration mapping. Reads ``'frames'``, ``'net_rend_head'``,
            ``'net_simply_head'`` (at least one of the last two must be truthy),
            plus whatever ``FeatureExtractor`` reads.
        act_func: currently unused.
        norm: currently unused.
    """

    # Number of accumulated yaw/pitch samples that triggers the scatter plot.
    PLOT_EVERY_N_SAMPLES = 10000
    # Isomap neighbourhood size; Isomap needs strictly more samples than this.
    ISOMAP_N_NEIGHBORS = 100
    # Maximum angle (degrees) of the random rotation applied to each GT vector.
    GT_ROTATION_MAX_ANGLE = 50

    def __init__(self,
                 args,
                 act_func=F.leaky_relu,
                 norm=nn.BatchNorm2d):
        super(res_18_gaze, self).__init__()

        self.N_win = args['frames']
        self.net_rend_head = args['net_rend_head']
        self.net_simply_head = args['net_simply_head']
        embed_dim = 512
        assert (self.net_rend_head or self.net_simply_head), \
            "enable at least one of 'net_rend_head' / 'net_simply_head'"
        self.featureExtractor = FeatureExtractor(args)
        # Optional: load pretrained extractor weights and freeze them.
        # net_dict = torch.load("cur_objs/pretrained/myPara.pt",
        #                       map_location=torch.device('cuda'))
        # state_dict_single = move_to_single(net_dict['state_dict'])
        # self.featureExtractor.load_state_dict(state_dict_single, strict=False)
        # for param in self.featureExtractor.parameters():
        #     param.requires_grad = False
        # Linear head: 512-d embedding -> unnormalised 3D gaze vector.
        self.fc = nn.Linear(embed_dim, 3, bias=True)

        # Debug/visualisation buffers filled by ``forward``.
        self.all_features = []  # Isomap projections, one (B*N, 3) array per batch
        self.GT = []  # currently unused
        self.yaw_list = []  # yaw (degrees) of randomly rotated GT gaze vectors
        self.pitch_list = []  # pitch (degrees) of the same vectors

    def forward(self, data_dict, args):
        """Predict a unit 3D gaze vector for every frame.

        Args:
            data_dict: mapping with an ``'image'`` tensor of shape (B, N, H, W)
                (consumed by ``FeatureExtractor``) and a ``'gaze_vector'`` entry
                with B * N * 3 elements (used only for the debug statistics).
            args: forwarded to ``FeatureExtractor``.

        Returns:
            (out_dict_gaze, eye_feature):
            ``out_dict_gaze['gaze_vector_3D']`` is a (B * N, 3) tensor of unit
            vectors, and ``eye_feature`` is the (B, N, 512) embedding.
        """
        # Re-read the device on every call so a later ``model.to(...)`` is
        # respected rather than undone by a stale cached value.
        self.device = next(self.parameters()).device

        eye_feature, _ = self.featureExtractor(data_dict, args)

        gaze_predict = self.fc(eye_feature)
        # Normalise to unit length; the epsilon guards against a zero vector.
        norm_3D_gaze = gaze_predict / (
            torch.norm(gaze_predict, dim=-1, keepdim=True) + 1e-9)
        out_dict_gaze = {
            'gaze_vector_3D': rearrange(norm_3D_gaze, 'B N d -> (B N) d'),
        }

        self._collect_visualization(eye_feature, data_dict)
        return out_dict_gaze, eye_feature

    def _collect_visualization(self, eye_feature, data_dict):
        """Accumulate debug statistics and plot once enough samples exist."""
        n_samples = eye_feature.shape[0] * eye_feature.shape[1]
        features = rearrange(eye_feature, 'B N d -> (B N) d').detach().cpu().numpy()

        # Isomap requires n_neighbors < n_samples. With the fixed neighbourhood
        # size, small batches used to raise, so the projection is skipped there.
        if n_samples > self.ISOMAP_N_NEIGHBORS:
            isomap = Isomap(n_neighbors=self.ISOMAP_N_NEIGHBORS, n_components=3)
            self.all_features.append(isomap.fit_transform(features))

        gaze_gt = data_dict["gaze_vector"]
        if torch.is_tensor(gaze_gt):
            # Bring to host memory so CUDA tensors work with NumPy as well.
            gaze_gt = gaze_gt.detach().cpu().numpy()
        gaze_gt = np.asarray(gaze_gt, dtype=np.float64).reshape(n_samples, 3)

        for vec in gaze_gt:
            # Each GT vector gets its own random rotation before it is
            # converted to yaw/pitch in degrees.
            R = random_rotation_matrix(max_angle=self.GT_ROTATION_MAX_ANGLE)
            x, y, z = R @ vec
            self.yaw_list.append(np.degrees(np.arctan2(x, z)))
            self.pitch_list.append(np.degrees(np.arctan2(-y, np.hypot(x, z))))

        if len(self.yaw_list) >= self.PLOT_EVERY_N_SAMPLES:
            self._plot_gaze_distribution()
            # Reset so the (possibly blocking) plot appears once per threshold
            # instead of on every later forward pass, and memory stays bounded.
            self.yaw_list.clear()
            self.pitch_list.clear()
            self.all_features.clear()

    def _plot_gaze_distribution(self):
        """Show a yaw/pitch scatter of the accumulated GT gaze angles."""
        fig = plt.figure(figsize=(8, 8))
        plt.scatter(self.yaw_list, self.pitch_list, s=15, c='tab:red', alpha=0.5, label="TEyeD_TP1")

        plt.xlabel("Yaw (degrees)", fontsize=20)
        plt.ylabel("Pitch (degrees)", fontsize=20)
        plt.xlim([-90, 90])
        plt.ylim([-80, 80])
        plt.xticks(fontsize=14)
        plt.yticks(fontsize=14)
        plt.grid(True, linestyle="--", alpha=1.0, color="gray")
        plt.legend(fontsize=20)
        plt.title("Gaze Distributions", fontsize=20)

        plt.show()
        # Release the figure; otherwise figures accumulate under
        # non-interactive backends where ``show`` returns immediately.
        plt.close(fig)


def plot_transparent_sphere(ax, radius=1.0, alpha=0.2):
    """Draw a translucent grey sphere of the given radius centred at the origin.

    Args:
        ax: a Matplotlib 3D axes (anything providing ``plot_surface``).
        radius: sphere radius.
        alpha: surface opacity.
    """
    # 30 x 30 parameter grid: azimuth phi in [0, 2*pi] (rows), polar angle theta in [0, pi] (columns).
    phi, theta = np.meshgrid(np.linspace(0, 2 * np.pi, 30),
                             np.linspace(0, np.pi, 30),
                             indexing='ij')
    sin_theta = np.sin(theta)
    surface = (
        radius * (np.cos(phi) * sin_theta),
        radius * (np.sin(phi) * sin_theta),
        radius * np.cos(theta),
    )
    ax.plot_surface(*surface, color='gray', alpha=alpha, edgecolors='none')