import argparse
import copy
import math
import pprint
import random
from typing import Tuple

import numpy as np
import torch
import torch_pruning
from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs
from torch_pruning import ops
from torch_pruning.pruner.algorithms import BasePruner as MagnitudePruner
from torch_pruning.utils import count_ops_and_params  # <<< 新增：统计 MACs / Params

# 评测
from assp.core.eval import evaluate_downstream, evaluate_pretrain

# 剪枝
from assp.core.pruning import (
    PGRGroupImportance,
    RelPosManager,
    collect_bank,
    collect_target_and_ignored_layers,
)

# 补偿训练
from assp.core.train import compensate_downstream_adapter_only, compensate_unified_distillation, set_trainable_parts

# 数据
from assp.data.dataset_builder import build_dataloaders

# 模型
from assp.models.model_builder import build_teacher_student

# 实验系统（system v1）
from assp.utils.system import create_logger, setup_experiment_directory
from assp.utils.vis_runner import visualize_downstream_once


def set_seed(seed: int):
    """Seed every RNG this script relies on.

    Seeds Python's ``random`` module, NumPy's global RNG, and PyTorch (CPU and
    all CUDA devices). NumPy only accepts seeds in ``[0, 2**32 - 1]``, so the
    value is reduced modulo ``2**32`` for it. That way negative seeds, which
    ``torch.manual_seed`` accepts, keep working.

    Args:
        seed: The seed value (typically ``--seed`` from the command line).
    """
    random.seed(seed)
    # NumPy is imported by this module and may be used by data/eval helpers;
    # leaving it unseeded breaks run-to-run reproducibility.
    np.random.seed(seed % (2**32))
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def _count_params(model: torch.nn.Module) -> Tuple[int, int]:
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return total, trainable


def _dims_from_encoder(encoder) -> Tuple[int, int]:
    any_attn = encoder.blocks[0].attn
    any_mlp = encoder.blocks[0].mlp
    head_dim = any_attn.q.out_features // any_attn.num_heads
    mlp_dim = any_mlp.lin1.out_features
    return head_dim, mlp_dim


def _fmt_down_metrics(metrics_dict: dict, metrics_list):
    iou_idx = metrics_list.index("iou") if "iou" in metrics_list else -1
    dice_idx = metrics_list.index("dice") if "dice" in metrics_list else -1

    parts = []
    if "final" in metrics_dict:  # box-only
        arr = metrics_dict["final"]
        biou = float(arr[iou_idx]) if iou_idx != -1 and arr.size > iou_idx else 0.0
        bdice = float(arr[dice_idx]) if dice_idx != -1 and arr.size > dice_idx else 0.0
        parts.append(f"BOX IoU/Dice={biou:.4f}/{bdice:.4f}")

    pt_keys = sorted([k for k in metrics_dict.keys() if isinstance(k, int)])
    dice_list = []
    for k in pt_keys:
        arr = metrics_dict[k]
        if dice_idx != -1 and arr.size > dice_idx:
            parts.append(f"P{k}_Dice={float(arr[dice_idx]):.4f}")
            dice_list.append(float(arr[dice_idx]))
    if dice_list:
        parts.append(f"DiceAvg@pts={float(np.mean(dice_list)):.4f}")
    return " | ".join(parts) if parts else "(no metrics)"


def _format_number(x: float) -> str:
    # 直观格式：K/M/G
    if x >= 1e9:
        return f"{x / 1e9:.2f} G"
    if x >= 1e6:
        return f"{x / 1e6:.2f} M"
    if x >= 1e3:
        return f"{x / 1e3:.2f} K"
    return f"{int(x)}"


