#!/usr/bin/env python3
# train_with_topo_loss.py
"""
INP-Former based Texture Classification Network
基于INP-Former的纹理分类网络 - 两阶段训练

第一阶段：INP特征提取器训练（使用gather loss）
第二阶段：分类器训练（使用INP特征）

参考INP-Former的建模流程
"""

import os
import sys
import time
import argparse
import logging
import json
import math
from datetime import datetime
from typing import Dict, List, Tuple, Optional, Union
from functools import partial
import numpy as np
from sklearn.metrics import accuracy_score

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
from tqdm import tqdm
from torch.nn.init import trunc_normal_

# 导入数据加载模块
from multi_dataset import create_multi_dataloaders
# [新增] 导入拓扑损失函数
from topology_tools import calculate_supervised_topological_loss

# ================================ INP-Former模块定义 ================================

class Mlp(nn.Module):
    """Position-wise feed-forward block: fc1 -> activation -> dropout -> fc2 -> dropout.

    Args:
        in_features: size of the input's last dimension.
        hidden_features: hidden width; a falsy value means ``in_features``.
        out_features: output width; a falsy value means ``in_features``.
        act_layer: activation class, instantiated without arguments.
        drop: dropout probability applied after the activation and after ``fc2``.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        hidden_dim = hidden_features or in_features
        output_dim = out_features or in_features
        # Submodules are created in this order so parameter initialisation
        # consumes the RNG in a fixed sequence.
        self.fc1 = nn.Linear(in_features, hidden_dim)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Apply the two-layer MLP to the last dimension of ``x``."""
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))


def drop_path(x, drop_prob: Optional[float] = 0., training: bool = False):
    """Randomly zero whole samples (stochastic depth) and rescale the survivors.

    Each slice along dim 0 is kept with probability ``1 - drop_prob``. Kept
    slices are divided by that keep probability, so the output's expected value
    equals ``x``.

    Args:
        x: input tensor; dim 0 is treated as the sample dimension.
        drop_prob: probability of dropping a sample. ``None`` or ``0`` disables
            dropping; a value >= 1 drops every sample.
        training: dropping only happens when True.

    Returns:
        ``x`` itself when dropping is disabled, otherwise a new tensor of the
        same shape.
    """
    # `not drop_prob` also covers None, which would otherwise fail in `1 - drop_prob`.
    if not drop_prob or not training:
        return x
    keep_prob = 1. - drop_prob
    if keep_prob <= 0.:
        # Every sample is dropped. Dividing by keep_prob would produce inf * 0 = NaN.
        return torch.zeros_like(x)
    # One Bernoulli(keep_prob) draw per sample, broadcast over the remaining dims.
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    random_tensor = keep_prob + torch.rand(mask_shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()
    return x.div(keep_prob) * random_tensor


class DropPath(nn.Module):
    """Module wrapper around ``drop_path`` (stochastic depth per sample).

    Dropping is active only while the module is in training mode.

    Args:
        drop_prob: probability of dropping each sample, forwarded unchanged to
            ``drop_path``.
    """

    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        """Apply stochastic depth to ``x`` using this module's train/eval state."""
        return drop_path(x, drop_prob=self.drop_prob, training=self.training)


class Aggregation_Attention(nn.Module):
    """Multi-head cross-attention used for INP extraction.

    Queries are projected from ``x``; keys and values are projected from ``y``.
    The heads are concatenated and passed through an output projection.

    Args:
        dim: channel width of both ``x`` and ``y`` (split evenly across heads).
        num_heads: number of attention heads.
        qkv_bias: whether the q / kv projections have a bias.
        qk_scale: explicit logit scale; a falsy value means ``head_dim ** -0.5``.
        attn_drop: dropout on the attention weights.
        proj_drop: dropout after the output projection.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        self.scale = qk_scale or (dim // num_heads) ** -0.5
        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, y):
        """Let each token of ``x`` ([B, T, C]) attend over the tokens of ``y`` ([B, N, C])."""
        batch, n_query, channels = x.shape
        _, n_kv, _ = y.shape
        heads = self.num_heads
        head_dim = channels // heads

        # [B, T, C] -> [B, heads, T, head_dim]
        query = self.q(x).reshape(batch, n_query, heads, head_dim).permute(0, 2, 1, 3)
        # [B, N, 2C] -> [2, B, heads, N, head_dim]; the first half of the channels is K, the second V.
        key, value = self.kv(y).reshape(batch, n_kv, 2, heads, head_dim).permute(2, 0, 3, 1, 4).unbind(0)

        weights = torch.matmul(query, key.transpose(-2, -1)) * self.scale
        weights = self.attn_drop(weights.softmax(dim=-1))

        merged = torch.matmul(weights, value).transpose(1, 2).reshape(batch, n_query, channels)
        return self.proj_drop(self.proj(merged))


class Aggregation_Block(nn.Module):
    """Pre-norm transformer block whose attention is cross-attention (``x`` attends to ``y``).

    The same ``norm1`` layer normalises both the queries ``x`` and the context ``y``.
    That is followed by a residual MLP sub-block, and both residual branches pass
    through the optional stochastic-depth layer.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        # Creation order is kept fixed so parameter initialisation draws from the RNG in the same order.
        self.norm1 = norm_layer(dim)
        self.attn = Aggregation_Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
        )

    def forward(self, x, y):
        """Update ``x`` with cross-attention over ``y`` and then the MLP, both as residuals."""
        attended = self.attn(self.norm1(x), self.norm1(y))
        x = x + self.drop_path(attended)
        return x + self.drop_path(self.mlp(self.norm2(x)))


