#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
This started as a copy of https://bitbucket.org/RSKothari/multiset_gaze/src/master/
and has since been modified to fit our implementation.

Copyright (c) 2021 Rakshit Kothari, Aayush Chaudhary, Reynold Bailey, Jeff Pelz, 
and Gabriel Diaz
"""

import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

import sys

from matplotlib import pyplot as plt
from sklearn.manifold import Isomap

from models.transformer_Single import Transformer

sys.path.append('..')

from models.resnet_encoder import resnet18
from timm.layers.adaptive_avgmax_pool import SelectAdaptivePool2d

class FeatureExtractor(nn.Module):
    """Embed every frame of a window with a single-channel ResNet-18.

    The encoder's feature map is global-average-pooled and flattened, giving
    one embedding vector per frame.
    """

    def __init__(self, args):
        super(FeatureExtractor, self).__init__()

        # Single-channel (grayscale) ResNet-18; ``args['pretrained_resnet']``
        # is forwarded as the ``pretrained`` flag.
        self.enc = resnet18(pretrained=args['pretrained_resnet'],
                            **{'in_chans': 1})  # TODO enable parametrizing pretrained weights
        # Global average pooling over an NCHW feature map, flattened to (N, C).
        self.global_pool = SelectAdaptivePool2d(
            pool_type='avg',
            flatten=True,
            input_fmt='NCHW',
        )

    def forward(self, data_dict, args):
        """Embed each frame of ``data_dict['image']``.

        Args:
            data_dict: dict whose ``'image'`` entry is a 4-D tensor unpacked as
                (B, N, H, W) -- batch, frames per window, height, width.
            args: unused; kept for interface compatibility.

        Returns:
            ``(eye_enc_out, latent)``: the pooled per-frame embeddings, shaped
            (B, N, C) and (B * N, C) respectively.
        """
        # Look the device up on every call instead of caching it once and then
        # forcing ``self.to(cached)``: the cached value goes stale after a later
        # ``.to()`` / ``.cuda()`` and would silently move the weights back.
        # ``self.device`` is still refreshed for any caller that reads it.
        self.device = next(self.parameters()).device

        images = data_dict['image']
        B, N, H, W = images.shape
        # Cast to float32 and move to the weights' device, then merge the batch
        # and window dims and add the single grey channel -> (B * N, 1, H, W).
        x = images.to(device=self.device, dtype=torch.float32,
                      non_blocking=True)
        x = x.reshape(B * N, 1, H, W)

        # The encoder returns a 2-tuple; only its first item is pooled here.
        feature_map, _ = self.enc(x)
        latent = self.global_pool(feature_map)  # (B * N, C)

        # Split the merged dim back into batch and window -> (B, N, C).
        eye_enc_out = latent.reshape(B, N, -1)

        return eye_enc_out, latent



class EyeBallBranch(nn.Module):
    """Regress 3D eye-model parameters from a sequence of frame embeddings.

    A Transformer mixes the N frame embeddings of each sequence. Two linear
    heads follow, and both outputs are squashed with ``tanh`` into (-1, 1):

    * ``rend_out_same``: 7 values per sequence, predicted from the token mean.
      They are sliced into ``L`` (0), ``r_iris`` (1), ``T`` (2:5) and
      ``focal`` (5:7).
    * ``rend_out_diff``: 7 values per frame. They are sliced into ``R`` (0:3),
      ``r_pupil`` (3) and a 3D gaze vector (4:7), which is normalised to unit
      length.
    """

    def __init__(self, args):
        super(EyeBallBranch, self).__init__()
        self.N_win = args['frames']
        # Token width; presumably equals the ResNet-18 embedding width produced
        # by FeatureExtractor -- TODO confirm if the encoder changes.
        embed_dim = 512
        self.eye_3d = Transformer(
            num_tokens=self.N_win,  # one token per frame in the window
            embed_dim=embed_dim,
            depth=3,  # number of Transformer blocks
            num_heads=8,  # attention heads
            mlp_ratio=2.,  # MLP hidden width = mlp_ratio * embed_dim
            qkv_bias=True,
            qk_norm=False,
            init_values=None,
            pre_norm=False,
            fc_norm=True,
            drop_rate=0.,
            pos_drop_rate=0.,
            proj_drop_rate=0.,
            attn_drop_rate=0.,
            drop_path_rate=0.,
            weight_init='',
            norm_layer=None,
            act_layer=None
        )

        # Number of outputs per frame (diff) and per sequence (same); the
        # slicing in ``forward`` assumes exactly 7 of each.
        self.n_feat_eye_diff = 7
        self.n_feat_eye_same = 7

        # Per-frame head: embed_dim (512) -> n_feat_eye_diff.
        self.rend_out_diff = nn.Linear(
            in_features=embed_dim,
            out_features=self.n_feat_eye_diff,
            bias=True
        )

        # Per-sequence head: embed_dim (512) -> n_feat_eye_same.
        self.rend_out_same = nn.Linear(
            in_features=embed_dim,
            out_features=self.n_feat_eye_same,
            bias=True
        )

    def forward(self, future):
        """Predict eye-model parameters for a batch of embedding sequences.

        Args:
            future: frame embeddings, assumed (B, N, embed_dim), where
                ``self.eye_3d`` keeps the (B, N, ...) layout -- TODO confirm
                against the Transformer implementation.

        Returns:
            dict with, in this order:
              'L', 'r_iris'   -> (B, 1)
              'T'             -> (B, 1, 3)
              'focal'         -> (B, 1, 2)
              'R'             -> (B, N, 3)
              'r_pupil'       -> (B, N)
              'gaze_vector_3D'-> (B * N, 3), unit length
        """
        # Refresh instead of caching and then forcing ``self.to(cached)``: a
        # cached device goes stale after a later ``.to()`` / ``.cuda()`` and
        # would silently drag the weights back. The attribute is kept for
        # callers that read it.
        self.device = next(self.parameters()).device

        tokens = self.eye_3d(future)

        # Sequence-level parameters from the token mean -> (B, 1, n_feat_same).
        same = torch.tanh(self.rend_out_same(tokens.mean(dim=1))).unsqueeze(1)
        # Per-frame parameters. nn.Linear acts on the last dim, so no
        # flatten/unflatten round-trip is needed -> (B, N, n_feat_diff).
        diff = torch.tanh(self.rend_out_diff(tokens))

        # Normalise the 3D gaze vector; the epsilon guards against a zero norm.
        gaze = diff[..., -3:]
        gaze = gaze / (torch.norm(gaze, dim=-1, keepdim=True) + 1e-9)

        return {
            'L': same[..., 0],
            'r_iris': same[..., 1],
            'T': same[..., 2:5],
            'focal': same[..., -2:],
            'R': diff[..., :3],
            'r_pupil': diff[..., 3],
            'gaze_vector_3D': gaze.flatten(0, 1),
        }





class res_18_eye(nn.Module):
    """ResNet-18 frame encoder followed by an optional 3D eye-model head.

    Per-frame embeddings come from ``FeatureExtractor``. When
    ``args['net_rend_head']`` is set, ``EyeBallBranch`` turns them into the
    eye-model outputs. The gaze ("simply") head is not implemented in this
    class, so the gaze output dict is always empty.
    """

    def __init__(self,
                 args,
                 act_func=F.leaky_relu,
                 norm=nn.BatchNorm2d):
        """Build the network.

        Args:
            args: config dict. This method reads ``'frames'``,
                ``'net_rend_head'`` and ``'net_simply_head'``; the whole dict
                is also passed on to the sub-modules.
            act_func, norm: unused; kept for interface compatibility.

        Raises:
            ValueError: if neither ``net_rend_head`` nor ``net_simply_head``
                is enabled.
        """
        super(res_18_eye, self).__init__()

        self.N_win = args['frames']
        self.net_rend_head = args['net_rend_head']
        self.net_simply_head = args['net_simply_head']
        # Explicit check rather than ``assert`` so it still runs under ``python -O``.
        if not (self.net_rend_head or self.net_simply_head):
            raise ValueError(
                "at least one of 'net_rend_head' / 'net_simply_head' must be enabled")

        self.featureExtractor = FeatureExtractor(args)
        # NOTE: no gaze ("simply") branch is built here; with only
        # ``net_simply_head`` enabled, ``forward`` returns two empty dicts.
        if self.net_rend_head:
            self.eyeBallBranch = EyeBallBranch(args)

        # Accumulators for offline feature / ground-truth visualisation.
        # Nothing in this class writes to them at present.
        self.all_features = []
        self.GT = []

    def forward(self, data_dict, args):
        """Run the network on one batch.

        Args:
            data_dict: batch dict, passed to ``FeatureExtractor`` (which reads
                ``'image'``).
            args: passed through to ``FeatureExtractor``.

        Returns:
            ``(out_dict_gaze, out_dict_eye)``. ``out_dict_gaze`` is always
            empty. ``out_dict_eye`` holds the ``EyeBallBranch`` outputs when
            ``net_rend_head`` is set; otherwise it is empty.
        """
        # Refresh instead of caching and then forcing ``self.to(cached)``: a
        # cached device goes stale after a later ``.to()`` / ``.cuda()`` and
        # would silently move the weights back. The attribute is kept for
        # callers that read it.
        self.device = next(self.parameters()).device

        out_dict_gaze = {}
        out_dict_eye = {}

        # Only the (B, N, C) per-frame embeddings are consumed downstream.
        eye_feature, _ = self.featureExtractor(data_dict, args)
        if self.net_rend_head:
            out_dict_eye = self.eyeBallBranch(eye_feature)

        return out_dict_gaze, out_dict_eye


