
import os
import json
import time
import argparse
import re
from collections import defaultdict
from tqdm import tqdm
import numpy as np
import umap.umap_ as umap
import hdbscan
from sentence_transformers import SentenceTransformer

# Dataset-specific few-shot prompt examples (fact-checking variants)
from utils.prompt_shots_fact_check import FEW_SHOT_PRUNE_XSUM, FEW_SHOT_VOTE_XSUM, \
    FEW_SHOT_GEN_XSUM, FEW_SHOT_EXT_XSUM, FEW_SHOT_CLUSTER_NAME_XSUM, FEW_SHOT_EXT_COVID, FEW_SHOT_GEN_COVID, \
    FEW_SHOT_VOTE_COVID, FEW_SHOT_CLUSTER_NAME_COVID, FEW_SHOT_PRUNE_COVID, FEW_SHOT_GEN_CNN, FEW_SHOT_GEN_EXPERTQA, \
    FEW_SHOT_EXT_EXPERTQA, FEW_SHOT_EXT_CNN, FEW_SHOT_VOTE_CNN, FEW_SHOT_VOTE_EXPERTQA, FEW_SHOT_CLUSTER_NAME_CNN, \
    FEW_SHOT_CLUSTER_NAME_EXPERTQA, FEW_SHOT_PRUNE_EXPERTQA, FEW_SHOT_PRUNE_CNN

from utils import config
from utils.utils import ask_gpt, print_args
from utils.prompt_shots import JSON_CORRECTION_PROMPT
from utils.response_parse_utils import safe_json_parse,extract_choice_from_response



