import torch
import argparse
from src.configs import cfg
from src.model import build_model
from src.data import build_loader
from src.optim import build_optimizer
from src.adapter import build_adapter
from tqdm import tqdm
from setproctitle import setproctitle
import numpy as np
import os.path as osp
import os
import torch.multiprocessing
import pandas as pd
from src.utils import set_random_seed
from datetime import datetime
import time  # <<< 新增：用于高精度计时

def recurring_test_time_adaptation(cfg):
    """Run test-time adaptation over the configured corruption data stream.

    Builds the model, optimizer and TTA adapter from ``cfg``, feeds every batch
    of the corruption loader through the adapted model, and lets the loader's
    ``processor`` accumulate per-batch correctness.

    Args:
        cfg: the merged experiment config (read for ``CORRUPTION.*`` and passed
            to the project builders).

    Returns:
        tuple: ``(prcss_eval_csv, tta_model)`` -- the second value returned by
        ``processor.info()`` after ``processor.calculate()``, and the adapted model.
    """
    # Building model, optimizer and adapter
    model = build_model(cfg)
    optimizer = build_optimizer(cfg)

    tta_adapter = build_adapter(cfg)
    tta_model = tta_adapter(cfg, model, optimizer)
    tta_model.cuda()

    # Building data loader
    loader, processor = build_loader(cfg, cfg.CORRUPTION.DATASET, cfg.CORRUPTION.TYPE, cfg.CORRUPTION.SEVERITY)

    # Whether the adapter accepts the dict-style label {"label": ..., "domain": ...}.
    # None until probed on the first batch. Afterwards the choice is fixed, so that
    # errors from a dict-capable adapter are no longer swallowed, and a label-only
    # adapter is not run twice on every batch.
    accepts_dict_label = None

    # Main test-time adaptation loop
    tbar = tqdm(loader)
    for data_package in tbar:
        data, label, domain = data_package["image"], data_package['label'], data_package['domain']

        if len(label) == 1:
            continue  # ignore the final single point

        data, label = data.cuda(), label.cuda()

        if accepts_dict_label is None:
            try:
                output = tta_model(data, label={"label": label, "domain": domain.cuda()})
                accepts_dict_label = True
            except Exception as err:
                # Compatibility with older adapters that only take a plain label.
                # NOTE(review): if the adapter raised after already updating its
                # parameters, that partial update is kept.
                print(f"[TTA] dict label rejected ({type(err).__name__}: {err}); "
                      f"falling back to label-only calls")
                accepts_dict_label = False
                output = tta_model(data, label=label)
        elif accepts_dict_label:
            output = tta_model(data, label={"label": label, "domain": domain.cuda()})
        else:
            output = tta_model(data, label=label)

        predict = torch.argmax(output, dim=1)
        accurate = (predict == label)
        processor.process(accurate, domain)

        tbar.set_postfix(acc=processor.cumulative_acc())

    processor.calculate()
    _, prcss_eval_csv = processor.info()
    return prcss_eval_csv, tta_model

