# -*- coding: UTF-8 -*-
'''
@Project ：PD_Gaze 
@File    ：GPModule.py
@Author  ：xyf
@Date    ：2025/6/10 19:22 
'''
import gc
import json
import logging
import math
import pickle
import sys
import time

from einops import rearrange
from torch import nn
from torch.utils.data import DataLoader
from tqdm import tqdm

from helperfunctions.CurriculumLib import DataLoader_riteyes
from helperfunctions.hfunctions import create_experiment_folder_tree, fix_batch, merge_two_dicts
import helperfunctions.CurriculumLib as CurLib
# !/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
from pprint import pprint
import helperfunctions.DataProcessFuncs as funcs
from models.SphereAlignment import ISOMap, FitGaze
from rendering.rendered_semantics_loss import SobelFilter, loss_fn_rend_sprvs
from scripts import reshape_gt, get_metrics_simple, move_gpu, detach_cpu_numpy, aggregate_metrics, log_wandb, \
    angular_error
from timm.optim import Lamb as Lamb_timm
from timm.scheduler import CosineLRScheduler as CosineLRScheduler_timm
import torch
import random
import warnings
import numpy as np
import wandb
import torch.nn.functional as F
from args_maker import make_args
from helperfunctions.utils import move_to_single, make_logger, get_nparams, EarlyStopping, SpikeDetection
from models.models_mux import model_dict

# Suppress warnings
warnings.filterwarnings('ignore')
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
# Replace the '+' placeholder with your own wandb API key (or export WANDB_API_KEY
# before launching). setdefault keeps an already-exported key instead of
# clobbering it with the placeholder.
os.environ.setdefault("WANDB_API_KEY", '+++++++++++')
os.environ["WANDB_MODE"] = "offline"  # log offline (this line does not need changing)

