
import os
import argparse
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR
import torch.backends.cudnn as cudnn
from torch.cuda.amp import GradScaler, autocast
import random
from tqdm import tqdm
import warnings
from functools import partial
from copy import deepcopy
import torch.nn.functional as F
from utils import extract_coordinates, euclidean_distance, save_net_opt, calculate_aop, load_net
from model_factory import get_model, replace_bn_with_in
from visualization import visualize_keypoints
from semi_dataset import LabeledDataset, UnlabeledDataset

# Suppress every Python warning emitted from this point on, for the whole process.
warnings.filterwarnings('ignore')

def semi_collate_fn(batch):
    """Collate a list of sample tuples field by field, passing ``None`` fields through.

    Each position of the sample tuples is gathered across the batch. A field
    whose value in the first sample is ``None`` is emitted as ``None``; every
    other field is stacked with ``torch.utils.data.default_collate``.

    Args:
        batch: Sequence of equally long sample tuples.

    Returns:
        A tuple with one entry per field: a collated value or ``None``.
    """
    return tuple(
        None if field_values[0] is None
        else torch.utils.data.default_collate(field_values)
        for field_values in zip(*batch)
    )


# --- Model configuration ---
# MODEL_TYPE and ENCODER_NAME together name the output sub-directory;
# ENCODER_NAME is also forwarded to get_model().
MODEL_TYPE = 'unet'
ENCODER_NAME = 'efficientnet-b4'

USE_VIT_BOTTLENECK = True  # <-- set to True to enable the ViT bottleneck
# Forwarded to get_model() as vit_pretrained_path.
VIT_PRETRAINED_PATH = '/root/autodl-tmp/code/mae_pretrained_new/mae_vit_tiny_encoder_epoch_200.pth'

# The flags below are forwarded to get_model(), except REPLACE_BN, which
# triggers replace_bn_with_in() on the freshly built student model.
USE_DROPOUT = True
DROPOUT_PROBABILITY = 0.1
USE_KAN = True
REPLACE_BN = True
USE_HIERARCHICAL_HEAD = True
# Checkpoint loaded into both student and teacher; an empty string (or a
# path that does not exist) means training starts from scratch.
PRETRAINED_PATH = '' 
# --- Semi-supervised learning parameters ---
# Final weight of the unsupervised (consistency) loss after ramp-up.
# NOTE: with the value 0 the unsupervised term is multiplied by zero in the total loss.
SSL_LAMBDA_FINAL = 0
SSL_RAMP_UP_EPOCHS = 30

MIN_CONFIDENCE_THRESHOLD = 0.5  # threshold at the start of training
MAX_CONFIDENCE_THRESHOLD = 0.9  # threshold at the end of training

EMA_DECAY_INITIAL = 0.99    # initial decay rate
EMA_DECAY_FINAL = 0.999     # final decay rate
EMA_DECAY_RAMP_EPOCHS = 30  # number of epochs over which the decay ramps up

# Weight applied to the third output channel's loss (called FH1 in the original
# comments); values above 1 amplify it, the current 1.0 leaves it unscaled.
FH1_LOSS_WEIGHT = 1.0

# NOTE(review): LOSS_ALPHA is not referenced anywhere in the visible part of this file.
LOSS_ALPHA = 0.1
MIN_RETAIN_RATIO = 0.3  # keep at least 30% of each unlabeled batch for the unsupervised loss

def get_ssl_lambda(current_epoch, ramp_up_epochs, final_lambda):
    """Return the unsupervised-loss weight for ``current_epoch``.

    The weight follows ``final_lambda * exp(-5 * (1 - t)^2)`` where ``t`` is
    ``current_epoch / ramp_up_epochs`` with the epoch clipped to
    ``[0, ramp_up_epochs]``, so it reaches ``final_lambda`` once the ramp-up
    period is over.

    Args:
        current_epoch: Epoch index the weight is requested for.
        ramp_up_epochs: Length of the ramp-up; ``0`` disables the ramp.
        final_lambda: Weight reached at the end of the ramp-up.

    Returns:
        ``final_lambda`` unchanged when ``ramp_up_epochs`` is 0, otherwise the
        scaled weight.
    """
    if ramp_up_epochs != 0:
        clipped_epoch = np.clip(current_epoch, 0, ramp_up_epochs)
        remaining = 1.0 - clipped_epoch / ramp_up_epochs
        scale = float(np.exp(-5.0 * remaining * remaining))
        return final_lambda * scale
    return final_lambda

