from _common import *

# Module-level logger. `logging` is not imported explicitly in this file, so it
# is presumably re-exported by the `_common` star import above — TODO confirm.
log = logging.getLogger(__name__)

import types
from src.clip_eval import eval_single_dataset
from src.task_vectors import StateDict, TaskVector, state_dict_mean
from src.ties_merging_utils import check_parameterNamesMatch

from clip_checkpoint_path import CHECKPOINT_DIR, finetuned_model_path, pretrained_model_path

from src.task_vectors import NonLinearTaskVector
from singular_vector import TSVM_utils
from src.tasks.shortest_route_classification_heads import *

# ==== NEW
import os
import numpy as np
import pandas as pd
# ==== NEW END


@hydra.main(config_path=str(CONFIG_DIR), config_name="default", version_base=None)
def main(cfg: DictConfig) -> None:
    """Merge fine-tuned task vectors one task at a time with TSVM and evaluate.

    Workflow, as implemented below:

    1. Build one task vector per entry of ``cfg.datasets``.
    2. Walk the list pairwise: merge the running result with the next task
       vector via ``TSVM_utils.compute_and_sum_svd_mem_reduction`` and store
       the merged result back at position ``i + 1``.
    3. After each merge, evaluate the merged encoder on the task that was just
       added (``R_{i,i}``); at the milestone task counts also evaluate it on
       every task seen so far (``R_{k,i}``) and log a step-wise BWT.
    4. Evaluate the final merged encoder on all datasets, then write
       ``averaging.csv`` plus the final and milestone BWT text files under
       ``RESULTS_DIR / cfg.model``.

    BWT bookkeeping is best-effort: failures there are logged as warnings
    (with traceback) and never abort the merge or the final evaluation.

    Args:
        cfg: Hydra configuration. Reads ``cfg.model`` and ``cfg.datasets``;
            writes ``cfg.save`` and ``cfg.data_location``.

    Raises:
        ValueError: If ``cfg.datasets`` is empty.
    """
    cfg.save = str(CHECKPOINT_DIR / cfg.model)
    cfg.data_location = str(DATA_DIR)
    model = cfg.model

    datasets = list(cfg.datasets)
    if not datasets:
        raise ValueError("cfg.datasets is empty; nothing to merge or evaluate.")

    log.info("load finetuned models")
    task_vectors: List[StateDict] = [
        TaskVector(
            pretrained_checkpoint=pretrained_model_path(model),
            finetuned_checkpoint=finetuned_model_path(model, dataset_name),
        ).vector
        for dataset_name in tqdm(datasets)
    ]
    check_parameterNamesMatch(task_vectors)

    # ---- BWT caches and configuration ----
    acc_when_learned = {}         # dataset -> R_{i,i}: accuracy right after it was merged
    acc_step = {}                 # k -> {dataset: R_{k,i}} collected at milestone k
    bwt_step_values = {}          # k -> BWT_k
    bwt_milestones = [8, 14, 20]  # task counts at which a step-wise BWT is reported
    results_dir = RESULTS_DIR / cfg.model
    os.makedirs(results_dir, exist_ok=True)
    bwt_steps_path = results_dir / "clip_tsvm_bwt_steps.txt"
    bwt_final_path = results_dir / "clip_tsvm_bwt.txt"
    bwt_summary_path = results_dir / "clip_tsvm_bwt_summary.txt"

    def _mean_backward_transfer(acc_later, names):
        """Average ``acc_later[n] - acc_when_learned[n]`` over usable names.

        Names missing from either mapping are skipped. Returns
        ``(mean, term_count)``, or ``(None, 0)`` when no name is usable.
        """
        terms = [
            acc_later[name] - acc_when_learned[name]
            for name in names
            if name in acc_when_learned and name in acc_later
        ]
        if not terms:
            return None, 0
        return float(sum(terms) / len(terms)), len(terms)

    def _compute_and_log_bwt_for_step(k: int) -> None:
        """Compute BWT_k = mean_{i in seen[1:]} (R_{k,i} - R_{i,i}), seen = datasets[:k].

        On success the value is cached in ``bwt_step_values`` and appended to
        the step-wise results file. Any failure is logged and swallowed.
        """
        try:
            datasets_seen = datasets[:k]
            if len(datasets_seen) < 2:
                log.warning(f"[BWT@{k}] Need at least 2 tasks.")
                return
            if k not in acc_step:
                log.warning(f"[BWT@{k}] acc_step[{k}] not found.")
                return
            bwt_k, n_terms = _mean_backward_transfer(acc_step[k], datasets_seen[1:])
            if bwt_k is None:
                log.warning(f"[BWT@{k}] No valid terms (missing R_i_i or R_k_i).")
                return
            bwt_step_values[k] = bwt_k
            # Append, so every milestone of the run ends up in the same file.
            with open(bwt_steps_path, "a", encoding="utf-8") as f:
                f.write(f"{k}-task BWT: {bwt_k:.6f}\n")
            log.info(f"[BWT] {k}-task BWT = {bwt_k:.4f} (avg over {n_terms} tasks)")
        except Exception as e:
            log.warning(f"[BWT@{k}] compute/log failed: {e}", exc_info=True)

    # Running merged weights. Seeding with the first task vector keeps the
    # final evaluation working when only one dataset is configured (the merge
    # loop below does not execute in that case).
    TSVM_state_dict = task_vectors[0]

    for i in range(len(task_vectors) - 1):
        print(f"Computing mean state dict for checkpoints {i} and {i + 1}")
        continual_ft = [task_vectors[i], task_vectors[i + 1]]

        config = types.SimpleNamespace()
        config.DATASETS = ["0", "1"]
        config.device = "cuda:0"
        TSVM_state_dict = TSVM_utils.compute_and_sum_svd_mem_reduction(continual_ft, config)

        # The merged result becomes the "learned so far" weights for the next step.
        task_vectors[i + 1] = TSVM_state_dict

        # ---- Evaluate R_{i+1,i+1} (right after merging task i+1) ----
        step_encoder = None
        try:
            step_vector = NonLinearTaskVector(vector=TSVM_state_dict)
            for kname in list(step_vector.vector.keys()):
                step_vector.vector[kname] = step_vector.vector[kname].cpu()
            step_encoder = step_vector.apply_to(pretrained_model_path(model))

            just_learned_dataset = datasets[i + 1]
            metrics_step = eval_single_dataset(step_encoder, just_learned_dataset, cfg)
            acc_ii = float(metrics_step.get("top1"))
            acc_when_learned[just_learned_dataset] = acc_ii
            log.info(f"[BWT trace] After learning {just_learned_dataset}: R_i_i={acc_ii:.4f}")

            # ---- Milestone: k = i + 2 tasks merged so far ----
            # k never exceeds len(datasets) because i stops at len(task_vectors) - 2,
            # so no extra bounds check is needed.
            k = i + 2
            if k in bwt_milestones:
                # Collect R_{k,i} for each of the first k tasks.
                acc_step[k] = {}
                for name in datasets[:k]:
                    try:
                        if name == just_learned_dataset:
                            # Already measured on this very encoder above.
                            acc_ki = acc_ii
                        else:
                            m = eval_single_dataset(step_encoder, name, cfg)
                            acc_ki = float(m.get("top1"))
                        acc_step[k][name] = acc_ki
                        log.info(f"[BWT trace] Step {k}: R_k,i on {name} = {acc_ki:.4f}")
                    except Exception as ee:
                        log.warning(f"[BWT step {k}] eval on {name} failed: {ee}", exc_info=True)
                _compute_and_log_bwt_for_step(k)
        except Exception as e:
            log.warning(
                f"[BWT trace] Eval R_i_i / milestone failed at step {i+1}: {e}",
                exc_info=True,
            )
        finally:
            # Drop the temporary encoder even when the evaluation failed.
            step_encoder = None

    # ==== Final model evaluation ====
    merged_vector = NonLinearTaskVector(vector=TSVM_state_dict)
    for key in merged_vector.vector:
        merged_vector.vector[key] = merged_vector.vector[key].cpu()
    image_encoder = merged_vector.apply_to(pretrained_model_path(model))

    results = {"dataset": [], "acc": []}
    accs = []       # accuracies in percent, used for the average
    final_acc = {}  # dataset -> R_{T,i}

    for dataset in datasets:
        metrics = eval_single_dataset(image_encoder, dataset, cfg)
        log.info(str(dataset) + ":" + str(metrics.get("top1") * 100) + "%")
        acc = float(metrics.get("top1"))
        accs.append(acc * 100.0)

        results["dataset"].append(dataset)
        results["acc"].append(acc)
        final_acc[dataset] = acc

    log.info("Avg ACC:" + str(np.mean(accs)) + "%")

    log.info("Eval: All tasks accuracy:")
    # One line, newline-terminated, so later output does not run into it.
    print(" ".join(f"{acc:.3f}" for acc in results["acc"]))

    df = pd.DataFrame(results)
    df.to_csv(results_dir / "averaging.csv", index=False)

    # ==== Final (T-task) BWT ====
    # NOTE(review): terms are taken over datasets[1:], so the last task's term
    # compares two evaluations of the same final weights, and the first task's
    # R_{1,1} is never recorded. The conventional BWT averages over tasks
    # 1..T-1 instead — confirm which definition is intended.
    try:
        if len(datasets) >= 2:
            bwt, n_terms = _mean_backward_transfer(final_acc, datasets[1:])
            if bwt is not None:
                with open(bwt_final_path, "w", encoding="utf-8") as f:
                    f.write(f"{bwt:.6f}\n")
                log.info(f"[BWT] Final Backward Transfer = {bwt:.4f}  (avg over {n_terms} tasks)")
            else:
                log.warning("[BWT] No valid terms for final BWT (missing R_i_i or R_T_i).")
        else:
            log.warning("[BWT] Need at least 2 tasks for final BWT.")
    except Exception as e:
        log.warning(f"[BWT] Failed to compute final BWT: {e}", exc_info=True)

    # ==== Optional: summary of the milestone BWT values that were computed ====
    try:
        with open(bwt_summary_path, "w", encoding="utf-8") as f:
            for k in bwt_milestones:
                if k in bwt_step_values:
                    f.write(f"{k}-task BWT: {bwt_step_values[k]:.6f}\n")
        log.info(f"[BWT] Milestone summary written to {bwt_summary_path}")
    except Exception as e:
        log.warning(f"[BWT] Failed to write milestone summary: {e}", exc_info=True)


# Run the Hydra-decorated entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()
