#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
This started as a copy of https://bitbucket.org/RSKothari/multiset_gaze/src/master/ 
with additional changes and modifications to adjust it to our implementation. 

Copyright (c) 2021 Rakshit Kothari, Aayush Chaudhary, Reynold Bailey, Jeff Pelz, 
and Gabriel Diaz
"""

import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

import sys

from matplotlib import pyplot as plt
from sklearn.manifold import Isomap

from models.transformer_Single import Transformer

sys.path.append('..')

from models.resnet_encoder import resnet18
from timm.layers.adaptive_avgmax_pool import SelectAdaptivePool2d

class FeatureExtractor(nn.Module):
    """ResNet-18 encoder turning a window of grayscale eye images into per-frame embeddings.

    Every frame of the ``(B, N, H, W)`` input is encoded independently (batch and window
    dimensions are merged), globally average-pooled, and the result is split back per sample.
    """

    def __init__(self, args):
        super(FeatureExtractor, self).__init__()

        # Frames are fed as single-channel images (forward adds a channel dim of size 1).
        self.enc = resnet18(pretrained=args['pretrained_resnet'],
                            **{'in_chans': 1})  # TODO enable parametrizing pretrained weights
        # Adaptive average pooling over an NCHW feature map, flattened to (B*N, C).
        self.global_pool = SelectAdaptivePool2d(
            pool_type='avg',
            flatten=True,
            input_fmt='NCHW',
        )

    def forward(self, data_dict, args):
        """Encode ``data_dict['image']``.

        Args:
            data_dict: dict whose ``'image'`` entry is a 4-D tensor unpacked as
                ``(B, N, H, W)``; it is cast to float32 and moved to the module's device.
            args: unused; kept for interface compatibility.

        Returns:
            ``(eye_enc_out, latent)``: the pooled features reshaped to ``(B, N, C)`` and the
            same features flattened to ``(B*N, C)``.
        """
        # Look the device up on every call. A device cached on the first call goes stale
        # once the module is moved with .to()/.cuda()/.cpu(), and calling self.to() with
        # it would silently move the module back.
        self.device = next(self.parameters()).device
        x = data_dict['image'].to(torch.float32).to(self.device,
                                                    non_blocking=True)
        B, N, H, W = x.shape
        # Merge batch and window dims and add a dummy colour channel.
        x = rearrange(x, 'B N H W -> (B N) 1 H W')
        # self.enc returns two values; only the first (the map to be pooled) is used here.
        latent_map, _ = self.enc(x)
        latent = self.global_pool(latent_map)

        # Un-merge batch and window size.
        eye_enc_out = rearrange(latent, '(B N) C -> B N C', B=B, N=N)
        return eye_enc_out, latent

class GazeBranch(nn.Module):
    """Gaze head predicting per-frame parameters for ``GeodesicProjectionModule``.

    A linear layer maps each ``embed_dim`` feature to ``n_gpm = 11`` values. They are
    squashed with ``tanh`` and laid out as::

        [0:4]  quaternion (w, x, y, z)  -> rotation matrix R
        [4:7]  sphere centre O_c
        [7] k1   [8] k2   [9] b1   [10] b2
    """

    def __init__(self, args):
        super(GazeBranch, self).__init__()
        embed_dim = 512
        self.n_gpm = 11
        self.gaze_rotate = nn.Linear(
            in_features=embed_dim,
            out_features=self.n_gpm,
            bias=True
        )

    def forward(self, future, projected_features_GT):
        """Predict unit gaze vectors.

        Args:
            future: per-frame features unpacked as ``(B, N, C)`` with ``C == embed_dim``.
            projected_features_GT: 3-D features fed to ``GeodesicProjectionModule``;
                presumably ``(B*N, >=3)`` aligned with the flattened frames — confirm with caller.

        Returns:
            dict with ``'gaze_vector_3D'``: normalised gaze vectors, one row per frame (B*N rows).
        """
        # Refresh the device each call so a module moved after a forward pass is honoured.
        self.device = next(self.parameters()).device
        future = future.to(self.device)
        # (B, N, C) -> (B*N, n_gpm), bounded to (-1, 1).
        gpm_params = torch.tanh(
            self.gaze_rotate(rearrange(future, 'B N C -> (B N) C')))
        # Disjoint slices that together cover all n_gpm outputs: the quaternion owns
        # indices 0..3, so O_c must start at 4.
        gpm_dict = {
            'R': quat_to_matrix(gpm_params[..., 0:4]),  # (B*N, 3, 3)
            'O_c': gpm_params[..., 4:7].reshape(-1, 3),  # (B*N, 3)
            'k1': gpm_params[..., 7].flatten(),  # (B*N,)
            'k2': gpm_params[..., 8].flatten(),  # (B*N,)
            'b1': gpm_params[..., 9].flatten(),  # (B*N,)
            'b2': gpm_params[..., 10].flatten(),  # (B*N,)
        }

        gaze_predict = GeodesicProjectionModule(projected_features_GT, gpm_dict)
        # Normalise the 3-D gaze vector (epsilon avoids division by zero).
        norm_3D_gaze = gaze_predict / (
                torch.norm(gaze_predict, dim=-1, keepdim=True) + 1e-9)
        return {'gaze_vector_3D': norm_3D_gaze}

class EyeBallBranch(nn.Module):
    """Transformer head regressing 3-D eyeball model parameters from per-frame features.

    Window-level quantities are predicted from the mean over frame tokens. Per-frame
    quantities are predicted from each token. All outputs are squashed with ``tanh``.
    """

    def __init__(self, args):
        super(EyeBallBranch, self).__init__()
        self.N_win = args['frames']
        embed_dim = 512
        # Transformer over the N_win frame tokens of a window.
        self.eye_3d = Transformer(
            num_tokens=self.N_win,  # number of tokens
            embed_dim=embed_dim,  # embedding dimension
            depth=3,  # transformer depth (number of layers)
            num_heads=8,  # number of attention heads
            mlp_ratio=2.,  # MLP ratio
            qkv_bias=True,  # use bias on query/key/value
            qk_norm=False,  # normalise queries and keys
            init_values=None,  # initial values
            pre_norm=False,  # use pre-normalisation
            fc_norm=True,  # normalise the fully connected layer
            drop_rate=0.,  # dropout rate
            pos_drop_rate=0.,  # positional dropout rate
            proj_drop_rate=0.,  # projection dropout rate
            attn_drop_rate=0.,  # attention dropout rate
            drop_path_rate=0.,  # drop-path rate
            weight_init='',  # weight initialisation scheme
            norm_layer=None,  # normalisation layer
            act_layer=None  # activation layer
        )

        # Per-frame outputs: R (3), r_pupil (1), gaze vector (3).
        self.n_feat_eye_diff = 7
        # Window-level outputs: L (1), r_iris (1), T (3), focal (2).
        self.n_feat_eye_same = 7

        # Maps each embed_dim (512) token to the per-frame outputs.
        self.rend_out_diff = nn.Linear(
            in_features=embed_dim,
            out_features=self.n_feat_eye_diff,
            bias=True
        )

        # Maps the window-averaged embed_dim (512) token to the window-level outputs.
        self.rend_out_same = nn.Linear(
            in_features=embed_dim,
            out_features=self.n_feat_eye_same,
            bias=True
        )

    def forward(self, future):
        """Regress eyeball parameters.

        Args:
            future: per-frame features unpacked as ``(B, N, C)``.

        Returns:
            dict with
            - window-level ``'L'`` (B, 1), ``'r_iris'`` (B, 1), ``'T'`` (B, 1, 3) and
              ``'focal'`` (B, 1, 2);
            - per-frame ``'R'`` (B, N, 3) and ``'r_pupil'`` (B, N);
            - ``'gaze_vector_3D'``: normalised gaze vectors flattened to (B*N, 3).
        """
        # Refresh the device each call so a module moved after a forward pass is honoured.
        self.device = next(self.parameters()).device
        future = future.to(self.device)
        out_dict_eye = {}
        B, N, C = future.shape
        # assumes the transformer keeps one token per frame, i.e. (B, N, embed_dim) — TODO confirm
        eye_3d_out = self.eye_3d(future)
        # Window-level head on the mean token; unsqueeze keeps a singleton frame dim.
        eye_3d_out_same = self.rend_out_same(eye_3d_out.mean(dim=1)).unsqueeze(1)
        eye_3d_out_diff = self.rend_out_diff(rearrange(eye_3d_out, 'B N C -> (B N) C'))
        eye_3d_out_diff = rearrange(eye_3d_out_diff, '(B N) C -> B N C', B=B, N=N)

        eye_3d_out_same = torch.tanh(eye_3d_out_same)
        eye_3d_out_diff = torch.tanh(eye_3d_out_diff)

        out_dict_eye['L'] = eye_3d_out_same[..., 0]
        out_dict_eye['r_iris'] = eye_3d_out_same[..., 1]
        out_dict_eye['T'] = eye_3d_out_same[..., 2:5]
        out_dict_eye['focal'] = eye_3d_out_same[..., -2:]

        out_dict_eye['R'] = eye_3d_out_diff[..., :3]
        out_dict_eye['r_pupil'] = eye_3d_out_diff[..., 3]
        # Normalise the 3-D gaze vector (epsilon avoids division by zero).
        gaze_raw = eye_3d_out_diff[..., -3:]
        norm_3D_gaze = gaze_raw / (torch.norm(gaze_raw, dim=-1, keepdim=True) + 1e-9)

        out_dict_eye['gaze_vector_3D'] = rearrange(norm_3D_gaze, 'B N d -> (B N) d')

        return out_dict_eye


def quat_to_matrix(quat, eps=1e-8):
    """Convert quaternions ``(w, x, y, z)`` to 3x3 rotation matrices.

    The quaternions are first normalised to unit length. The closed-form expression below
    is only a proper rotation (orthonormal, det = +1) for unit quaternions, and callers
    such as ``GazeBranch`` pass unconstrained ``tanh`` outputs. A zero quaternion maps to
    the identity.

    Args:
        quat: tensor of shape ``(..., 4)`` laid out as ``(w, x, y, z)``.
        eps: lower bound on the quaternion norm, guarding against division by zero.

    Returns:
        Tensor of shape ``(..., 3, 3)`` whose rows are the rows of the rotation matrix.
    """
    quat = quat / quat.norm(dim=-1, keepdim=True).clamp_min(eps)
    w, x, y, z = quat.unbind(dim=-1)

    # Products of quaternion components, each of shape (...,).
    xx, yy, zz = x * x, y * y, z * z
    xy, xz, yz = x * y, x * z, y * z
    wx, wy, wz = w * x, w * y, w * z

    # Each row has shape (..., 3).
    row1 = torch.stack([1 - 2 * (yy + zz), 2 * (xy - wz), 2 * (xz + wy)], dim=-1)
    row2 = torch.stack([2 * (xy + wz), 1 - 2 * (xx + zz), 2 * (yz - wx)], dim=-1)
    row3 = torch.stack([2 * (xz - wy), 2 * (yz + wx), 1 - 2 * (xx + yy)], dim=-1)

    # Stack along dim=-2 so row1..row3 become matrix rows. Stacking along the last dim
    # would place them as columns, returning the transpose (the inverse rotation).
    return torch.stack([row1, row2, row3], dim=-2)

def GeodesicProjectionModule(future, gpm_dict, eps=1e-6):
    """Project 3-D features onto unit gaze directions using per-sample sphere parameters.

    The first three components of each feature are centred on ``O_c`` and rotated by
    ``R``. They are then converted to two angles,
    ``theta = k1 * atan2(-e_x, -e_z) + b1`` and ``psi = k2 * asin(-e_y) + b2``, which are
    mapped back to the direction ``(cos psi sin theta, sin psi, cos psi cos theta)``.

    Args:
        future: features of shape ``(M, D)`` with ``D >= 3``; only ``[..., :3]`` is used.
            The input is not modified.
        gpm_dict: dict with ``'R'`` ``(M, 3, 3)``, ``'O_c'`` ``(M, 3)`` and the ``(M,)``
            tensors ``'k1'``, ``'k2'``, ``'b1'``, ``'b2'``.
        eps: margin keeping the ``asin`` argument strictly inside ``(-1, 1)``. At exactly
            ``+-1`` the derivative of ``asin`` is infinite, and back-propagating through the
            saturated clamp produces NaN gradients.

    Returns:
        Tensor of shape ``(M, 3)`` on the device of ``gpm_dict['R']``. The rows are unit
        vectors by construction.
    """
    # Compute on the device of the predicted parameters rather than a hard-coded 'cuda',
    # so the function also runs on CPU-only machines.
    device = gpm_dict['R'].device
    future = future.to(device)
    R = gpm_dict['R'].to(device)        # (M, 3, 3)
    O_c = gpm_dict['O_c'].to(device)    # (M, 3)
    k1 = gpm_dict['k1'].to(device)      # (M,)
    k2 = gpm_dict['k2'].to(device)      # (M,)
    b1 = gpm_dict['b1'].to(device)      # (M,)
    b2 = gpm_dict['b2'].to(device)      # (M,)

    # 1. Centre the (x, y, z) part on O_c and rotate: e' = R @ (f - O_c).
    centred = future[..., :3] - O_c
    e_aligned = torch.matmul(R, centred.unsqueeze(-1)).squeeze(-1)

    # 2. Angles theta' and psi', remapped by the per-sample affine parameters.
    theta = k1 * torch.atan2(-e_aligned[..., 0], -e_aligned[..., 2]) + b1
    sin_arg = torch.clamp(-e_aligned[..., 1], -1.0 + eps, 1.0 - eps)
    psi = k2 * torch.asin(sin_arg) + b2

    # 3. Unit direction vector [x, y, z].
    x = torch.cos(psi) * torch.sin(theta)
    y = torch.sin(psi)
    z = torch.cos(psi) * torch.cos(theta)

    return torch.stack([x, y, z], dim=-1)


# class GeodesicProjectionModule(nn.Module):
#     """GPM: 解析 PGF Sphere 视线预测"""
#     def __init__(self):
#         super(GeodesicProjectionModule, self).__init__()
#         self.register_parameter("O_c", nn.Parameter(torch.zeros(3)))  # PGF Sphere 的中心
#         self.register_parameter("R", nn.Parameter(torch.eye(3)))  # 旋转矩阵
#         self.k1 = nn.Parameter(torch.tensor(1.0))
#         self.k2 = nn.Parameter(torch.tensor(1.0))
#         self.b1 = nn.Parameter(torch.tensor(0.0))
#         self.b2 = nn.Parameter(torch.tensor(0.0))
#
#     def forward(self, e):
#         # 1. 计算旋转对齐的特征 e'
#         e = torch.tensor(e, dtype=torch.float32, device=self.O_c.device)
#         # 2. 旋转矩阵 R 进行数值稳定化 (保持正交性)
#         with torch.no_grad():
#             U, _, V = torch.svd(self.R)
#             self.R.data = torch.matmul(U, V.T)
#
#         e_aligned = torch.matmul(self.R, (e - self.O_c).T).T
#
#         # 2. 计算欧拉角 (θ', ψ')
#         theta = self.k1 * torch.atan2(-e_aligned[:, 0], -e_aligned[:, 2]) + self.b1
#         psi = self.k2 * torch.asin(torch.clamp(-e_aligned[:, 1], -1.0, 1.0)) + self.b2
#
#
#         # 3. 计算单位方向向量 y'
#         x = torch.cos(psi) * torch.sin(theta)
#         y = torch.sin(psi)
#         z = torch.cos(psi) * torch.cos(theta)
#         # print("R:", self.R)
#         # print("O_c:", self.O_c)
#         # print("e_aligned min/max:", e_aligned.min(), e_aligned.max())
#
#         # if torch.any(e_aligned[:, 1] > 1) or torch.any(e_aligned[:, 1] < -1):
#         #     print("Warning: e_aligned[:, 1] out of bounds:", e_aligned[:, 1])
#
#         return torch.stack([x, y, z], dim=1)

class IsometricPropagator(nn.Module):
    """Learnable stand-in for Isomap: an MLP projecting features to a low-dimensional space.

    Architecture: ``in_features -> 256 -> 128 -> out_features``, with BatchNorm1d + ReLU
    after each hidden linear layer.

    Args:
        in_features: input feature size (default 512).
        out_features: output size (default 3).
    """

    def __init__(self, in_features=512, out_features=3):
        super(IsometricPropagator, self).__init__()
        self.mlp = nn.Sequential(
            nn.Linear(in_features, 256), nn.BatchNorm1d(256), nn.ReLU(),
            nn.Linear(256, 128), nn.BatchNorm1d(128), nn.ReLU(),
            nn.Linear(128, out_features)
        )

    def forward(self, x):
        """Project ``x``.

        Args:
            x: array-like or tensor of shape ``(batch, in_features)``. It is converted to
                float32 on the module's device. In training mode BatchNorm1d needs batch > 1.

        Returns:
            Tensor of shape ``(batch, out_features)``. Tensor inputs are used without copying,
            so gradients propagate back into them.
        """
        # Use the module's own device rather than a hard-coded 'cuda'. Use torch.as_tensor
        # rather than torch.tensor, which copies tensor inputs, detaches them from the
        # autograd graph and emits a UserWarning.
        device = next(self.mlp.parameters()).device
        x = torch.as_tensor(x, dtype=torch.float32, device=device)
        return self.mlp(x)

class res_18_gaze(nn.Module):
    """ResNet-18 gaze regressor with a debugging visualisation of a learned 3-D projection.

    ``forward`` predicts a normalised 3-D gaze vector per frame, using a linear head on top
    of ``FeatureExtractor``. As a side effect it buffers the output of
    ``IsometricPropagator`` together with the ground-truth gaze. Every ``_VIS_BATCHES``
    calls it draws scatter plots comparing the two with ``plt.show()``, which may block
    depending on the matplotlib backend, and then empties the buffers.

    Args:
        args: config dict; ``'frames'``, ``'net_rend_head'`` and ``'net_simply_head'`` are
            read here, and the whole dict is passed to ``FeatureExtractor``.
        act_func, norm: unused; kept for interface compatibility.
    """

    # Number of buffered batches after which the projection/GT plots are drawn and reset.
    _VIS_BATCHES = 10

    def __init__(self,
                 args,
                 act_func=F.leaky_relu,
                 norm=nn.BatchNorm2d):
        super(res_18_gaze, self).__init__()

        self.N_win = args['frames']
        self.net_rend_head = args['net_rend_head']
        self.net_simply_head = args['net_simply_head']
        embed_dim = 512
        # At least one of the heads must be enabled in the config.
        assert (self.net_rend_head or self.net_simply_head)
        self.featureExtractor = FeatureExtractor(args)
        # Linear gaze head: embed_dim -> raw 3-D gaze vector.
        self.fc = nn.Linear(embed_dim, 3, bias=True)
        # Learned projection whose output is only used for the visualisation.
        self.Ipnet = IsometricPropagator()
        # Host-side (numpy) buffers of projected features and GT gaze, one entry per batch.
        self.all_features = []
        self.GT = []

    def forward(self, data_dict, args):
        """Predict normalised gaze vectors for a window of eye images.

        Args:
            data_dict: must contain ``'image'`` (consumed by ``FeatureExtractor``) and
                ``'gaze_vector'`` (ground truth, reshaped to ``(B*N, 3)`` for visualisation).
            args: forwarded to ``FeatureExtractor``.

        Returns:
            ``(out_dict_gaze, out_dict_eye)``. ``out_dict_gaze['gaze_vector_3D']`` holds the
            normalised gaze vectors flattened to ``(B*N, 3)``; ``out_dict_eye`` is empty.
        """
        # Refresh the device on every call. A device cached once goes stale after the
        # module is moved with .to()/.cuda()/.cpu().
        self.device = next(self.parameters()).device
        eye_feature, eye_feature_flat = self.featureExtractor(data_dict, args)

        gaze_predict = self.fc(eye_feature)
        # Normalise the 3-D gaze vector (epsilon avoids division by zero).
        norm_3D_gaze = gaze_predict / (
                torch.norm(gaze_predict, dim=-1, keepdim=True) + 1e-9)
        out_dict_gaze = {'gaze_vector_3D': rearrange(norm_3D_gaze, 'B N d -> (B N) d')}
        out_dict_eye = {}

        self._collect_for_visualisation(eye_feature, eye_feature_flat, data_dict)
        return out_dict_gaze, out_dict_eye

    def _collect_for_visualisation(self, eye_feature, eye_feature_flat, data_dict):
        """Buffer projected features and GT gaze; plot and empty the buffers when full."""
        batch_size, frames, _ = eye_feature.shape
        projected = self.Ipnet(eye_feature_flat)
        # Only used on the host for plotting, hence detached.
        self.all_features.append(projected.detach().cpu().numpy())
        gt = data_dict["gaze_vector"].reshape(batch_size * frames, 3)
        self.GT.append(gt.detach().cpu().numpy())

        if len(self.all_features) >= self._VIS_BATCHES:
            self._plot_projection_vs_gt(np.vstack(self.all_features), np.vstack(self.GT))
            # Empty the buffers. Otherwise they grow without bound and every later call
            # re-plots (and blocks on plt.show()) with ever more data.
            self.all_features.clear()
            self.GT.clear()

    @staticmethod
    def _plot_projection_vs_gt(projected_features, gt_all):
        """Scatter-plot projected features (top row) against GT gaze vectors (bottom row).

        Each row contains:
        - a 3-D view over a transparent sphere whose radius is the largest point norm;
        - a top view (X vs Y);
        - a side view (X vs Z).

        Points are coloured by ``arctan2(y, x)`` and plotted without normalisation.
        """
        print(projected_features.shape)
        print(gt_all.shape)

        fig = plt.figure(figsize=(14, 8))
        rows = (
            (projected_features, ('3D View', 'Top View', 'Side View')),
            (gt_all, ('GT View', 'GT Top View', 'GT Side View')),
        )
        for row_idx, (points, (title_3d, title_top, title_side)) in enumerate(rows):
            first_slot = 231 + 3 * row_idx
            # Colour by the yaw angle in the XY plane.
            colours = np.arctan2(points[:, 1], points[:, 0])

            ax = fig.add_subplot(first_slot, projection='3d')
            # NOTE(review): plot_transparent_sphere is neither defined nor imported in this
            # file as seen here — confirm it is provided elsewhere.
            plot_transparent_sphere(ax, np.max(np.linalg.norm(points, axis=1)))
            ax.scatter(points[:, 0], points[:, 1], points[:, 2],
                       c=colours, cmap='jet', alpha=0.8, s=5)
            ax.set_title(title_3d)
            ax.set_xlabel("X")
            ax.set_ylabel("Y")
            ax.set_zlabel("Z")

            # 2-D projections: X against Y (top view) and X against Z (side view).
            for axis_idx, axis_label, title in ((1, 'Y', title_top), (2, 'Z', title_side)):
                ax = fig.add_subplot(first_slot + axis_idx)
                ax.scatter(points[:, 0], points[:, axis_idx],
                           c=colours, cmap='jet', alpha=0.8, s=5)
                ax.set_title(title)
                ax.set_xlabel('X')
                ax.set_ylabel(axis_label)
                ax.axis('equal')

        plt.tight_layout()
        plt.show()
        # Release the figure; otherwise open figures accumulate across calls.
        plt.close(fig)


