import random
import json
import os
# Import pgmpy to build Bayesian networks
from pgmpy.models import BayesianNetwork
from pgmpy.factors.discrete import TabularCPD
from pgmpy.inference import VariableElimination
from tqdm import tqdm
from utils.response_parse_utils import safe_json_parse, parse_latents, parse_latents_prob
from itertools import product
from typing import List, Dict, Any, Tuple
import argparse
from utils import config
from utils.prompt_shots import generate_para_prob_shots, LLM_LATENT_SHOTS, \
    GENERATE_LATENT_PROBS_SHOTS, generate_para_prob_shots_cot_v1, generate_para_prob_shots_cot_v0, LLM_LATENT_SHOTS_COT
from utils.utils import print_args, ask_gpt
# Seed Python's global RNG so the random.uniform draws used by the 'stat'
# factor-probability branch of the main loop are reproducible across runs.
random.seed(42)
def calculate_prob_for_sent(implied_values, value_prob_list):
    """Score a sentence against a table of value combinations.

    Each entry of ``value_prob_list`` is a sequence whose last element is a
    probability and whose preceding elements are values. Every entry whose
    values include all of ``implied_values`` contributes its probability; the
    contributions are averaged into ``p_1`` (support for Statement 1), with
    ``p_2 = 1 - p_1``. When no entry matches, both are 0.5.

    Returns:
        ``(prediction, [p_1, p_2])`` where prediction is ``'Statement 1'`` if
        ``p_1 > 0.5``, ``'Statement 2'`` if ``p_1 < 0.5`` and ``'Unknown'``
        otherwise.
    """
    matched = [
        entry[-1]
        for entry in value_prob_list
        if set(implied_values).issubset(set(entry[0:-1]))
    ]

    if not matched:
        p_1, p_2 = 0.5, 0.5
    else:
        p_1 = sum(matched) / len(matched)
        p_2 = 1 - p_1

    if p_1 > 0.5:
        return 'Statement 1', [p_1, p_2]
    if p_1 < 0.5:
        return 'Statement 2', [p_1, p_2]
    return 'Unknown', [p_1, p_2]

def get_outcome(mapping_addition_sentence_dic, value_prob_list):
    """Predict which statement each additional sentence supports.

    Args:
        mapping_addition_sentence_dic: sentence -> values it implies. A falsy
            value (empty / None) yields ``'Unknown'`` with ``[0.5, 0.5]``
            without consulting ``value_prob_list``.
        value_prob_list: table forwarded to ``calculate_prob_for_sent``.

    Returns:
        ``(predictions, probabilities)``: two dicts keyed by sentence.
    """
    predictions = {}
    probabilities = {}
    for sentence, implied in mapping_addition_sentence_dic.items():
        if implied:
            outcome = calculate_prob_for_sent(implied, value_prob_list)
        else:
            outcome = ('Unknown', [0.5, 0.5])
        predictions[sentence], probabilities[sentence] = outcome
    return predictions, probabilities




def build_naive_bn(factors, p_true, smooth_alpha):
    """Build a naive-Bayes network with "Statement" as the only parent of every factor.

    "Statement" gets a uniform prior. Each factor ``f`` gets a CPD conditioned
    on "Statement" built from ``p_true[f]`` (looked up case-insensitively) and
    additively smoothed towards 0.5:

        pt = (p_true[f] + alpha) / (1 + 2 * alpha)
        pf = (1 - p_true[f] + alpha) / (1 + 2 * alpha)   # so pt + pf == 1

    Args:
        factors: factor (node) names to attach under "Statement".
        p_true: factor name -> probability; keys may differ from ``factors``
            in letter case only.
        smooth_alpha: additive smoothing strength (0 disables smoothing).

    Returns:
        The ``BayesianNetwork`` after ``check_model()``.

    Raises:
        KeyError: if a factor has no case-insensitive match in ``p_true``.
    """
    lower_to_key = {key.lower(): key for key in p_true}

    # Resolve every factor's probability up front so a missing factor raises
    # before any network is built. The resolved key must also be used for the
    # lookup itself; indexing p_true with the raw factor name would undo the
    # case-insensitive match and raise KeyError on a case-only difference.
    raw_probs = {}
    for f in factors:
        true_key = lower_to_key.get(f.lower())
        if true_key is None:
            raise KeyError(f"Factor '{f}' not found in p_true (case‐insensitive)")
        raw_probs[f] = p_true[true_key]

    bn = BayesianNetwork([("Statement", f) for f in factors])
    bn.add_cpds(TabularCPD("Statement", 2, [[0.5], [0.5]]))
    denom = 1 + 2 * smooth_alpha
    for f in factors:
        raw_pt = raw_probs[f]
        raw_pf = 1.0 - raw_pt
        pt = (raw_pt + smooth_alpha) / denom
        pf = (raw_pf + smooth_alpha) / denom
        bn.add_cpds(TabularCPD(
            f, 2,
            [[pt, pf], [1 - pt, 1 - pf]],
            evidence=["Statement"], evidence_card=[2]
        ))
    bn.check_model()
    return bn


