import os
import csv
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from typing import Iterable, Optional
from timm.data import Mixup
from sklearn.metrics import (
    accuracy_score, roc_auc_score, f1_score, average_precision_score,
    cohen_kappa_score
)
from pycm import ConfusionMatrix
import util.misc as misc
from util.misc import log_gradients_in_tensorboard
import util.lr_sched as lr_sched
from util.losses import MultiClassFocalLoss
from collections import defaultdict
import seaborn as sns
from pathlib import Path

# ------------------------- KL sparsity regularization -------------------------
def kl_sparsity_loss(gates: torch.Tensor,
                     rho: float = 0.3,
                     eps: float = 1e-6) -> torch.Tensor:
    """Return KL(rho || g_mean), pulling the mean gate activation towards ``rho``.

    Both distributions are treated as Bernoulli: one with parameter ``rho``,
    the other with parameter ``mean(gates)`` taken over every element.

    Args:
        gates: Gate activations of any shape, expected to lie in (0, 1).
        rho: Target mean activation. Values outside ``[eps, 1 - eps]`` are
            clamped so that ``rho == 0`` or ``rho == 1`` cannot produce
            ``0 * log(0) = NaN``.
        eps: Clamp margin applied to both ``rho`` and the gate mean.

    Returns:
        A scalar tensor holding the KL divergence.
    """
    # In half precision ``1 - eps`` rounds to exactly 1.0, which would defeat the
    # clamp below and let ``log((1 - rho) / 0)`` blow up; do the maths in float32.
    if gates.dtype in (torch.float16, torch.bfloat16):
        gates = gates.float()

    rho = min(max(float(rho), eps), 1.0 - eps)
    g_mean = torch.clamp(gates.mean(), eps, 1 - eps)
    return rho * torch.log(rho / g_mean) + (1 - rho) * torch.log((1 - rho) / (1 - g_mean))
# --------------------------------------------------------------