def get_confidence_threshold(current_epoch, total_epochs):
    """Linearly raise the confidence threshold as training progresses.

    Interpolates from ``MIN_CONFIDENCE_THRESHOLD`` to
    ``MAX_CONFIDENCE_THRESHOLD`` using ``current_epoch / total_epochs``
    clipped to ``[0, 1]``.

    Args:
        current_epoch: Epoch index the threshold is requested for.
        total_epochs: Total number of epochs; ``0`` yields the maximum threshold.

    Returns:
        The confidence threshold for ``current_epoch``.
    """
    if total_epochs != 0:
        fraction_done = np.clip(current_epoch / total_epochs, 0, 1)
        span = MAX_CONFIDENCE_THRESHOLD - MIN_CONFIDENCE_THRESHOLD
        return MIN_CONFIDENCE_THRESHOLD + span * fraction_done
    return MAX_CONFIDENCE_THRESHOLD

def get_ema_decay(current_epoch, ramp_epochs):
    """Return the EMA decay for ``current_epoch``.

    A smaller decay is used early in training so the teacher tracks the
    student's updates more quickly. The value is interpolated linearly from
    ``EMA_DECAY_INITIAL`` to ``EMA_DECAY_FINAL`` using
    ``current_epoch / ramp_epochs`` clipped to ``[0, 1]``.

    Args:
        current_epoch: Epoch index the decay is requested for.
        ramp_epochs: Length of the ramp; ``0`` yields the final decay directly.

    Returns:
        The EMA decay rate for ``current_epoch``.
    """
    if ramp_epochs != 0:
        fraction_done = np.clip(current_epoch / ramp_epochs, 0, 1)
        span = EMA_DECAY_FINAL - EMA_DECAY_INITIAL
        return EMA_DECAY_INITIAL + span * fraction_done
    return EMA_DECAY_FINAL

@torch.no_grad()
def update_teacher_model(student_model, teacher_model, decay):
    """Update the teacher in place from the student with an exponential moving average.

    Every teacher parameter becomes ``decay * teacher + (1 - decay) * student``.
    Buffers (for example the running statistics of normalisation layers) are
    copied from the student as they are: the teacher is kept in eval mode and
    never updates them itself, so without this copy they would stay at their
    initial values.

    Args:
        student_model: Model being trained; may be wrapped so that the
            original module is exposed as ``_orig_mod``.
        teacher_model: Model receiving the averaged weights; may be wrapped
            the same way.
        decay: EMA decay in ``[0, 1]``; larger values make the teacher move
            more slowly.
    """
    # Unwrap both sides the same way so parameters and buffers are paired
    # from identical module structures.
    student = getattr(student_model, '_orig_mod', student_model)
    teacher = getattr(teacher_model, '_orig_mod', teacher_model)

    for student_p, teacher_p in zip(student.parameters(), teacher.parameters()):
        teacher_p.mul_(decay).add_(student_p.detach(), alpha=1 - decay)

    # Buffers are already running averages (or counters) in the student, so
    # they are mirrored rather than averaged a second time.
    for student_b, teacher_b in zip(student.buffers(), teacher.buffers()):
        teacher_b.copy_(student_b)