def generate_valid_latents(
    factors: List[str],
    messages: List[Dict[str, str]],
    model_name: str,
    max_retries: int = 3
) -> List[Dict[str, Any]]:
    """
    Ask the LLM for a latent grouping and retry until every referenced factor
    is one of ``factors`` (case-insensitive).

    Latent names (including 'Statement') are not checked. A response that
    cannot be parsed into a list is retried as-is. A response that references
    unknown or malformed factors is reported back to the LLM as an extra user
    message before the retry. The caller's ``messages`` list is never mutated.

    Args:
        factors: allowed factor names.
        messages: chat messages (few-shot examples + request) sent to the LLM.
        model_name: model name passed to ``ask_gpt``.
        max_retries: maximum number of LLM calls.

    Returns:
        The latents list; each element is ``{"name": str, "factors": [str, ...]}``.

    Raises:
        RuntimeError: if no valid response is obtained within ``max_retries`` calls.
    """
    factors_lower = {f.lower() for f in factors}
    # Private working copy: correction prompts accumulate here across attempts.
    conversation = list(messages)

    for attempt in range(1, max_retries + 1):
        raw_spec = ask_gpt(
            messages=conversation,
            model_name=model_name,
            max_token=1024
        )
        try:
            latents = parse_latents(raw_spec)
            if not isinstance(latents, list):
                raise ValueError("`latents` is not a list")
        except Exception as e:
            print(f"[Attempt {attempt}] JSON parse error: {e}, Response: {raw_spec}")
            continue

        # Collect every problem as a string so the set/sort below cannot fail
        # on unhashable (dict/list) or non-comparable (e.g. int) entries.
        problems: List[str] = []
        for latent_spec in latents:
            if not isinstance(latent_spec, dict):
                problems.append(f"Invalid latent entry: {latent_spec}")
                continue
            name = latent_spec.get("name")
            facs = latent_spec.get("factors")
            if not isinstance(name, str) or not isinstance(facs, list):
                problems.append(f"Missing or invalid name/factors for latent {latent_spec}")
                continue
            for f in facs:
                if not isinstance(f, str):
                    problems.append(repr(f))
                elif f.lower() not in factors_lower:
                    problems.append(f)

        if not problems:
            return latents

        missing_str = ", ".join(sorted(set(problems)))
        print(f"[Attempt {attempt}] Unknown factors (case-insensitive): {missing_str}")
        conversation.append({
            "role": "user",
            "content": (
                f"The following factors were not in the provided list (case-insensitive): "
                f"{missing_str}. Please regenerate your JSON using only the original factors."
            )
        })

    raise RuntimeError(
        f"Failed to generate valid latents after {max_retries} attempts."
    )