def _str2bool(value) -> bool:
    """Parse a boolean command-line value.

    ``argparse`` with ``type=bool`` turns every non-empty string (including
    ``"False"`` and ``"0"``) into ``True``. This accepts the usual spellings
    instead and rejects anything else.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("1", "true", "t", "yes", "y", "on"):
        return True
    if text in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the prune-and-compensate run."""
    ap = argparse.ArgumentParser()

    # --- Experiment paths ---
    ap.add_argument("--exp_root", type=str, default="./runs")
    ap.add_argument("--exp_name", type=str, default="assp")
    ap.add_argument("--metrics", nargs="+", default=["iou", "dice"])
    # --- Model and checkpoint parameters ---
    ap.add_argument("--variant", type=str, default="vit_b", choices=["vit_b", "vit_l", "vit_h"])
    ap.add_argument("--sam_checkpoint", type=str, default=None)
    ap.add_argument("--sammed2d_checkpoint", type=str, default=None)

    # --- Dataset and path parameters ---
    ap.add_argument("--sa1b_train_root", type=str, required=True)
    ap.add_argument("--sa1b_val_root", type=str, required=True)
    ap.add_argument("--samed2d_root", type=str, required=True)

    # --- Training configuration parameters ---
    ap.add_argument("--sammed2d_image_size", type=int, default=256)
    ap.add_argument("--adapter", type=_str2bool, default=True)
    ap.add_argument("--epochs_per_round_mode1", type=int, default=40)
    ap.add_argument("--epochs_per_round_mode3", type=int, default=40)
    ap.add_argument("--batch_size", type=int, default=4)
    ap.add_argument("--train_size", type=int, default=10000)
    ap.add_argument("--val_size", type=int, default=50)
    ap.add_argument("--lr", type=float, default=1e-4)
    ap.add_argument("--seed", type=int, default=42)

    # --- Pruning parameters ---
    ap.add_argument("--rank", type=int, default=16)
    ap.add_argument("--energy", type=float, default=None)
    ap.add_argument("--pruning_rounds", type=int, default=1)
    ap.add_argument("--pruning_ratio", type=float, default=0.5)

    # --- Stage-1 compensation training parameters (loss weights) ---
    ap.add_argument("--w_pre", type=float, default=1.0)
    ap.add_argument("--w_down", type=float, default=1.0)

    # --- Stage-2 compensation training parameters ---
    ap.add_argument("--point_num", type=int, default=1)
    ap.add_argument("--iter_point", type=int, default=7)
    ap.add_argument("--multimask", type=_str2bool, default=True)
    return ap


def _set_sampler_epochs(dataloaders, epoch, logger):
    """Best-effort ``sampler.set_epoch(epoch)`` on every loader whose sampler supports it.

    A failing sampler is logged as a warning rather than aborting the run.
    """
    for dl in dataloaders:
        sampler = getattr(dl, "sampler", None)
        set_epoch = getattr(sampler, "set_epoch", None)
        if set_epoch is None:
            continue
        try:
            set_epoch(epoch)
        except Exception as exc:  # deliberate best-effort; do not kill the run
            logger.warning(f"[SAMPLER] set_epoch({epoch}) failed on {type(sampler).__name__}: {exc}")