class INP_Extractor(nn.Module):
    """
    INP feature extractor (stage 1 of training).

    Following INP-Former, the outputs of selected DINOv2 blocks are averaged and
    passed through an MLP bottleneck. A set of learnable prototypes then
    cross-attends to the resulting patch tokens to produce the INP features.
    The module is trained with the gather loss (see ``gather_loss``).
    """

    def __init__(
        self,
        dinov2_model_name: str = 'dinov2_vitb14',
        num_inp: int = 16,  # number of INP prototypes
        target_layers: Union[List[int], Tuple[int, ...]] = (2, 3, 4, 5, 6, 7, 8, 9),
        freeze_backbone: bool = False,
        unfrozen_blocks: int = 4
    ):
        """
        Args:
            dinov2_model_name: torch.hub model name (e.g. 'dinov2_vitb14'). It also
                selects the feature width and the number of aggregator heads.
            num_inp: number of learnable INP prototypes.
            target_layers: indices of the backbone blocks whose outputs are fused.
                Any order is accepted, and the sequence is copied.
            freeze_backbone: if True, every backbone parameter is frozen.
            unfrozen_blocks: when the backbone is not fully frozen, the number of
                trailing blocks that stay trainable. -1, or a value >= the block
                count, unfreezes everything.

        Raises:
            ValueError: if ``target_layers`` is empty.
        """
        super().__init__()
        self.num_inp = num_inp
        # The tuple default avoids a shared mutable default; copying means the
        # caller's list is never aliased.
        self.target_layers = list(target_layers)
        if not self.target_layers:
            raise ValueError("target_layers must contain at least one block index")

        # Load the DINOv2 backbone.
        print(f"Loading DINOv2 model: {dinov2_model_name}")
        self.dinov2_model = self._load_dinov2_model(dinov2_model_name)

        # Feature width / head count are derived from the model name.
        self.dinov2_dim, self.num_heads = self._dims_for(dinov2_model_name)

        # Freezing policy.
        if freeze_backbone:
            for param in self.dinov2_model.parameters():
                param.requires_grad = False
            print("DINOv2 backbone fully frozen.")
        else:
            self._partial_freeze_dinov2(num_unfrozen_blocks=unfrozen_blocks)

        # forward() reads num_register_tokens; default it to 0 when the backbone lacks it.
        if not hasattr(self.dinov2_model, 'num_register_tokens'):
            self.dinov2_model.num_register_tokens = 0

        # Learnable intrinsic texture prototypes, shape [num_inp, D].
        self.inp_prototypes = nn.Parameter(torch.randn(num_inp, self.dinov2_dim))

        # Bottleneck MLP applied to the fused patch tokens.
        self.bottleneck = Mlp(self.dinov2_dim, self.dinov2_dim * 4, self.dinov2_dim, drop=0.)

        # Cross-attention block: prototypes (queries) aggregate patch tokens (context).
        self.inp_aggregator = Aggregation_Block(
            dim=self.dinov2_dim,
            num_heads=self.num_heads,
            mlp_ratio=4.,
            qkv_bias=True,
            norm_layer=partial(nn.LayerNorm, eps=1e-8)
        )

        self._initialize_weights()

        print(f"INP_Extractor initialized: {num_inp} INP prototypes, dim={self.dinov2_dim}")

    @staticmethod
    def _dims_for(model_name: str) -> Tuple[int, int]:
        """Return ``(embed_dim, num_heads)`` for a DINOv2 model name.

        Unrecognised names fall back to the ViT-B/14 values.
        """
        if 'vits14' in model_name:
            return 384, 6
        if 'vitb14' in model_name:
            return 768, 12
        if 'vitl14' in model_name:
            return 1024, 16
        return 768, 12

    def _load_dinov2_model(self, model_name: str) -> nn.Module:
        """Load DINOv2 via torch.hub, falling back to a random mock on any failure."""
        try:
            model = torch.hub.load('facebookresearch/dinov2', model_name, pretrained=True)
            print(f"✅ DINOv2 loaded via torch.hub")
            return model
        except Exception as e:
            print(f"❌ torch.hub loading failed: {e}")
            print("⚠️ Falling back to a randomly initialised mock backbone; features will not be meaningful.")
            return self._create_mock_dinov2(model_name)

    def _create_mock_dinov2(self, model_name: str) -> nn.Module:
        """Create a randomly initialised DINOv2 stand-in, for testing without torch.hub.

        The mock mirrors the attributes that ``extract_features`` and the freezing
        code use:

        * ``patch_embed`` returns ``[B, N, D]`` tokens and exposes a ``patch_size`` tuple.
        * ``cls_token`` / ``pos_embed`` use the CLS-then-patches layout.
        * ``blocks`` / ``norm`` exist.

        Its weights are random.
        """
        class MockPatchEmbed(nn.Module):
            """Convolutional patchifier that returns flattened tokens ``[B, N, D]``."""

            def __init__(self, embed_dim: int, patch: int = 14):
                super().__init__()
                self.patch_size = (patch, patch)
                self.proj = nn.Conv2d(3, embed_dim, kernel_size=patch, stride=patch)

            def forward(self, x):
                # [B, 3, H, W] -> [B, D, H/p, W/p] -> [B, N, D]
                return self.proj(x).flatten(2).transpose(1, 2)

        class MockDINOv2(nn.Module):
            def __init__(self, embed_dim: int, grid: int = 37):
                super().__init__()
                self.embed_dim = embed_dim
                self.patch_embed = MockPatchEmbed(embed_dim)
                self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
                # 1 CLS position + grid*grid patch positions. 37x37 is the patch
                # grid of a 518px input with 14px patches.
                self.pos_embed = nn.Parameter(torch.randn(1, 1 + grid * grid, embed_dim))

                self.blocks = nn.ModuleList([
                    nn.TransformerEncoderLayer(
                        d_model=embed_dim,
                        nhead=8,
                        dim_feedforward=embed_dim * 4,
                        dropout=0.1,
                        batch_first=True
                    ) for _ in range(12)
                ])
                self.norm = nn.LayerNorm(embed_dim)
                self.num_register_tokens = 0

            def prepare_tokens(self, x):
                # Only valid when the input's patch grid equals the pos_embed grid.
                B = x.shape[0]
                x = self.patch_embed(x)
                cls_tokens = self.cls_token.expand(B, -1, -1)
                x = torch.cat([cls_tokens, x], dim=1)
                return x + self.pos_embed

            def forward(self, x):
                x = self.prepare_tokens(x)
                for block in self.blocks:
                    x = block(x)
                return self.norm(x)

        embed_dim, _ = self._dims_for(model_name)
        print(f"✅ Mock DINOv2 created with embed_dim={embed_dim}")
        return MockDINOv2(embed_dim)

    def _partial_freeze_dinov2(self, num_unfrozen_blocks: int = 4):
        """Freeze the backbone except for its last ``num_unfrozen_blocks`` blocks (and ``norm``).

        -1, or a value >= the block count, makes the whole backbone trainable.
        If the backbone has no ``blocks`` attribute, everything stays frozen.
        """
        for param in self.dinov2_model.parameters():
            param.requires_grad = False

        if hasattr(self.dinov2_model, 'blocks'):
            total_blocks = len(self.dinov2_model.blocks)

            if num_unfrozen_blocks == -1 or num_unfrozen_blocks >= total_blocks:
                for param in self.dinov2_model.parameters():
                    param.requires_grad = True
                print(f"DINOv2: Fully unfrozen. All {total_blocks} blocks are trainable.")
            else:
                unfrozen_start_index = max(0, total_blocks - num_unfrozen_blocks)

                for i in range(unfrozen_start_index, total_blocks):
                    for param in self.dinov2_model.blocks[i].parameters():
                        param.requires_grad = True

                if hasattr(self.dinov2_model, 'norm'):
                    for param in self.dinov2_model.norm.parameters():
                        param.requires_grad = True

                print(f"DINOv2: Partially frozen. Unfreezing last {num_unfrozen_blocks}/{total_blocks} blocks.")

            trainable_params = sum(p.numel() for p in self.dinov2_model.parameters() if p.requires_grad)
            total_params = sum(p.numel() for p in self.dinov2_model.parameters())
            print(f"DINOv2 trainable parameters: {trainable_params:,} / {total_params:,} ({100*trainable_params/total_params:.1f}%)")

    def _initialize_weights(self):
        """Initialise the prototypes, bottleneck and aggregator; the backbone is left untouched."""
        trunc_normal_(self.inp_prototypes, std=0.01, a=-0.03, b=0.03)

        for m in [self.bottleneck, self.inp_aggregator]:
            for module in m.modules():
                if isinstance(module, nn.Linear):
                    trunc_normal_(module.weight, std=0.01, a=-0.03, b=0.03)
                    if module.bias is not None:
                        nn.init.constant_(module.bias, 0)
                elif isinstance(module, nn.LayerNorm):
                    nn.init.constant_(module.bias, 0)
                    nn.init.constant_(module.weight, 1.0)

    def extract_features(self, images: torch.Tensor) -> List[torch.Tensor]:
        """Run the backbone and collect the outputs of the blocks in ``self.target_layers``.

        Tokens are built manually from three parts:

        * the patch embedding,
        * a prepended CLS token,
        * a position embedding, bicubically resized whenever the input's patch
          grid differs from the stored one.

        Blocks deeper than the deepest target layer are not run.

        Args:
            images: image batch indexed as ``[B, C, H, W]``.

        Returns:
            One tensor per target layer, in block order, each ``[B, 1 + N, D]``
            with the CLS token first (assuming ``patch_embed`` yields ``[B, N, D]``).

        NOTE(review): register tokens are not inserted here, while ``forward``
        strips ``1 + num_register_tokens`` leading tokens. This is only
        consistent for backbones with 0 register tokens.
        """
        B, _, H, W = images.shape
        backbone = self.dinov2_model

        # 1. Patch embedding, then prepend the CLS token.
        x = backbone.patch_embed(images)
        cls_tokens = backbone.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # 2. Resize the patch position embedding to the input's patch grid.
        pos_embed_cls = backbone.pos_embed[:, 0:1, :]
        pos_embed_patch = backbone.pos_embed[:, 1:, :]
        orig_side = int(math.sqrt(pos_embed_patch.shape[1]))

        new_side_h = H // backbone.patch_embed.patch_size[0]
        new_side_w = W // backbone.patch_embed.patch_size[1]

        # Compare grid shapes, not token counts: a non-square input can have the
        # same number of patches as the stored square grid but a different layout.
        if (new_side_h, new_side_w) != (orig_side, orig_side):
            pos_embed_patch = pos_embed_patch.reshape(1, orig_side, orig_side, -1).permute(0, 3, 1, 2)
            pos_embed_patch = F.interpolate(
                pos_embed_patch,
                size=(new_side_h, new_side_w),
                mode='bicubic',
                align_corners=False
            )
            pos_embed_patch = pos_embed_patch.permute(0, 2, 3, 1).flatten(1, 2)

        x = x + torch.cat((pos_embed_cls, pos_embed_patch), dim=1)

        # 3. Run the blocks up to the deepest requested layer. max() makes this
        # independent of the order of target_layers.
        last_layer = max(self.target_layers)
        wanted = set(self.target_layers)
        en_list = []
        for i, blk in enumerate(backbone.blocks):
            if i > last_layer:
                break
            x = blk(x)
            if i in wanted:
                en_list.append(x)

        return en_list

    def fuse_features(self, feat_list: List[torch.Tensor]) -> torch.Tensor:
        """Average a list of same-shaped tensors element-wise."""
        return torch.stack(feat_list, dim=1).mean(dim=1)

    def gather_loss(self, query: torch.Tensor, keys: torch.Tensor) -> torch.Tensor:
        """Mean cosine distance from every query token to its nearest key (INP-Former gather loss).

        Args:
            query: tokens indexed as ``[B, N, D]``.
            keys: prototypes indexed as ``[B, M, D]``.

        Returns:
            Scalar tensor: the mean over B and N of ``min_m (1 - cos(query_n, key_m))``.
        """
        # Normalise once and use a batched matmul. This yields the [B, N, M]
        # similarity matrix directly instead of broadcasting to [B, N, M, D].
        # It matches F.cosine_similarity except for (near-)zero vectors, where
        # the eps handling differs.
        q = F.normalize(query, dim=-1)
        k = F.normalize(keys, dim=-1)
        distances = 1. - torch.matmul(q, k.transpose(-2, -1))
        nearest, _ = distances.min(dim=2)
        return nearest.mean()

    def forward(self, images: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Stage-1 forward pass.

        Returns:
            features: fused, bottlenecked patch tokens, [B, N, D]
            inp_features: aggregated INP features, [B, M, D] with M = num_inp
            gather_loss: scalar gather loss between ``features`` and ``inp_features``
        """
        B = images.size(0)

        en_list = self.extract_features(images)
        # Drop the CLS (and any register) tokens, keeping patch tokens only.
        num_prefix = 1 + self.dinov2_model.num_register_tokens
        en_list = [e[:, num_prefix:, :] for e in en_list]
        x = self.bottleneck(self.fuse_features(en_list))

        # expand() avoids copying; gradients still accumulate into the shared prototypes.
        inp_prototypes = self.inp_prototypes.unsqueeze(0).expand(B, -1, -1)
        inp_features = self.inp_aggregator(inp_prototypes, x)
        g_loss = self.gather_loss(x, inp_features)

        return x, inp_features, g_loss


class INP_Classifier(nn.Module):
    """
    INP classifier (stage 2 of training).

    Mean-pools the INP features produced by an ``INP_Extractor`` over the
    prototype axis. The pooled vector is classified with LayerNorm -> Dropout -> Linear.
    """

    def __init__(
        self,
        inp_extractor: INP_Extractor,
        num_classes: int,
        dropout: float = 0.1
    ):
        super().__init__()
        self.inp_extractor = inp_extractor
        self.num_classes = num_classes

        feat_dim = inp_extractor.dinov2_dim
        head_layers = [
            nn.LayerNorm(feat_dim),
            nn.Dropout(dropout),
            nn.Linear(feat_dim, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)
        self._initialize_weights()

    def _initialize_weights(self):
        """Initialise the classification head: truncated-normal Linear weights, identity LayerNorm."""
        for layer in self.classifier.modules():
            if isinstance(layer, nn.LayerNorm):
                nn.init.constant_(layer.bias, 0)
                nn.init.constant_(layer.weight, 1.0)
            elif isinstance(layer, nn.Linear):
                trunc_normal_(layer.weight, std=0.02)
                if layer.bias is not None:
                    nn.init.constant_(layer.bias, 0)

    def forward(self, images: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Classify a batch of images.

        Returns:
            ``(logits, pooled)``. ``logits`` is ``[B, num_classes]``. ``pooled`` is
            the ``[B, D]`` mean of the INP features, which the caller uses for the
            topological loss.
        """
        _, inp_features, _ = self.inp_extractor(images)
        pooled = inp_features.mean(dim=1)
        return self.classifier(pooled), pooled


# ================================ 训练函数 ================================

def train_inp_extractor(
    model: INP_Extractor,
    dataloader: DataLoader,
    optimizer: optim.Optimizer,
    device: torch.device,
    epoch: int,
    logger: logging.Logger,
    writer: SummaryWriter = None
) -> float:
    """Train the INP extractor for one epoch (stage 1) on the gather loss alone.

    Batches that raise are logged with their traceback and skipped. Batches whose
    loss is NaN/inf are skipped before ``backward``, so they cannot push
    non-finite values into the weights through the optimizer step.

    Args:
        model: extractor returning ``(features, inp_features, gather_loss)``.
        dataloader: yields dicts with an ``'image'`` tensor.
        optimizer: optimizer over the trainable parameters.
        device: device the images are moved to.
        epoch: epoch index, used for logging and TensorBoard steps.
        logger: destination for progress and error messages.
        writer: optional TensorBoard writer; the step loss is logged every 50 batches.

    Returns:
        Sample-weighted mean gather loss over the applied batches. Returns
        ``float('inf')`` when no batch was applied, so a caller that keeps the
        lowest loss never records a failed epoch as its best.
    """
    model.train()

    total_loss = 0.0
    total_samples = 0
    failed_batches = 0
    nonfinite_batches = 0

    progress_bar = tqdm(dataloader, desc=f'Epoch {epoch} - INP Extractor Training')

    for batch_idx, batch in enumerate(progress_bar):
        images = batch['image'].to(device)
        batch_size = images.size(0)

        optimizer.zero_grad()

        try:
            _, _, gather_loss = model(images)
            loss_value = gather_loss.item()  # one host sync per step, reused below

            if not math.isfinite(loss_value):
                nonfinite_batches += 1
                logger.warning(f"Epoch {epoch} batch {batch_idx}: non-finite gather loss ({loss_value}), step skipped")
                continue

            gather_loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.1)
            optimizer.step()

        except Exception as e:
            failed_batches += 1
            # logger.exception also records the traceback.
            logger.exception(f"训练步骤出错: {e}")
            continue

        total_loss += loss_value * batch_size
        total_samples += batch_size

        progress_bar.set_postfix({
            'gather_loss': f'{loss_value:.4f}',
            'avg_loss': f'{total_loss/total_samples:.4f}'
        })

        if writer is not None and batch_idx % 50 == 0:
            global_step = epoch * len(dataloader) + batch_idx
            writer.add_scalar('Stage1/Gather_Loss_Step', loss_value, global_step)

    if failed_batches or nonfinite_batches:
        logger.warning(
            f'Epoch {epoch} - INP Extractor Training - {failed_batches} batch(es) failed, '
            f'{nonfinite_batches} skipped for non-finite loss'
        )

    avg_loss = total_loss / total_samples if total_samples > 0 else float('inf')
    logger.info(f'Epoch {epoch} - INP Extractor Training - Avg Gather Loss: {avg_loss:.4f}')

    return avg_loss