def train_one_epoch(
    model: torch.nn.Module,
    criterion: torch.nn.Module,
    data_loader: Iterable,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
    epoch: int,
    loss_scaler,
    max_norm: float = 0,
    mixup_fn: Optional[Mixup] = None,
    log_writer=None,
    args=None
):
    """Train ``model`` for one epoch with gradient accumulation and a gate-sparsity penalty.

    The model is expected to return ``(class_logits, routing_info)`` where
    ``routing_info`` is a (possibly empty) list of ``(gate_probs, step_mask)``
    tuples. The total loss is the classification loss plus
    ``args.sparsity_lambda * KL(args.sparsity_target || mean gate)``.

    Args:
        model: Network to train; called as ``model(samples)``.
        criterion: Classification loss, replaced by ``MultiClassFocalLoss`` when
            ``args.use_focal_loss`` is truthy.
        data_loader: Sized iterable yielding ``(samples, targets)`` batches.
        optimizer: Optimizer whose learning rate is adjusted by ``lr_sched``.
        device: Device the batches are moved to.
        epoch: Zero-based epoch index, used for scheduling and log steps.
        loss_scaler: Callable performing scaled backward/step, or ``None`` to
            use a plain ``backward()`` / ``optimizer.step()``.
        max_norm: Gradient-clipping norm; only applied in the plain path when > 0.
        mixup_fn: Optional mixup/cutmix transform applied to each batch.
        log_writer: Optional TensorBoard-style writer.
        args: Namespace providing ``accum_iter``, ``use_focal_loss``,
            ``sparsity_target``, ``sparsity_lambda``, ``nb_classes`` and
            ``output_dir`` (required in practice despite the ``None`` default).

    Returns:
        Dict mapping each meter name to its global average for the epoch.
    """
    model.train(True)
    metric_logger = misc.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))

    print_freq, accum_iter = 20, args.accum_iter
    optimizer.zero_grad()

    criterion_cls = MultiClassFocalLoss(gamma=2.0) if args.use_focal_loss else criterion

    rho_target = args.sparsity_target          # e.g. 0.3
    kl_weight = args.sparsity_lambda           # e.g. 0.01

    if log_writer:
        print(f'log_dir: {log_writer.log_dir}')

    num_batches = len(data_loader)
    all_routing_data_for_analysis = []
    all_targets_for_analysis = []

    for data_iter_step, (samples, targets) in enumerate(
            metric_logger.log_every(data_loader, print_freq, f'Epoch: [{epoch}]')):
        is_update_step = (data_iter_step + 1) % accum_iter == 0

        # Per-iteration (fractional-epoch) LR schedule, refreshed once per accumulation cycle.
        if data_iter_step % accum_iter == 0:
            lr_sched.adjust_learning_rate(optimizer, data_iter_step / num_batches + epoch, args)

        samples = samples.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)
        if mixup_fn:
            samples, targets = mixup_fn(samples, targets)

        # Mixed-precision forward pass, consistent with the single-domain setup.
        with torch.cuda.amp.autocast():
            class_logits, routing_info = model(samples)
            loss = criterion_cls(class_logits, targets)
            loss_router = torch.tensor(0.0)

            # Pool every gate tensor and penalise deviation from the target sparsity.
            if routing_info and kl_weight > 0:
                all_gates = torch.cat([g for (g, _) in routing_info], dim=1)
                loss_router = kl_sparsity_loss(all_gates, rho=rho_target)
                loss = loss + kl_weight * loss_router

        # Sample routing data for the end-of-epoch analysis (every 10th batch to save memory).
        if data_iter_step % 10 == 0:
            # Detach so the stored tensors do not keep this step's autograd graph
            # (and all of its activations) alive until the end of the epoch.
            if routing_info:
                routing_snapshot = [
                    tuple(t.detach() if torch.is_tensor(t) else t for t in entry)
                    for entry in routing_info
                ]
            else:
                routing_snapshot = routing_info
            all_routing_data_for_analysis.append(routing_snapshot)

            # Mixup yields soft (B, C) targets; the routing analysis needs one hard
            # label per sample, so keep the dominant class in that case.
            if targets.dim() > 1 and targets.shape[-1] > 1:
                all_targets_for_analysis.append(targets.argmax(dim=-1))
            else:
                all_targets_for_analysis.append(targets.clone())

            # Log gate statistics to TensorBoard.
            global_step = epoch * num_batches + data_iter_step
            if routing_info and log_writer:
                # routing_info is a list of (p_soft, step_mask) tuples.
                gate_probs = [entry[0] for entry in routing_info if len(entry) > 0]
                if gate_probs:
                    gate_vals = torch.cat([p.flatten() for p in gate_probs], dim=0).detach().float()
                    log_writer.add_scalar('8_Gates/mean', gate_vals.mean().item(), global_step)
                    log_writer.add_scalar('8_Gates/std', gate_vals.std().item(), global_step)
                    log_writer.add_histogram('8_Gates/distribution', gate_vals, global_step)

        loss_value = loss.item()
        loss /= accum_iter

        if loss_scaler is not None:
            loss_scaler(loss, optimizer, clip_grad=max_norm, parameters=model.parameters(),
                        create_graph=False, update_grad=is_update_step)
        else:
            # Plain (non-scaled) backward pass.
            loss.backward()
            if is_update_step:
                if max_norm is not None and max_norm > 0:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
                optimizer.step()

        if is_update_step:
            # Log gradients while they still exist, i.e. before they are cleared.
            if log_writer is not None:
                global_step = epoch * num_batches + data_iter_step
                model_without_ddp = model.module if hasattr(model, 'module') else model
                log_gradients_in_tensorboard(model_without_ddp, log_writer, global_step)
            # NOTE(review): assumes loss_scaler does not clear gradients itself; if it
            # does, this call is a harmless no-op. Without it the scaler path would keep
            # accumulating gradients across optimizer steps.
            optimizer.zero_grad()

        if torch.cuda.is_available():
            torch.cuda.synchronize()
        metric_logger.update(loss=loss_value)

        max_lr = max((group["lr"] for group in optimizer.param_groups), default=0.)
        metric_logger.update(lr=max_lr)

        loss_value_reduce = misc.all_reduce_mean(loss_value)
        if log_writer is not None and is_update_step:
            # x-axis in 1/1000 epoch units so curves are comparable across batch sizes.
            epoch_1000x = int((data_iter_step / num_batches + epoch) * 1000)
            log_writer.add_scalar('4_Loss/train_total', loss_value_reduce, epoch_1000x)
            log_writer.add_scalar('4_Loss/sparsity_loss_kl', loss_router.item(), epoch_1000x)
            log_writer.add_scalar('5_Training/learning_rate', max_lr, epoch_1000x)

    # At the end of the epoch, perform detailed routing analysis.
    if log_writer and all_routing_data_for_analysis:
        log_detailed_routing_analysis(
            all_routing_data_for_analysis,
            all_targets_for_analysis,
            args.nb_classes,
            log_writer,
            epoch,
            'train',
            save_plots=True,
            output_dir=args.output_dir,
            args=args
        )

    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}

