import argparse
import os
import json
from tqdm import tqdm
import torch
import logging
from src.dataset_utils import *
from src.models import create_model
from src.defense import *
from src.baselines import *
from src.attack import *
from src.helper import get_log_name
from src.sampleMIS import *
import matplotlib.pyplot as plt
import pandas as pd
from llm_judge import LLMJudge
import time
import csv

# API keys (kept as-is)


def parse_args(argv=None):
    """Build the command-line parser for the dynamic robust-RAG experiment and parse it.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``, which is the original behaviour.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description='Robust Dynamic RAG')

    # --- Basic settings ---
    parser.add_argument("--local_rank", type=int, default=0, help="Local rank")
    parser.add_argument('--model_name', type=str, default='mistral7b',
                        choices=['mistral7b', 'llama3b', 'gpt-4o', 'gpt-4o-mini', 'o1-mini',
                                 'deepseek7b', 'llama1b', 'tai_llama8b', 'tai_mistral7b',
                                 'deepseek-chat', 'deepseek-reasoner', 'grok-4-fast'],
                        help='LLM model name')
    parser.add_argument('--dataset_name', type=str, default='dynamic_serpapi', help='dataset name')
    # main() looks this attribute up with a fallback to the same default, so
    # exposing it here only adds the ability to override the path.
    parser.add_argument('--dynamic_json_path', type=str,
                        default='data/poisoned_dynamic_dataset_500.json',
                        help='Path to the dynamic dataset JSON file')
    parser.add_argument('--model_dir', type=str, help='directory for huggingface models')
    parser.add_argument('--rep', type=int, default=1, help='Repeat times')

    # --- Dynamic / retrieval settings ---
    parser.add_argument('--top_k', type=int, default=50, help='Total retrieved documents to use')
    parser.add_argument('--initial_k', type=int, default=1, help='Number of documents in the initial batch')
    parser.add_argument('--dynamic_step_size', type=int, default=1, help='Number of documents added in each subsequent step')

    # --- Defense method (dynamic methods plus the static baselines) ---
    # NOTE(review): 'em_based' is accepted here, but the visible main() has no
    # branch for it and raises ValueError — confirm whether it should be wired up.
    parser.add_argument('--defense_method', type=str, default='mincut',
                        choices=['none', 'mincut', 'cluster', 'voting', 'keyword', 'decoding',
                                 'sampling', 'astuterag', 'instructrag_icl', 'graph', 'MIS',
                                 'sampling_keyword', 'sampleMIS', 'em_based'],
                        help='The defense method to use')

    # --- Defense parameters (carried over from the static experiments) ---
    parser.add_argument('--alpha', type=float, default=0.3, help='keyword filtering threshold alpha')
    parser.add_argument('--beta', type=float, default=3.0, help='keyword filtering threshold beta')
    parser.add_argument('--eta', type=float, default=0.0, help='decoding confidence threshold eta')
    parser.add_argument('--T', type=int, default=20, help='number of samples for sampling method')
    parser.add_argument('--m', type=int, default=2, help='number of docs per sample for sampling method')
    parser.add_argument('--gamma', type=float, default=0.9, help='weight discount factor for reliability-aware methods')
    parser.add_argument('--err', type=float, default=0, help='the added error probability of NLI')

    # NLI parameters
    parser.add_argument('--nli_model_path', type=str, default="DeBERTa-v3-large-mnli-fever-anli-ling-wanli",
                        help='Path to NLI model')
    parser.add_argument('--nli_batch_size', type=int, default=32, help='Batch size for NLI')

    # Cluster-method parameters
    parser.add_argument('--sim_threshold', type=float, default=0.5, help='Threshold for NLI entailment in clustering')

    # EM-based parameters
    parser.add_argument('--prior_mu', type=float, default=0.65, help='Prior mu for EM')
    parser.add_argument('--prior_kappa', type=float, default=4.0, help='Prior kappa for EM')
    parser.add_argument('--em_mode', type=str, default='variational', choices=['variational', 'map'], help='EM mode')
    parser.add_argument('--select_top_m', type=int, default=None, help='Select top M docs')
    parser.add_argument('--prob_threshold', type=float, default=0.5, help='Probability threshold for selection')
    parser.add_argument('--ignore_idk', action='store_true', help='Ignore IDK in conflicts')
    parser.add_argument('--neutral_val', type=float, default=0.5, help='Neutral value for conflicts')
    parser.add_argument('--damping', type=float, default=0.5, help='Damping for EM')
    parser.add_argument('--max_iter', type=int, default=200, help='Max iterations for EM')
    parser.add_argument('--tol', type=float, default=1e-4, help='Tolerance for EM')
    parser.add_argument('--normalize', action='store_true', help='Normalize in EM')
    parser.add_argument('--seed', type=int, default=None, help='Seed for EM')

    # Other defense parameters
    parser.add_argument('--temperature', type=float, default=0, help='The temperature for softmax')

    # --- Attack settings ---
    parser.add_argument('--attack_method', type=str, default='none', choices=['none', 'Poison', 'PIA'],
                        help='Attack method')
    parser.add_argument('--attackpos', type=int, default=0, help='Position of attack in the full top-k list')
    parser.add_argument('--corruption_size', type=int, default=1, help='Number of poisoned documents')
    parser.add_argument('--attack_each_step', action='store_true', help='Attack during each dynamic step')

    # --- Miscellaneous ---
    parser.add_argument('--debug', action='store_true', help='Enable debug logging')
    parser.add_argument('--save_response', action='store_true', help='Save JSON results')
    parser.add_argument('--use_cache', action='store_true', help='Use LLM cache')
    parser.add_argument('--use_open_model_api', action='store_true', help='Use API for open models')
    parser.add_argument('--max_samples', type=int, default=None, help='Limit number of samples for testing')

    return parser.parse_args(argv)


def _build_defense_model(args, llm, longgen, logger):
    """Instantiate the defense wrapper selected by ``args.defense_method``.

    Raises:
        ValueError: if no branch exists for the requested method.
    """
    method = args.defense_method
    if method == 'none':
        return RRAG(llm)
    if method == 'mincut':
        return DynamicMinCutRRAG(llm, nli_model_path=args.nli_model_path)
    if method == 'cluster':
        return DynamicClusterBasedRRAG(llm, sim_threshold=args.sim_threshold)
    # --- Static defense methods ---
    if method == 'voting':
        return WeightedMajorityVoting(llm)
    if method == 'keyword':
        return WeightedKeywordAgg(llm, relative_threshold=args.alpha, absolute_threshold=args.beta,
                                  gamma=args.gamma, longgen=longgen)
    if method == 'decoding':
        if args.eta > 0 and not longgen:
            logger.warning(f"using non-zero eta {args.eta} for QA")
        return WeightedDecodingAgg(llm, eta=args.eta, gamma=args.gamma)
    if method == 'graph':
        return GraphBasedRRAG(llm)
    if method == 'MIS':
        return MISBasedRRAG(llm, err=args.err)
    if method == "sampling":
        return RandomSamplingReQueryAgg(llm=llm, sample_size=args.m, num_samples=args.T, gamma=args.gamma)
    if method == "sampling_keyword":
        return SamplingWithKeyWordAggregation(
            llm=llm, sample_size=args.m, num_samples=args.T, gamma=args.gamma,
            relative_threshold=args.alpha, absolute_threshold=args.beta, abstention_threshold=1
        )
    if method == 'instructrag_icl':
        return InstructRAG_ICL(llm)
    if method == 'astuterag':
        return AstuteRAG(llm)
    if method == "sampleMIS":
        return SampleMISRRAG(llm, sample_size=args.m, num_samples=args.T, gamma=args.gamma, err=args.err)
    raise ValueError(f"Invalid defense method: {args.defense_method}")


def _build_attacker(args, is_biogen):
    """Return the attacker for ``args.attack_method``, or ``None`` when it is 'none'.

    Raises:
        ValueError: if the attack method is not one of 'none', 'PIA', 'Poison'.
    """
    if args.attack_method == 'none':
        return None
    if args.attack_method == 'PIA':
        cls = PIALONG if is_biogen else PIA
    elif args.attack_method == 'Poison':
        cls = PoisonLONG if is_biogen else Poison
    else:
        raise ValueError("Invalid attack method")
    return cls(top_k=args.top_k, repeat=10, poison_pos=args.attackpos, poison_num=args.corruption_size)


def _adapt_docs(docs, year):
    """Convert raw per-year document dicts into the dict layout used by the step loop."""
    adapted = []
    for d in docs:
        # `or ""` keeps an explicit null field from crashing on .strip().
        title = (d.get("title") or "").strip()
        snippet = (d.get("snippet") or "").strip()
        content = (d.get("content") or "").strip()
        full_text = f"[Title] {title}\n[Snippet] {snippet}\n[Content] {content}".strip()
        adapted.append({
            "title": title,
            "text": full_text,
            "year": year,
            "month": 0,
            "id": d.get("id", ""),
            "sorting_key": 0
        })
    return adapted


def main():
    """Run the dynamic (step-by-step) robust-RAG experiment.

    For every repetition and every dataset item, the per-year documents are fed
    to the selected defense in windows: ``initial_k`` years first, then
    ``dynamic_step_size`` more years per step. Each step's answer is evaluated,
    and the last answer of each question is evaluated once more against the
    item itself. One summary row per repetition is appended to
    ``output/{LOG_NAME}.csv``; with ``--save_response`` the per-question
    responses are also written to ``output/{LOG_NAME}.json``.

    Raises:
        ValueError: for a non-positive ``--dynamic_step_size``, a negative
            ``--initial_k``, or an unsupported defense / attack method.
    """
    print(torch.cuda.get_device_name(0) if torch.cuda.is_available() else "No GPU")
    args = parse_args()

    # A step size below 1 never advances the year window, so the step loop
    # below would never terminate; fail before any model is loaded.
    step_size = int(args.dynamic_step_size)
    initial_k = int(args.initial_k)
    if step_size < 1:
        raise ValueError(f"--dynamic_step_size must be >= 1, got {step_size}")
    if initial_k < 0:
        raise ValueError(f"--initial_k must be >= 0, got {initial_k}")

    # Environment variable consumed by LLMJudge.
    os.environ["SANDBOX_GRADER_MODEL"] = args.model_name

    LOG_NAME = get_log_name(args) + f"_dynamic_init{args.initial_k}_step{args.dynamic_step_size}"

    # ========== 0) Logger setup ==========
    logging_level = logging.DEBUG if args.debug else logging.INFO
    os.makedirs('log', exist_ok=True)
    logging.basicConfig(format=':::::::::::::: %(message)s')
    logger = logging.getLogger('RRAG-main')
    logger.setLevel(level=logging_level)
    if logger.hasHandlers():
        logger.handlers.clear()
    # Answers may contain non-ASCII text, so pin the log file encoding.
    logger.addHandler(logging.FileHandler(f"log/{LOG_NAME}.log", encoding="utf-8"))
    # NOTE(review): records also propagate to the root handler installed by
    # basicConfig, so console output is likely printed twice — confirm intended.
    logger.addHandler(logging.StreamHandler())
    logger.info(args)

    # ========== 1) Load the dynamic dataset ==========
    dynamic_json_path = getattr(args, "dynamic_json_path", "data/poisoned_dynamic_dataset_500.json")
    data_tool = DynamicDataset(dynamic_json_path, args.top_k, logger)
    data_list = load_json(dynamic_json_path)
    if args.max_samples is not None:
        data_list = data_list[:args.max_samples]

    # ========== 2) Initialise the LLM ==========
    if args.use_cache:
        os.makedirs('cache/', exist_ok=True)
        cache_path = f'cache/{args.model_name}.z'
    else:
        cache_path = None

    # 'biogen' is the long-generation task: larger output budget and the
    # long-form attack / defense variants.
    is_biogen = getattr(args, "dataset_name", "") == 'biogen'
    llm = create_model(
        args.model_name,
        args.model_dir,
        args.use_open_model_api,
        cache_path=cache_path,
        max_output_tokens=2048 if is_biogen else 512
    )
    longgen = is_biogen  # forwarded to the defenses that take a `longgen` flag

    # ========== 3) Initialise the defense model ==========
    model = _build_defense_model(args, llm, longgen, logger)

    # AstuteRAG / InstructRAG answers are post-processed by an LLM judge.
    if args.defense_method in ['astuterag', 'instructrag_icl']:
        llm_judge = LLMJudge(model=args.model_name)
    else:
        llm_judge = None

    # ========== 4) Initialise the attacker ==========
    attacker = _build_attacker(args, is_biogen)

    # ========== 5) Prepare the result CSV ==========
    os.makedirs('output', exist_ok=True)
    output_csv_file = f"./output/{LOG_NAME}.csv"

    # Both step-wise and final-per-question metrics are stored.
    fieldnames = [
        "rep_idx",
        "step_acc", "step_asr", "total_steps",
        "final_acc", "final_asr", "num_questions",
        "input_tokens", "output_tokens", "total_time_sec",
        "defense_method", "attack_method", "dataset_name", "step_size", "init_k"
    ]
    with open(output_csv_file, mode='w', newline='') as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()

    response_list = []

    # Methods that carry their own state between steps and therefore receive
    # only the newly added documents; every other method gets current + history.
    dynamic_methods = ('mincut', 'cluster')

    # ========== 6) Main loop: repeat the experiment `rep` times ==========
    for rep_idx in range(args.rep):
        corr_cnt = 0
        asr_cnt = 0
        input_tokens = 0
        output_tokens = 0
        total_time = 0

        # Step-wise counters over all questions and all steps.
        total_step_correct = 0
        total_step_asr_success = 0
        total_steps = 0

        # Questions that were correct (resp. attacked successfully) at every step.
        perfect_questions = 0
        perfect_asr_questions = 0

        # ========== 7) Iterate over samples ==========
        for data_idx, raw_item in enumerate(tqdm(data_list)):
            logger.info(f'==== rep_idx #{rep_idx}; item: {data_idx} ====')
            data_item = data_tool.process_data_item(raw_item)

            # One-off attack on the whole item (when not attacking at each step).
            # NOTE(review): every step below overwrites "topk_content", "answer",
            # "incorrect_answer" and "incorrect_context" on its own copy, so changes
            # the attacker makes to those keys here do not reach the per-step
            # query — confirm this is intended.
            if attacker and not args.attack_each_step:
                data_item = attacker.attack(data_item)

            yearly_contexts = data_item.get("yearly_contexts", {}) or {}
            years_sorted = sorted([int(y) for y in yearly_contexts.keys()])
            total_years = len(years_sorted)

            # State carried between steps by the dynamic methods.
            current_ans = None
            priors = None
            cluster_state = []

            # History of context strings for the static methods (newest first).
            accumulated_context_str = []

            llm.reset_token_count()
            start_time = time.perf_counter()

            year_ptr = 0
            end_ptr = min(total_years, initial_k)
            step_count = 0

            all_steps_correct = True
            all_steps_asr_success = True if attacker else False

            while year_ptr < total_years:
                batch_years = years_sorted[year_ptr:end_ptr]
                new_docs_batch = []

                # --- Collect the documents of this window ---
                for y in batch_years:
                    year_data = yearly_contexts.get(str(y), {})
                    docs = year_data.get("docs", []) or []
                    new_docs_batch.extend(_adapt_docs(docs, y))

                if not new_docs_batch:
                    # Nothing to query for this window; slide on without counting a step.
                    year_ptr = end_ptr
                    end_ptr = min(total_years, end_ptr + step_size)
                    continue

                step_count += 1
                logger.info(f"--- Step {step_count}: Adding docs for years {batch_years} ---")

                # --- Build the per-step query item ---
                latest_year = str(batch_years[-1])
                original_question = data_item['question']

                step_data_item = data_item.copy()
                step_data_item["question"] = original_question

                # Ground truth is taken from the latest year in the window.
                latest_year_data = yearly_contexts.get(latest_year, {})
                step_data_item["answer"] = data_tool._coerce_list(latest_year_data.get("answer", []))
                step_data_item["incorrect_answer"] = data_tool._coerce_list(latest_year_data.get("incorrect_answer", []))
                step_data_item["incorrect_context"] = data_tool._coerce_list(latest_year_data.get("incorrect_context", []))

                # --- Document pipeline: build -> attack -> (accumulate) -> inject ---

                # 1. Content of the current step only.
                current_step_content = docs_to_topk_content(new_docs_batch, include_title=True)
                step_data_item["topk_content"] = current_step_content

                # 2. Per-step attack is applied to the increment before accumulation.
                if args.attack_each_step and attacker:
                    step_data_item = attacker.attack(step_data_item)

                # 3. Static methods see "current step + history"; dynamic methods
                #    keep the (possibly attacked) increment untouched.
                if args.defense_method not in dynamic_methods:
                    final_current_text = step_data_item.get("topk_content", [])

                    # Normalise the current step's content to a list.
                    if isinstance(final_current_text, str):
                        final_current_text = [final_current_text]
                    elif final_current_text is None:
                        final_current_text = []
                    else:
                        final_current_text = list(final_current_text)

                    # Current step goes in front of the history.
                    accumulated_context_str = final_current_text + accumulated_context_str
                    step_data_item["topk_content"] = accumulated_context_str

                # --- Run the query ---
                if args.defense_method == 'none':
                    current_ans = model.query_undefended(step_data_item)

                elif args.defense_method == 'mincut':
                    current_ans, priors = model.dynamic_query(
                        step_data_item, previous_answer=current_ans, previous_priors=priors
                    )

                elif args.defense_method == 'cluster':
                    current_ans, cluster_state = model.dynamic_query(
                        step_data_item, previous_state=cluster_state
                    )

                else:
                    current_ans = model.query(step_data_item)

                # --- Post-processing for the methods that need LLMJudge ---
                if args.defense_method in ['astuterag', 'instructrag_icl'] and llm_judge is not None:
                    current_ans = llm_judge.judge(step_data_item["question"], current_ans)

                logger.info(f"Step {step_count} Answer: {current_ans}")

                # --- Evaluation (counted per step) ---
                is_correct = data_tool.eval_response(current_ans, step_data_item)
                is_asr = data_tool.eval_response_asr(current_ans, step_data_item) if attacker else 0

                logger.info(f"Step {step_count} correct: {is_correct}, asr: {is_asr}")

                total_steps += 1
                total_step_correct += int(is_correct)
                if attacker:
                    total_step_asr_success += int(is_asr)

                if not is_correct:
                    all_steps_correct = False
                if attacker and not is_asr:
                    all_steps_asr_success = False

                year_ptr = end_ptr
                end_ptr = min(total_years, end_ptr + step_size)

            # --- End of sample: evaluate the last answer against the item itself ---
            final_response = current_ans if current_ans else "I don't know."
            final_correct = data_tool.eval_response(final_response, data_item)
            final_asr = data_tool.eval_response_asr(final_response, data_item) if attacker else 0

            corr_cnt += int(final_correct)
            if attacker:
                asr_cnt += int(final_asr)

            if all_steps_correct:
                perfect_questions += 1
            if attacker and all_steps_asr_success:
                perfect_asr_questions += 1

            response_list.append({
                "rep_idx": rep_idx,
                "query": data_item['question'],
                "final_response": final_response,
                "defense": args.defense_method,
                "is_correct": bool(final_correct),
                "steps_taken": step_count
            })

            token_count = llm.get_token_count()
            input_tokens += token_count.get("input", 0)
            output_tokens += token_count.get("output", 0)
            total_time += time.perf_counter() - start_time

        # ========== 8) Per-repetition summary ==========
        logger.info(f'\n=== Result for rep: {rep_idx} ===')
        num_questions = len(data_list)

        # Final-per-question metrics.
        final_acc = corr_cnt / num_questions if num_questions > 0 else 0.0
        final_asr = asr_cnt / num_questions if (num_questions > 0 and attacker) else 0.0

        # Step-wise metrics.
        step_acc = total_step_correct / total_steps if total_steps > 0 else 0.0
        step_asr = total_step_asr_success / total_steps if (total_steps > 0 and attacker) else 0.0

        logger.info(f'Total Steps: {total_steps}')
        logger.info(f'Step-wise Avg Accuracy: {step_acc:.4f}')
        if attacker:
            logger.info(f'Step-wise Avg ASR: {step_asr:.4f}')

        # Last-answer metrics, kept for comparison with the step-wise ones.
        logger.info(f'Final (per-question last answer) Accuracy: {final_acc:.4f}')
        if attacker:
            logger.info(f'Final (per-question last answer) ASR: {final_asr:.4f}')

        # Report the per-question "every step" counters collected above.
        logger.info(f'Questions correct at every step: {perfect_questions}/{num_questions}')
        if attacker:
            logger.info(f'Questions with ASR success at every step: {perfect_asr_questions}/{num_questions}')

        if args.use_cache:
            llm.dump_cache()

        # Append this repetition's row with the same DictWriter layout as the
        # header so columns are matched by name.
        result_current = {
            "rep_idx": rep_idx,

            "step_acc": step_acc,
            "step_asr": step_asr,
            "total_steps": total_steps,

            "final_acc": final_acc,
            "final_asr": final_asr,
            "num_questions": num_questions,

            "input_tokens": input_tokens,
            "output_tokens": output_tokens,
            "total_time_sec": round(total_time, 2),

            "defense_method": args.defense_method,
            "attack_method": args.attack_method,
            # NOTE(review): hard-coded rather than args.dataset_name — confirm intended.
            "dataset_name": "dynamic_serpapi",
            "step_size": args.dynamic_step_size,
            "init_k": args.initial_k
        }
        with open(output_csv_file, mode='a', newline='') as f:
            csv.DictWriter(f, fieldnames=fieldnames).writerow(result_current)

    # ========== 9) Optionally persist per-question responses ==========
    if args.save_response:
        output_json_file = f"./output/{LOG_NAME}.json"
        with open(output_json_file, mode='w', encoding='utf-8') as f:
            # default=str: an answer object that is not JSON-native should not
            # abort the dump at the very end of a long run.
            json.dump(response_list, f, ensure_ascii=False, indent=2, default=str)
        logger.info(f'Saved responses to {output_json_file}')

# Script entry point: run the experiment only when executed directly.
if __name__ == "__main__":
    main()