# Stage-2 training loop: cross-entropy plus a weighted supervised topological loss.
def train_classifier(
    model: INP_Classifier,
    dataloader: DataLoader,
    criterion: nn.Module,
    optimizer: optim.Optimizer,
    device: torch.device,
    epoch: int,
    logger: logging.Logger,
    writer: SummaryWriter = None,
    lambda_topo: float = 0.1
) -> Tuple[float, float]:
    """Train the classifier for one epoch (stage 2).

    The per-batch loss is ``criterion(logits, labels) + lambda_topo * topo_loss``.
    ``topo_loss`` comes from ``calculate_supervised_topological_loss`` applied to
    the pooled features the model returns. Batches that raise are logged with
    their traceback and skipped. Batches whose total loss is NaN/inf are skipped
    before ``backward``.

    Args:
        model: classifier returning ``(logits, features)``.
        dataloader: yields dicts with ``'image'`` and ``'label'`` tensors.
        criterion: classification loss applied to ``(logits, labels)``.
        optimizer: optimizer over the model's parameters.
        device: device the batch is moved to.
        epoch: epoch index, used for logging and TensorBoard steps.
        logger: destination for progress and error messages.
        writer: optional TensorBoard writer; step metrics are logged every 50 batches.
        lambda_topo: weight of the topological loss term.

    Returns:
        ``(avg_loss, avg_accuracy)`` weighted by sample count over the applied
        batches. Returns ``(0.0, 0.0)`` if no batch was applied.
    """
    model.train()

    total_loss = 0.0
    total_correct = 0
    total_samples = 0
    failed_batches = 0
    nonfinite_batches = 0

    progress_bar = tqdm(dataloader, desc=f'Epoch {epoch} - Classifier Training')

    for batch_idx, batch in enumerate(progress_bar):
        images = batch['image'].to(device)
        labels = batch['label'].to(device)
        batch_size = images.size(0)

        optimizer.zero_grad()

        try:
            logits, features = model(images)

            ce_loss = criterion(logits, labels)
            topo_loss = calculate_supervised_topological_loss(features, labels)
            loss = ce_loss + lambda_topo * topo_loss

            loss_value = loss.item()
            if not math.isfinite(loss_value):
                nonfinite_batches += 1
                logger.warning(f"Epoch {epoch} batch {batch_idx}: non-finite loss ({loss_value}), step skipped")
                continue

            # Integer count of correct predictions (avoids float accumulation of per-batch means).
            correct = (torch.argmax(logits, dim=1) == labels).sum().item()

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            ce_value = ce_loss.item()
            topo_value = topo_loss.item()

        except Exception as e:
            failed_batches += 1
            # logger.exception also records the traceback.
            logger.exception(f"训练步骤出错: {e}")
            continue

        accuracy = correct / batch_size
        total_loss += loss_value * batch_size
        total_correct += correct
        total_samples += batch_size

        progress_bar.set_postfix({
            'loss': f'{loss_value:.4f}',
            'ce_loss': f'{ce_value:.4f}',
            'topo_loss': f'{topo_value:.4f}',
            'acc': f'{accuracy:.4f}',
            'avg_loss': f'{total_loss/total_samples:.4f}',
            'avg_acc': f'{total_correct/total_samples:.4f}'
        })

        if writer is not None and batch_idx % 50 == 0:
            global_step = epoch * len(dataloader) + batch_idx
            writer.add_scalar('Stage2/Total_Loss_Step', loss_value, global_step)
            writer.add_scalar('Stage2/CE_Loss_Step', ce_value, global_step)
            writer.add_scalar('Stage2/Topo_Loss_Step', topo_value, global_step)
            writer.add_scalar('Stage2/Accuracy_Step', accuracy, global_step)

    if failed_batches or nonfinite_batches:
        logger.warning(
            f'Epoch {epoch} - Classifier Training - {failed_batches} batch(es) failed, '
            f'{nonfinite_batches} skipped for non-finite loss'
        )

    avg_loss = total_loss / total_samples if total_samples > 0 else 0.0
    avg_accuracy = total_correct / total_samples if total_samples > 0 else 0.0

    logger.info(f'Epoch {epoch} - Classifier Training - Loss: {avg_loss:.4f}, Accuracy: {avg_accuracy:.4f}')

    return avg_loss, avg_accuracy