# --- Main training function ---
def mean_teacher_train(args, save_path):
    """Run Mean Teacher semi-supervised training and checkpoint the best teacher.

    A student network is optimised on a channel-weighted heatmap MSE loss over
    labeled images plus a confidence-masked consistency loss between its
    prediction on strongly augmented unlabeled images and the teacher's
    prediction on the weakly augmented view. The teacher is updated after
    every optimisation step through ``update_teacher_model`` and is the model
    that gets validated, visualised and saved.

    Args:
        args: Parsed command-line namespace. Fields read here: ``seed``,
            ``gpu``, ``root_dataset``, ``img_size``, ``sigma``,
            ``labeled_batch_size``, ``unlabeled_batch_size``, ``batch_size``,
            ``num_workers``, ``lr``, ``weight_decay``, ``epochs``,
            ``warmup_epochs`` and ``patience``.
        save_path: Directory that receives ``best_model.pth`` and is passed to
            ``visualize_keypoints``; it is not created here.

    Returns:
        The path ``<save_path>/best_model.pth``.

    Raises:
        ValueError: If the labeled, unlabeled or validation dataset is empty.
    """
    # --- Environment setup ---
    cudnn.benchmark = True
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # --- Data loading ---
    print("Loading labeled and unlabeled data...")
    labeled_df = pd.read_csv(os.path.join(args.root_dataset, 'train_set.csv'))
    unlabeled_df = pd.read_csv(os.path.join(args.root_dataset, 'unlabeled_set.csv'))
    val_df = pd.read_csv(os.path.join(args.root_dataset, 'val_set.csv'))

    labeled_dataset = LabeledDataset(dataframe=labeled_df, image_size=args.img_size, sigma=args.sigma, train=True)
    unlabeled_dataset = UnlabeledDataset(dataframe=unlabeled_df, image_size=args.img_size)
    val_dataset = LabeledDataset(dataframe=val_df, image_size=args.img_size, sigma=args.sigma, train=False)

    # Fail early with a clear message: an empty set would otherwise surface
    # later as a bare StopIteration or ZeroDivisionError.
    for set_name, dataset in (
        ("labeled", labeled_dataset),
        ("unlabeled", unlabeled_dataset),
        ("validation", val_dataset),
    ):
        if len(dataset) == 0:
            raise ValueError(f"The {set_name} dataset is empty; cannot run training.")

    labeled_loader = DataLoader(labeled_dataset, batch_size=args.labeled_batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=True, collate_fn=semi_collate_fn)
    unlabeled_loader = DataLoader(unlabeled_dataset, batch_size=args.unlabeled_batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, pin_memory=True, collate_fn=semi_collate_fn)

    print(f"Labeled set: {len(labeled_dataset)}, Unlabeled set: {len(unlabeled_dataset)}, Validation set: {len(val_dataset)}")

    # --- Models, optimizer, scheduler ---
    print("Creating student and teacher models...")
    student_model = get_model(
        encoder_name=ENCODER_NAME,
        use_kan=USE_KAN,
        use_hierarchical_head=USE_HIERARCHICAL_HEAD,
        use_vit_bottleneck=USE_VIT_BOTTLENECK,
        vit_pretrained_path=VIT_PRETRAINED_PATH,
        use_dropout=USE_DROPOUT,
        dropout_p=DROPOUT_PROBABILITY,
    )
    if REPLACE_BN:
        replace_bn_with_in(student_model)
    student_model.to(device)

    # The teacher starts as an exact copy of the student.
    teacher_model = deepcopy(student_model)

    print("Compiling models...")
    student_model = torch.compile(student_model)
    teacher_model = torch.compile(teacher_model)

    if PRETRAINED_PATH and os.path.exists(PRETRAINED_PATH):
        print(f"Loading weights from pre-trained model: {PRETRAINED_PATH}")
        load_net(student_model, PRETRAINED_PATH, map_location=device)
        load_net(teacher_model, PRETRAINED_PATH, map_location=device)
    else:
        print("No pre-trained model provided. Training from scratch.")

    # The teacher is never touched by the optimizer, only by the EMA update.
    for param in teacher_model.parameters():
        param.requires_grad = False
    teacher_model.eval()

    optimizer = optim.AdamW(student_model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    scheduler = CosineAnnealingLR(optimizer, T_max=args.epochs - args.warmup_epochs, eta_min=1e-6)
    scaler = GradScaler()

    # Per-channel loss weights; the third channel is scaled by FH1_LOSS_WEIGHT.
    # The tensor never changes, so it is built once instead of on every step.
    loss_weights = torch.tensor([1.0, 1.0, FH1_LOSS_WEIGHT], device=device).view(1, 3, 1, 1)

    num_train_batches = len(labeled_loader)
    num_val_samples = len(val_dataset)
    best_model_path = os.path.join(save_path, 'best_model.pth')
    best_score = float('inf')
    patience_counter = 0

    # --- Training loop ---
    for epoch in range(1, args.epochs + 1):
        student_model.train()
        total_loss, sup_loss, unsup_loss_val = 0, 0, 0

        # Schedules that depend on the current epoch.
        current_ssl_lambda = get_ssl_lambda(epoch, SSL_RAMP_UP_EPOCHS, SSL_LAMBDA_FINAL)
        current_confidence = get_confidence_threshold(epoch, args.epochs)
        current_ema_decay = get_ema_decay(epoch, EMA_DECAY_RAMP_EPOCHS)

        unlabeled_iter = iter(unlabeled_loader)
        train_iter = tqdm(labeled_loader, desc=f"Train {epoch}/{args.epochs}", leave=False)

        # Linear learning-rate warm-up; the cosine scheduler takes over afterwards.
        if epoch <= args.warmup_epochs:
            current_lr = args.lr * (epoch / args.warmup_epochs)
            for param_group in optimizer.param_groups:
                param_group['lr'] = current_lr

        for labeled_imgs, labeled_targets, _, _ in train_iter:
            # One unlabeled batch per labeled batch; restart the unlabeled
            # loader whenever it runs out first.
            try:
                unlabeled_weak, unlabeled_strong = next(unlabeled_iter)
            except StopIteration:
                unlabeled_iter = iter(unlabeled_loader)
                unlabeled_weak, unlabeled_strong = next(unlabeled_iter)

            labeled_imgs, labeled_targets = labeled_imgs.to(device), labeled_targets.to(device)
            unlabeled_weak, unlabeled_strong = unlabeled_weak.to(device), unlabeled_strong.to(device)

            optimizer.zero_grad(set_to_none=True)

            with autocast():
                # 1. Supervised loss: element-wise MSE, weighted per channel, then averaged.
                pred_labeled = student_model(labeled_imgs.float())
                loss_sup = torch.mean(
                    F.mse_loss(pred_labeled, labeled_targets, reduction='none') * loss_weights
                )

                # 2. Unsupervised loss: the teacher sees the weak view and
                #    provides the targets for the student's strong view.
                with torch.no_grad():
                    pseudo_targets = teacher_model(unlabeled_weak.float())

                batch_size_unlabeled = pseudo_targets.size(0)
                mask = _confidence_mask(pseudo_targets, current_confidence, MIN_RETAIN_RATIO)

                pred_strong = student_model(unlabeled_strong.float())
                weighted_unsup = F.mse_loss(pred_strong, pseudo_targets, reduction='none') * loss_weights
                loss_unsup_per_sample = torch.mean(weighted_unsup.reshape(batch_size_unlabeled, -1), dim=1)

                # Average only over the samples that passed the confidence filter.
                if mask.sum() > 0:
                    loss_unsup = (mask * loss_unsup_per_sample).sum() / mask.sum()
                else:
                    loss_unsup = torch.zeros((), device=device)

                # 3. Total loss.
                total_loss_batch = loss_sup + current_ssl_lambda * loss_unsup

            scaler.scale(total_loss_batch).backward()
            scaler.step(optimizer)
            scaler.update()

            # Move the teacher towards the student using the epoch's decay.
            update_teacher_model(student_model, teacher_model, current_ema_decay)

            # Read each scalar back from the device once and reuse it below.
            batch_total = total_loss_batch.item()
            batch_sup = loss_sup.item()
            batch_unsup = loss_unsup.item()

            total_loss += batch_total
            sup_loss += batch_sup
            unsup_loss_val += batch_unsup

            train_iter.set_postfix({
                "t_l": f"{batch_total:.6f}", "s_l": f"{batch_sup:.6f}",
                "u_l": f"{batch_unsup:.6f}", "m_ratio": f"{mask.mean().item():.2f}",
                "s_lmd": f"{current_ssl_lambda:.2f}", "conf": f"{current_confidence:.2f}"
            })

        if epoch > args.warmup_epochs:
            scheduler.step()

        avg_train_loss = total_loss / num_train_batches
        avg_sup_loss = sup_loss / num_train_batches
        avg_unsup_loss = unsup_loss_val / num_train_batches

        # --- Validation (teacher only) ---
        avg_mre, avg_apd = _evaluate_teacher(
            teacher_model, val_loader, num_val_samples, device,
            desc=f"Eval {epoch}/{args.epochs}",
        )
        composite_score = 1000 * avg_mre

        print(
            f"Epoch [{epoch}/{args.epochs}], "
            f"Train Loss: {avg_train_loss:.6f} (Sup: {avg_sup_loss:.6f}, Unsup: {avg_unsup_loss:.6f}), "
            f"Val MRE: {avg_mre:.6f}, Val APD: {avg_apd:.4f}, Score: {composite_score:.4f}, "
            f"EMA Decay: {current_ema_decay:.4f}"
        )

        if epoch % 2 == 0:
            visualize_keypoints(teacher_model, val_loader, device, epoch, save_path=save_path)

        # Lower score is better; patience counts epochs without improvement.
        if composite_score < best_score:
            best_score = composite_score
            patience_counter = 0
            print(f"    -> New best model found! Score: {best_score:.4f}. Saving teacher model...")
            torch.save({'net': teacher_model.state_dict()}, best_model_path)
        else:
            patience_counter += 1

        if patience_counter >= args.patience:
            print(f"Early stopping at epoch {epoch}.")
            break

    return best_model_path


def _confidence_mask(pseudo_targets, threshold, min_retain_ratio):
    """Return a per-sample 0/1 mask selecting confident teacher predictions.

    A sample counts as confident when the maximum value over all of its
    outputs is at least ``threshold``. If fewer than ``min_retain_ratio`` of
    the batch pass, the mask is replaced by one that keeps the top-scoring
    ``max(int(batch * min_retain_ratio), 1)`` samples instead.
    """
    batch_size = pseudo_targets.size(0)
    peak_per_sample = torch.max(pseudo_targets.reshape(batch_size, -1), dim=1)[0]
    mask = (peak_per_sample >= threshold).float()

    if batch_size > 0 and mask.sum().item() / batch_size < min_retain_ratio:
        keep = max(int(batch_size * min_retain_ratio), 1)
        _, top_indices = torch.topk(peak_per_sample, keep)
        mask = torch.zeros_like(peak_per_sample)
        mask[top_indices] = 1.0
    return mask


def _evaluate_teacher(teacher_model, val_loader, num_samples, device, desc):
    """Return ``(mean distance error, mean absolute AoP difference)`` over ``val_loader``.

    Each batch's mean distance is weighted by the batch size, so a shorter
    final batch does not count more per sample than the full ones.
    """
    teacher_model.eval()
    mre_sum, apd_sum = 0.0, 0.0

    with torch.no_grad():
        for imgs, _, landmarks_gt, _ in tqdm(val_loader, desc=desc, leave=False):
            imgs, landmarks_gt = imgs.to(device), landmarks_gt.to(device)
            pred_coords = extract_coordinates(teacher_model(imgs.float()))
            batch_n = imgs.size(0)

            mre_sum += torch.mean(euclidean_distance(pred_coords, landmarks_gt)).item() * batch_n

            # The angle is computed per sample from its three (x, y) points.
            for i in range(batch_n):
                pred_points = pred_coords[i].view(3, 2)
                gt_points = landmarks_gt[i].view(3, 2)
                pred_aop = calculate_aop(pred_points[0], pred_points[1], pred_points[2])
                gt_aop = calculate_aop(gt_points[0], gt_points[1], gt_points[2])
                apd_sum += abs(pred_aop - gt_aop)

    return mre_sum / num_samples, apd_sum / num_samples

if __name__ == '__main__':
    # Command-line entry point: parse the options, create the output
    # directory "<save_dir>/<MODEL_TYPE>_<ENCODER_NAME>" and start training.
    parser = argparse.ArgumentParser(description="Mean Teacher Semi-Supervised Training")
    parser.add_argument('--root_dataset', type=str, default='./dataset', help='Dataset root directory')
    parser.add_argument('--save_dir', type=str, default='./Unet_cp', help='Directory to save results')
    parser.add_argument('--gpu', type=str, default="0", help='GPU ID')
    parser.add_argument('--num_workers', type=int, default=8, help='Number of workers for DataLoader')
    parser.add_argument('--seed', type=int, default=2025, help='Random seed')

    parser.add_argument('--epochs', type=int, default=200, help='Max number of training epochs')
    parser.add_argument('--batch_size', type=int, default=24, help='Batch size for validation and inference')
    parser.add_argument('--labeled_batch_size', type=int, default=24, help='Batch size for labeled data')
    # The previous default of 0 is rejected by DataLoader, so the script could
    # not run without this flag. 24 mirrors --labeled_batch_size because the
    # training loop draws one unlabeled batch per labeled batch.
    parser.add_argument('--unlabeled_batch_size', type=int, default=24, help='Batch size for unlabeled data')

    parser.add_argument('--lr', type=float, default=5e-4, help='Initial learning rate for fine-tuning')
    parser.add_argument('--weight_decay', type=float, default=1e-5, help='Weight decay to prevent overfitting')
    parser.add_argument('--warmup_epochs', type=int, default=10, help='Number of epochs for learning rate warm-up')
    parser.add_argument('--patience', type=int, default=30, help='Patience for early stopping')

    parser.add_argument('--img_size', type=int, default=512, help='Image size')
    parser.add_argument('--sigma', type=float, default=4.0, help='Sigma for heatmap generation')

    args = parser.parse_args()

    # Report a non-positive batch size as a usage error before any data is
    # loaded, instead of letting DataLoader raise later.
    for option in ('batch_size', 'labeled_batch_size', 'unlabeled_batch_size'):
        if getattr(args, option) <= 0:
            parser.error(f"--{option} must be a positive integer")

    save_path = os.path.join(args.save_dir, f"{MODEL_TYPE}_{ENCODER_NAME}")
    os.makedirs(save_path, exist_ok=True)

    mean_teacher_train(args, save_path)