@torch.no_grad()
def evaluate(data_loader, model, device, args, epoch, mode, num_class, log_writer):
    """Evaluate ``model`` on ``data_loader`` and record classification metrics.

    Computes accuracy, macro F1, macro one-vs-rest ROC AUC, macro average
    precision and Cohen's kappa; appends one row to
    ``<output_dir>/<task>/metrics_<mode>.csv``; optionally logs to TensorBoard
    and, when ``mode == 'test'``, saves a normalized confusion-matrix image.

    Args:
        data_loader: Iterable of batches; element 0 is the image tensor and the
            last element is the integer target tensor.
        model: Network returning ``(logits, routing_info)``.
        device: Device the batches are moved to.
        args: Namespace providing ``output_dir``, ``task`` and ``use_focal_loss``.
        epoch: Epoch index used as the logging step.
        mode: Split name (e.g. ``'val'`` or ``'test'``) used in file names and tags.
        num_class: Number of classes, used for one-hot encoding of the labels.
        log_writer: Optional TensorBoard-style writer.

    Returns:
        ``(result_dict, score)`` where ``result_dict`` holds the meter averages
        plus all metrics and ``score`` is the mean of F1, ROC AUC and kappa.
    """
    criterion = nn.CrossEntropyLoss()
    metric_logger = misc.MetricLogger(delimiter="  ")
    os.makedirs(os.path.join(args.output_dir, args.task), exist_ok=True)

    criterion_cls = MultiClassFocalLoss(gamma=2.0) if args.use_focal_loss else criterion

    model.eval()

    all_routing_data_list, all_targets_list = [], []
    true_labels, pred_labels, pred_softmax = [], [], []
    debug_samples_printed = 0

    for batch in metric_logger.log_every(data_loader, 10, f'{mode}:'):
        images = batch[0].to(device, non_blocking=True)
        target = batch[-1].to(device, non_blocking=True)

        # Mixed-precision forward pass, consistent with the single-domain setup.
        with torch.cuda.amp.autocast():
            output, routing_info = model(images)
            loss = criterion_cls(output, target)
        all_routing_data_list.append(routing_info)
        all_targets_list.append(target)

        # Autocast may hand back half-precision logits; take the softmax in
        # float32 so the probabilities fed to the metrics are not quantized.
        probs = F.softmax(output.float(), dim=1)

        # One device-to-host copy per tensor instead of one per sample.
        target_np = target.cpu().numpy()
        probs_np = probs.cpu().numpy()
        pred_np = probs.argmax(dim=1).cpu().numpy()

        # Debug aid: dump per-sample predictions for the first batches
        # (stops once at least 20 samples have been printed).
        if debug_samples_printed < 20:
            for true_label, pred_label, softmax_probs in zip(target_np, pred_np, probs_np):
                print(f"True: {true_label}, Predicted: {pred_label}, Probs: [{', '.join(f'{p:.2f}' for p in softmax_probs)}]")
                debug_samples_printed += 1

        metric_logger.update(loss=loss.item())

        true_labels.extend(target_np)
        pred_labels.extend(pred_np)
        pred_softmax.extend(probs_np)

    # Calculate evaluation metrics.
    true_labels = np.array(true_labels)
    pred_labels = np.array(pred_labels)
    pred_softmax = np.array(pred_softmax)

    accuracy = accuracy_score(true_labels, pred_labels)
    f1 = f1_score(true_labels, pred_labels, average='macro', zero_division=0)
    true_onehot = F.one_hot(torch.from_numpy(true_labels), num_classes=num_class).numpy()
    roc_auc = roc_auc_score(true_onehot, pred_softmax, multi_class='ovr', average='macro')
    average_precision = average_precision_score(true_onehot, pred_softmax, average='macro')
    kappa = cohen_kappa_score(true_labels, pred_labels)
    score = (f1 + roc_auc + kappa) / 3

    if log_writer:
        log_writer.add_scalar('3_Performance/accuracy', accuracy, epoch)
        log_writer.add_scalar('3_Performance/f1_score', f1, epoch)
        log_writer.add_scalar('3_Performance/roc_auc', roc_auc, epoch)
        log_writer.add_scalar('3_Performance/kappa', kappa, epoch)
        log_writer.add_scalar('3_Performance/score', score, epoch)
        log_writer.add_scalar('4_Loss/val_total', metric_logger.meters["loss"].global_avg, epoch)

        if all_targets_list:
            log_detailed_routing_analysis(
                all_routing_data_list, all_targets_list, num_class,
                log_writer, epoch, prefix=mode,
                save_plots=True,
                output_dir=args.output_dir,
                args=args
            )

    print(f'val loss: {metric_logger.meters["loss"].global_avg:.4f}')
    print(f'Accuracy: {accuracy:.4f}, F1 Score: {f1:.4f}, ROC AUC: {roc_auc:.4f}, Kappa: {kappa:.4f}, Score: {score:.4f}')

    metric_logger.synchronize_between_processes()

    results_path = os.path.join(args.output_dir, args.task, f'metrics_{mode}.csv')
    with open(results_path, 'a', newline='', encoding='utf8') as cfa:
        wf = csv.writer(cfa)
        # Write the header only when the file is new/empty.
        if cfa.tell() == 0:
            wf.writerow(['epoch', 'val_loss', 'accuracy', 'f1', 'roc_auc', 'kappa', 'score'])
        wf.writerow([epoch, metric_logger.meters["loss"].global_avg, accuracy, f1, roc_auc, kappa, score])

    if mode == 'test':
        cm = ConfusionMatrix(actual_vector=true_labels, predict_vector=pred_labels)
        cm.plot(cmap=plt.cm.Blues, number_label=True, normalized=True, plot_lib="matplotlib")
        plt.savefig(os.path.join(args.output_dir, args.task, 'confusion_matrix_test.jpg'), dpi=600, bbox_inches='tight')
        # Release the figure; otherwise every test run leaves one open in pyplot.
        plt.close()

    # Expose the computed metrics alongside the meter averages.
    result_dict = {k: meter.global_avg for k, meter in metric_logger.meters.items()}
    result_dict.update({
        'accuracy': accuracy,
        'f1': f1,
        'roc_auc': roc_auc,
        'average_precision': average_precision,
        'kappa': kappa,
        'score': score
    })

    return result_dict, score