# Evaluation loop: the classifier returns (logits, features); only the logits are scored.
def evaluate_model(
    model: INP_Classifier,
    dataloader: DataLoader,
    device: torch.device,
    logger: logging.Logger
) -> Tuple[float, float]:
    """Compute the mean cross-entropy and accuracy of ``model`` over ``dataloader``.

    Runs in eval mode under ``torch.no_grad()``. Batches that raise are logged
    with their traceback and excluded. A warning reports how many batches were
    excluded, because the returned metrics then cover only part of the data.

    Returns:
        ``(avg_loss, avg_accuracy)`` weighted by sample count. Returns
        ``(0.0, 0.0)`` if no batch succeeded.
    """
    model.eval()

    total_correct = 0
    total_samples = 0
    total_loss = 0.0
    failed_batches = 0

    criterion = nn.CrossEntropyLoss()

    with torch.no_grad():
        progress_bar = tqdm(dataloader, desc='Evaluating')

        for batch in progress_bar:
            try:
                images = batch['image'].to(device)
                labels = batch['label'].to(device)

                logits, _ = model(images)

                loss_value = criterion(logits, labels).item()
                correct = (torch.argmax(logits, dim=1) == labels).sum().item()
            except Exception as e:
                failed_batches += 1
                # logger.exception also records the traceback.
                logger.exception(f"评估步骤出错: {e}")
                continue

            batch_size = images.size(0)
            total_loss += loss_value * batch_size
            total_correct += correct
            total_samples += batch_size

            progress_bar.set_postfix({
                'loss': f'{loss_value:.4f}',
                'acc': f'{correct/batch_size:.4f}'
            })

    if failed_batches:
        logger.warning(f'Evaluation - {failed_batches} batch(es) failed and were excluded from the metrics')

    avg_loss = total_loss / total_samples if total_samples > 0 else 0.0
    avg_accuracy = total_correct / total_samples if total_samples > 0 else 0.0

    logger.info(f'Evaluation - Loss: {avg_loss:.4f}, Accuracy: {avg_accuracy:.4f}')

    return avg_loss, avg_accuracy


