"""
trainer.py
Model trainer for circuit-similarity models.
"""

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from typing import Dict, Optional, Callable, List, Tuple
import numpy as np
from pathlib import Path
from datetime import datetime
from tqdm import tqdm
import logging
import json
import matplotlib.pyplot as plt

logger = logging.getLogger(__name__)


class CircuitTrainer:
    """电路相似度模型训练器"""
    
    def __init__(
        self,
        model: nn.Module,
        train_loader: DataLoader,
        val_loader: Optional[DataLoader] = None,
        optimizer: Optional[Optimizer] = None,
        scheduler: Optional[_LRScheduler] = None,
        criterion: Optional[nn.Module] = None,
        device: Optional[torch.device] = None,
        config: Optional[Dict] = None,
    ):
        """Set up the model, data loaders, optimisation objects and bookkeeping.

        Args:
            model: Model to train; it is moved to the selected device.
            train_loader: Data loader for the training set.
            val_loader: Data loader for the validation set (optional).
            optimizer: Optimizer. When not given, Adam is created with
                ``config['learning_rate']`` (default 0.001) and
                ``config['weight_decay']`` (default 1e-4).
            scheduler: Learning-rate scheduler. When not given, a
                ``ReduceLROnPlateau`` in ``'min'`` mode is created with
                ``config['scheduler_patience']`` (default 5) and
                ``config['scheduler_factor']`` (default 0.5).
            criterion: Loss function. Defaults to ``nn.MSELoss()``.
            device: Compute device. Defaults to CUDA when available, else CPU.
            config: Training configuration dictionary. Defaults to ``{}``.
        """
        # Configuration is read by almost everything below, so resolve it first.
        settings = config if config else {}
        self.config = settings

        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader

        # Device selection, then move the model onto it.
        if not device:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.device = device
        self.model.to(self.device)

        # Optimizer (Adam unless the caller supplied one).
        if not optimizer:
            optimizer = torch.optim.Adam(
                self.model.parameters(),
                lr=settings.get('learning_rate', 0.001),
                weight_decay=settings.get('weight_decay', 1e-4),
            )
        self.optimizer = optimizer

        # Learning-rate scheduler (plateau-based unless the caller supplied one).
        if not scheduler:
            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                self.optimizer,
                mode='min',
                patience=settings.get('scheduler_patience', 5),
                factor=settings.get('scheduler_factor', 0.5),
            )
        self.scheduler = scheduler

        # Loss function.
        self.criterion = criterion if criterion else nn.MSELoss()

        # Per-epoch training history.
        self.history = {
            key: [] for key in ('train_loss', 'val_loss', 'learning_rates')
        }

        # Best-model tracking and early-stopping state.
        self.best_val_loss = float('inf')
        self.best_epoch = 0
        self.early_stopping_counter = 0
        self.early_stopping_patience = settings.get('early_stopping_patience', 10)

        # Checkpoint directory, created eagerly.
        checkpoint_dir = Path(settings.get('save_dir', './checkpoints'))
        checkpoint_dir.mkdir(parents=True, exist_ok=True)
        self.save_dir = checkpoint_dir

        logger.info(f"Trainer initialized on {self.device}")
    
    def train_epoch(self) -> float:
        """Run one optimisation pass over ``self.train_loader``.

        Each batch is unpacked into ``(graph1, matrix1, graph2, matrix2,
        distances)``; every element is moved to ``self.device`` via ``.to``.
        Both circuits are embedded with the same model, the L2 distance
        between the embeddings is the prediction, and ``self.criterion``
        compares it with ``distances``. An optional regularisation term
        (``config['use_regularization']``, weighted by
        ``config['reg_weight']``) and optional gradient-norm clipping
        (``config['gradient_clip']``) are applied.

        Returns:
            Mean of the per-batch losses.

        Raises:
            ValueError: If the training loader yields no batches (previously
                this surfaced as a bare ``ZeroDivisionError``).
        """
        self.model.train()

        # These settings do not change within an epoch; read them once
        # instead of once per batch.
        use_regularization = self.config.get('use_regularization', False)
        reg_weight = self.config.get('reg_weight', 0.01)
        gradient_clip = self.config.get('gradient_clip', None)

        total_loss = 0.0
        num_batches = 0

        progress_bar = tqdm(self.train_loader, desc='Training')
        for batch in progress_bar:
            # Move the batch to the compute device.
            graph1, matrix1, graph2, matrix2, distances = [x.to(self.device) for x in batch]

            # Forward pass: embed both circuits with the shared model.
            embedding1 = self.model(graph1, matrix1)
            embedding2 = self.model(graph2, matrix2)

            # Predicted distance is the Euclidean distance between embeddings.
            pred_distances = torch.norm(embedding1 - embedding2, p=2, dim=1)

            loss = self.criterion(pred_distances, distances)

            # Optional regularisation term.
            if use_regularization:
                reg_loss = self._compute_regularization_loss(embedding1, embedding2)
                loss = loss + reg_weight * reg_loss

            # Backward pass.
            self.optimizer.zero_grad()
            loss.backward()

            # Optional gradient-norm clipping.
            if gradient_clip:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), gradient_clip)

            self.optimizer.step()

            # .item() forces a device sync, so call it once and reuse the value.
            loss_value = loss.item()
            total_loss += loss_value
            num_batches += 1
            progress_bar.set_postfix({'loss': loss_value})

        if num_batches == 0:
            raise ValueError("train_loader yielded no batches; cannot compute an epoch loss")

        return total_loss / num_batches
    
    def validate(self) -> Tuple[float, Dict]:
        """Evaluate the model on ``self.val_loader`` with gradients disabled.

        Batches are unpacked and scored exactly as in training (shared model,
        L2 distance between the two embeddings, ``self.criterion`` against the
        target distances). Predictions and targets from all batches are then
        passed to ``_compute_metrics``.

        Returns:
            ``(avg_loss, metrics)``: the mean per-batch loss and the metrics
            dictionary. When no validation loader is configured the result is
            ``(0.0, {})``.

        Raises:
            ValueError: If the validation loader yields no batches (previously
                this surfaced as a bare ``ZeroDivisionError``).
        """
        if self.val_loader is None:
            return 0.0, {}

        self.model.eval()
        total_loss = 0.0
        num_batches = 0

        # One array per batch; concatenated once at the end instead of
        # extending a Python list element by element.
        prediction_chunks = []
        target_chunks = []

        with torch.no_grad():
            for batch in tqdm(self.val_loader, desc='Validation'):
                # Move the batch to the compute device.
                graph1, matrix1, graph2, matrix2, distances = [x.to(self.device) for x in batch]

                # Forward pass.
                embedding1 = self.model(graph1, matrix1)
                embedding2 = self.model(graph2, matrix2)

                # Predicted distance.
                pred_distances = torch.norm(embedding1 - embedding2, p=2, dim=1)

                total_loss += self.criterion(pred_distances, distances).item()
                num_batches += 1

                prediction_chunks.append(pred_distances.cpu().numpy())
                target_chunks.append(distances.cpu().numpy())

        if num_batches == 0:
            raise ValueError("val_loader yielded no batches; cannot compute a validation loss")

        avg_loss = total_loss / num_batches

        # Additional evaluation metrics over the whole validation set.
        metrics = self._compute_metrics(
            np.concatenate(prediction_chunks),
            np.concatenate(target_chunks),
        )

        return avg_loss, metrics
    
    def _compute_metrics(self, predictions: np.ndarray, targets: np.ndarray) -> Dict:
        """计算评估指标"""
        from sklearn.metrics import mean_absolute_error, r2_score
        from scipy.stats import pearsonr, spearmanr
        
        mae = mean_absolute_error(targets, predictions)
        mse = np.mean((predictions - targets) ** 2)
        rmse = np.sqrt(mse)
        
        # 相关系数
        if len(predictions) > 1:
            pearson_corr, _ = pearsonr(targets, predictions)
            spearman_corr, _ = spearmanr(targets, predictions)
            r2 = r2_score(targets, predictions)
        else:
            pearson_corr = spearman_corr = r2 = 0.0
        
        return {
            'mae': mae,
            'mse': mse,
            'rmse': rmse,
            'pearson': pearson_corr,
            'spearman': spearman_corr,
            'r2': r2
        }
    
    def _compute_regularization_loss(self, embedding1: torch.Tensor, embedding2: torch.Tensor) -> torch.Tensor:
        """计算正则化损失（例如：对比学习损失）"""
        # 示例：使嵌入向量更加分散
        batch_size = embedding1.size(0)
        
        # 计算批内所有嵌入的相似度
        all_embeddings = torch.cat([embedding1, embedding2], dim=0)
        similarity_matrix = torch.mm(all_embeddings, all_embeddings.t())
        
        # 排除对角线
        mask = torch.eye(2 * batch_size, device=self.device).bool()
        similarity_matrix = similarity_matrix.masked_fill(mask, 0)
        
        # 最大化嵌入之间的距离
        reg_loss = -torch.mean(similarity_matrix)
        
        return reg_loss
    
    def train(self, num_epochs: int) -> Dict:
        """Run the full training loop.

        Each epoch trains once, validates when a validation loader is
        configured, steps the scheduler, tracks the best model, applies early
        stopping and writes periodic checkpoints. Training curves are plotted
        once the loop ends.

        The monitored loss (used for the plateau scheduler, best-model
        tracking and early stopping) is the validation loss when
        ``self.val_loader`` is set, and the training loss otherwise. Without
        this fallback ``validate()`` returns a constant 0.0, which made the
        scheduler decay the learning rate and early stopping abort training
        after ``early_stopping_patience`` epochs regardless of progress.
        ``self.best_val_loss`` therefore holds the best *monitored* loss, and
        ``history['val_loss']`` is only appended to when validation runs.

        Args:
            num_epochs: Maximum number of epochs to run.

        Returns:
            The training history dictionary (``self.history``).
        """
        logger.info(f"Starting training for {num_epochs} epochs")

        has_validation = self.val_loader is not None
        monitor_name = 'val_loss' if has_validation else 'train_loss'
        save_every = self.config.get('save_every', 10)

        for epoch in range(1, num_epochs + 1):
            logger.info(f"\nEpoch {epoch}/{num_epochs}")

            # Training pass.
            train_loss = self.train_epoch()
            self.history['train_loss'].append(train_loss)
            logger.info(f"Train Loss: {train_loss:.4f}")

            # Validation pass (only when a validation loader exists).
            if has_validation:
                val_loss, val_metrics = self.validate()
                self.history['val_loss'].append(val_loss)
                logger.info(f"Val Loss: {val_loss:.4f}")
                logger.info(f"Val Metrics: {val_metrics}")
                monitored_loss = val_loss
            else:
                monitored_loss = train_loss

            # Record the learning rate used during this epoch, then step.
            current_lr = self.optimizer.param_groups[0]['lr']
            self.history['learning_rates'].append(current_lr)

            if self.scheduler is not None:
                if isinstance(self.scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                    self.scheduler.step(monitored_loss)
                else:
                    self.scheduler.step()

            # Best-model tracking.
            if monitored_loss < self.best_val_loss:
                self.best_val_loss = monitored_loss
                self.best_epoch = epoch
                self.save_checkpoint(epoch, is_best=True)
                self.early_stopping_counter = 0
                logger.info(f"New best model saved ({monitor_name}: {monitored_loss:.4f})")
            else:
                self.early_stopping_counter += 1

            # Early stopping.
            if self.early_stopping_counter >= self.early_stopping_patience:
                logger.info(f"Early stopping triggered after {epoch} epochs")
                break

            # Periodic checkpoint; a falsy save_every (0/None) disables it
            # instead of raising on the modulo.
            if save_every and epoch % save_every == 0:
                self.save_checkpoint(epoch, is_best=False)

        # Training finished: draw the training curves.
        self.plot_training_history()

        return self.history
    
    def save_checkpoint(self, epoch: int, is_best: bool = False):
        """Write the current training state to ``self.save_dir``.

        The checkpoint stores the epoch, model/optimizer/scheduler state
        dicts, the best monitored loss and its epoch, the history and the
        config. The file is written to a temporary sibling first and then
        renamed over the target, so an interrupted save cannot leave a
        truncated ``best_model.pt`` behind.

        Args:
            epoch: Epoch number stored in the checkpoint; also used in the
                file name of regular checkpoints.
            is_best: When true the file is ``best_model.pt``; otherwise it is
                ``checkpoint_epoch_{epoch}.pt``.
        """
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': (
                self.scheduler.state_dict() if self.scheduler is not None else None
            ),
            'best_val_loss': self.best_val_loss,
            'best_epoch': self.best_epoch,
            'history': self.history,
            'config': self.config
        }

        filename = 'best_model.pt' if is_best else f'checkpoint_epoch_{epoch}.pt'
        path = self.save_dir / filename

        # The directory is created in __init__, but may have been removed since.
        self.save_dir.mkdir(parents=True, exist_ok=True)

        tmp_path = path.with_name(path.name + '.tmp')
        try:
            torch.save(checkpoint, tmp_path)
            # Rename within the same directory to swap the file in one step.
            tmp_path.replace(path)
        finally:
            # Only present here if the save or rename failed part-way.
            if tmp_path.exists():
                tmp_path.unlink()

        logger.info(f"Checkpoint saved to {path}")
    
    def load_checkpoint(self, checkpoint_path: str):
        """Restore model, optimizer, scheduler and bookkeeping from a checkpoint.

        Args:
            checkpoint_path: Path of a file in the format written by
                ``save_checkpoint``.

        Returns:
            The epoch number stored in the checkpoint.

        Raises:
            KeyError: If the checkpoint lacks ``model_state_dict``,
                ``optimizer_state_dict`` or ``epoch``.
        """
        # NOTE(security): torch.load is pickle-based; only load checkpoints
        # from trusted sources. Depending on the installed PyTorch version the
        # default for ``weights_only`` differs — confirm it suits the objects
        # stored under 'config' and 'history'.
        checkpoint = torch.load(checkpoint_path, map_location=self.device)

        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        scheduler_state = checkpoint.get('scheduler_state_dict')
        if self.scheduler is not None and scheduler_state:
            self.scheduler.load_state_dict(scheduler_state)

        self.best_val_loss = checkpoint.get('best_val_loss', float('inf'))
        # Older checkpoints have no 'best_epoch'; keep the current value then.
        self.best_epoch = checkpoint.get('best_epoch', getattr(self, 'best_epoch', 0))

        # Always keep the three history keys the rest of the trainer indexes
        # directly; a missing or partial history must not turn into a KeyError
        # later in train() or plot_training_history().
        history = {'train_loss': [], 'val_loss': [], 'learning_rates': []}
        history.update(checkpoint.get('history') or {})
        self.history = history

        logger.info(f"Checkpoint loaded from {checkpoint_path}")
        return checkpoint['epoch']
    
    def plot_training_history(self):
        """Plot loss and learning-rate curves and save them as a PNG.

        The figure is written to ``self.save_dir / 'training_history.png'``.
        Nothing is drawn when no training loss has been recorded. The figure
        is always closed, even if drawing or saving fails, so repeated calls
        do not accumulate open matplotlib figures.
        """
        # .get() tolerates a history that lacks some keys (e.g. one restored
        # from an older checkpoint).
        train_losses = self.history.get('train_loss')
        if not train_losses:
            return

        val_losses = self.history.get('val_loss')
        learning_rates = self.history.get('learning_rates')

        fig, (loss_ax, lr_ax) = plt.subplots(1, 2, figsize=(12, 4))
        try:
            # Loss curves.
            loss_ax.plot(train_losses, label='Train Loss')
            if val_losses:
                loss_ax.plot(val_losses, label='Val Loss')
            loss_ax.set_xlabel('Epoch')
            loss_ax.set_ylabel('Loss')
            loss_ax.set_title('Training and Validation Loss')
            loss_ax.legend()
            loss_ax.grid(True)

            # Learning-rate curve.
            if learning_rates:
                lr_ax.plot(learning_rates)
                lr_ax.set_xlabel('Epoch')
                lr_ax.set_ylabel('Learning Rate')
                lr_ax.set_title('Learning Rate Schedule')
                lr_ax.grid(True)

            fig.tight_layout()

            # Save the image.
            plot_path = self.save_dir / 'training_history.png'
            fig.savefig(plot_path, dpi=100)
        finally:
            # Close this specific figure rather than whichever is "current".
            plt.close(fig)

        logger.info(f"Training history plot saved to {plot_path}")