class IterativeFactorGenerator:
    """Collect, label, and optionally cluster-prune key factors for one fact-checking instance.

    ``run`` drives the pipeline: each round asks the LLM for sentences that are
    consistent with / contradict a statement and extracts labelled factors from
    them, stopping once ``target_count`` distinct factors exist or ``max_rounds``
    rounds have run. With ``do_cluster`` the factors are embedded, clustered
    (UMAP + HDBSCAN), and the LLM is asked to drop redundant factors per cluster.
    """

    # Public Hugging Face model loaded when config.MiniLM_L6_PATH cannot be loaded.
    HF_MINILM_ID = "sentence-transformers/all-MiniLM-L6-v2"

    # Substrings of config.dataset_name that select few-shot examples, checked in this order.
    _DATASET_KEYWORDS = ('xsum', 'covid', 'cnn', 'expert')

    def __init__(self, model_name, target_count=100, batch_size=10, max_rounds=20, do_cluster=False):
        """
        Args:
            model_name: LLM identifier forwarded to ``ask_gpt``.
            target_count: stop generating once this many distinct factors are collected.
            batch_size: ``n`` passed to the sentence generator each round.
            max_rounds: upper bound on the number of generation rounds.
            do_cluster: if True, load the embedding model and cluster-prune in ``run``.
        """
        self.model = model_name
        self.target_count = target_count
        self.batch_size = batch_size
        self.max_rounds = max_rounds
        self.do_cluster = do_cluster
        self.factors = set()
        self.init_factor_label = {}  # factor -> 'Consistent' | 'Contradictory' (from extraction)
        self.max_retries = 5
        # Store additional information for output
        self.generated_sentences = []
        self.voting_records = {}
        self.pruning_records = {}  # Store pruning raw responses

        # Initialize embedding model for clustering
        if do_cluster:
            print("[Init] Loading embedding model for clustering...")
            self.embedding_model = self._load_embedding_model()
            print("[Init] Embedding model loaded successfully")

    def _load_embedding_model(self):
        """Load the sentence-embedding model, preferring ``config.MiniLM_L6_PATH``.

        The configured value (a local directory or a hub id) is tried first; if
        loading it fails for any reason, the public ``HF_MINILM_ID`` model is
        loaded instead.
        """
        source = config.MiniLM_L6_PATH
        try:
            if os.path.isdir(source):
                print(f"[Init] Trying local model path: {source}")
            return SentenceTransformer(source)
        except Exception as e:
            if source == self.HF_MINILM_ID:
                raise
            print(f"[Init] Local model failed ({e}), downloading from Hugging Face...")
            return SentenceTransformer(self.HF_MINILM_ID)

    @classmethod
    def _few_shot_messages(cls, kind):
        """Return a fresh list of few-shot chat messages of ``kind`` for the active dataset.

        Args:
            kind: one of 'gen', 'ext', 'vote', 'cluster_name', 'prune'.

        Returns:
            A new list (safe to mutate) copied from the matching FEW_SHOT_* constant.
            The first keyword of ``_DATASET_KEYWORDS`` contained in
            ``config.dataset_name`` wins; an unrecognised dataset gets ``[]`` so
            callers always have a list to build the prompt on.
        """
        shots_by_kind = {
            'gen': (FEW_SHOT_GEN_XSUM, FEW_SHOT_GEN_COVID, FEW_SHOT_GEN_CNN, FEW_SHOT_GEN_EXPERTQA),
            'ext': (FEW_SHOT_EXT_XSUM, FEW_SHOT_EXT_COVID, FEW_SHOT_EXT_CNN, FEW_SHOT_EXT_EXPERTQA),
            'vote': (FEW_SHOT_VOTE_XSUM, FEW_SHOT_VOTE_COVID, FEW_SHOT_VOTE_CNN, FEW_SHOT_VOTE_EXPERTQA),
            'cluster_name': (FEW_SHOT_CLUSTER_NAME_XSUM, FEW_SHOT_CLUSTER_NAME_COVID,
                             FEW_SHOT_CLUSTER_NAME_CNN, FEW_SHOT_CLUSTER_NAME_EXPERTQA),
            'prune': (FEW_SHOT_PRUNE_XSUM, FEW_SHOT_PRUNE_COVID, FEW_SHOT_PRUNE_CNN, FEW_SHOT_PRUNE_EXPERTQA),
        }
        for keyword, shots in zip(cls._DATASET_KEYWORDS, shots_by_kind[kind]):
            if keyword in config.dataset_name:
                return list(shots)
        return []

    @staticmethod
    def _normalize_factor_list(value):
        """Coerce an LLM-provided factor collection into a list of non-blank strings.

        A bare string becomes a one-element list; non-string items (dicts, numbers,
        None) and blank strings are dropped so every factor can be stored in a set
        and used as a dict key. Anything that is not a str/list/tuple yields [].
        """
        if isinstance(value, str):
            value = [value]
        if not isinstance(value, (list, tuple)):
            return []
        return [item for item in value if isinstance(item, str) and item.strip()]

    @staticmethod
    def _dedupe_name(name, taken):
        """Return ``name``, or the first of ``name_2``, ``name_3``, ... that is not in ``taken``."""
        if name not in taken:
            return name
        suffix = 2
        while f"{name}_{suffix}" in taken:
            suffix += 1
        return f"{name}_{suffix}"

    @staticmethod
    def _stable_cluster_id(cluster_factors):
        """Return an integer in [0, 10000) derived from the cluster's factors.

        Uses CRC32 over the sorted factors so the id is reproducible across runs
        (the built-in ``hash`` of ``str`` is salted per process).
        """
        import zlib
        payload = json.dumps(sorted(cluster_factors), ensure_ascii=False)
        return zlib.crc32(payload.encode('utf-8')) % 10000

    @staticmethod
    def _match_to_cluster(candidates, originals):
        """Map LLM-returned factor strings back onto a cluster's original factors.

        Matching ignores letter case and runs of whitespace. Non-string items,
        items that match no original factor, and repeats are dropped. The result
        follows the candidates' order but uses the original spellings, so every
        kept factor is one that was actually in the cluster.
        """
        def norm(text):
            return ' '.join(text.split()).casefold()

        by_norm = {}
        for original in originals:
            if isinstance(original, str):
                by_norm.setdefault(norm(original), original)

        kept, seen = [], set()
        for candidate in candidates:
            if not isinstance(candidate, str):
                continue
            original = by_norm.get(norm(candidate))
            if original is not None and original not in seen:
                seen.add(original)
                kept.append(original)
        return kept

    def generate_batch_sentences_fact_check(self, statement, n):
        """Ask the LLM for ``n`` consistent and ``n`` contradicting sentences about ``statement``.

        Returns the non-blank, stripped lines of the response; they are also
        appended to ``self.generated_sentences``.
        """
        print(f"\n[Generate] Generating {n} sentences for statement:")
        print(f"  Statement: {statement}\n")

        msgs = [
            {
                'role': 'system',
                'content': (
                    "You are a fact-checking assistant. "
                    "When given a Scenario and Statement, generate a list of sentences based on it: "
                    "half should be Consistent with the statement’s facts and half should contradict them. "
                    "Number each sentence and do not add any extra commentary."
                )
            }
        ]
        # Insert dataset-specific few-shot examples (none for unknown datasets)
        msgs += self._few_shot_messages('gen')

        # User prompt only includes the combined statement; the newline keeps the
        # statement and the instruction from running together.
        msgs.append({
            'role': 'user',
            'content': (
                f"Scenario and Statement: {statement}\n"
                f"Generate {n} sentences that are Consistent with this statement and {n} that contradict it."
            )
        })

        # Call the model
        resp = ask_gpt(msgs, model_name=self.model, max_token=512)

        # Split into lines and clean
        lines = [line.strip() for line in resp.split("\n") if line.strip()]
        print(f"[Generate] Sentences:\n  " + "\n  ".join(lines))

        # Store and return
        self.generated_sentences.extend(lines)
        return lines

    def extract_factors(self, sentences):
        """Extract key factors from ``sentences``, labelled Consistent / Contradictory.

        Retries up to ``self.max_retries`` times, feeding the previous response
        and ``JSON_CORRECTION_PROMPT`` back to the model after a failed parse.

        Returns:
            The parsed dict, with 'Consistent' and 'Contradictory' always present
            and each normalized to a list of non-blank strings (empty lists when
            every attempt failed).
        """
        print("[Extract] Extracting key factors from sentences...")
        base_msgs = [
            {
                'role': 'system',
                'content': (
                    "You are a fact-checking assistant. When given a list of sentences about a Scenario and Statement—"
                    "half of which support the Statement and half contradict it—extract each distinct key factor mentioned. "
                    "Label each factor as “Consistent” if it aligns with the Statement’s facts, or “Contradictory” if it conflicts. "
                    "Then output **only** a single JSON object with exactly two keys, “Consistent” and “Contradictory”, each mapping to an array of factor strings. "
                    "Do not include any additional text or commentary."
                )
            }
        ]
        # few-shot examples for extraction
        base_msgs += self._few_shot_messages('ext')

        # user prompt with the sentences to analyze
        user_content = (
                "Based on the Scenario and Statement, extract and label key factors from these sentences:\n\n" +
                "\n".join(f"{i + 1}. {s}" for i, s in enumerate(sentences))
        )
        base_msgs.append({'role': 'user', 'content': user_content})

        last_resp = ""
        factor_dict = {}
        for attempt in range(1, self.max_retries + 1):
            msg_copy = list(base_msgs)
            if attempt > 1:
                # include previous model response and correction prompt
                msg_copy.append({'role': 'assistant', 'content': last_resp})
                msg_copy.append(JSON_CORRECTION_PROMPT)
                msg_copy.append({'role': 'user', 'content': user_content})

            resp = ask_gpt(msg_copy, model_name=self.model, max_token=1024)
            last_resp = resp
            factor_dict = safe_json_parse(resp, default_value={}, context="Factor extraction")

            # check that result is a dict with both keys
            if isinstance(factor_dict, dict) and \
                    'Consistent' in factor_dict and \
                    'Contradictory' in factor_dict:
                print(f"[Extract] Successful on attempt {attempt}")
                return self._normalize_factor_dict(factor_dict)
            print(f"[Extract] Attempt {attempt} failed or missing keys, retrying...")

        # After retries, ensure keys exist
        print("[Extract] All retries exhausted; filling missing keys with empty lists")
        if not isinstance(factor_dict, dict):
            factor_dict = {}
        return self._normalize_factor_dict(factor_dict)

    def _normalize_factor_dict(self, factor_dict):
        """Normalize the two label keys of ``factor_dict`` in place and return it."""
        for label in ('Consistent', 'Contradictory'):
            factor_dict[label] = self._normalize_factor_list(factor_dict.get(label))
        return factor_dict

    def vote_single_factor_with_retry_fact_check(self, factor, scenario, statement):
        """
        Vote on a single factor: determine if it supports (Consistent) or contradicts (Contradictory) the given statement.
        Includes retry logic for JSON parsing failures.

        Returns:
            ``(choice, raw_response)`` where choice is 'Consistent' or
            'Contradictory', or 'Neutral' after ``self.max_retries`` failed attempts.
            ``scenario`` is accepted for interface compatibility but not used.
        """
        # Copy few-shot examples, then insert the system prompt in front
        msgs = self._few_shot_messages('vote')
        msgs.insert(0, {
            'role': 'system',
            'content': (
                "Decide whether the given factor is Consistent or Contradictory towards the statement. "
                "Give a very brief reasoning (1–2 sentences), then output a JSON object with the factor as the key "
                "and one of the values \"Consistent\", \"Contradictory\". "
                "No extra commentary."
            )
        })

        # Add user prompt, only including statement and factor
        msgs.append({
            'role': 'user',
            'content': (
                f"Scenario and Statement: {statement}\n"
                f"Factor: {factor}\n"
                "Please decide if the factor is Consistent or Contradictory to the statement, "
                "and respond with a JSON object like {"
                f"\"{factor}\": \"Contradictory\""
                "}."
            )
        })

        resp = ""
        for retry in range(1, self.max_retries + 1):
            resp = ask_gpt(msgs, model_name=self.model, max_token=256)
            print(f"[Vote] Response for factor '{factor}' (attempt {retry}): {resp}")

            parsed_response = safe_json_parse(resp, default_value={},
                                              context=f"Voting for factor '{factor}', attempt {retry}")
            if isinstance(parsed_response, dict) and factor in parsed_response:
                choice = extract_choice_from_response(parsed_response, factor)
                if choice in ["Consistent", "Contradictory"]:
                    return choice, resp

            # Add correction prompt after second failure
            if retry == 2:
                print(f"[Vote] JSON parsing failed, retrying with correction prompt...")
                msgs.append({'role': 'assistant', 'content': resp})
                msgs.append(JSON_CORRECTION_PROMPT)
                msgs.append({
                    'role': 'user',
                    'content': (
                        "Please return a valid JSON object with the format "
                        "{\"<factor>\": \"Consistent\" | \"Contradictory\" }."
                    )
                })

        # Default to Neutral after multiple retries
        print(f"[Vote] All retries failed for factor '{factor}', defaulting to 'Neutral'")
        return "Neutral", resp

    def vote_factors(self, factors, scenario, stmt1, stmt2):
        """Label each factor by up to three LLM votes against ``stmt1``.

        Voting stops after two votes when they agree; otherwise the third vote
        decides. A final 'Both' is converted to 'Neutral'. Per-factor details are
        stored in ``self.voting_records``. ``stmt2`` is not used.

        Returns:
            dict mapping factor -> final choice.
        """
        print("[Vote] Voting on factor Consistent with majority vote (3 calls per factor)...")
        mapping = {}
        conversion_log = {}  # Track Both→Neutral conversions

        for f in tqdm(factors, desc="Voting factors", unit="factor"):
            votes = []
            vote_details = []

            for attempt in range(3):
                choice, resp = self.vote_single_factor_with_retry_fact_check(f, scenario, stmt1)

                votes.append(choice)
                vote_details.append({
                    'attempt': attempt + 1,
                    'response': resp,
                    'choice': choice
                })

                print(f"[Vote] Voting time {attempt + 1} for factor '{f}' => {choice}")
                # if this is the second vote (attempt == 1) and it matches the first, stop early
                if attempt == 1 and votes[0] == votes[1]:
                    print(f"[Vote] First two attempts agree ('{votes[0]}'), ending early.")
                    break

            # Majority vote: first two agree, or the third vote breaks the tie
            final_choice = votes[0] if len(votes) >= 2 and votes[0] == votes[1] else votes[-1]

            # Convert "Both" to "Neutral" as per requirement
            original_choice = final_choice
            if final_choice == 'Both':
                final_choice = 'Neutral'
                conversion_log[f] = {
                    'original': 'Both',
                    'converted_to': 'Neutral',
                    'votes': votes
                }
                print(f"[Vote] Factor '{f}' converted: Both → Neutral")

            print(f"[Vote] Factor '{f}' final choice: {final_choice} (original: {original_choice})")
            mapping[f] = final_choice

            # Store voting records
            self.voting_records[f] = {
                'votes': votes,
                'original_choice': original_choice,
                'final_choice': final_choice,
                'vote_details': vote_details,
                'was_converted': f in conversion_log
            }

        # Log conversion summary
        if conversion_log:
            print(f"[Vote] Conversion Summary: {len(conversion_log)} factors converted from 'Both' to 'Neutral'")
            for factor, info in conversion_log.items():
                print(f"  - {factor}: votes={info['votes']}")

        return mapping

    def cluster_factors(self, factors):
        """Group factors by embedding similarity and name each group with the LLM.

        Args:
            factors: iterable of factor strings; its iteration order is the order
                in which factors are embedded.

        Returns:
            dict mapping a unique cluster name to its list of factors. HDBSCAN
            noise points go under 'uncategorized'. Fewer than two factors yields
            ``{'cluster_0': [...]}``; any error during clustering yields
            ``{'all_factors': [...]}``.
        """
        print("[Cluster] Clustering factors using UMAP + HDBSCAN...")
        factor_list = list(factors)

        if len(factor_list) < 2:
            print("[Cluster] Too few factors for clustering, returning single cluster")
            return {'cluster_0': factor_list}

        try:
            # Step 1: Embed the factors with the sentence-embedding model
            print(f"[Cluster] Generating embeddings for {len(factor_list)} factors...")
            embeddings = self.embedding_model.encode(factor_list, show_progress_bar=True)
            print(f"[Cluster] Embeddings shape: {embeddings.shape}")
            # Release the embedding model right after use. A second call on this
            # instance then raises AttributeError and takes the fallback below.
            del self.embedding_model
            import gc
            gc.collect()

            # Step 2: UMAP dimensionality reduction
            # Reduce to 10-50 dimensions, dynamically choose based on factor count
            n_factors = len(factor_list)
            umap_dims = min(50, max(10, n_factors // 5))  # 10-50 dims, scale with factor count
            umap_dims = min(umap_dims, n_factors - 1)  # Cannot exceed n_factors - 1

            print(f"[Cluster] Applying UMAP reduction to {umap_dims} dimensions...")
            umap_reducer = umap.UMAP(
                n_components=umap_dims,
                n_neighbors=min(15, n_factors - 1),  # Adjust neighbors based on factor count
                min_dist=0.1,
                metric='cosine',
                random_state=42
            )
            reduced_embeddings = umap_reducer.fit_transform(embeddings)
            print(f"[Cluster] UMAP reduced embeddings shape: {reduced_embeddings.shape}")

            # Step 3: HDBSCAN clustering
            min_cluster_size = max(2, n_factors // 20)  # ~5% of factors, minimum 2
            print(f"[Cluster] Applying HDBSCAN with min_cluster_size={min_cluster_size}...")
            clusterer = hdbscan.HDBSCAN(
                min_cluster_size=min_cluster_size,
                min_samples=1,
                metric='euclidean',
                cluster_selection_epsilon=0.0
            )
            cluster_labels = clusterer.fit_predict(reduced_embeddings)

            # Organize factors by cluster
            clusters_dict = defaultdict(list)
            noise_factors = []

            for factor, label in zip(factor_list, cluster_labels):
                if label == -1:  # Noise point
                    noise_factors.append(factor)
                else:
                    clusters_dict[f'cluster_{label}'].append(factor)

            # Add noise as separate cluster if exists
            if noise_factors:
                clusters_dict['noise'] = noise_factors

            print(f"[Cluster] Found {len(clusters_dict)} clusters:")
            for cluster_name, cluster_factors in clusters_dict.items():
                print(f"  - {cluster_name}: {len(cluster_factors)} factors")

            # Step 4: Use LLM to generate meaningful cluster names.
            # LLM themes can repeat; dedupe them so one cluster never overwrites
            # another (which would silently drop its factors). 'uncategorized' is
            # reserved for the noise cluster, which stays last.
            taken_names = {'uncategorized'} if noise_factors else set()
            final_clusters = {}
            for cluster_name, cluster_factors in clusters_dict.items():
                if cluster_name == 'noise':
                    continue
                theme_name = self._dedupe_name(self._generate_cluster_theme(cluster_factors), taken_names)
                taken_names.add(theme_name)
                final_clusters[theme_name] = cluster_factors
            if noise_factors:
                final_clusters['uncategorized'] = noise_factors

            print(f"[Cluster] Final clusters with themes: {list(final_clusters.keys())}")
            return final_clusters

        except Exception as e:
            print(f"[Cluster] Error in UMAP+HDBSCAN clustering: {e}")
            print("[Cluster] Falling back to single cluster")
            return {'all_factors': factor_list}

    def _generate_cluster_theme(self, cluster_factors):
        """Generate a meaningful theme name for a cluster using LLM.

        The response is truncated at leaked chat-template tokens, reduced to
        ``[A-Za-z0-9_-]`` with whitespace runs turned into underscores, and capped
        at 50 characters. Falls back to ``cluster_<id>`` (``theme_<id>`` for
        clusters with fewer than two factors) where ``<id>`` is reproducible.
        """
        if len(cluster_factors) < 2:
            return f"theme_{self._stable_cluster_id(cluster_factors)}"

        try:
            msgs = self._few_shot_messages('cluster_name')
            msgs.insert(0, {'role': 'system',
                            'content': 'Generate a concise English theme name (1-3 words) that captures the common topic of these factors. Return only the theme name, no explanation.'})
            msgs.append(
                {'role': 'user',
                 'content': f"Generate a theme name for these related factors:\n{json.dumps(cluster_factors[:10], ensure_ascii=False)}"}  # Limit to first 10 for context
            )

            resp = ask_gpt(msgs, model_name=self.model, max_token=50)
            theme = resp.strip().strip('"').strip("'")

            # Remove special tokens and unwanted patterns.
            # NOTE(review): these are plain substring matches, so a legitimate theme
            # such as "ecosystem" is cut at "system"; consider tightening.
            special_tokens = [
                'eot_id', 'start_header_id', 'end_header_id', 'user', 'assistant',
                'system', 'bot_id', 'human_id', '<|', '|>', '</', '/>', '###'
            ]

            # First, remove any content after the first listed token that occurs
            for token in special_tokens:
                if token in theme:
                    theme = theme.split(token)[0]
                    break

            # Clean up theme name - keep only alphanumeric, spaces, underscores, and hyphens
            theme = re.sub(r'[^a-zA-Z0-9\s_-]', '', theme)
            theme = '_'.join(theme.split())

            # Remove any remaining underscores at start/end and limit length
            theme = theme.strip('_')[:50]

            if not theme or len(theme) < 2:
                theme = f"cluster_{self._stable_cluster_id(cluster_factors)}"

            print(f"[Cluster] Generated theme '{theme}' for {len(cluster_factors)} factors")
            return theme

        except Exception as e:
            print(f"[Cluster] Failed to generate theme name: {e}")
            return f"cluster_{self._stable_cluster_id(cluster_factors)}"

    def prune_redundancy(self, clusters, mapping):
        """Ask the LLM to drop redundant factors within each cluster.

        The LLM's JSON list is mapped back onto the cluster's original factors
        (see ``_match_to_cluster``), so only factors that were in the cluster
        survive. If the reply is not a list, or none of its items match, the
        cluster is kept unchanged. Raw responses and per-cluster stats are stored
        in ``self.pruning_records``. ``mapping`` is accepted for interface
        compatibility but not used.

        Returns:
            ``(set_of_kept_factors, {cluster_name: kept_factor_list})``.
        """
        print("[Prune] Pruning redundant factors in each cluster...")
        pruned = []
        pruned_clusters = {}  # Store clusters after pruning

        for cname, facs in clusters.items():
            print(f"[Prune] Cluster '{cname}': {facs}")
            msgs = self._few_shot_messages('prune')
            msgs.insert(0, {'role': 'system', 'content': f"From these factors in cluster '{cname}', remove redundant ones. Think about which factors are essentially the same concept, then return a JSON array of the remaining unique factors."})
            msgs.append({'role': 'user', 'content': (
                f"From these factors in cluster '{cname}', remove redundant ones. "
                f"Think about which factors are essentially the same concept, then return a JSON array of the remaining unique factors.\n\n"
                f"{json.dumps(facs, ensure_ascii=False)}")})
            resp = ask_gpt(msgs, model_name=self.model, max_token=512)

            # Store the raw response for this cluster
            self.pruning_records[cname] = {
                'original_factors': facs,
                'raw_response': resp,
                'prompt_used': msgs[-1]['content']
            }

            # Use unified JSON parser
            keep = safe_json_parse(resp, default_value=facs, context=f"Pruning cluster '{cname}'")
            if not isinstance(keep, list):
                print(f"[Prune] Warning: Expected list but got {type(keep)}, using original factors")
                keep = list(facs)
            else:
                matched = self._match_to_cluster(keep, facs)
                if not matched and facs:
                    # An empty or entirely unrecognised reply must not wipe out the cluster
                    print(f"[Prune] Warning: no returned factor matches cluster '{cname}', using original factors")
                    matched = list(facs)
                keep = matched

            kept_set = set(keep)
            removed = [f for f in facs if f not in kept_set]

            # Update pruning record with parsed result
            self.pruning_records[cname]['parsed_factors'] = keep
            self.pruning_records[cname]['factors_removed'] = removed
            self.pruning_records[cname]['factors_kept_count'] = len(keep)
            self.pruning_records[cname]['factors_removed_count'] = len(removed)

            print(f"[Prune] Kept after pruning: {keep}")
            pruned += keep

            # Store pruned cluster (only if it has factors left)
            if keep:
                pruned_clusters[cname] = keep

        return set(pruned), pruned_clusters

    def run(self, scenario, statement, oppo_statement=''):
        """Run the full pipeline for one instance and return a results dict.

        ``oppo_statement`` is accepted for interface compatibility but not used.
        Output factor lists follow first-extraction order, which keeps the
        clustering input order reproducible for the same LLM outputs.
        """
        print(f"\n=== Start instance: '{scenario}' ===")
        for rnd in tqdm(range(self.max_rounds), desc="Gen rounds", unit="round"):
            if len(self.factors) >= self.target_count:
                break
            sents = self.generate_batch_sentences_fact_check(statement, self.batch_size)
            newf = self.extract_factors(sents)
            temp_factors = set()
            # Later labels win if a factor appears under both keys
            for label in ('Consistent', 'Contradictory'):
                for factor in newf[label]:
                    self.init_factor_label[factor] = label
                    temp_factors.add(factor)

            before = len(self.factors)
            self.factors |= temp_factors
            after = len(self.factors)
            print(f"[Round] {rnd + 1}: +{after - before} factors, total={after}")
        print(f"[Result] Collected {len(self.factors)} factors.")

        # Voting is disabled; the extraction labels are used directly.
        mapping = self.init_factor_label.copy()
        factors_before_clustering = [f for f, v in mapping.items() if v in ('Consistent', 'Contradictory')]
        factors_after_voting = set(factors_before_clustering)

        # Create factor mapping for final results (factor -> label)
        final_factor_mapping = {f: mapping[f] for f in factors_before_clustering}

        # Optional clustering prune
        factors_after_clustering = factors_after_voting
        initial_clusters = None
        pruned_clusters = None
        clustering_stats = {
            'applied': self.do_cluster,
            'factors_before_count': len(factors_before_clustering),
            'factors_after_count': 0,
            'factors_removed_count': 0,
            'factors_removed': [],
            'reduction_rate': 0.0,
            'initial_clusters': None,
            'cluster_count': 0
        }

        if self.do_cluster:
            initial_clusters = self.cluster_factors(factors_before_clustering)
            factors_after_clustering, pruned_clusters = self.prune_redundancy(initial_clusters, mapping)
            print(f"[Prune] After clustering prune: {len(factors_after_clustering)} factors remain.")

            # Calculate clustering statistics
            factors_removed = [f for f in factors_before_clustering if f not in factors_after_clustering]
            clustering_stats.update({
                'factors_after_count': len(factors_after_clustering),
                'factors_removed_count': len(factors_removed),
                'reduction_rate': len(factors_removed) / len(factors_before_clustering) if factors_before_clustering else 0.0,
                'cluster_count': len(initial_clusters) if initial_clusters else 0,
                'factors_removed': factors_removed,
                'initial_clusters': initial_clusters
            })

            print(f"[Clustering Stats] Before: {clustering_stats['factors_before_count']}, "
                  f"After: {clustering_stats['factors_after_count']}, "
                  f"Removed: {clustering_stats['factors_removed_count']}, "
                  f"Reduction: {clustering_stats['reduction_rate']:.2%}, "
                  f"Clusters: {clustering_stats['cluster_count']}")
        else:
            clustering_stats['factors_after_count'] = len(factors_before_clustering)

        print(f"=== End instance: '{scenario}', final factors count: {len(factors_after_clustering)} ===\n")

        # Return comprehensive results with both factor sets and statistics
        return {
            'factors_before_clustering': factors_before_clustering,
            'factors_after_clustering': [f for f in factors_before_clustering if f in factors_after_clustering],
            'factor_statement_mapping': final_factor_mapping,  # factor -> label mapping
            'generated_sentences': self.generated_sentences,
            'pruning_records': self.pruning_records,  # Add pruning raw responses
            'clustering_applied': self.do_cluster,
            'clustering_stats': clustering_stats,
            'pruned_clusters': pruned_clusters
        }


def parse_args():
    """Parse the command-line options for factor generation, print them, and return them.

    Clustering prune is on by default; ``--no-cluster`` turns it off
    (``--cluster`` is accepted for explicitness).
    """
    parser = argparse.ArgumentParser()

    # (flag, default) pairs for string options whose defaults come from config
    string_options = (
        ("--model_name", config.model_name),
        ("--dataset_name", config.dataset_name),
        ("--dataset_file_dic", config.dataset_file_dic),
        ("--save_file_dic", config.save_file_dic),
    )
    for flag, default in string_options:
        parser.add_argument(flag, type=str, default=default)

    # Index range of instances, then generation budget: target factors, batch size, rounds
    int_options = (
        ("--start", 0),
        ("--end", 1000),
        ("--target", 40),
        ("--batch", 6),
        ("--rounds", 10),
    )
    for flag, default in int_options:
        parser.add_argument(flag, type=int, default=default)

    parser.add_argument("--cluster", action='store_true', help="enable clustering prune")
    parser.add_argument("--no-cluster", dest='cluster', action='store_false', help="disable clustering prune")
    parser.set_defaults(cluster=True)

    parsed = parser.parse_args()
    print_args(parsed)
    return parsed


if __name__ == '__main__':
    # Batch driver: run the factor generator over data[start:end] of the dataset,
    # checkpointing the full result list after every instance so an interrupted
    # run resumes where it stopped.
    import sys

    args = parse_args()
    infile = os.path.join(args.dataset_file_dic, args.dataset_name + '.json')
    print('input file--------->: {}'.format(infile))
    with open(infile, 'r', encoding='utf-8') as f:
        data = json.load(f)
    suffix = ''
    os.makedirs(args.save_file_dic, exist_ok=True)
    out_objs = []
    out_path = os.path.join(
        args.save_file_dic,
        f"{args.dataset_name}_{args.model_name.replace(':', '-')}_{args.start}_{args.end}_factors{suffix}.json"
    )
    print('output file---------->: {}'.format(out_path))

    # Resume support: skip records already saved for this range. start_index must be
    # defined even when there is no previous output file.
    start_index = args.start
    if os.path.exists(out_path):
        with open(out_path, 'r', encoding='utf-8') as f:
            out_objs = json.load(f)
        start_index = args.start + len(out_objs)  # Continue from the number of existing records

    # Ensure it does not exceed the end index
    if start_index >= args.end:
        print(f"All records from {args.start} to {args.end} have been processed. Exiting.")
        sys.exit(0)

    for idx, item in tqdm(enumerate(data[start_index:args.end], start=start_index),
                          total=max(0, min(len(data), args.end) - start_index),
                          desc="Processing instances", unit="instance"):
        start_time = time.time()  # Record start time for this iteration

        generator = IterativeFactorGenerator(
            args.model_name, args.target, args.batch, args.rounds, args.cluster
        )
        if 'common' in args.dataset_name:
            results = generator.run(item['scenario'], item['statement'], item['opposite_statement'])
        else:
            results = generator.run(item['scenario'], item['statement'])

        elapsed_time = time.time() - start_time  # Elapsed seconds for this instance

        out_objs.append({
            'scenario': item['scenario'],
            'statement': item['statement'],
            'factors_before_clustering': results['factors_before_clustering'],
            'factors_after_clustering': results['factors_after_clustering'],
            'factor_statement_mapping': results['factor_statement_mapping'],
            'generated_sentences': results['generated_sentences'],
            'pruning_records': results['pruning_records'],  # Add pruning raw responses
            'clustering_applied': results['clustering_applied'],
            'clustering_stats': results['clustering_stats'],
            'pruned_clusters': results['pruned_clusters'],
            'elapsed_time': elapsed_time  # Add elapsed time field
        })

        # Write to a temp file and swap it in with os.replace, so an interruption
        # mid-write cannot corrupt the checkpoint the resume logic reads.
        tmp_path = out_path + '.tmp'
        with open(tmp_path, 'w', encoding='utf-8') as outf:
            json.dump(out_objs, outf, indent=2, ensure_ascii=False)
        os.replace(tmp_path, out_path)
        print(f"[Saved] index {idx} to {out_path}")