# ================================ 工具函数 ================================

def setup_logging(save_dir: str) -> logging.Logger:
    """Configure and return the ``'INP_Former_Texture'`` logger.

    Logs at INFO level to ``<save_dir>/training.log`` (UTF-8, append mode) and
    to the console. Calling it again replaces and closes the handlers from the
    previous call, so repeated calls neither duplicate output nor leak open log
    files.

    Args:
        save_dir: existing directory that receives ``training.log``.

    Returns:
        The configured logger.
    """
    logger = logging.getLogger('INP_Former_Texture')
    logger.setLevel(logging.INFO)

    for handler in list(logger.handlers):
        logger.removeHandler(handler)
        handler.close()  # releases the previous log file handle

    # Explicit UTF-8: log messages contain CJK text and emoji, which the
    # locale's default encoding may not be able to represent.
    file_handler = logging.FileHandler(os.path.join(save_dir, 'training.log'), encoding='utf-8')
    file_handler.setLevel(logging.INFO)

    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)

    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler.setFormatter(formatter)
    console_handler.setFormatter(formatter)

    logger.addHandler(file_handler)
    logger.addHandler(console_handler)

    return logger


def save_checkpoint(
    model: nn.Module,
    optimizer: optim.Optimizer,
    epoch: int,
    best_metric: float,
    save_path: str,
    is_best: bool = False
):
    """Serialize model/optimizer state plus metadata to ``save_path``.

    The checkpoint dict has the keys ``epoch``, ``model_state_dict``,
    ``optimizer_state_dict``, ``best_metric`` and ``timestamp``.

    If ``is_best`` is True, the saved file is also copied to a sibling path with
    ``_best`` inserted before the extension (``ckpt.pth`` -> ``ckpt_best.pth``,
    ``model.pt`` -> ``model_best.pt``).
    """
    import shutil

    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'best_metric': best_metric,
        'timestamp': datetime.now().isoformat()
    }

    torch.save(checkpoint, save_path)

    if is_best:
        # splitext only edits the final extension, so directory names that
        # contain '.pth' are left alone, and non-.pth paths still get a
        # distinct best file.
        root, ext = os.path.splitext(save_path)
        best_path = f'{root}_best{ext}'
        # Copy the bytes just written instead of serializing the state a second time.
        shutil.copyfile(save_path, best_path)