@torch.no_grad()
def log_detailed_routing_analysis(
    all_routing_data_list: list,
    all_targets_list: list,
    num_classes: int,
    log_writer,
    epoch: int,
    prefix: str,
    save_plots: bool = True,
    output_dir: str = None,
    args=None
):
    """Summarise token-routing masks and log them to TensorBoard and, optionally, PNG files.

    For every (block, step) pair the fraction of tokens kept by the routing
    mask is computed per sample, then averaged overall and per class. The
    per-position mean of the mask (first column dropped) is logged as a square
    heatmap whenever the remaining token count is a perfect square.

    Args:
        all_routing_data_list: One entry per batch; each entry is a flat list of
            ``(p_soft, step_mask)`` tuples ordered block-major, i.e. entry ``i``
            belongs to block ``i // steps`` and step ``i % steps``.
            ``step_mask`` is indexed as ``(batch, tokens)``.
        all_targets_list: One target tensor per batch. Hard ``(B,)`` labels or
            soft/one-hot ``(B, C)`` targets (reduced with argmax) are accepted.
        num_classes: Number of classes to report per-class statistics for.
        log_writer: TensorBoard-style writer; nothing happens when it is falsy.
        epoch: Epoch index used as the logging step and in output paths.
        prefix: Split name (e.g. ``'train'``) embedded in tags and file names.
        save_plots: Whether to also write PNG plots under ``output_dir``.
        output_dir: Root directory for the PNG plots.
        args: Namespace providing ``num_recursion_steps`` and, when plots are
            saved, ``task``. An optional truthy ``grid_search_mode`` disables
            local plots while keeping the TensorBoard logs.
    """
    if not log_writer or not all_routing_data_list:
        return

    # Grid-search runs keep TensorBoard logs but skip writing local images.
    is_grid_search_mode = getattr(args, 'grid_search_mode', False)

    if save_plots and output_dir and not is_grid_search_mode:
        plots_dir = Path(output_dir) / args.task / "routing_analysis_plots" / f"epoch_{epoch:03d}"
        plots_dir.mkdir(parents=True, exist_ok=True)
        print(f"Routing analysis plots will be saved to: {plots_dir}")
        should_save_local_plots = True
    else:
        plots_dir = None
        should_save_local_plots = False
        if is_grid_search_mode:
            print(f"Grid search mode: Skipping local plot generation, keeping TensorBoard logs only")

    print(f"\n=== Starting Routing Analysis ({prefix.upper()}, Epoch {epoch}) ===")
    if is_grid_search_mode:
        print("Mode: Grid Search (TensorBoard only)")

    block_iteration_ratios, class_block_ratios, mask_data = _collect_routing_stats(
        all_routing_data_list, all_targets_list, args
    )

    # --- TensorBoard logging (always performed) ---
    # Overall token-passing ratio per (block, step).
    for block_idx in sorted(block_iteration_ratios):
        for step_idx, values in sorted(block_iteration_ratios[block_idx].items()):
            log_writer.add_scalar(
                f'2_Token_Ratios/overall/block{block_idx}_step{step_idx}_{prefix}',
                np.mean(values), epoch)

    # Per-class token-passing ratio per (block, step).
    for class_idx in range(num_classes):
        if class_idx not in class_block_ratios:
            continue
        class_data = class_block_ratios[class_idx]
        for block_idx in sorted(class_data):
            for step_idx, values in sorted(class_data[block_idx].items()):
                log_writer.add_scalar(
                    f'2_Token_Ratios/class{class_idx}/block{block_idx}_step{step_idx}_{prefix}',
                    np.mean(values), epoch)

    # Heatmaps are computed once and shared by TensorBoard and the PNG output.
    heatmaps = list(_average_patch_heatmaps(mask_data))
    for block_idx, step_idx, heatmap in heatmaps:
        log_writer.add_image(f'9_Heatmaps/{prefix}/block{block_idx}_step{step_idx}', heatmap, epoch, dataformats='HW')
        log_writer.add_scalar(f'9_Heatmaps/stats/block{block_idx}_step{step_idx}_{prefix}_mean', heatmap.mean(), epoch)
        log_writer.add_scalar(f'9_Heatmaps/stats/block{block_idx}_step{step_idx}_{prefix}_std', heatmap.std(), epoch)

    # --- Local image generation (skipped in grid-search mode) ---
    if should_save_local_plots:
        plt.style.use('default')
        sns.set_palette("husl")

        # 1. Overall token-passing ratio curves.
        _plot_ratio_grid(
            block_iteration_ratios,
            'Block {block} - Overall Token Ratio',
            plots_dir / f'{prefix}_token_passing_ratios.png',
        )

        # 2. Per-class token-passing ratio curves.
        for class_idx in range(num_classes):
            if class_idx not in class_block_ratios:
                continue
            _plot_ratio_grid(
                class_block_ratios[class_idx],
                f'Class {class_idx} - Block {{block}} Ratio',
                plots_dir / f'{prefix}_class{class_idx}_token_ratios.png',
            )

        # 3. Patch-selection heatmaps.
        for block_idx, step_idx, heatmap in heatmaps:
            fig, ax = plt.subplots(figsize=(8, 8))
            sns.heatmap(heatmap, ax=ax, cmap='viridis', vmin=0, vmax=1, square=True,
                        cbar_kws={'label': 'Selection Probability'})
            ax.set_title(f'{prefix.upper()} - Block{block_idx} Step{step_idx} Heatmap (Epoch {epoch})')

            plt.tight_layout()
            plt.savefig(plots_dir / f'{prefix}_heatmap_block{block_idx}_step{step_idx}.png', dpi=300)
            plt.close(fig)

    print(f"Routing analysis complete - Epoch {epoch}")
    if should_save_local_plots:
        print(f"Analysis plots saved to: {plots_dir}")
    elif is_grid_search_mode:
        print("Grid search mode: Local plots skipped, TensorBoard logs recorded")
    print("=" * 60)