def main():
    """Run the full prune-and-compensate pipeline.

    Parses args and builds the student/anchor/switcher models and data loaders.
    It then evaluates the unpruned model and runs two stages, in order "mode3"
    then "mode1". Each stage builds a torch_pruning pruner over the image
    encoder and performs ``--pruning_rounds`` rounds of:
    score -> prune -> sync weights from rank 0 -> compensation training.
    """
    args = _build_arg_parser().parse_args()

    # Compatibility shim: the evaluate_* helpers read args.medsam_image_size.
    if not hasattr(args, "medsam_image_size"):
        args.medsam_image_size = args.sammed2d_image_size

    set_seed(args.seed)
    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True, static_graph=False)
    accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])
    accelerator.even_batches = True

    # ====== Experiment directory & logging ======
    paths = setup_experiment_directory(root_dir=args.exp_root, experiment_name=args.exp_name, accelerator=accelerator)
    logger = create_logger(log_directory=paths["logs"], logger_name="main", accelerator=accelerator)
    logger.info(f"[EXP] root={paths['root']} | variant={args.variant} | seed={args.seed} ")
    logger.info("Running with the following configurations:")
    # Log the pretty-printed args line by line so each line carries the logger prefix.
    for line in pprint.pformat(vars(args)).split("\n"):
        logger.info(line)

    # --- Build models (student / anchor / switcher) ---
    student, anchor, switcher = build_teacher_student(args, accelerator.device)
    encoder = student.image_encoder

    head_dim0, mlp_dim0 = _dims_from_encoder(encoder)
    t_params, t_trainable = _count_params(student)
    logger.info(
        f"[INIT] embed_dim={encoder.embed_dim}, heads={encoder.num_heads}, "
        f"head_dim={head_dim0}, mlp_dim={mlp_dim0}, depth={encoder.depth} | "
        f"params={t_params / 1e6:.2f}M (trainable={t_trainable / 1e6:.2f}M)"
    )

    # --- Shared data loaders ---
    loaders = build_dataloaders(
        datasets_root={
            "train_pre": args.sa1b_train_root,
            "val_pre": args.sa1b_val_root,
            "train_down": args.samed2d_root,
            "val_down": args.samed2d_root,
        },
        batch_size=args.batch_size,
        train_size=args.train_size,
        grad_size=args.train_size,
        val_size=args.val_size,
        num_workers=4,
        prefetch_factor=None,
        pin_memory=None,
        subset_seed=args.seed,
    )
    grad_pre_dl = loaders["grad_pre"]
    grad_down_dl = loaders["grad_down"]
    train_down_dl = loaders["train_down"]

    # ======================
    # Initial evaluation of the loaded model (pretrain + downstream)
    # ======================
    # Prepared once here and reused for every per-round visualization, instead of
    # re-wrapping the same DataLoader with accelerator.prepare() each time.
    val_down_loader = None
    if loaders.get("val_pre", None) is not None:
        val_pre_loader = accelerator.prepare(loaders["val_pre"])
        accelerator.wait_for_everyone()
        pre_arr = evaluate_pretrain(
            student=student,
            switcher=switcher,
            loader=val_pre_loader,
            args=args,
            tag="pre_init",
            accelerator=accelerator,
        )
        logger.info(f"[EVAL:init:pre] metrics={pre_arr}")

    if loaders.get("val_down", None) is not None:
        val_down_loader = accelerator.prepare(loaders["val_down"])
        accelerator.wait_for_everyone()
        down_dict = evaluate_downstream(
            student=student,
            switcher=switcher,
            loader=val_down_loader,
            args=args,
            tag="down_init",
            accelerator=accelerator,
        )
        logger.info(f"[EVAL:init:down] {_fmt_down_metrics(down_dict, args.metrics)}")

    # Example input used both to trace the pruning graph and for MAC/param counting.
    example = torch.randn(
        1,
        3,
        args.sammed2d_image_size,
        args.sammed2d_image_size,
        device=accelerator.device,
    )

    # Dummy importance used only while constructing the pruner; it is replaced
    # before every pruner.step().
    importance_placeholder = PGRGroupImportance(
        {},
        {},
        {},
        {},
        rank=1,
        energy_thresh=None,
        two_fold=False,
        weight_scale=False,
        fallback_p=1,
        seed=args.seed,
    )
    relpos = RelPosManager()

    # ======================
    # Two stages, run in this order: "mode3" (magnitude importance + adapter-only
    # downstream compensation), then "mode1" (PGR importance + unified distillation).
    # ======================
    for stage_name in ("mode3", "mode1"):
        logger.info(f"\n======== Start Stage: {stage_name} ========")

        mentor = None  # reset at the start of each stage; refreshed between mode1 rounds
        is_mode1 = stage_name == "mode1"
        # PGR scoring needs the gradient banks collected below; other modes use L2 magnitude.
        use_pgr = stage_name in ("mode1", "mode2")
        # NOTE(review): ignored_layers is returned but not passed to the pruner — confirm intended.
        target_modules, ignored_layers, num_heads_map = collect_target_and_ignored_layers(encoder, mode=stage_name)
        pruning_ratio_dict = {tuple(target_modules): args.pruning_ratio}

        # Temporarily remove the relative-position params before tracing the graph.
        relpos.capture_current(encoder).delete(encoder)
        encoder.pos_embed.requires_grad_(True)
        if is_mode1:
            encoder.set_skip_adapter_mod()
        else:
            encoder.cancel_skip_adapter_mod()
        # Build the dependency graph / pruner.
        pruner = MagnitudePruner(
            model=encoder,
            example_inputs=example,
            importance=importance_placeholder,
            global_pruning=False,
            pruning_ratio=0.0,
            pruning_ratio_dict=pruning_ratio_dict,
            iterative_steps=args.pruning_rounds,
            num_heads=num_heads_map,
            prune_num_heads=False,
            prune_head_dims=True,
            forward_fn=lambda m, x: m(x)[0],
            output_transform=lambda out: out,
            root_module_types=[ops.TORCH_LINEAR, ops.TORCH_CONV],
            unwrapped_parameters=None if is_mode1 else [(encoder.pos_embed, 3)],
        )
        relpos.restore(encoder)

        for r in range(args.pruning_rounds):
            logger.info(f"\n---- {stage_name} | Round {r + 1}: collect grads & prune ----")

            # Re-seed distributed samplers; mode3 uses an offset so its epochs differ from mode1's.
            _set_sampler_epochs((grad_pre_dl, grad_down_dl, train_down_dl), r if is_mode1 else 1000 + r, logger)

            if use_pgr:
                # Pretrain grads: adapters frozen, non-adapter encoder weights trainable,
                # so the encoder weights receive gradients.
                set_trainable_parts(
                    accelerator.unwrap_model(student),
                    train_image_encoder_adapter=False,
                    train_image_encoder_non_adapter=True,
                    train_prompt_encoder=False,
                    train_mask_decoder=False,
                )
                bank_pre_row, bank_pre_col = collect_bank(
                    grad_pre_dl,
                    encoder,
                    use_adapter=False,
                    accelerator=accelerator,
                    loss_fn=None,
                )
                bank_down_row, bank_down_col = collect_bank(
                    grad_down_dl,
                    encoder,
                    use_adapter=True,
                    accelerator=accelerator,
                    loss_fn=None,
                )
                logger.info(f"[BANK] pre(row/col) keys={len(bank_pre_row)}/{len(bank_pre_col)} | down(row/col) keys={len(bank_down_row)}/{len(bank_down_col)}")

                pruner.importance = PGRGroupImportance(
                    bank_pre_row,
                    bank_pre_col,
                    bank_down_row,
                    bank_down_col,
                    rank=args.rank,
                    energy_thresh=args.energy,
                    two_fold=True,
                    weight_scale=True,
                    fallback_p=2,
                    seed=args.seed,
                )
                logger.info(f"[PGR] rank={args.rank} energy={args.energy} two_fold=True weight_scale=True fallback_p=2 seed={args.seed}")
            else:
                pruner.importance = torch_pruning.importance.MagnitudeImportance(p=2)
                logger.info(f"[IMP] stage={stage_name} uses MagnitudeImportance(p=2)")

            # ========== Encoder-level stats before/after pruning ==========
            # Stats are taken with the full relative-position params present, i.e.
            # before delete() and after restore().
            mac0, para0 = count_ops_and_params(encoder, example)  # before
            relpos.delete(encoder)
            pruner.step()
            relpos.restore(encoder)
            mac1, para1 = count_ops_and_params(encoder, example)  # after
            accelerator.wait_for_everyone()

            para_red_pct = ((para0 - para1) / para0 * 100.0) if para0 > 0 else 0.0
            mac_red_pct = ((mac0 - mac1) / mac0 * 100.0) if mac0 > 0 else 0.0

            model_name = f"{args.variant}-encoder"
            logger.info(
                f"\n[Pruning Summary] Model: {model_name} | Mode: {stage_name}\n"
                f"  - Parameters: {_format_number(para0)} -> {_format_number(para1)} "
                f"(Reduced by {para_red_pct:.2f}%)\n"
                f"  - MACs:       {_format_number(mac0)} -> {_format_number(mac1)} "
                f"(Reduced by {mac_red_pct:.2f}%)"
            )

            accelerator.wait_for_everyone()

            # Every rank pruned independently; make all ranks hold rank 0's weights.
            student = sync_from_rank0_state_dict(accelerator, student)
            encoder = student.image_encoder

            accelerator.wait_for_everyone()

            if is_mode1:
                # Unified distillation compensation (anchor + mentor from the previous round).
                args.epochs_per_round = args.epochs_per_round_mode1
                compensate_unified_distillation(
                    accelerator=accelerator,
                    student=student,
                    teacher_anchor=anchor,
                    teacher_prev=mentor,
                    switcher=switcher,
                    loaders={
                        "train_pre": loaders.get("train_pre", None),
                        "train_down": loaders.get("train_down", None),
                        "val_pre": loaders.get("val_pre", None),
                        "val_down": loaders.get("val_down", None),
                    },
                    epoch_num=args.epochs_per_round,
                    w_pre=args.w_pre,
                    w_down=args.w_down,
                    lr=args.lr,
                    args=args,
                    logger=logger,
                    best_ckpt_path_unified=paths["checkpoints"] / f"best_unified_{stage_name}_R{r + 1}.pth",
                )
                accelerator.wait_for_everyone()
                # Per-round visual check on a few val_down samples.
                if val_down_loader is not None:
                    visualize_downstream_once(
                        args=args,
                        student=student,
                        switcher=switcher,
                        loader=val_down_loader,
                        accelerator=accelerator,
                        tag=f"unified_{stage_name}_round{r + 1}",
                        out_root=paths["visualizations"],
                        top_k=16,  # global K per modality
                        top_k_by_modality=None,  # pass a dict for per-modality K
                    )
                accelerator.wait_for_everyone()
            elif stage_name == "mode3":
                # Downstream (adapter-only) compensation.
                args.epochs_per_round = args.epochs_per_round_mode3
                compensate_downstream_adapter_only(
                    accelerator=accelerator,
                    args=args,
                    student=student,
                    switcher=switcher,
                    loaders={
                        "train_down": loaders.get("train_down", None),
                        "val_down": loaders.get("val_down", None),
                    },
                    epoch_num=args.epochs_per_round,
                    logger=logger,
                    best_ckpt_path_down=paths["checkpoints"] / f"best_down_{stage_name}_R{r + 1}.pth",
                )
                accelerator.wait_for_everyone()
                if val_down_loader is not None:
                    visualize_downstream_once(
                        args=args,
                        student=student,
                        switcher=switcher,
                        loader=val_down_loader,
                        accelerator=accelerator,
                        tag=f"down_{stage_name}_round{r + 1}",
                        out_root=paths["visualizations"],
                        top_k=16,
                    )
                accelerator.wait_for_everyone()

            # Refresh the encoder reference.
            encoder = student.image_encoder

            # Freeze a copy of this round's student as the mentor for the next round.
            # Only mode1's unified distillation reads it, and only in a later round
            # of the same stage (mentor is reset per stage), so skip the full-model
            # deepcopy otherwise.
            if is_mode1 and r + 1 < args.pruning_rounds:
                mentor = None  # release the previous mentor before allocating a new copy
                mentor = copy.deepcopy(student).to(accelerator.device)
                mentor.eval()
                mentor.requires_grad_(False)

            # Post-round dims & parameter counts.
            head_dim, mlp_dim = _dims_from_encoder(encoder)
            t_params, t_trainable = _count_params(student)
            logger.info(
                f"[{stage_name} | Round {r + 1}] head_dim={head_dim} | mlp_dim={mlp_dim} | params={t_params / 1e6:.2f}M (trainable={t_trainable / 1e6:.2f}M)"
            )
            # Per-round pretrain/downstream evaluation is intentionally disabled here.

        # End of stage: report dims.
        head_dim_s, mlp_dim_s = _dims_from_encoder(encoder)
        logger.info(f"======== End Stage: {stage_name} | head_dim={head_dim_s} | mlp_dim={mlp_dim_s} ========")

    # --- Final dims, compared against the target implied by --pruning_ratio ---
    head_dim_f, mlp_dim_f = _dims_from_encoder(encoder)
    keep_ratio = 1.0 - args.pruning_ratio
    logger.info(
        f"\nFinal: head_dim {head_dim0} -> {head_dim_f} (~{int(math.floor(head_dim0 * keep_ratio))}) | mlp_dim {mlp_dim0} -> {mlp_dim_f} (~{int(math.floor(mlp_dim0 * keep_ratio))})"
    )
    logger.info("Done.")