def generate_cpt_from_stats(
    latents: List[Dict[str, Any]],
    latent_names: List[str],
    statement1: str,
    statement2: str,
    args,
    prior: float = 0.5
):
    """
    Ask the LLM for per-latent likelihoods and turn them into Statement's CPT.

    For each latent the LLM returns ``[p1, p0]`` with
    ``p1 = P(latent=True | Statement1)`` and ``p0 = P(latent=True | Not Statement1)``.
    The latents are multiplied together as if conditionally independent given
    the statement, and for every joint latent state the posterior
    P(Statement1 | state) is computed with Bayes' rule from ``prior``.
    Latents the LLM omits get (0.5, 0.5).

    Args:
      - latents: all latent definitions (dicts with 'name' and 'factors'),
        shown to the LLM.
      - latent_names: latents forming the CPT columns; any other name in the
        LLM response is rejected.
      - statement1: statement the probabilities are conditioned on.
      - statement2: currently unused (the prompt only mentions Statement1).
      - args: namespace providing max_retries, model_name and use_temp.
      - prior: P(Statement1) before seeing the latents.

    Returns:
      - ``([true_row, false_row], llm_map)``: the rows hold P(Statement1 | state)
        and its complement for the states of
        ``itertools.product([1, 0], repeat=len(latent_names))``; ``llm_map`` is
        the parsed mapping from the accepted LLM response.

    Raises:
      - RuntimeError: if no valid response is obtained within args.max_retries calls.
    """
    # Few-shot examples + request (a new list, so appending feedback below
    # never mutates GENERATE_LATENT_PROBS_SHOTS).
    messages = GENERATE_LATENT_PROBS_SHOTS + [{
        "role": "user",
        "content": (
            f"Statement1: {statement1}\n"
            "Here are your latent variables and their factors:\n"
            f"{json.dumps(latents, ensure_ascii=False, indent=2)}\n\n"
            "Based on these, please return a JSON object mapping each latent name "
            "(must be one of the latent_names list) to a two-element array [p1, p0], where:\n"
            "  • p1 = P(latent=True | Statement1)\n"
            "  • p0 = P(latent=True | Not Statement1)\n"
            "Use floats between 0 and 1."
        )
    }]

    cond: Dict[str, Tuple[float, float]] = {}
    for attempt in range(1, args.max_retries + 1):
        raw = ask_gpt(messages, model_name=args.model_name, use_temp=args.use_temp, max_token=1024)
        try:
            llm_map = parse_latents_prob(raw)
            # Validate into a fresh dict: a response rejected half-way must not
            # leave entries behind for a later (possibly partial) response.
            parsed: Dict[str, Tuple[float, float]] = {}
            for ln, arr in llm_map.items():
                if ln not in latent_names:
                    raise ValueError(f"Unexpected latent: {ln}")
                if not (isinstance(arr, list) and len(arr) == 2):
                    raise ValueError(f"Bad entry for {ln}: {arr}")
                p1, p0 = float(arr[0]), float(arr[1])
                if not (0 <= p1 <= 1 and 0 <= p0 <= 1):
                    raise ValueError(f"Prob out of bounds for {ln}: {arr}")
                parsed[ln] = (p1, p0)
        except Exception as e:
            print(f"[Attempt {attempt}] JSON parse error: {e}, Response: {raw}")
            messages.append({
                "role": "user",
                "content": (
                    f"Your JSON was invalid: {e}. "
                    "Please return **only** the JSON mapping latent names to [p1, p0]."
                )
            })
        else:
            cond = parsed
            break
    else:
        raise RuntimeError("Failed to get valid probabilities from LLM")

    # NOTE(review): pgmpy's TabularCPD orders evidence columns with state 0
    # first, while this table starts with every latent = 1 — confirm the
    # intended state encoding with the callers.
    true_row, false_row = [], []
    for state in product([1, 0], repeat=len(latent_names)):
        like1 = like0 = 1.0
        for xi, ln in zip(state, latent_names):
            p1, p0 = cond.get(ln, (0.5, 0.5))
            if xi:
                like1 *= p1
                like0 *= p0
            else:
                like1 *= (1 - p1)
                like0 *= (1 - p0)
        post1 = prior * like1
        post0 = (1 - prior) * like0
        norm = post1 + post0
        if norm > 0:
            true_row.append(post1 / norm)
            false_row.append(post0 / norm)
        else:
            # Both likelihoods vanished; fall back to the prior.
            true_row.append(prior)
            false_row.append(1 - prior)

    return [true_row, false_row], llm_map


def compute_cpt_from_stats(
    latent_names: List[str],
    prob_list: List[Dict[str, Dict[str, int]]],
    prior: float = 0.5,
    alpha: float = 0.5  # Laplace smoothing factor
):
    """
    Laplace-smoothed Statement CPT from per-latent statement vote counts.

    Args:
      - latent_names: latents forming the CPT columns, e.g. ["RiskLat", "EfficiencyLat"].
      - prob_list: dicts mapping a latent name to its vote counts, e.g.
            [{'RiskLat': {'Statement1': 4, 'Statement2': 1, 'Neutral': 2}}, ...]
        'Neutral' counts (and, for compatibility, 'Both') are split equally
        between Statement1 and Statement2.
      - prior: P(Statement1), default 0.5.
      - alpha: pseudo-count added to both the Statement1 and Statement2 counts.

    For each latent, p1 = c1 / (c1 + c0) is used as P(latent=1 | Statement1)
    and p0 = c0 / (c1 + c0) as P(latent=1 | Statement2). Latents missing from
    prob_list, or with no counts at all and alpha == 0, get (0.5, 0.5).

    Returns:
      - ``([true_row, false_row], cond)``: the rows hold P(Statement1 | state)
        and P(Statement2 | state) for the states of
        ``itertools.product([1, 0], repeat=len(latent_names))``; ``cond`` maps
        each counted latent to its (p1, p0).
    """
    cond: Dict[str, Tuple[float, float]] = {}
    for entry in prob_list:
        for latent, counts in entry.items():
            neutral = counts.get("Neutral", 0) + counts.get("Both", 0)
            c1 = counts.get("Statement1", 0) + alpha + 0.5 * neutral
            c0 = counts.get("Statement2", 0) + alpha + 0.5 * neutral
            total = c1 + c0
            if total > 0:
                cond[latent] = (c1 / total, c0 / total)
            else:
                # No evidence and no smoothing: stay neutral instead of dividing by zero.
                cond[latent] = (0.5, 0.5)

    true_row = []
    false_row = []
    for state in product([1, 0], repeat=len(latent_names)):
        like1 = 1.0
        like0 = 1.0
        for xi, latent in zip(state, latent_names):
            p1, p0 = cond.get(latent, (0.5, 0.5))
            if xi == 1:
                like1 *= p1
                like0 *= p0
            else:
                like1 *= (1 - p1)
                like0 *= (1 - p0)
        post1 = prior * like1
        post0 = (1 - prior) * like0
        norm = post1 + post0
        if norm > 0:
            true_row.append(post1 / norm)
            false_row.append(post0 / norm)
        else:
            true_row.append(prior)
            false_row.append(1 - prior)

    return [true_row, false_row], cond