def _to_hard_labels(targets: torch.Tensor) -> torch.Tensor:
    """Reduce soft/one-hot ``(B, C)`` targets to ``(B,)`` class indices; flatten hard labels."""
    if targets.dim() > 1 and targets.shape[-1] > 1:
        return targets.argmax(dim=-1)
    return targets.reshape(-1)


def _collect_routing_stats(all_routing_data_list, all_targets_list, args):
    """Gather per-sample token-passing ratios and raw masks, keyed by block and step.

    Returns ``(block_ratios, class_ratios, mask_data)`` where
    ``block_ratios[block][step]`` and ``class_ratios[cls][block][step]`` are
    lists of per-sample ratios and ``mask_data[block][step]`` is the list of
    mask tensors (one per batch).
    """
    block_ratios = defaultdict(lambda: defaultdict(list))
    class_ratios = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
    mask_data = defaultdict(lambda: defaultdict(list))

    for routing_info_for_batch, batch_targets in zip(all_routing_data_list, all_targets_list):
        if not routing_info_for_batch:
            continue

        # The flat routing list is block-major: each block contributes
        # ``num_recursion_steps`` consecutive entries.
        steps_per_block = args.num_recursion_steps
        sample_classes = _to_hard_labels(batch_targets).tolist()

        for flat_idx, (_, step_mask) in enumerate(routing_info_for_batch):
            block_idx, step_idx = divmod(flat_idx, steps_per_block)
            mask_data[block_idx][step_idx].append(step_mask)

            total_tokens = step_mask.shape[1]
            # One reduction and one host transfer per mask instead of a
            # device synchronisation for every single sample.
            passed_per_sample = step_mask.flatten(1).sum(dim=1).tolist()
            if not passed_per_sample:
                continue

            for passed_tokens, sample_class in zip(passed_per_sample, sample_classes):
                ratio = passed_tokens / total_tokens if total_tokens > 0 else 0
                block_ratios[block_idx][step_idx].append(ratio)
                class_ratios[sample_class][block_idx][step_idx].append(ratio)

    return block_ratios, class_ratios, mask_data