def main():
    parser = argparse.ArgumentParser("Pytorch Implementation for Test Time Adaptation!")
    parser.add_argument(
        '-acfg',
        '--adapter-config-file',
        metavar="FILE",
        default="",
        help="path to adapter config file",
        type=str)
    parser.add_argument(
        '-dcfg',
        '--dataset-config-file',
        metavar="FILE",
        default="",
        help="path to dataset config file",
        type=str)
    parser.add_argument(
        'opts',
        help='modify the configuration by command line',
        nargs=argparse.REMAINDER,
        default=None)

    # Parsing arguments
    args = parser.parse_args()
    if len(args.opts) > 0:
        args.opts[-1] = args.opts[-1].strip('\r\n')

    # Merge configs
    cfg.merge_from_file(args.adapter_config_file)
    cfg.merge_from_file(args.dataset_config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    # Debug prints for sanity check
    print("ADAPTER.NAME:", cfg.ADAPTER.NAME)
    print("MODEL.ARCH:", cfg.MODEL.ARCH)
    print("CKPT_DIR:", cfg.CKPT_DIR)
    print("CKPT_PATH:", cfg.CKPT_PATH)

    setproctitle(f"TTA:{cfg.CORRUPTION.DATASET:>8s}:{cfg.ADAPTER.NAME:<10s}")

    # For reproducibility
    torch.backends.cudnn.benchmark = True
    set_random_seed(cfg.SEED)

    # ===== 新增：根据 C2FTTA 参数拼接子目录 =====
    def _sanitize(v):
        # 将浮点数里的点替换为 p，避免路径里的 "."
        return str(v).replace('.', 'p')

    param_subdir = None
    try:
        # 优先：若有 C2FTTA，就按 C2FTTA 的习惯命名
        cap  = cfg.ADAPTER.C2FTTA.STMEM_CAPACITY
        clus = cfg.ADAPTER.C2FTTA.STMEM_MAX_CLUS
        topk = cfg.ADAPTER.C2FTTA.STMEM_TOPK_CLUS
        thr  = cfg.ADAPTER.C2FTTA.BASE_THRESHOLD
        param_subdir = f"cap{cap}_clus{clus}_topk{topk}_th{_sanitize(thr)}"
    except Exception:
        # 若没有 C2FTTA，则在是 ResiTTA 时用 ResiTTA 的关键超参命名
        try:
            if cfg.ADAPTER.NAME.lower() == "resitta":
                cap   = cfg.ADAPTER.RESITTA.CAPACITY
                uf    = cfg.ADAPTER.RESITTA.UPDATE_FREQUENCY
                bna   = cfg.ADAPTER.RESITTA.BN_ALPHA
                lbd   = cfg.ADAPTER.RESITTA.LAMBDA_BN_D
                emg   = cfg.ADAPTER.RESITTA.E_MARGIN
                param_subdir = f"cap{cap}_uf{uf}_bna{_sanitize(bna)}_lbd{_sanitize(lbd)}_em{_sanitize(emg)}"
        except Exception:
            pass

    base_output_dir = cfg.OUTPUT_DIR
    output_dir = osp.join(base_output_dir, param_subdir) if param_subdir else base_output_dir
    os.makedirs(output_dir, exist_ok=True)
    print("OUTPUT_DIR:", output_dir)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M")

    # ====== 新增：记录 TTA 时长（仅围绕 recurring TTA 主流程）======
    start_iso = datetime.now().isoformat(timespec='seconds')
    t0 = time.perf_counter()

    # Running recurring TTA
    prcss_eval_csv, _ = recurring_test_time_adaptation(cfg)

    t1 = time.perf_counter()
    end_iso = datetime.now().isoformat(timespec='seconds')
    elapsed_seconds = t1 - t0

    print(f"[TTA Timing] start={start_iso} end={end_iso} duration_sec={elapsed_seconds:.6f}")

    # Saving evaluation results to files:
    log_file_name = "%s_%s_%s" % (
        osp.basename(args.dataset_config_file).split('.')[0],
        osp.basename(args.adapter_config_file).split('.')[0],
        timestamp)

    result_path = osp.join(output_dir, "%s.csv" % log_file_name)
    with open(result_path, "w") as fo:
        fo.write(prcss_eval_csv)

        # 在 CSV 末尾追加计时信息（以注释行写入，方便用 pandas 的 comment='#' 读取）
        fo.write("\n")
        fo.write("# ==== TTA Timing (program-appended) ====\n")
        fo.write(f"# START_TIME,{start_iso}\n")
        fo.write(f"# END_TIME,{end_iso}\n")
        fo.write(f"# TTA_DURATION_SECONDS,{elapsed_seconds:.6f}\n")

    print(f"Saved results to: {result_path}")

if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when this module is imported.
    main()