def build_causal_bn(args, factors, p_true, factor_statement_mapping, statement1,
                    statement2=None, latent=None, latent_prob_maps=None):
    """
    Build a causal Bayesian network over latents, "Statement" and factors.

    Structure: every non-Statement latent -> "Statement", and every latent
    (including "Statement") -> each factor assigned to it.

    Steps:
    1. Latents come from ``latent`` when given, otherwise from the LLM via
       ``generate_valid_latents``; if that fails, all factors are placed under
       a single "Statement" latent.
    2. Factor names are rewritten to the keys of ``factor_statement_mapping``
       (case-insensitively). Each factor is kept under exactly one latent, and
       every unassigned factor goes to "Statement".
    3. Statement's CPT over all non-Statement latents comes from
       ``latent_prob_maps`` if given. Otherwise it comes from the LLM
       (``generate_cpt_from_stats``) or from factor vote counts
       (``compute_cpt_from_stats``), depending on ``args``.
    4. Each factor gets a CPD conditioned on its latent, built from ``p_true``
       (default 0.5) smoothed with ``args.SMOOTH_ALPHA``.

    Args:
        args: parsed CLI namespace. Reads use_cot, model_name, max_retries,
            dataset_name, latent_prob_type and SMOOTH_ALPHA; use_temp is read
            via the LLM helper.
        factors: factor names to place in the network.
        p_true: factor name -> probability used for the factor CPDs.
        factor_statement_mapping: factor name -> "Statement1" / "Statement2" /
            "Neutral"; any other value raises KeyError when votes are counted.
        statement1, statement2: statement texts used in the LLM prompt.
        latent: optional precomputed latents list; it is modified in place.
        latent_prob_maps: optional latent name -> (p1, p0) pairs used directly
            for Statement's CPT.

    Returns:
        (bn, edges, latents, latent_prob_map)
    """
    # ─── 1. Latents: reuse the given ones or ask the LLM ───
    if latent:
        latents = latent
    else:
        prompt = (
            "Please identify latent variables and assign each factor to a latent. "
            "Then return JSON with fields:\n"
            " latents: [{\"name\": string, \"factors\": [...]}, ...],\n"
            f"Factors: {json.dumps(factors, ensure_ascii=False)}"
        )
        shots = LLM_LATENT_SHOTS_COT if 'cot' in args.use_cot else LLM_LATENT_SHOTS
        messages = shots + [{"role": "user", "content": prompt}]
        try:
            latents = generate_valid_latents(factors, messages, args.model_name, args.max_retries)
        except Exception as e:
            print('identify latent variables error: ', e)
            # Fallback: put every factor under the single "Statement" latent.
            latents = [{"name": "Statement", "factors": factors[:]}]

    # ─── 2. Normalize latent membership ───
    if not any(spec["name"] == "Statement" for spec in latents):
        latents.append({"name": "Statement", "factors": []})

    # Rewrite factor names to the spelling used as keys of
    # factor_statement_mapping (case-insensitive; unknown names are kept).
    lower_to_key = {key.lower(): key for key in factor_statement_mapping}
    for spec in latents:
        spec["factors"] = [lower_to_key.get(f.lower(), f) for f in spec["factors"]]

    # A factor may have only one parent: the first non-Statement latent that
    # lists it keeps it, later duplicates are dropped.
    seen = set()
    for spec in latents:
        if spec["name"] == "Statement":
            continue
        unique = []
        for f in spec["factors"]:
            if f not in seen:
                unique.append(f)
                seen.add(f)
        spec["factors"] = unique

    # "Statement" keeps only (deduplicated) factors no other latent owns, then
    # receives every remaining unassigned factor. Without this filtering a
    # factor listed under both a latent and Statement would get two parents
    # but a single-parent CPD (check_model fails), and the fallback above
    # would list every factor twice.
    stmt = next(spec for spec in latents if spec["name"] == "Statement")
    stmt_factors = []
    for f in stmt["factors"]:
        if f not in seen:
            stmt_factors.append(f)
            seen.add(f)
    for f in factors:
        f_norm = lower_to_key.get(f.lower(), f)
        if f_norm not in seen:
            stmt_factors.append(f_norm)
            seen.add(f_norm)
    stmt["factors"] = stmt_factors

    edges = []
    for spec in latents:
        ln = spec["name"]
        edges.extend([ln, f] for f in spec["factors"])  # latent -> factor
        if ln != "Statement":
            edges.append([ln, "Statement"])  # latent -> Statement

    bn = BayesianNetwork(edges)
    # Add nodes that have no edges (e.g. an empty "Statement" latent).
    for spec in latents:
        if spec["name"] not in bn.nodes():
            bn.add_node(spec["name"])
        for f in spec["factors"]:
            if f not in bn.nodes():
                bn.add_node(f)

    # Placeholder CPDs for latents: a 0.5/0.5 prior for roots, and for nodes
    # with parents a conditional CPT with 0.5 in every cell, so each column
    # sums to 1. "Statement"'s placeholder is overwritten in step 3.
    for spec in latents:
        ln = spec["name"]
        parents = bn.get_parents(ln)
        if parents:
            ncols = 2 ** len(parents)
            values = [[0.5] * ncols, [0.5] * ncols]
            print(f"Adding uniform CPT for latent '{ln}' with parents {parents}")
            bn.add_cpds(
                TabularCPD(
                    ln, 2,
                    values,
                    evidence=parents,
                    evidence_card=[2] * len(parents)
                )
            )
        else:
            print(f"Adding prior CPD for root latent '{ln}' (no parents)")
            bn.add_cpds(TabularCPD(ln, 2, [[0.5], [0.5]]))

    print("Before adding Statement CPT, bn.nodes() =", bn.nodes())

    # ─── 3. Statement's CPT ───
    # Per-latent vote counts of its factors' statement labels, only for
    # datasets other than xsum/covid/cnn/expert. NOTE(review): for cnn/expert
    # with latent_prob_type != 'llm', compute_cpt_from_stats receives an empty
    # prob_list and every latent gets (0.5, 0.5) — confirm this is intended.
    prob_list = []
    if not any(name in args.dataset_name for name in ('xsum', 'covid', 'cnn', 'expert')):
        for spec in latents:
            if spec['name'] == 'Statement':
                continue
            cnt_dict = {"Statement1": 0, "Statement2": 0, "Neutral": 0}
            for factor in spec["factors"]:
                key = lower_to_key.get(factor.lower())
                if key:
                    cnt_dict[factor_statement_mapping[key]] += 1
                else:
                    print(f"Warning: Factor '{factor}' not found in factor_statement_mapping")
            prob_list.append({spec['name']: cnt_dict})

    latent_names = [spec["name"] for spec in latents if spec["name"] != "Statement"]
    is_xsum_or_covid = 'xsum' in args.dataset_name or 'covid' in args.dataset_name
    if latent_prob_maps:
        # Precomputed (p1, p0) per latent: same Bayes enumeration as the helpers.
        latent_prob_map = latent_prob_maps
        prior = 0.5
        true_row, false_row = [], []
        for state in product([1, 0], repeat=len(latent_names)):
            like1 = like0 = 1.0
            for xi, name in zip(state, latent_names):
                p1, p0 = latent_prob_map.get(name, (0.5, 0.5))
                if xi == 1:
                    like1 *= p1
                    like0 *= p0
                else:
                    like1 *= (1 - p1)
                    like0 *= (1 - p0)
            post1 = prior * like1
            post0 = (1 - prior) * like0
            norm = post1 + post0
            if norm > 0:
                true_row.append(post1 / norm)
                false_row.append(post0 / norm)
            else:
                true_row.append(prior)
                false_row.append(1 - prior)
        cpt_statement = [true_row, false_row]
    elif args.latent_prob_type == 'llm' and not is_xsum_or_covid:
        cpt_statement, latent_prob_map = generate_cpt_from_stats(
            latents=latents,
            latent_names=latent_names,
            statement1=statement1,
            statement2=statement2,
            prior=0.5,
            args=args
        )
    elif is_xsum_or_covid:
        cpt_statement, latent_prob_map = generate_cpt_from_stats(
            latents=latents,
            latent_names=latent_names,
            statement1=statement1,
            statement2='Not Statement1',
            prior=0.5,
            args=args
        )
    else:
        cpt_statement, latent_prob_map = compute_cpt_from_stats(latent_names, prob_list, prior=0.5)

    # The Statement CPD's evidence must be exactly Statement's graph parents.
    assert set(latent_names) == set(bn.get_parents("Statement"))

    bn.add_cpds(
        TabularCPD(
            "Statement", 2,
            cpt_statement,
            evidence=latent_names,
            evidence_card=[2] * len(latent_names)
        )
    )
    print("After adding Statement CPD, bn.nodes() =", bn.nodes())

    # ─── 4. Factor CPDs ───
    def make_cpd_matrix(pt, pf, parents):
        """Two-row CPD matrix; only a "Statement" parent's state selects pt vs pf."""
        # NOTE(review): for a non-Statement parent every column uses pf, so the
        # factor does not depend on its latent and evidence on it cannot move
        # the Statement posterior — confirm this is intended.
        true_row, false_row = [], []
        for state in product([1, 0], repeat=len(parents)):
            p = pt if ("Statement" in parents and state[parents.index("Statement")] == 1) else pf
            true_row.append(p)
            false_row.append(1 - p)
        return [true_row, false_row]

    for spec in latents:
        parent = spec["name"]
        for f in spec["factors"]:
            raw_pt = p_true.get(f, 0.5)
            raw_pf = 1 - raw_pt
            pt = (raw_pt + args.SMOOTH_ALPHA) / (1 + 2 * args.SMOOTH_ALPHA)
            pf = (raw_pf + args.SMOOTH_ALPHA) / (1 + 2 * args.SMOOTH_ALPHA)
            cpd = TabularCPD(
                f, 2,
                make_cpd_matrix(pt, pf, [parent]),
                evidence=[parent],
                evidence_card=[2]
            )
            try:
                bn.add_cpds(cpd)
            except ValueError as e:
                print("ERROR adding CPD:", e)
                # Dump the CPD and the network nodes to help locate the issue.
                print("CPD repr:", cpd)
                print("Current bn.nodes():", bn.nodes())
                raise

    # Final sanity check: report nodes that never received a CPD.
    declared = {cpd.variable for cpd in bn.get_cpds()}
    expected = ({spec["name"] for spec in latents} | {"Statement"}
                | {f for spec in latents for f in spec["factors"]})
    missing = sorted(expected - declared)
    if missing:
        print(f"No CPD associated with nodes: {missing}")
    bn.check_model()
    return bn, edges, latents, latent_prob_map

