"""
main.py
主程序入口 - 整合训练和测试流程
"""

import torch
import argparse
import logging
import json
from pathlib import Path
import random
import numpy as np

from model import CircuitDistanceModel
from data_loader import create_dataloaders
from trainer import CircuitTrainer
from tester import CircuitTester

# Logging setup: records at INFO and above, each prefixed with a timestamp,
# the emitting logger's name and the level.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)  # module-level logger used by main()


def set_seed(seed: int):
    """Seed every random number generator the project relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch's
    generator. When CUDA is available, the current GPU generator and all GPU
    generators are seeded too.

    Args:
        seed: Integer seed applied to every generator.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def load_config(config_path: str) -> dict:
    """Load the JSON configuration file at *config_path*.

    The file is decoded explicitly as UTF-8 rather than with the platform's
    locale encoding, so non-ASCII values load the same way on every OS.
    The ``utf-8-sig`` codec also accepts files saved with a leading UTF-8 BOM,
    which plain ``json.load`` would reject.

    Args:
        config_path: Path to the JSON file.

    Returns:
        The parsed top-level JSON object as a ``dict``.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If the file is not valid UTF-8 / JSON, or if its top-level
            value is not a JSON object. Callers read the result with ``dict``
            methods such as ``.get``, so a list or scalar is rejected up front.
    """
    with open(config_path, 'r', encoding='utf-8-sig') as f:
        config = json.load(f)
    if not isinstance(config, dict):
        raise ValueError(
            f"Config file {config_path!r} must contain a JSON object, "
            f"got {type(config).__name__}"
        )
    return config


def _build_model(config: dict) -> "CircuitDistanceModel":
    """Construct a CircuitDistanceModel from the hyper-parameters in *config*.

    Shared by the training and testing paths so both always build the same
    architecture. Missing keys fall back to node_feature_dim=7,
    embedding_dim=256 and fusion_method='concat'.
    """
    return CircuitDistanceModel(
        node_feature_dim=config.get('node_feature_dim', 7),
        embedding_dim=config.get('embedding_dim', 256),
        fusion_method=config.get('fusion_method', 'concat'),
    )


def main():
    """Command-line entry point: train, test, or train-then-test the model.

    CLI options:
        --mode        One of 'train', 'test', 'both' (default 'both').
        --config      Path to the JSON config file (default 'config.json').
        --checkpoint  In 'train'/'both' mode, passed to
                      ``CircuitTrainer.load_checkpoint`` to resume training.
                      In 'test' mode, the checkpoint file to evaluate.

    In 'both' mode the test phase evaluates ``<save_dir>/best_model.pt``
    (``save_dir`` defaults to './checkpoints'). If no checkpoint file is found
    for testing, an error is logged and the function returns early.
    """
    parser = argparse.ArgumentParser(description='Circuit Similarity Learning')
    parser.add_argument('--mode', type=str, choices=['train', 'test', 'both'],
                        default='both', help='运行模式')
    parser.add_argument('--config', type=str, default='config.json',
                        help='配置文件路径')
    parser.add_argument('--checkpoint', type=str, default=None,
                        help='检查点路径（用于恢复训练或测试）')

    args = parser.parse_args()

    # Load the configuration, then fix all RNG seeds before building data/models.
    config = load_config(args.config)
    set_seed(config.get('seed', 42))

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logger.info(f"Using device: {device}")

    # All three loaders are created regardless of mode; each phase below
    # uses only the ones it needs.
    train_loader, val_loader, test_loader = create_dataloaders(config)

    if args.mode in ['train', 'both']:
        model = _build_model(config)

        trainer = CircuitTrainer(
            model=model,
            train_loader=train_loader,
            val_loader=val_loader,
            device=device,
            config=config
        )

        # Resume from a checkpoint when one is given; load_checkpoint returns
        # the epoch to continue from.
        start_epoch = 0
        if args.checkpoint:
            start_epoch = trainer.load_checkpoint(args.checkpoint)

        num_epochs = config['num_epochs']
        remaining_epochs = num_epochs - start_epoch
        # A resumed checkpoint may already be at or past the configured budget
        # (e.g. num_epochs was lowered); never request a negative epoch count.
        if remaining_epochs < 0:
            logger.warning(
                "Checkpoint epoch %s exceeds num_epochs %s; skipping training",
                start_epoch, num_epochs,
            )
            remaining_epochs = 0
        trainer.train(remaining_epochs)

        # Save the final model state (not flagged as best).
        trainer.save_checkpoint(num_epochs, is_best=False)

    if args.mode in ['test', 'both']:
        # 'both': evaluate best_model.pt under save_dir (assumes the trainer
        # writes its best model there, confirm against CircuitTrainer).
        # 'test': the user must point --checkpoint at a model file.
        if args.mode == 'both':
            model_path = Path(config.get('save_dir', './checkpoints')) / 'best_model.pt'
        else:
            model_path = args.checkpoint

        if not model_path or not Path(model_path).exists():
            logger.error("Model checkpoint not found for testing")
            return

        model = _build_model(config)

        # NOTE(review): torch.load unpickles the file, so only load trusted
        # checkpoints. Recent PyTorch releases default to weights_only=True,
        # which may reject checkpoints holding non-tensor objects. Confirm
        # against the installed version.
        checkpoint = torch.load(model_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])

        tester = CircuitTester(
            model=model,
            test_loader=test_loader,
            device=device,
            save_dir=config.get('test_results_dir', './test_results')
        )

        results = tester.test()

        # Summary metrics. 'pearson_corr' may be absent and is reported as 0;
        # the other keys are read directly.
        logger.info("Test Results:")
        logger.info(f"  MAE: {results['mae']:.4f}")
        logger.info(f"  RMSE: {results['rmse']:.4f}")
        logger.info(f"  Pearson Correlation: {results.get('pearson_corr', 0):.4f}")
        logger.info(f"  R²: {results['r2']:.4f}")


if __name__ == "__main__":
    main()