class GPModuleTrainer:
    """Fit and evaluate the GPM on pre-extracted Baseline features.

    Per epoch the pipeline either:
      * trains: Isomap-reduces the saved training features to PGF, fits GPM
        parameters against the ground-truth gaze labels and saves the reducer,
        the PGF and the parameters; then reloads them and tests, or
      * only tests: reloads the saved reducer/parameters, projects the test
        features, predicts gaze and writes a JSON log with the angular errors.
    """

    def __init__(self, args):
        """Store configuration, select the device and set up experiment output.

        Args:
            args: dict of parsed options (the ``__main__`` block passes
                ``vars(make_args())``).
        """
        self.args = args
        self.model_name = args['model']
        self.cur_objs = args['cur_obj']
        self.test_objs = args['test_obj']
        print(f'[Trainer Log] Model Name: \033[0;32;40m\t{self.model_name}\033[0m')
        # Pick CUDA when available, otherwise fall back to CPU.
        if torch.cuda.is_available():
            self.device = torch.device('cuda')
            self.GPU_num = torch.cuda.device_count()
            print(f'[Trainer Log] \033[0;32;40m\t{self.GPU_num} GPUs Detected \033[0m')
        else:
            self.device = torch.device('cpu')
            self.GPU_num = 0
            print(f'[Trainer Log] \033[0;32;40m\tNo GPU Detected. Run on CPU \033[0m')

        # In distributed runs only local rank 0 creates the folder tree.
        create_tree = args['local_rank'] == 0 if args['do_distributed'] else True
        self.path_dict, self.exp_name_str = create_experiment_folder_tree(args['repo_root'],
                                                                          args['path_exp_tree'],
                                                                          args['exp_name'],
                                                                          args['only_test'],
                                                                          create_tree=create_tree)
        # Debug experiments are not tracked in wandb.
        if 'DEBUG' not in args['exp_name']:
            wandb.init(project="GazeModel",
                       entity='xyfdxb',
                       config=args, name=self.exp_name_str)

    def run(self):
        """Run the GPM pipeline for ``args['epochs']`` epochs.

        Mode selection, read from ``self.args``:
          * ``only_test`` truthy: load the saved Isomap reducer and GPM
            parameters for each epoch, then test.
          * otherwise, if ``only_valid`` is falsy: train, reload, then test.
          * ``only_valid`` truthy without ``only_test``: no per-epoch work.
        """
        # Read configuration from the instance instead of the module-level
        # ``args``, which only exists when this file is executed as a script.
        args = self.args
        self.path_dict['repo_root'] = args['repo_root']
        self.path_dict['path_data_source'] = args['path_data_source']
        self.path_dict['path_data_target'] = args['path_data_target']
        self.source_name = f'{self.cur_objs}--{self.cur_objs}'
        self.target_name = f'{self.cur_objs}--{self.test_objs}'
        baseline_dir = os.path.join(args['pathResult'], 'Baseline')
        # Ground-truth label files produced by the Baseline run.
        self.train_gt_path = os.path.join(baseline_dir, f'{self.source_name}[Epoch00_GT].json')
        self.test_gt_path = os.path.join(baseline_dir, f'{self.target_name}[Epoch00_GT].json')
        # High-dimensional feature files produced by the Baseline run.
        self.train_feature_path = os.path.join(baseline_dir, f'{self.source_name}[Epoch00_Feature].npy')
        self.test_feature_path = os.path.join(baseline_dir, f'{self.target_name}[Epoch00_Feature].npy')

        torch.backends.cudnn.deterministic = False
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.enabled = True

        # Set seeds
        torch.manual_seed(args['seed'])
        np.random.seed(args['seed'])
        random.seed(args['seed'])

        print('---------')
        print('解析的参数')
        pprint(args)
        print('---------')

        logger = make_logger(self.path_dict['logs'] + '/train_log.log',
                             rank=args['local_rank'] if args['do_distributed'] else 0)
        # NOTE(review): the input feature/label files above are fixed (Epoch00),
        # so each epoch refits on identical inputs and only the output names differ.
        for epoch in tqdm(range(args['epochs']), desc="Epochs", leave=False):
            self._set_epoch_paths(epoch)
            logger.write(f'-------[TrainSet][Epoch {epoch}-------')
            if args['only_test']:
                print('test mode')
                # Testing without training: load the previously saved reducer/params.
                self._load_gpm()
                self.test(args, self.path_dict, logger)
            elif not args['only_valid']:
                print('train mode')
                self.train(args, self.path_dict, logger)

                print('test mode')
                self._load_gpm()
                self.test(args, self.path_dict, logger)

        print("run done!")
        wandb.finish()

    def _set_epoch_paths(self, epoch):
        """Set the per-epoch output paths (reducer, GPM params, PGF, test log) under ``path_dict['gpm']``."""
        gpm_dir = self.path_dict['gpm']
        epoch_tag = f'Epoch{str(epoch).zfill(2)}'
        train_prefix = f'[TrainSet][{self.source_name}][{epoch_tag}]'
        eval_prefix = f'[Evaluation][{self.target_name}][{epoch_tag}]'
        # Pickled Isomap reducer.
        self.ISO_fitter_path = os.path.join(gpm_dir, f'{train_prefix}[ISO]')
        # GPM parameters.
        self.GPM_path = os.path.join(gpm_dir, f'{train_prefix}[GPM].npy')
        # PGF features for the train and test sets.
        self.train_PGF_path = os.path.join(gpm_dir, f'{train_prefix}[PGF].npy')
        self.test_PGF_path = os.path.join(gpm_dir, f'{eval_prefix}[PGF].npy')
        # JSON test log.
        self.test_log_path = os.path.join(gpm_dir, f'{eval_prefix}[GPM_test_log].json')

    def _load_gpm(self):
        """Load the Isomap reducer and GPM parameters saved for the current epoch."""
        # NOTE: pickle.load can execute arbitrary code; only load reducer files
        # written by this pipeline, never files from an untrusted source.
        with open(self.ISO_fitter_path, 'rb') as f:
            self.ISO_fitter = pickle.load(f)
        print(f'[Baseline GPM] Isomap Fitter Found and Load: {self.ISO_fitter_path}')
        self.GPM_param = np.load(self.GPM_path)

    def train(self, args, path_dict, logger):
        """Fit the GPM on the training features.

        1. Load the Baseline training features (first ``feature_save_num`` rows).
        2. Reduce them with Isomap to PGF and keep the fitted reducer.
        3. Fit the GPM parameters from PGF and the ground-truth gaze labels.
        4. Save the reducer (pickle), the PGF and the GPM parameters.

        Args:
            args: unused; configuration is read from ``self.args``. Kept for
                interface compatibility.
            path_dict: unused; kept for interface compatibility.
            logger: object exposing ``write(str)``, as returned by ``make_logger``.
        """
        logger.write('Training!')
        logging.info('Entering Training mode ...')
        logger.write('Entering Training mode ...')

        num = self.args['feature_save_num']
        # Keep at most `feature_save_num` samples.
        feature = np.load(self.train_feature_path)[:num]
        print("[Baseline GPM][Train] features Load")
        # Reduce features to PGF and obtain the fitted Isomap reducer.
        PGF, self.ISO_fitter = ISOMap(feature)

        with open(self.ISO_fitter_path, 'wb') as f:
            pickle.dump(self.ISO_fitter, f)
        print(f'[Baseline GPM] Isomap Fitter Saved: {self.ISO_fitter_path}')

        np.save(self.train_PGF_path, PGF)
        print(f'[Baseline GPM] PGF Saved: {self.train_PGF_path}')

        # Ground-truth gaze labels are stored in the Baseline JSON log.
        with open(self.train_gt_path, 'r') as f:
            baseline_log = json.load(f)
        labels = np.array(baseline_log['gaze_label'])[:num]

        # FitGaze returns (params, predictions, fitted model); only params are kept.
        self.GPM_param, _, _ = FitGaze(logger, PGF, labels)

        np.save(self.GPM_path, self.GPM_param)
        print(f'[Baseline GPM] GPM param Saved: {self.GPM_path}')

    def test(self, args, path_dict, logger):
        """Evaluate the fitted GPM on the test features.

        Requires ``self.ISO_fitter`` and ``self.GPM_param`` to be set (by
        ``train`` or ``_load_gpm``).

        1. Load the Baseline test features and gaze labels (first
           ``feature_save_num`` samples).
        2. Project the features to PGF with the saved Isomap reducer.
        3. Predict gaze with the saved GPM parameters.
        4. Save the test PGF and a JSON log with labels, predictions and
           angular errors.

        Args:
            args: unused; configuration is read from ``self.args``. Kept for
                interface compatibility.
            path_dict: unused; kept for interface compatibility.
            logger: object exposing ``write(str)``, as returned by ``make_logger``.
        """
        logger.write('Testing!')
        logging.info('Entering Testing mode ...')
        logger.write('Entering Testing mode ...')

        num = self.args['feature_save_num']
        feature = np.load(self.test_feature_path)[:num]
        print("[Baseline GPM][Test] features Load")
        with open(self.test_gt_path, 'r') as f:
            test_gt = json.load(f)
        label_list = test_gt['gaze_label'][:num]
        label = np.array(label_list)

        # Project with the reducer fitted on the training set (no refitting).
        test_PGF, _ = ISOMap(feature, fitter=self.ISO_fitter)

        # Predict with the fitted GPM parameters.
        _, gaze_pred, _ = FitGaze(logger, test_PGF, label, self.GPM_param)

        np.save(self.test_PGF_path, test_PGF)
        print(f'[Baseline GPM] Test PGF Saved: {self.test_PGF_path}')

        # Angular error per sample (original note: radians — confirm against angular_error).
        errors = angular_error(label, gaze_pred)
        log_dict = {
            # Log the same truncated labels the predictions were computed for,
            # so the three lists stay index-aligned.
            'gaze_label': label_list,
            'gaze_GPM': gaze_pred.tolist(),
            'gaze_error': errors.tolist()
        }

        with open(self.test_log_path, 'w') as f:
            json.dump(log_dict, f)


if __name__ == '__main__':
    # Parse CLI options into a plain dict and launch the GPM pipeline.
    args = vars(make_args())
    GPModuleTrainer(args).run()