class ContrastiveTrainer(CircuitTrainer):
    """Trainer variant that optimises a margin-based contrastive loss."""

    def __init__(self, *args, **kwargs):
        """Build the base trainer, then install the contrastive criterion.

        All arguments are forwarded unchanged to ``CircuitTrainer``. The
        margin comes from ``config['contrastive_margin']`` (default 1.0), and
        ``self.criterion`` is replaced by ``_contrastive_loss``.
        """
        super().__init__(*args, **kwargs)

        # Contrastive-learning specific settings and loss function.
        margin = self.config.get('contrastive_margin', 1.0)
        self.contrastive_margin = margin
        self.criterion = self._contrastive_loss

    def _contrastive_loss(self, pred_distances: torch.Tensor, target_distances: torch.Tensor) -> torch.Tensor:
        """Contrastive loss over predicted distances.

        Pairs whose target distance is below
        ``config['similarity_threshold']`` (default 0.5) count as similar and
        are penalised by the squared predicted distance. All other pairs are
        penalised by the squared shortfall of the predicted distance from
        ``self.contrastive_margin``. The result is the mean over all pairs.
        """
        # Turn target distances into similar / dissimilar indicators.
        cutoff = self.config.get('similarity_threshold', 0.5)
        is_similar = (target_distances < cutoff).float()
        is_dissimilar = 1 - is_similar

        # Similar pairs: pull together.
        pull_term = is_similar * pred_distances.pow(2)

        # Dissimilar pairs: push apart until the margin is reached.
        margin_shortfall = torch.clamp(self.contrastive_margin - pred_distances, min=0)
        push_term = is_dissimilar * margin_shortfall.pow(2)

        return torch.mean(pull_term + push_term)