import os
import json
import matplotlib
matplotlib.use('Agg')  # 使用非交互式后端
import matplotlib.pyplot as plt
import numpy as np
from datetime import datetime
from mmcv.runner import HOOKS, Hook

@HOOKS.register_module()
class PCTClassifierLossPlotter(Hook):
    """Hook that records, plots and dumps key PCT Classifier training metrics.

    Every ``interval`` training iterations the hook samples a few entries of
    ``runner.log_buffer.output``. After each training epoch it renders a
    2x2 figure of the accumulated history and writes a JSON snapshot into a
    per-epoch folder. At the end of the run it writes a summary JSON.

    NOTE(review): the values read from ``runner.log_buffer.output`` depend on
    when other hooks fill and clear that buffer, so the priority of this hook
    relative to the logger hooks presumably matters -- confirm in the config.

    Args:
        log_dir (str): Base directory; a timestamped sub-folder is created
            inside it in ``before_run``.
        interval (int): Record metrics every ``interval`` iterations.
        start_epoch (int): First epoch covered by the recording; only used
            for labelling plots and output file names.
    """

    # Fallback used when the runner does not expose ``max_epochs``.
    DEFAULT_MAX_EPOCHS = 270

    def __init__(self, log_dir='C:/Users/USER/Downloads/PCT-main/loss_logs', interval=100, start_epoch=45):
        self.interval = interval  # sampling period, in iterations
        self.base_log_dir = log_dir
        self.log_dir = None  # resolved in before_run
        self.max_epochs = self.DEFAULT_MAX_EPOCHS  # overwritten in before_run
        self.start_epoch = start_epoch
        self._logger = None  # runner logger, captured in before_run
        self._epoch_dirs = []  # names of the epoch folders actually created

        # Only the key PCT Classifier metrics are kept. The per-iteration
        # lists are appended together, so they always share the same length.
        self.loss_history = {
            'iterations': [],
            'epochs': [],
            'token_loss': [],
            'kpt_loss': [],
            'top1_acc': [],
            'learning_rate': [],
            'val_AP': [],
            'val_epochs': []
        }

    # ------------------------------------------------------------------
    # Small internal helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _scalar(value):
        """Return ``value`` as a builtin number when it exposes ``.item()``.

        NumPy scalar types other than ``float64`` are not JSON serialisable,
        so values are normalised before they are stored.
        """
        item = getattr(value, 'item', None)
        return item() if callable(item) else value

    def _last(self, key, default=0):
        """Return the most recent value recorded under ``key``."""
        series = self.loss_history[key]
        return series[-1] if series else default

    def _best(self, key, default=0):
        """Return the maximum value recorded under ``key``."""
        series = self.loss_history[key]
        return max(series) if series else default

    def _positive_mean(self, key):
        """Return the mean of the strictly positive values under ``key``.

        Zeros are placeholders for "metric missing from the log buffer", so
        they are excluded. Returns 0 when there is no positive value.
        """
        positives = [x for x in self.loss_history[key] if x > 0]
        return float(np.mean(positives)) if positives else 0

    def _warn(self, message):
        """Report a non-fatal problem through the runner logger if known."""
        if self._logger is not None:
            self._logger.warning(message)
        else:
            print(message)

    def _ensure_epoch_dir(self, current_epoch):
        """Create (if needed) and return the folder for ``current_epoch``."""
        name = f'epoch_{current_epoch:03d}'
        path = os.path.join(self.log_dir, name)
        os.makedirs(path, exist_ok=True)
        if name not in self._epoch_dirs:
            self._epoch_dirs.append(name)
        return path

    def _resolve_lr(self, runner, log_vars):
        """Return the current learning rate, or 0 when it is unavailable.

        ``log_vars['lr']`` wins when present. Otherwise the hook falls back
        to ``runner.current_lr()``, which may return a list (one entry per
        parameter group) or a dict of such lists; the first entry is used.
        """
        if 'lr' in log_vars:
            return self._scalar(log_vars['lr'])
        current_lr = getattr(runner, 'current_lr', None)
        if not callable(current_lr):
            return 0
        try:
            lr = current_lr()
        except RuntimeError:
            # NOTE(review): mmcv is expected to raise RuntimeError when the
            # runner has no optimizer -- confirm against the mmcv version.
            return 0
        if isinstance(lr, dict):
            lr = next(iter(lr.values()), 0)
        if isinstance(lr, (list, tuple)):
            lr = lr[0] if lr else 0
        return self._scalar(lr)

    # ------------------------------------------------------------------
    # Hook callbacks
    # ------------------------------------------------------------------
    def before_run(self, runner):
        """Create the timestamped output folder and dump basic run info."""
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        model_name = "PCT_Classifier"
        self.log_dir = os.path.join(self.base_log_dir, f"{model_name}_{timestamp}")
        self._logger = getattr(runner, 'logger', None)

        # Cached so the plotting code does not need access to the runner.
        self.max_epochs = getattr(runner, 'max_epochs', self.DEFAULT_MAX_EPOCHS)

        os.makedirs(self.log_dir, exist_ok=True)
        runner.logger.info(f'PCT Classifier loss plots will be saved to {self.log_dir}')

        config_info = {
            'model_name': model_name,
            'start_time': timestamp,
            'work_dir': runner.work_dir,
            'total_epochs': self.max_epochs,
            'plot_type': 'PCT_Classifier_metrics'
        }

        with open(os.path.join(self.log_dir, 'config_info.json'), 'w', encoding='utf-8') as f:
            json.dump(config_info, f, indent=2)

    def after_train_iter(self, runner):
        """Sample the key metrics every ``interval`` training iterations."""
        if not self.every_n_iters(runner, self.interval):
            return

        log_vars = runner.log_buffer.output
        history = self.loss_history

        history['iterations'].append(runner.iter)
        history['epochs'].append(runner.epoch)
        # The log keys follow the tokenizer outputs:
        #   token_loss <- 'e_latent_loss', kpt_loss <- 'joint_loss'.
        history['token_loss'].append(self._scalar(log_vars.get('e_latent_loss', 0)))
        history['kpt_loss'].append(self._scalar(log_vars.get('joint_loss', 0)))
        # 'top1-acc' is not produced in the tokenizer stage, so it stays 0 there.
        history['top1_acc'].append(self._scalar(log_vars.get('top1-acc', 0)))
        history['learning_rate'].append(self._resolve_lr(runner, log_vars))

    def after_train_epoch(self, runner):
        """Render the cumulative figure and JSON snapshot for this epoch."""
        if self.loss_history['iterations']:
            self.plot_pct_classifier_metrics(runner.epoch)
            self.save_epoch_data(runner.epoch)

    def after_val_epoch(self, runner):
        """Record the validation AP if the log buffer contains ``'AP'``."""
        log_vars = runner.log_buffer.output

        if 'AP' in log_vars:
            ap = self._scalar(log_vars['AP'])
            self.loss_history['val_AP'].append(ap)
            self.loss_history['val_epochs'].append(runner.epoch)

            runner.logger.info(f'Epoch {runner.epoch}: Validation AP = {ap:.4f}')

    # ------------------------------------------------------------------
    # Plotting
    # ------------------------------------------------------------------
    def _plot_iteration_series(self, ax, key, label, ylabel, line_style,
                               box_color, value_format, current_epoch,
                               show_best=False):
        """Draw one per-iteration metric on ``ax``.

        The panel is left empty when the series has no strictly positive
        value, i.e. when the metric never appeared in the log buffer.
        """
        series = self.loss_history[key]
        if not any(x > 0 for x in series):
            return

        iterations = self.loss_history['iterations']
        ax.plot(iterations, series, line_style, linewidth=2, alpha=0.8)
        # The recording starts at start_epoch, so the title must not claim
        # that the curve covers epoch 1 onwards.
        ax.set_title(f'{label} (Epochs {self.start_epoch}-{current_epoch})', fontweight='bold')
        ax.set_xlabel('Iteration')
        ax.set_ylabel(ylabel)
        ax.grid(True, alpha=0.3)

        self._add_epoch_lines(ax, iterations, self.loss_history['epochs'])

        ax.text(0.02, 0.95, 'Current: ' + value_format.format(series[-1]),
                transform=ax.transAxes,
                bbox=dict(boxstyle="round,pad=0.3", facecolor=box_color, alpha=0.7))
        if show_best:
            ax.text(0.02, 0.85, 'Best: ' + value_format.format(max(series)),
                    transform=ax.transAxes,
                    bbox=dict(boxstyle="round,pad=0.3", facecolor="yellow", alpha=0.7))

    def _plot_validation_ap(self, ax, current_epoch):
        """Draw the validation AP curve with best/current annotations."""
        val_ap = self.loss_history['val_AP']
        val_epochs = self.loss_history['val_epochs']
        if not (val_ap and val_epochs):
            return

        ax.plot(val_epochs, val_ap, 'purple', linewidth=3, marker='o', markersize=8, alpha=0.8)
        ax.set_title(f'Validation AP (Epochs {self.start_epoch}-{current_epoch})', fontweight='bold')
        ax.set_xlabel('Epoch')
        ax.set_ylabel('AP')
        ax.grid(True, alpha=0.3)

        best_ap = max(val_ap)
        best_epoch = val_epochs[val_ap.index(best_ap)]
        current_ap = val_ap[-1]

        ax.axhline(y=best_ap, color='red', linestyle='--', alpha=0.5)
        ax.text(0.02, 0.95, f'Best AP: {best_ap:.4f} (Epoch {best_epoch})',
                transform=ax.transAxes,
                bbox=dict(boxstyle="round,pad=0.3", facecolor="lightgreen", alpha=0.7))
        ax.text(0.02, 0.85, f'Current AP: {current_ap:.4f}',
                transform=ax.transAxes,
                bbox=dict(boxstyle="round,pad=0.3", facecolor="lightcoral", alpha=0.7))

    def plot_pct_classifier_metrics(self, current_epoch):
        """Render the 2x2 metrics figure for the history recorded so far.

        The figure is saved as a PNG inside the folder of ``current_epoch``.
        Plotting is best-effort: any failure is reported and swallowed so
        that it can never abort training.
        """
        fig = None
        try:
            epoch_dir = self._ensure_epoch_dir(current_epoch)

            fig, axes = plt.subplots(2, 2, figsize=(16, 12))

            start_epoch = self.start_epoch
            fig.suptitle(f'PCT Classifier Metrics - Epochs {start_epoch}-{current_epoch}',
                         fontsize=16, fontweight='bold')

            # The "earlier epochs were not recorded" note only makes sense
            # when recording did not begin at the first epoch.
            if start_epoch > 1:
                fig.text(0.5, 0.01, f"注意: 前{start_epoch-1}个epoch的数据未记录",
                         ha='center', color='red', fontsize=12)

            self._plot_iteration_series(
                axes[0, 0], 'token_loss', 'Token Loss', 'Token Loss',
                'r-', 'lightcoral', '{:.4f}', current_epoch)
            self._plot_iteration_series(
                axes[0, 1], 'kpt_loss', 'Keypoint Loss', 'Keypoint Loss',
                'g-', 'lightgreen', '{:.4f}', current_epoch)
            self._plot_iteration_series(
                axes[1, 0], 'top1_acc', 'Top-1 Accuracy', 'Top-1 Accuracy (%)',
                'b-', 'lightblue', '{:.2f}%', current_epoch, show_best=True)
            self._plot_validation_ap(axes[1, 1], current_epoch)

            data_points = len(self.loss_history['iterations'])
            fig.text(0.02, 0.02, f'PCT Classifier Training - Epoch {current_epoch}/{self.max_epochs} - Data Points: {data_points}',
                     fontsize=12, style='italic', fontweight='bold')

            fig.tight_layout()

            fig.savefig(os.path.join(epoch_dir, f'PCT_Classifier_metrics_epochs_{start_epoch}-{current_epoch}.png'),
                        dpi=200, bbox_inches='tight')

        except Exception as e:  # deliberate boundary: plotting must not stop training
            self._warn(f"Error plotting PCT Classifier metrics for epoch {current_epoch}: {e}")
        finally:
            # Always release the figure, including when plotting failed.
            if fig is not None:
                plt.close(fig)

    def _add_epoch_lines(self, ax, iterations, epochs):
        """Draw a dotted vertical line at the first sample of each epoch."""
        seen_epochs = set()
        # zip() stops at the shorter sequence, so mismatched lengths cannot
        # raise and no exception needs to be swallowed here.
        for iteration, epoch in zip(iterations, epochs):
            if epoch in seen_epochs:
                continue
            seen_epochs.add(epoch)
            ax.axvline(x=iteration, color='gray', linestyle=':', alpha=0.5, linewidth=0.8)

    # ------------------------------------------------------------------
    # Persistence
    # ------------------------------------------------------------------
    def save_epoch_data(self, current_epoch):
        """Write the cumulative history and summary statistics as JSON.

        Two copies are written: one inside the folder of ``current_epoch``
        and one as ``latest_PCT_Classifier_data.json`` in the run folder.
        Failures are reported and swallowed.
        """
        try:
            # Do not rely on the plotting step having created the folder.
            epoch_dir = self._ensure_epoch_dir(current_epoch)
            history = self.loss_history

            epoch_data = {
                'epoch': current_epoch,
                # 'data_range' and the file name below keep their original
                # "epochs_1" wording so existing consumers keep working.
                'data_range': f'epochs_1_to_{current_epoch}',
                'total_iterations': len(history['iterations']),
                'pct_classifier_metrics': {k: list(v) for k, v in history.items()},
                'epoch_statistics': {
                    'current_token_loss': self._last('token_loss'),
                    'current_kpt_loss': self._last('kpt_loss'),
                    'current_top1_acc': self._last('top1_acc'),
                    'best_top1_acc': self._best('top1_acc'),
                    'avg_token_loss': self._positive_mean('token_loss'),
                    'avg_kpt_loss': self._positive_mean('kpt_loss'),
                    'current_learning_rate': self._last('learning_rate'),
                    'best_val_AP_so_far': self._best('val_AP'),
                    'current_val_AP': self._last('val_AP'),
                    'data_points_recorded': len(history['iterations']),
                    'epochs_completed': len(set(history['epochs']))
                },
                'timestamp': datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            }

            targets = (
                os.path.join(epoch_dir, f'PCT_Classifier_data_epochs_1-{current_epoch}.json'),
                os.path.join(self.log_dir, 'latest_PCT_Classifier_data.json'),
            )
            for target in targets:
                with open(target, 'w', encoding='utf-8') as f:
                    json.dump(epoch_data, f, indent=2)

        except (OSError, TypeError, ValueError) as e:
            self._warn(f"Error saving PCT Classifier epoch {current_epoch} data: {e}")

    def after_run(self, runner):
        """Write the end-of-training summary JSON into the run folder."""
        if self.log_dir is None:
            # before_run never ran, so there is no folder to write into.
            runner.logger.warning('PCT Classifier summary skipped: log directory was never created')
            return

        history = self.loss_history
        total_epochs = len(set(history['epochs']))
        val_ap = history['val_AP']
        best_val_epoch = history['val_epochs'][val_ap.index(max(val_ap))] if val_ap else 0

        summary = {
            'training_completed': datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            'model_type': 'PCT_Classifier',
            'total_epochs_completed': total_epochs,
            'total_iterations': len(history['iterations']),
            # Report the folders that were really created instead of
            # assuming they are numbered 1..N.
            'total_epoch_folders_created': len(self._epoch_dirs),
            'final_statistics': {
                'final_token_loss': self._last('token_loss'),
                'final_kpt_loss': self._last('kpt_loss'),
                'final_top1_acc': self._last('top1_acc'),
                'best_top1_acc': self._best('top1_acc'),
                'best_val_AP': self._best('val_AP'),
                'final_val_AP': self._last('val_AP'),
                'best_val_epoch': best_val_epoch,
                'final_learning_rate': self._last('learning_rate'),
                'data_points_recorded': len(history['iterations'])
            },
            'folder_structure': list(self._epoch_dirs)
        }

        with open(os.path.join(self.log_dir, 'PCT_Classifier_training_summary.json'), 'w', encoding='utf-8') as f:
            json.dump(summary, f, indent=2)

        runner.logger.info(f'PCT Classifier training completed! {len(self._epoch_dirs)} epoch folders with metrics saved to {self.log_dir}')