# ================================ 主训练脚本 ================================

def main():
    """主函数"""
    
    parser = argparse.ArgumentParser(description='INP-Former based Texture Classification - Two Stage Training')
    
    # 数据集参数
    parser.add_argument('--gtos_root', type=str, default='/data2/bpeng/dataset/gtos/gtos')
    parser.add_argument('--dtd_root', type=str, default='/home/tione/notebook/research/luciuspan/TDA-cls/dtd')
    parser.add_argument('--use_datasets', type=str, nargs='+', default=['dtd'],
                        choices=['gtos', 'dtd'])
    parser.add_argument('--split_id', type=int, default=1)
    parser.add_argument('--img_size', type=int, default=518)
    parser.add_argument('--few_shot', type=int, default=-1)
    
    # 训练阶段控制
    parser.add_argument('--stage', type=str, default='stage2', choices=['stage1', 'stage2', 'all'],
                        help='训练阶段: stage1(INP提取器), stage2(分类器), all(两阶段)')
    parser.add_argument('--stage1_epochs', type=int, default=3, help='第一阶段训练轮数')
    parser.add_argument('--stage2_epochs', type=int, default=100, help='第二阶段训练轮数')
    parser.add_argument('--stage1_checkpoint', type=str, default='/home/tione/notebook/research/luciuspan/TDA-cls/experiments/inp_former/dtd_inp_former_518px_20250910_130018/stage1_best.pth', 
                        help='第一阶段检查点路径(用于第二阶段)')
    
    # 模型参数
    parser.add_argument('--dinov2_model', type=str, default='dinov2_vitb14',
                        choices=['dinov2_vits14', 'dinov2_vitb14', 'dinov2_vitl14'])
    parser.add_argument('--num_inp', type=int, default=16, help='INP原型数量')
    parser.add_argument('--target_layers', type=int, nargs='+', default=[2, 3, 4, 5, 6, 7, 8, 9],
                        help='目标层')
    parser.add_argument('--freeze_backbone', action='store_true', help='完全冻结DINOv2主干')
    parser.add_argument('--unfrozen_blocks', type=int, default=4)
    parser.add_argument('--unfreeze_all', action='store_true')
    
    # 训练参数
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--stage1_lr', type=float, default=1e-5, help='第一阶段学习率')
    parser.add_argument('--stage2_lr', type=float, default=1e-4, help='第二阶段学习率')
    parser.add_argument('--weight_decay', type=float, default=1e-4)
    
    # [新增] 监督拓扑损失的权重
    parser.add_argument('--stage2_lambda_topo', type=float, default=0.1, help='Weight for supervised topological loss in Stage 2')

    # 系统参数
    parser.add_argument('--gpu_id', type=int, default=1)
    parser.add_argument('--num_workers', type=int, default=4)
    parser.add_argument('--save_dir', type=str, default='./experiments/inp_former1')
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--eval_interval', type=int, default=1)
    
    args = parser.parse_args()
    
    # 验证图像尺寸
    if args.img_size % 14 != 0:
        closest_size = ((args.img_size + 6) // 14) * 14
        args.img_size = closest_size
        print(f"✅ 调整图像尺寸为: {args.img_size}x{args.img_size}")
    
    # 处理解冻参数
    unfrozen_blocks = args.unfrozen_blocks
    if args.unfreeze_all:
        unfrozen_blocks = -1
    
    # 设置随机种子
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    
    # 设置设备
    device = torch.device(f'cuda:{args.gpu_id}' if torch.cuda.is_available() else 'cpu')
    print(f"使用设备: {device}")
    
    # 创建实验目录
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    dataset_name = '_'.join(args.use_datasets)
    save_dir = os.path.join(args.save_dir, f'{dataset_name}_inp_former_{args.img_size}px_{timestamp}')
    os.makedirs(save_dir, exist_ok=True)
    
    # 保存配置
    with open(os.path.join(save_dir, 'config.json'), 'w') as f:
        json.dump(vars(args), f, indent=2)
    
    # 设置日志
    logger = setup_logging(save_dir)
    logger.info("🚀 INP-Former纹理分类训练启动")
    logger.info(f"训练阶段: {args.stage}")
    logger.info(f"图像尺寸: {args.img_size}x{args.img_size}")
    logger.info(f"INP原型数量: {args.num_inp}")
    
    # 设置TensorBoard
    writer = SummaryWriter(os.path.join(save_dir, 'tensorboard'))
    
    # 准备数据集
    dataset_configs = []
    num_classes = 0
    
    for dataset_name in args.use_datasets:
        if dataset_name == 'gtos':
            dataset_configs.extend([
                {
                    'dataset_type': 'gtos',
                    'data_root': args.gtos_root,
                    'split': 'train',
                    'img_size': args.img_size,
                    'few_shot': args.few_shot
                },
                {
                    'dataset_type': 'gtos',
                    'data_root': args.gtos_root,
                    'split': 'test',
                    'img_size': args.img_size,
                    'few_shot': -1
                }
            ])
            num_classes = 39
        elif dataset_name == 'dtd':
            dataset_configs.extend([
                {
                    'dataset_type': 'dtd',
                    'data_root': args.dtd_root,
                    'split': 'train',
                    'img_size': args.img_size,
                    'few_shot': args.few_shot
                },
                {
                    'dataset_type': 'dtd',
                    'data_root': args.dtd_root,
                    'split': 'test',
                    'img_size': args.img_size,
                    'few_shot': -1
                }
            ])
            num_classes = 47
    
    logger.info("创建数据加载器...")
    dataloaders = create_multi_dataloaders(
        dataset_configs,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        split_id=args.split_id
    )
    
    train_loader = None
    test_loaders = {}
    
    for name, loader in dataloaders.items():
        if 'train' in name:
            if train_loader is None:
                train_loader = loader
            logger.info(f"训练集 {name}: {len(loader.dataset)} 样本")
        else:
            test_loaders[name] = loader
            logger.info(f"测试集 {name}: {len(loader.dataset)} 样本")
    
    if train_loader is None:
        raise ValueError("未找到训练数据集")
    
    # ==================== 第一阶段：INP提取器训练 ====================
    
    if args.stage in ['stage1', 'all']:
        logger.info("=" * 80)
        logger.info("第一阶段：INP特征提取器训练")
        logger.info("=" * 80)
        
        inp_extractor = INP_Extractor(
            dinov2_model_name=args.dinov2_model,
            num_inp=args.num_inp,
            target_layers=args.target_layers,
            freeze_backbone=args.freeze_backbone,
            unfrozen_blocks=unfrozen_blocks
        ).to(device)
        
        trainable_params = []
        trainable_params.extend(inp_extractor.bottleneck.parameters())
        trainable_params.extend(inp_extractor.inp_aggregator.parameters())
        trainable_params.append(inp_extractor.inp_prototypes)
        
        for param in inp_extractor.dinov2_model.parameters():
            if param.requires_grad:
                trainable_params.append(param)
        
        stage1_optimizer = optim.AdamW(
            trainable_params,
            lr=args.stage1_lr,
            weight_decay=args.weight_decay,
            betas=(0.9, 0.999),
            amsgrad=True,
            eps=1e-10
        )
        
        stage1_scheduler = optim.lr_scheduler.CosineAnnealingLR(
            stage1_optimizer, T_max=args.stage1_epochs
        )
        
        total_params = sum(p.numel() for p in trainable_params)
        logger.info(f"第一阶段可训练参数: {total_params:,}")
        
        best_gather_loss = float('inf')
        
        for epoch in range(args.stage1_epochs):
            gather_loss = train_inp_extractor(
                inp_extractor, train_loader, stage1_optimizer, device,
                epoch, logger, writer
            )
            
            stage1_scheduler.step()
            current_lr = stage1_optimizer.param_groups[0]['lr']
            
            writer.add_scalar('Stage1/Gather_Loss_Epoch', gather_loss, epoch)
            writer.add_scalar('Stage1/Learning_Rate', current_lr, epoch)
            
            if gather_loss < best_gather_loss:
                best_gather_loss = gather_loss
                checkpoint_path = os.path.join(save_dir, 'stage1_best.pth')
                save_checkpoint(inp_extractor, stage1_optimizer, epoch, best_gather_loss, checkpoint_path, is_best=True)
                logger.info(f"✅ 保存第一阶段最佳模型 (Gather Loss: {best_gather_loss:.4f})")
            
            if epoch % 20 == 0:
                checkpoint_path = os.path.join(save_dir, f'stage1_epoch_{epoch}.pth')
                save_checkpoint(inp_extractor, stage1_optimizer, epoch, best_gather_loss, checkpoint_path)
        
        final_path = os.path.join(save_dir, 'stage1_final.pth')
        save_checkpoint(inp_extractor, stage1_optimizer, args.stage1_epochs-1, best_gather_loss, final_path)
        
        logger.info(f"第一阶段完成! 最佳Gather Loss: {best_gather_loss:.4f}")
        logger.info(f"INP提取器权重保存至: {save_dir}/stage1_best.pth")
    
    # ==================== 第二阶段：分类器训练 ====================
    
    if args.stage in ['stage2', 'all']:
        logger.info("=" * 80)
        logger.info("第二阶段：分类器训练")
        logger.info("=" * 80)
        
        if args.stage == 'stage2' and args.stage1_checkpoint:
            checkpoint_path = args.stage1_checkpoint
        elif args.stage == 'all':
            checkpoint_path = os.path.join(save_dir, 'stage1_best.pth')
        else:
            raise ValueError("第二阶段训练需要指定第一阶段检查点路径")
        
        if not os.path.exists(checkpoint_path):
            raise FileNotFoundError(f"找不到第一阶段检查点: {checkpoint_path}")
        
        inp_extractor = INP_Extractor(
            dinov2_model_name=args.dinov2_model,
            num_inp=args.num_inp,
            target_layers=args.target_layers,
            freeze_backbone=False,
            unfrozen_blocks=-1 # 在第二阶段默认微调所有块
        ).to(device)
        
        checkpoint = torch.load(checkpoint_path, map_location=device)
        inp_extractor.load_state_dict(checkpoint['model_state_dict'])
        logger.info(f"✅ 加载第一阶段权重: {checkpoint_path}")
        
        classifier = INP_Classifier(
            inp_extractor=inp_extractor,
            num_classes=num_classes
        ).to(device)

        stage2_optimizer = optim.AdamW([
            {'params': classifier.inp_extractor.parameters(), 'lr': args.stage2_lr / 10},
            {'params': classifier.classifier.parameters(), 'lr': args.stage2_lr}
        ], weight_decay=args.weight_decay)

        logger.info(f"第二阶段优化器已配置差分学习率。")
        
        stage2_scheduler = optim.lr_scheduler.CosineAnnealingLR(
            stage2_optimizer, T_max=args.stage2_epochs
        )
        
        criterion = nn.CrossEntropyLoss()
        
        classifier_params = sum(p.numel() for p in classifier.classifier.parameters())
        logger.info(f"第二阶段可训练参数 (分类器): {classifier_params:,}")
        
        best_accuracy = 0.0
        
        for epoch in range(args.stage2_epochs):
            # [修改] 调用更新后的训练函数，传入拓扑损失权重
            train_loss, train_accuracy = train_classifier(
                classifier, train_loader, criterion, stage2_optimizer, device,
                epoch, logger, writer, lambda_topo=args.stage2_lambda_topo
            )
            
            stage2_scheduler.step()
            current_lr = stage2_optimizer.param_groups[0]['lr']
            
            writer.add_scalar('Stage2/Loss_Epoch', train_loss, epoch)
            writer.add_scalar('Stage2/Accuracy_Epoch', train_accuracy, epoch)
            writer.add_scalar('Stage2/Learning_Rate', current_lr, epoch)
            
            if epoch % args.eval_interval == 0 or epoch == args.stage2_epochs - 1:
                # [修改] 只在一个测试集上评估和保存
                main_test_loader_name = list(test_loaders.keys())[0]
                main_test_loader = test_loaders[main_test_loader_name]
                test_loss, test_accuracy = evaluate_model(classifier, main_test_loader, device, logger)
                
                writer.add_scalar(f'Test_{main_test_loader_name}/Loss', test_loss, epoch)
                writer.add_scalar(f'Test_{main_test_loader_name}/Accuracy', test_accuracy, epoch)
                
                if test_accuracy > best_accuracy:
                    best_accuracy = test_accuracy
                    checkpoint_path = os.path.join(save_dir, f'stage2_best.pth')
                    save_checkpoint(classifier, stage2_optimizer, epoch, best_accuracy, checkpoint_path, is_best=True)
                    logger.info(f"✅ 保存第二阶段最佳模型 (Accuracy: {best_accuracy:.4f})")
            
            if epoch % 10 == 0:
                checkpoint_path = os.path.join(save_dir, f'stage2_epoch_{epoch}.pth')
                save_checkpoint(classifier, stage2_optimizer, epoch, best_accuracy, checkpoint_path)
        
        final_path = os.path.join(save_dir, 'stage2_final.pth')
        save_checkpoint(classifier, stage2_optimizer, args.stage2_epochs-1, best_accuracy, final_path)
        
        logger.info(f"第二阶段完成! 最佳准确率: {best_accuracy:.4f}")
        logger.info(f"分类器权重保存至: {save_dir}/stage2_best.pth")
    
    logger.info("=" * 80)
    logger.info("🎉 INP-Former纹理分类训练完成!")
    if args.stage in ['stage1', 'all']:
        logger.info(f"第一阶段 (INP提取器): 最佳Gather Loss = {best_gather_loss:.4f}")
    if args.stage in ['stage2', 'all']:
        logger.info(f"第二阶段 (分类器): 最佳准确率 = {best_accuracy:.4f}")
    logger.info(f"模型保存至: {save_dir}")
    logger.info("=" * 80)
    
    writer.close()


if __name__ == "__main__":
    # Entry point: launch the training pipeline only when this file is
    # executed directly as a script, never as a side effect of importing it.
    main()