def sync_from_rank0_state_dict(accelerator, student):
    """Overwrite every rank's ``student`` weights with rank 0's state dict.

    Precondition: all ranks have applied identical pruning, so the module
    structure (state-dict keys and shapes) matches; the load is ``strict``.

    Args:
        accelerator: The ``accelerate.Accelerator`` driving the run.
        student: The model to synchronise (possibly wrapped by accelerate).

    Returns:
        The same ``student`` object; weights are loaded in place into the
        unwrapped module.
    """
    # Single process: nothing to synchronise, so skip serialising and reloading
    # the whole model.
    if getattr(accelerator, "num_processes", 1) <= 1:
        return student

    from accelerate.utils import broadcast_object_list

    accelerator.wait_for_everyone()
    payload = [None]
    if accelerator.is_main_process:
        # NOTE(review): under DeepSpeed ZeRO-3 / FSDP get_state_dict may need all
        # ranks to participate; calling it only on rank 0 would then hang.
        state = accelerator.get_state_dict(student)
        # Move tensors to CPU before pickling. A pickled CUDA tensor is restored
        # on its source device, so every receiving rank would otherwise allocate
        # a full copy of the model on rank 0's GPU. Values are replaced in place
        # on this freshly built dict so its `_metadata` (read by
        # load_state_dict) is kept.
        for key in list(state.keys()):
            value = state[key]
            if torch.is_tensor(value):
                state[key] = value.detach().cpu()
        payload[0] = state
    broadcast_object_list(payload, from_process=0)

    target = accelerator.unwrap_model(student) if hasattr(accelerator, "unwrap_model") else student
    target.load_state_dict(payload[0], strict=True)
    accelerator.wait_for_everyone()
    return student


# Script entry point: run the pipeline only when executed directly, not on import.
if __name__ == "__main__":
    main()