def generate_probability_map(
    rec: Dict[str, Any],
    flat_facs: List[str],
    generate_para_prob_shots: List[Dict[str, str]],
    args
) -> Dict[str, float]:
    """
    Ask the LLM for P(factor supports Statement 1) for every factor in ``flat_facs``.

    The response must be a JSON object whose keys match ``flat_facs``
    case-insensitively and whose values are numbers in [0, 1]. Keys are
    returned in the original ``flat_facs`` spelling. If either check fails,
    the problem is reported back to the LLM and the call is retried, up to
    ``args.max_retries`` times. If every attempt fails, each factor gets the
    neutral probability 0.5.

    Args:
        rec: record with at least 'scenario' and 'statement'.
        flat_facs: factor values to score.
        generate_para_prob_shots: few-shot messages prepended to the request
            (not mutated).
        args: namespace providing model_name, use_temp and max_retries.

    Returns:
        Dict mapping each factor in ``flat_facs`` to a float probability.
    """
    prior_prompt = [{
        "role": "user",
        "content": (
            f"Given the scenario: {rec['scenario']}, "
            f"For each of the following factor values, please estimate the probability "
            f"(a float between 0 and 1) that it supports Statement 1 {rec['statement']} "
            f"Return a JSON mapping from factor value to probability.\n"
            f"Factor values: {flat_facs}"
        )
    }]
    prior_message = generate_para_prob_shots + prior_prompt

    # Fallback if every attempt fails. Each value is an independent
    # P(factor supports Statement 1), so the uninformative choice is 0.5;
    # 1/len(flat_facs) would bias towards one statement (1.0 for a single
    # factor) and divide by zero for an empty list.
    p_true_map: Dict[str, float] = {f: 0.5 for f in flat_facs}

    expected_set = set(flat_facs)
    canonical = {f.lower(): f for f in flat_facs}

    for attempt in range(1, args.max_retries + 1):
        raw_resp = ask_gpt(prior_message, model_name=args.model_name, use_temp=args.use_temp, max_token=2048)
        print('[generate_probability_map]------>>>>> resp: ', raw_resp)
        try:
            prior_resp = safe_json_parse(raw_resp)
            if not isinstance(prior_resp, dict):
                raise ValueError("Response is not a JSON object")

            # Re-key onto the original spelling: exact match first, then
            # case-insensitive. Unknown or repeated keys count as unexpected.
            normalized: Dict[str, Any] = {}
            extra = set()
            for k, v in prior_resp.items():
                key = k if k in expected_set else canonical.get(str(k).lower())
                if key is None or key in normalized:
                    extra.add(str(k))
                else:
                    normalized[key] = v
            missing = expected_set - set(normalized)

            if missing or extra:
                msgs = []
                if missing:
                    msgs.append(f"Missing factor values: {sorted(missing)}")
                if extra:
                    msgs.append(f"Unexpected factor values: {sorted(extra)}")
                error_msg = "; ".join(msgs)

                print(f"[Attempt {attempt}] Key mismatch: {error_msg}")
                prior_message.append({
                    "role": "user",
                    "content": (
                        f"The following issues were found in your JSON mapping: {error_msg}. "
                        f"Please regenerate the JSON using exactly the original factor values: {flat_facs}."
                    )
                })
                continue

            probs: Dict[str, float] = {}
            for key, value in normalized.items():
                p = float(value)
                # Reject out-of-range values (and NaN) instead of building invalid CPDs.
                if not 0.0 <= p <= 1.0:
                    raise ValueError(f"Probability out of range for '{key}': {value}")
                probs[key] = p
            return probs

        except Exception as e:
            print(f"[Attempt {attempt}] JSON parse or validation error: {e}")
            prior_message.append({
                "role": "user",
                "content": (
                    f"Your response could not be parsed or validated: {e}. "
                    f"Please return a JSON mapping from the original factor values {flat_facs} to probabilities."
                )
            })

    print(f"Warning: failed after {args.max_retries} attempts, returning initial uniform p_true_map.")
    return p_true_map