def _average_patch_heatmaps(mask_data):
    """Yield ``(block_idx, step_idx, heatmap)`` with the mean mask reshaped to a square grid.

    Column 0 of each mask is dropped (presumably the CLS token -- confirm
    against the model). Entries whose remaining token count is not a perfect
    square are skipped.
    """
    for block_idx, steps in mask_data.items():
        for step_idx, masks in steps.items():
            patch_masks = torch.cat(masks, dim=0)[:, 1:]
            if patch_masks.shape[0] == 0:
                continue

            avg_mask = patch_masks.float().mean(dim=0)
            if avg_mask.dim() != 1:
                continue
            num_patches = avg_mask.shape[0]
            side = int(round(num_patches ** 0.5))
            if side == 0 or side * side != num_patches:
                continue

            yield block_idx, step_idx, avg_mask.reshape(side, side).cpu().numpy()


def _plot_ratio_grid(ratios_by_block, title_template, out_path):
    """Save one mean +/- std ratio curve per block as a grid of subplots.

    ``title_template`` is formatted with ``block=<block index>`` for each subplot.
    """
    block_ids = sorted(ratios_by_block)
    num_blocks = len(block_ids)
    if num_blocks == 0:
        return

    # Up to two blocks share a single row; more are laid out on two rows.
    fig_rows = 1 if num_blocks <= 2 else 2
    fig_cols = num_blocks if num_blocks <= 2 else (num_blocks + 1) // 2
    fig, axes = plt.subplots(fig_rows, fig_cols, figsize=(6 * fig_cols, 5 * fig_rows), squeeze=False)

    for ax, block_idx in zip(axes.flatten(), block_ids):
        steps = sorted(ratios_by_block[block_idx])
        mean_ratios = [np.mean(ratios_by_block[block_idx][step]) for step in steps]
        std_ratios = [np.std(ratios_by_block[block_idx][step]) for step in steps]

        ax.plot(steps, mean_ratios, 'o-')
        ax.fill_between(steps,
                        np.subtract(mean_ratios, std_ratios),
                        np.add(mean_ratios, std_ratios),
                        alpha=0.3)

        ax.set_xlabel('Iteration Step')
        ax.set_ylabel('Token Passing Ratio')
        ax.set_title(title_template.format(block=block_idx))
        ax.grid(True, alpha=0.3)
        ax.set_ylim(0, 1)
        for step, ratio in zip(steps, mean_ratios):
            ax.text(step, ratio + 0.05, f'{ratio:.1%}', ha='center')

    plt.tight_layout()
    plt.savefig(out_path, dpi=300)
    plt.close(fig)