def support_prob(engine, subset):
    """Return ``values[1]`` of the "Statement" query with every factor in ``subset`` set to state 1.

    ``engine`` is an inference object exposing ``query`` (the main loop passes
    pgmpy ``VariableElimination`` instances).
    """
    evidence = dict.fromkeys(subset, 1)
    posterior = engine.query(["Statement"], evidence=evidence, show_progress=False)
    return posterior.values[1]


def parse_args():
    """Parse this script's command-line options, print them and return the namespace."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, default=config.model_name)
    parser.add_argument("--dataset_name", type=str, default=config.dataset_name)
    parser.add_argument("--SMOOTH_ALPHA", type=float, default=0.5)
    parser.add_argument("--start", type=int, default=0)
    parser.add_argument("--end", type=int, default=1000)
    parser.add_argument("--dataset_file_dic", type=str, default=config.dataset_file_dic)
    parser.add_argument("--save_file_dic", type=str, default=config.save_file_dic)
    parser.add_argument("--use_bn", action="store_true")
    parser.add_argument("--max_retries", type=int, default=20, help="Maximum retries for LLM response parsing")
    parser.add_argument("--use_cot", type=str, default='_cot_v1_')
    # Forwarded to ask_gpt as use_temp. argparse applies `type` only to values
    # given on the command line, never to the default, so the type must match
    # the float default (type=str made "--use_temp 0.7" a string).
    parser.add_argument("--use_temp", type=float, default=0.5)
    parser.add_argument("--factor_prob_type", type=str, default='llm_')
    parser.add_argument("--latent_prob_type", type=str, default='llm')
    # 'init_cluster' or 'pruned_cluster'; any other value selects the
    # unclustered mapping in the main loop.
    parser.add_argument("--mapping_type", type=str, default='pruned_cluster')
    parser.add_argument("--ablation", type=str, default='false')
    args = parser.parse_args()
    print_args(args)
    return args




if __name__ == '__main__':
    args = parse_args()
    with open(f"{args.dataset_file_dic}{args.dataset_name}.json") as f:
        df_origin = json.load(f)
    suffix = ''
    model_name = args.model_name

    # ─── Inputs produced by earlier pipeline stages ───
    factor_path = os.path.join(args.save_file_dic, f"{args.dataset_name}_{model_name}_0_{args.end}_factors{suffix}.json")
    with open(factor_path) as f:
        df_factors = json.load(f)
    print('input factors file------------>: ', factor_path)
    mapping_path = os.path.join(
        args.save_file_dic,
        f"{args.dataset_name}_{model_name}_{args.start}_1000_condition_mapping{suffix}.json"
    )
    print('input mapping file------------>: ', mapping_path)
    with open(mapping_path, 'r', encoding='utf-8') as f:
        df_factor = json.load(f)

    out_path = (f"{args.save_file_dic}{args.dataset_name}_{model_name}_{args.model_name.replace(':', '-')}"
                f"_bn_compare_add_both{args.use_cot}{args.factor_prob_type}{args.latent_prob_type}_"
                f"{args.mapping_type}.json")
    print('output file------------>: ', out_path)

    # Precomputed parameters are only read on ablation runs, so the file is
    # only required then.
    df_para = []
    if args.ablation == 'true':
        para_path = 'ablation/results/raw/init_cluster/common2sense_qwen2.5-72b_qwen2.5-72b_bn_compare_add_both_cot_v1_llm_llm_init_cluster.json'
        print('input para file------------>: ', para_path)
        with open(para_path, 'r', encoding='utf-8') as f:
            df_para = json.load(f)

    def find_record(records, key):
        """Return the first record whose scenario + statement equals ``key`` (StopIteration if none)."""
        return next(d for d in records if d['scenario'] + d['statement'] == key)

    # Resume from checkpoint: if output already exists, continue after the
    # records it already holds. NOTE(review): this assumes the saved run began
    # at record 0; with --start > 0 some records may be reprocessed.
    if os.path.exists(out_path):
        with open(out_path, 'r', encoding='utf-8') as f:
            out = json.load(f)
        start_idx = len(out)
        print(f"🔄 Resuming from record #{start_idx + 1}")
    else:
        out = []
        start_idx = 0
        print("🚀 Starting fresh run")
    total = len(df_origin)

    for i in tqdm(range(start_idx, total), desc="Processing scenarios"):
        if i < args.start:
            continue
        if i >= args.end:
            break
        rec = df_origin[i]
        print(f"🚀 =====Processing scenario {i + 1}/{len(df_origin)}: {rec['scenario']}, statement: {rec['statement']}=====")
        key = rec['scenario'] + rec['statement']
        factor_record = find_record(df_factors, key)
        factor_statement_mapping = factor_record['factor_statement_mapping']
        mapping_record = find_record(df_factor, key)
        if args.mapping_type == 'init_cluster':
            mapping = mapping_record['initial_clustered_mapping_results']
        elif args.mapping_type == 'pruned_cluster':
            mapping = mapping_record['pruned_clustered_mapping_results']
        else:
            mapping = mapping_record['unclustered_clustered_mapping_results']
            for item in mapping.values():
                if 'final_factors' in item:
                    item['final_factors'] = [factor.lower() for factor in item['final_factors']]

        clustered = factor_record['factors_after_clustering']
        if args.ablation == 'true':
            para_record = find_record(df_para, key)
            para_prob = para_record['para_prob']
            latents = para_record['latents']
            latent_prob_map = para_record['latent_prob_map']
            flat_facs = [f for f in factor_statement_mapping if f in clustered and f in para_prob]
        else:
            flat_facs = [f for f in factor_statement_mapping if f in clustered]

        # Keep only mapped factors that made it into flat_facs (on ablation
        # runs flat_facs is already restricted to factors present in para_prob).
        allowed = set(flat_facs)
        for value in mapping.values():
            value["final_factors"] = [factor for factor in value["final_factors"] if factor in allowed]

        # ─── Factor probabilities P(factor supports Statement 1) ───
        try:
            if 'cot' in args.use_cot:
                if 'v1' in args.use_cot:
                    shots = generate_para_prob_shots_cot_v1
                else:
                    shots = generate_para_prob_shots_cot_v0
            else:
                shots = generate_para_prob_shots
            if args.ablation == 'true':
                p_true_map = para_prob
            elif 'stat' in args.factor_prob_type:
                p_true_map = {}
                for factor in flat_facs:
                    if factor_statement_mapping[factor] == 'Statement1':
                        p_true_map[factor] = round(random.uniform(0.45, 0.85), 2)
                    elif factor_statement_mapping[factor] == 'Statement2':
                        p_true_map[factor] = round(random.uniform(0.15, 0.55), 2)
                    else:
                        p_true_map[factor] = 0.5
            else:
                p_true_map = generate_probability_map(
                    rec=rec,
                    flat_facs=flat_facs,
                    generate_para_prob_shots=shots,
                    args=args,
                )
        except RuntimeError as e:
            print(f"Call failed: {e}")
            # Fall back to neutral probabilities; otherwise p_true_map would be
            # undefined (first record) or silently reused from the previous record.
            p_true_map = {factor: 0.5 for factor in flat_facs}

        # ─── Build naive and causal networks ───
        nb = build_naive_bn(flat_facs, p_true_map, args.SMOOTH_ALPHA)
        if args.ablation == 'true':
            cbn, edges, latents, latent_prob_map = build_causal_bn(
                args=args,
                factors=flat_facs,
                p_true=p_true_map,
                factor_statement_mapping=factor_statement_mapping,
                statement1=rec['statement'],
                statement2=rec['opposite_statement'],
                latent=latents,
                latent_prob_maps=latent_prob_map,
            )
        else:
            cbn, edges, latents, latent_prob_map = build_causal_bn(
                args=args,
                factors=flat_facs,
                p_true=p_true_map,
                factor_statement_mapping=factor_statement_mapping,
                statement1=rec['statement'],
            )
        inf_nb = VariableElimination(nb)
        inf_cbn = VariableElimination(cbn)

        # ─── Compare support for each condition's mapped factors ───
        results = {}
        for sent, value in mapping.items():
            results[sent] = {
                'nb': support_prob(inf_nb, value['final_factors']),
                'cbn': support_prob(inf_cbn, value['final_factors']),
                'mapped_factors': value['final_factors']
            }
        out.append({
            'scenario': rec['scenario'],
            'statement': rec['statement'],
            'condition_factor_mapping': mapping,
            'causal_edges': edges,
            'latents': latents,
            'latent_prob_map': latent_prob_map,
            'para_prob': p_true_map,
            'results': results
        })
        # Rewrite the whole file after every record so an interrupted run can resume.
        with open(out_path, 'w', encoding='utf-8') as f:
            json.dump(out, f, indent=4, ensure_ascii=False)

    print(f"Saved comparison to {out_path}")
