
import umap.umap_ as umap
import hdbscan
from sentence_transformers import SentenceTransformer
import sys
import os
import json
import time
import argparse
import re
import numpy as np
from sklearn.neighbors import NearestNeighbors
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm
from collections import defaultdict
import string
import nltk
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
import logging

# Download required NLTK data
# try:
#     nltk.data.find('tokenizers/punkt')
# except LookupError:
#     nltk.download('punkt')
#
# try:
#     nltk.data.find('corpora/stopwords')
# except LookupError:
#     nltk.download('stopwords')

from utils import config
from utils.prompt_shots import  MAPPING_REFLECTION_SHOT
from utils.prompt_shots_fact_check import MAPPING_VOTE_SHOT_XSUM, MAPPING_VOTE_SHOT_CNN, MAPPING_VOTE_SHOT_EXPERTQA, \
    MAPPING_VOTE_SHOT_COVID
from utils.utils import ask_gpt, print_args, correct_dict_keys

# Configure root logging at import time: INFO threshold, timestamped records.
# NOTE: this module's logger sets no level of its own, so the many
# logger.debug(...) calls below are dropped unless the level is lowered elsewhere.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Raise the "httpx" logger threshold to WARNING so its lower-severity records are dropped
logging.getLogger("httpx").setLevel(logging.WARNING)


class TextPreprocessor:
    """Normalise raw text before it is embedded.

    Text is lowercased, tokenized with NLTK's ``word_tokenize`` and stripped of
    stopwords and of tokens that are a single punctuation character.
    """

    def __init__(self, language='english'):
        # Sets give constant-time membership checks while filtering tokens
        self.stop_words = set(stopwords.words(language))
        self.punctuation = set(string.punctuation)

    def clean_text(self, text):
        """Return ``text`` lowercased with punctuation and stopword tokens removed."""
        kept_tokens = []
        for token in word_tokenize(text.lower()):
            # Skip noise tokens; everything else is kept in original order
            if token in self.punctuation or token in self.stop_words:
                continue
            kept_tokens.append(token)

        return ' '.join(kept_tokens)

    def preprocess_batch(self, texts):
        """Apply ``clean_text`` to every item of ``texts`` and return a list."""
        return list(map(self.clean_text, texts))


class HierarchicalConditionMapper:
    """
    Hierarchical condition mapping with mathematical optimization
    Following the mathematical framework:
    1. Text preprocessing and encoding
    2. Cluster prototype construction
    3. Hierarchical retrieval (cluster-level + factor-level)
    4. LLM mapping with majority voting
    5. Self-reflection correction
    
    Enhanced to handle both clustered and unclustered factors separately
    """

    def __init__(self, model_name, K1=3, K2=5, num_prompts=3, vote_ratio=0.5, alpha=0.5,
                 max_retries=3, enable_preprocessing=False, enable_self_reflection=True):
        """
        Store hyper-parameters, load the sentence embedder and reset fitted state.

        Args:
            model_name: LLM identifier forwarded to ``ask_gpt``.
            K1: Number of clusters retrieved per condition.
            K2: Number of factors retrieved per cluster (and for unclustered factors).
            num_prompts: Number of LLM voting rounds (stored as ``self.R``).
            vote_ratio: Fraction of rounds a factor must win to be selected.
            alpha: Weight of the cluster name versus the mean factor embedding.
            max_retries: Maximum LLM attempts per request.
            enable_preprocessing: Clean texts with ``TextPreprocessor`` before encoding.
            enable_self_reflection: Flag stored for the self-reflection step.
        """
        # --- Hyper-parameters ---
        self.model_name = model_name
        self.K1 = K1
        self.K2 = K2
        self.R = num_prompts
        self.alpha = alpha
        self.max_retries = max_retries
        self.enable_preprocessing = enable_preprocessing
        self.enable_self_reflection = enable_self_reflection

        # Vote threshold: τ = max(1, ⌈R · vote_ratio⌉)
        self.vote_threshold = max(1, int(np.ceil(self.R * vote_ratio)))

        # --- Components ---
        self.preprocessor = TextPreprocessor() if self.enable_preprocessing else None

        self.embedder = SentenceTransformer(config.MiniLM_L6_PATH)
        # Swap in a wrapper so encode() stays quiet unless a caller asks otherwise
        self.embedder.encode = self._wrap_encode_method(self.embedder.encode)

        # --- Fitted state: clustered factors ---
        self.cluster_names = []
        self.cluster_prototypes = None  # stacked prototypes e_Cj
        self.cluster_factor_embeddings = {}  # cluster name -> factor embeddings
        self.cluster_factors = {}  # cluster name -> factor texts
        self.knn_clusters = None  # cluster-level KNN index
        self.knn_factors = {}  # cluster name -> factor-level KNN index

        # --- Fitted state: unclustered factors ---
        self.unclustered_factors = []
        self.unclustered_factor_embeddings = None
        self.knn_unclustered = None

        preprocessing_state = 'enabled' if enable_preprocessing else 'disabled'
        reflection_state = 'enabled' if enable_self_reflection else 'disabled'
        logger.info(
            f"🚀 Initializing Hierarchical Condition Mapper: K1={K1}, K2={K2}, R={num_prompts}, α={alpha}, "
            f"max_retries={max_retries}, preprocessing={preprocessing_state}, "
            f"self_reflection={reflection_state}")

    def _wrap_encode_method(self, original_encode):
        """Wrap the encode method to disable progress bars"""

        def encode_wrapper(*args, **kwargs):
            # Set show_progress_bar to False if not explicitly set
            if 'show_progress_bar' not in kwargs:
                kwargs['show_progress_bar'] = False
            return original_encode(*args, **kwargs)

        return encode_wrapper

    def _encode_texts(self, texts, normalize=True):
        """Encode texts to embeddings: e_s = Encode(s) ∈ R^d"""
        if isinstance(texts, str):
            texts = [texts]

        # Preprocess texts only if enabled
        if self.enable_preprocessing and self.preprocessor:
            cleaned_texts = self.preprocessor.preprocess_batch(texts)
        else:
            cleaned_texts = texts

        # Encode with sentence transformer
        embeddings = self.embedder.encode(cleaned_texts, normalize_embeddings=normalize)

        # Always return 2D array for compatibility with sklearn
        if len(texts) == 1:
            # For single text, return 2D array with shape (1, embedding_dim)
            return embeddings.reshape(1, -1) if len(embeddings.shape) == 1 else embeddings
        else:
            # For multiple texts, ensure we have 2D array
            return embeddings

    def _compute_cluster_prototype(self, cluster_name, factor_texts):
        """
        Compute cluster prototype:
        e_Cj = α * Encode(Name(Cj)) + (1-α) * (1/|Fj|) * Σ_{f∈Fj} e_f
        """
        # Encode cluster name (single text, will be 2D array with shape (1, dim))
        name_embedding = self._encode_texts(cluster_name, normalize=True)
        # Extract the embedding vector for computation
        name_embedding_vector = name_embedding[0] if len(name_embedding.shape) == 2 else name_embedding

        # Encode factor texts (will be 2D array)
        factor_embeddings = self._encode_texts(factor_texts, normalize=True)

        # Compute mean factor embedding
        mean_factor_embedding = np.mean(factor_embeddings, axis=0)

        # Weighted combination
        prototype = (self.alpha * name_embedding_vector +
                     (1 - self.alpha) * mean_factor_embedding)

        # Normalize
        prototype = prototype / np.linalg.norm(prototype)

        return prototype, factor_embeddings

    def fit(self, initial_clusters, unclustered_factors=None):
        """
        Fit the hierarchical mapper on initial clusters and unclustered factors.

        For clustered factors this builds one prototype per cluster, a
        cluster-level KNN index over the prototypes and one factor-level KNN
        index per cluster. For unclustered factors it builds a single flat KNN
        index.

        Clusters whose factor list is empty are skipped with a warning: they have
        no factor embeddings to average or index, so they would otherwise produce
        a NaN prototype and make index construction fail.

        Args:
            initial_clusters: Mapping of cluster name -> list of factor texts.
                It is first passed to ``correct_dict_keys`` (return value unused,
                so presumably an in-place key fix — confirm in utils.utils).
            unclustered_factors: Optional list of factor texts outside any cluster.
        """
        logger.info("🔧 Starting to fit Hierarchical Condition Mapper...")

        # Correct keys
        correct_dict_keys(initial_clusters)

        # Process clustered factors
        if initial_clusters:
            logger.info("📊 Processing clustered factors...")
            cluster_names = []
            cluster_prototypes = []

            for cluster_name, factor_texts in initial_clusters.items():
                if not factor_texts:
                    # Nothing to embed: mean of zero vectors is NaN and KNN cannot be fitted
                    logger.warning(f"⚠️ Skipping cluster '{cluster_name}': it contains no factors")
                    continue

                logger.debug(f"  Processing cluster: {cluster_name} ({len(factor_texts)} factors)")
                cluster_names.append(cluster_name)

                # Compute prototype and factor embeddings
                prototype, factor_embeddings = self._compute_cluster_prototype(
                    cluster_name, factor_texts
                )
                cluster_prototypes.append(prototype)

                # Ensure factor_embeddings is 2D array
                if len(factor_embeddings.shape) == 1:
                    factor_embeddings = factor_embeddings.reshape(1, -1)

                # Store factor embeddings and texts
                self.cluster_factor_embeddings[cluster_name] = factor_embeddings
                self.cluster_factors[cluster_name] = factor_texts

            # np.vstack rejects an empty list, so only index when something is usable
            if cluster_names:
                # Store cluster information
                self.cluster_names = cluster_names
                self.cluster_prototypes = np.vstack(cluster_prototypes)

                # Build cluster-level KNN: C(s) = arg topK_j sim(s, Cj)
                n_clusters = min(self.K1, len(cluster_names))
                logger.info(f"🔍 Building cluster-level KNN index (K1={n_clusters})...")
                self.knn_clusters = NearestNeighbors(
                    n_neighbors=n_clusters,
                    metric='cosine'
                ).fit(self.cluster_prototypes)

                # Build factor-level KNN per cluster: Fj(s) = arg topK_{f∈Fj} sim(s, f)
                logger.info("🔍 Building factor-level KNN index...")
                for cluster_name, factor_embeddings in self.cluster_factor_embeddings.items():
                    # n_neighbors may not exceed the number of indexed factors
                    n_factors = min(self.K2, factor_embeddings.shape[0])
                    self.knn_factors[cluster_name] = NearestNeighbors(
                        n_neighbors=n_factors,
                        metric='cosine'
                    ).fit(factor_embeddings)

        # Process unclustered factors
        if unclustered_factors:
            logger.info(f"📋 Processing unclustered factors ({len(unclustered_factors)} items)...")
            self.unclustered_factors = unclustered_factors
            self.unclustered_factor_embeddings = self._encode_texts(unclustered_factors, normalize=True)

            # Ensure unclustered_factor_embeddings is 2D
            if len(self.unclustered_factor_embeddings.shape) == 1:
                self.unclustered_factor_embeddings = self.unclustered_factor_embeddings.reshape(1, -1)

            # Build KNN for unclustered factors
            n_unclustered = min(self.K2, len(unclustered_factors))
            logger.info(f"🔍 Building unclustered factors KNN index (K2={n_unclustered})...")
            self.knn_unclustered = NearestNeighbors(
                n_neighbors=n_unclustered,
                metric='cosine'
            ).fit(self.unclustered_factor_embeddings)

        clustered_count = sum(len(factors) for factors in initial_clusters.values()) if initial_clusters else 0
        unclustered_count = len(unclustered_factors) if unclustered_factors else 0

        logger.info(
            f"✅ Fitting complete: {len(self.cluster_names)} clusters ({clustered_count} clustered factors), "
            f"{unclustered_count} unclustered factors"
        )

    def _hierarchical_retrieval_clustered(self, condition_text):
        """
        Hierarchical retrieval for clustered factors:
        1. Cluster-level: C(s) = arg topK_j sim(s, Cj)
        2. Factor-level: F(s) = ∪_{Cj∈C(s)} Fj(s)

        Args:
            condition_text: The condition to retrieve candidate factors for.

        Returns:
            (top_cluster_names, candidate_factors, factor_cluster_mapping, knn_results)
            - top_cluster_names: retrieved cluster names, nearest first
            - candidate_factors: de-duplicated factor texts in retrieval order
            - factor_cluster_mapping: factor text -> cluster it was first found in
            - knn_results: per-level distances/similarities as plain Python
              lists and floats
            When no clusters have been fitted, returns ``([], [], {}, {})``.
        """
        if not self.cluster_names:
            logger.debug("🔍 No cluster data, skipping cluster retrieval")
            return [], [], {}, {}

        logger.debug(f"🔍 Starting hierarchical cluster retrieval: '{condition_text[:50]}...'")

        # Encode condition (single text, will be 2D array)
        condition_embedding = self._encode_texts(condition_text, normalize=True)

        # Step 1: Cluster-level retrieval
        logger.debug("  📊 Step 1: Cluster-level retrieval...")
        cluster_distances, cluster_indices = self.knn_clusters.kneighbors(
            condition_embedding, return_distance=True
        )

        top_cluster_names = [self.cluster_names[i] for i in cluster_indices[0]]
        # Cosine distance -> similarity; cast to float so the values match the
        # .tolist() distances below instead of staying NumPy scalars
        cluster_similarities = [float(1 - dist) for dist in cluster_distances[0]]

        # Log cluster KNN results
        logger.debug(f"  📊 KNN cluster retrieval results:")
        for i, (cluster_name, similarity) in enumerate(zip(top_cluster_names, cluster_similarities)):
            logger.debug(f"    Rank {i+1}: {cluster_name} (similarity: {similarity:.4f})")

        # Step 2: Factor-level retrieval
        logger.debug("  🔍 Step 2: Factor-level retrieval...")
        candidate_factors = []
        factor_cluster_mapping = {}  # Track which cluster each factor comes from
        knn_results = {
            'cluster_level': {
                'top_clusters': top_cluster_names,
                'cluster_similarities': cluster_similarities,
                'cluster_distances': cluster_distances[0].tolist()
            },
            'factor_level': {}
        }

        for cluster_name in top_cluster_names:
            factor_texts = self.cluster_factors[cluster_name]
            knn_factors = self.knn_factors[cluster_name]

            # Find top-K factors in this cluster
            factor_distances, factor_indices = knn_factors.kneighbors(
                condition_embedding, return_distance=True
            )

            cluster_candidates = []
            factor_similarities = [float(1 - dist) for dist in factor_distances[0]]

            # Log factor KNN results for this cluster
            logger.debug(f"    Cluster '{cluster_name}' KNN retrieval results:")
            for i, (idx, similarity) in enumerate(zip(factor_indices[0], factor_similarities)):
                factor_text = factor_texts[idx]
                logger.debug(f"      Rank {i+1}: {factor_text[:60]}... (similarity: {similarity:.4f})")

                # Every accepted candidate is a key of the mapping, so the dict
                # doubles as an O(1) "already seen" check; first cluster wins
                if factor_text not in factor_cluster_mapping:
                    candidate_factors.append(factor_text)
                    factor_cluster_mapping[factor_text] = cluster_name
                    cluster_candidates.append(factor_text)

            # Store KNN results for this cluster
            knn_results['factor_level'][cluster_name] = {
                'top_factors': [factor_texts[idx] for idx in factor_indices[0]],
                'factor_similarities': factor_similarities,
                'factor_distances': factor_distances[0].tolist(),
                'selected_factors': cluster_candidates
            }

            logger.debug(f"    Retrieved {len(cluster_candidates)} candidate factors from cluster '{cluster_name}'")

        logger.debug(f"🔍 Cluster retrieval complete: Total {len(candidate_factors)} candidate factors")
        return top_cluster_names, candidate_factors, factor_cluster_mapping, knn_results

    def _retrieval_unclustered(self, condition_text):
        """
        Direct retrieval for unclustered factors:
        F(s) = arg topK_f sim(s, f) for unclustered factors

        Args:
            condition_text: The condition to retrieve candidate factors for.

        Returns:
            (candidate_factors, knn_results): the retrieved factor texts, nearest
            first, and a dict with the same factors plus their similarities and
            distances as plain Python floats. Returns ``([], {})`` when no
            unclustered factors have been fitted.
        """
        if not self.unclustered_factors:
            logger.debug("🔍 No unclustered factors, skipping unclustered retrieval")
            return [], {}

        logger.debug(f"🔍 Starting unclustered factors retrieval: '{condition_text[:50]}...'")

        # Encode condition (single text, will be 2D array)
        condition_embedding = self._encode_texts(condition_text, normalize=True)

        # Direct factor retrieval
        factor_distances, factor_indices = self.knn_unclustered.kneighbors(
            condition_embedding, return_distance=True
        )

        candidate_factors = [self.unclustered_factors[idx] for idx in factor_indices[0]]
        # Cosine distance -> similarity; cast to float so the values match the
        # .tolist() distances below instead of staying NumPy scalars
        factor_similarities = [float(1 - dist) for dist in factor_distances[0]]

        # Log unclustered KNN results
        logger.debug(f"🔍 KNN unclustered factors retrieval results:")
        for i, (factor, similarity) in enumerate(zip(candidate_factors, factor_similarities)):
            logger.debug(f"  Rank {i+1}: {factor[:60]}... (similarity: {similarity:.4f})")

        # Store KNN results
        knn_results = {
            'top_factors': candidate_factors,
            'factor_similarities': factor_similarities,
            'factor_distances': factor_distances[0].tolist()
        }

        logger.debug(f"🔍 Unclustered retrieval complete: {len(candidate_factors)} candidate factors")
        return candidate_factors, knn_results

    def _parse_llm_response(self, response_text, candidates):
        """
        Parse LLM response with multiple strategies and error handling
        Returns (parsed_result, raw_response)
        Now handles factor lists instead of single factors
        """
        raw_response = response_text.strip()

        # Strategy 1: Extract JSON after "Final answer:" marker
        final_answer_pattern = r'Final answer:\s*(\{[^{}]*"answer"\s*:\s*\[[^\]]*\][^{}]*\})'
        final_answer_match = re.search(final_answer_pattern, raw_response, re.IGNORECASE | re.DOTALL)
        if final_answer_match:
            try:
                result = json.loads(final_answer_match.group(1))
                if isinstance(result, dict) and 'answer' in result:
                    answer = result.get('answer')
                    if isinstance(answer, list):
                        # Validate that all factors are in candidates
                        valid_factors = [f for f in answer if f in candidates]
                        return valid_factors if valid_factors else [], raw_response
                    elif isinstance(answer, str) and answer in candidates:
                        return [answer], raw_response
            except json.JSONDecodeError:
                pass

        # Strategy 2: Try to parse the entire response as JSON
        try:
            result = json.loads(raw_response)
            if isinstance(result, dict) and 'answer' in result:
                answer = result.get('answer')
                if isinstance(answer, list):
                    valid_factors = [f for f in answer if f in candidates]
                    return valid_factors if valid_factors else [], raw_response
                elif isinstance(answer, str) and answer in candidates:
                    return [answer], raw_response
        except json.JSONDecodeError:
            pass

        # Strategy 3: Extract JSON array pattern from anywhere in text
        json_array_patterns = [
            r'\{[^{}]*"answer"\s*:\s*(\[[^\]]*\])[^{}]*\}',  # {"answer": [...]}
            r'"answer"\s*:\s*(\[[^\]]*\])',  # "answer": [...]
            r'answer["\']?\s*:\s*(\[[^\]]*\])',  # answer: [...]
        ]

        for pattern in json_array_patterns:
            matches = re.findall(pattern, raw_response, re.IGNORECASE | re.DOTALL)
            for match in matches:
                try:
                    factor_list = json.loads(match)
                    if isinstance(factor_list, list):
                        valid_factors = [f for f in factor_list if f in candidates]
                        return valid_factors if valid_factors else [], raw_response
                except json.JSONDecodeError:
                    continue

        # Strategy 4: Look for individual candidate factors mentioned in response
        # and try to extract them as a list
        mentioned_factors = []
        raw_lower = raw_response.lower()
        for candidate in candidates:
            if candidate.lower() in raw_lower:
                mentioned_factors.append(candidate)
        
        if mentioned_factors:
            return mentioned_factors, raw_response

        # Strategy 5: Look for empty list indicators
        empty_patterns = [
            r'\[\s*\]',  # []
            r'\bnone\b',
            r'\bno\s+factor\b',
            r'\bnot\s+applicable\b',
            r'\bn/a\b',
            r'empty\s+list'
        ]

        for pattern in empty_patterns:
            if re.search(pattern, raw_response, re.IGNORECASE):
                return [], raw_response

        # If all strategies fail, return empty list
        return [], raw_response

    def _llm_mapping_with_voting(self, scenario, condition, candidates):
        """
        LLM mapping with majority voting:
        m^(r)(s) = LLM(s, F~^(r)(s)) ∈ F(s) ∪ {None}
        v_f(s) = Σ_{r=1}^R 1(m^(r)(s) = f)

        The LLM is queried ``self.R`` times with the same prompt. Each round
        casts at most one vote per candidate factor, or one 'None' vote when it
        selects nothing. Factors reaching ``self.vote_threshold`` are selected;
        if none does, every factor with at least one vote is selected instead.

        Args:
            scenario: Scenario description inserted into the prompt.
            condition: Condition text to map onto factors.
            candidates: Candidate factor texts offered to the LLM.

        Returns:
            (selected_factors, votes, raw_responses, parsing_details)
            - selected_factors: factors that passed the vote
            - votes: defaultdict of factor (or 'None') -> vote count
            - raw_responses: last raw reply (or error text) of each round
            - parsing_details: per-round dict with round index, raw reply,
              parsed factors and the number of LLM attempts actually made
        """
        logger.debug(f"🗳️ Starting LLM voting mapping (R={self.R} rounds): '{condition[:50]}...'")
        logger.debug(f"  Candidate factors: {len(candidates)} items")
        logger.debug(f"  Candidate factors list: {candidates}")

        votes = defaultdict(int)
        raw_responses = []  # Store all raw responses
        parsing_details = []  # Store parsing details for debugging

        # Pick the few-shot examples that match the configured dataset
        if 'xsum' in config.dataset_name:
            MAPPING_VOTE_SHOT = MAPPING_VOTE_SHOT_XSUM.copy()
        elif 'cnn' in config.dataset_name:
            MAPPING_VOTE_SHOT = MAPPING_VOTE_SHOT_CNN.copy()
        elif 'expertqa' in config.dataset_name:
            MAPPING_VOTE_SHOT = MAPPING_VOTE_SHOT_EXPERTQA.copy()
        elif 'covid' in config.dataset_name:
            MAPPING_VOTE_SHOT = MAPPING_VOTE_SHOT_COVID.copy()
        else:
            # Unknown dataset: run zero-shot rather than hit an unbound name below
            logger.warning(
                f"No voting few-shot examples for dataset '{config.dataset_name}'; "
                f"continuing without few-shot examples"
            )
            MAPPING_VOTE_SHOT = []

        # The prompt does not change between rounds, so build it once.
        # The last line is a plain (non-f) string: it contains literal braces and
        # double quotes that must reach the model unchanged.
        prompt = (
            f"Scenario: {scenario}\n"
            f"Condition: {condition}\n"
            f"Candidate factors: {json.dumps(candidates, ensure_ascii=False)}\n\n"
            f"Task: Select the factor list that is most reasonably related to the given condition. "
            f"Consider factors that could be logically connected, influenced by, or relevant to the condition. "
            f"Think step by step about each factor's relevance, then provide your selection. "
            f"If no factor has any reasonable connection, return an empty list [].\n"
            "Please respond with your reasoning followed by 'Final answer:' and a JSON object in this exact format: {\"answer\": [\"factor1\", \"factor2\", ...]}"
        )

        system_content = (
            "You are an expert at analyzing logical relationships between conditions and factors. "
            "Given a condition and a list of candidate factors, select a list of factors that have the most "
            "reasonable connection to the condition. Consider direct implications, indirect effects, "
            "causal relationships, or thematic relevance. Be inclusive rather than restrictive - "
            "if a factor could plausibly be related to the condition, include it. "
            "Think step by step about each factor, then provide your final answer. "
            "Always respond with your reasoning followed by 'Final answer:' and a JSON object in the format: {\"answer\": [\"factor1\", \"factor2\", ...]}"
        )

        for round_idx in range(self.R):
            logger.debug(f"  🗳️ Voting round {round_idx + 1}/{self.R}...")

            # Fresh message list per round so a round never sees another round's objects
            messages = (
                [{"role": "system", "content": system_content}]
                + MAPPING_VOTE_SHOT
                + [{"role": "user", "content": prompt}]
            )

            selected_factors = []
            raw_response = ""
            attempts_used = 0

            # Try with retries
            for retry in range(self.max_retries):
                attempts_used = retry + 1
                try:
                    logger.debug(f"    🔄 Sending LLM request (retry {retry + 1}/{self.max_retries})...")
                    response = ask_gpt(messages, model_name=self.model_name, max_token=512)  # Increased token limit for reasoning
                    logger.debug(f"    📨 Received LLM response: {len(response)} characters")
                    parsed_result, raw_response = self._parse_llm_response(response, candidates)

                    if parsed_result is not None:
                        selected_factors = parsed_result if isinstance(parsed_result, list) else [parsed_result]
                        logger.debug(f"    ✅ Parsing successful: {selected_factors}")
                        break
                    else:
                        logger.debug(f"    ❌ Parsing failed (retry {retry + 1}/{self.max_retries})")
                        if retry < self.max_retries - 1:
                            logger.debug(f"    Raw response: {response[:200]}...")

                except Exception as e:
                    # Best effort: any LLM/parsing failure just consumes one retry
                    logger.debug(f"    ❌ LLM call failed (retry {retry + 1}/{self.max_retries}): {e}")
                    raw_response = f"Error: {str(e)}"

            # Record raw response
            raw_responses.append(raw_response)

            # Record parsing details
            parsing_details.append({
                'round': round_idx,
                'raw_response': raw_response,
                'parsed_result': selected_factors,
                'retries_used': attempts_used
            })

            # One vote per factor per round: ignore repeats and non-candidates
            round_votes = []
            for factor in selected_factors:
                if factor in candidates and factor not in round_votes:
                    round_votes.append(factor)

            if round_votes:
                for factor in round_votes:
                    votes[factor] += 1
                    logger.debug(f"    📊 Vote recorded: {factor}")
                logger.debug(f"    📊 {len(round_votes)} factors selected this round")
            else:
                votes['None'] += 1
                logger.debug(f"    📊 Vote recorded: None (no factors selected)")

        # Apply voting threshold: F^(s) = {f | v_f(s) ≥ τ}
        logger.debug(f"🗳️ Voting results: {dict(votes)}")
        logger.debug(f"🗳️ Voting threshold: {self.vote_threshold}")

        # Primary selection: factors meeting the threshold
        primary_selected = [
            factor for factor, vote_count in votes.items()
            if vote_count >= self.vote_threshold and factor != 'None'
        ]

        # Secondary selection: if nothing met the threshold, accept any factor
        # with at least 1 vote (but not 'None')
        if len(primary_selected) == 0:
            logger.debug("🗳️ Primary selection empty, using lenient selection...")
            selected_factors = [
                factor for factor, vote_count in votes.items()
                if vote_count >= 1 and factor != 'None'
            ]
        else:
            selected_factors = primary_selected

        logger.debug(f"🗳️ Voting selection results: {len(selected_factors)} factors passed selection")

        return selected_factors, votes, raw_responses, parsing_details

    def _self_reflection(self, condition, selected_factors):
        """
        Self-reflection correction:
        F*(s) = ReflectLLM(s, F^(s))
        """
        if not selected_factors:
            logger.debug("🤔 No need for self-reflection: No factors selected")
            return selected_factors, "", []
            
        logger.debug(f"🤔 Starting self-reflection: '{condition[:50]}...'")
        logger.debug(f"  Initially selected factors: {len(selected_factors)} items")

        prompt = (
            f"Condition: {condition}\n"
            f"Initially selected factors: {json.dumps(selected_factors, ensure_ascii=False)}\n\n"
            f"Task: Review the selected factors and keep those that are reasonably related to the condition. "
            f"A factor should be kept if it has ANY logical connection or relevance to the condition, "
            f"even if the connection is not immediately obvious. Only remove factors that are clearly "
            f"irrelevant or contradictory. When in doubt, KEEP the factor.\n"
            f"Think step by step about each factor's relevance to the condition, then provide your final selection.\n"
            f"Please respond with your reasoning followed by 'Final answer:' and a JSON array containing the factors to keep."
        )

        messages = [
            {
                "role": "system",
                "content": (
                    "You are performing a self-reflection task. Given a condition and a list of initially "
                    "selected factors, review each factor with a LENIENT approach. Keep factors that have "
                    "ANY reasonable connection to the condition. Only remove factors that are clearly "
                    "irrelevant or contradictory. When uncertain, err on the side of keeping the factor. "
                    "Think step by step about each factor, then provide your reasoning followed by "
                    "'Final answer:' and a JSON array of the factors to keep."
                )
            }
        ] + MAPPING_REFLECTION_SHOT + [{"role": "user", "content": prompt}]

        raw_response = ""
        parsing_details = []

        for retry in range(self.max_retries):
            try:
                logger.debug(f"  🔄 Sending self-reflection LLM request (retry {retry + 1}/{self.max_retries})...")
                response = ask_gpt(messages, model_name=self.model_name, max_token=512)  # Increased for reasoning
                logger.debug(f"  📨 Received self-reflection LLM response: {len(response)} characters")
                raw_response = response.strip()

                # Strategy 1: Extract JSON array after "Final answer:" marker
                final_answer_pattern = r'Final answer:\s*(\[[^\]]*\])'
                final_answer_match = re.search(final_answer_pattern, raw_response, re.IGNORECASE | re.DOTALL)
                if final_answer_match:
                    try:
                        refined_factors = json.loads(final_answer_match.group(1))
                        if isinstance(refined_factors, list):
                            # Ensure all returned factors are from the original selection
                            final_factors = [
                                factor for factor in refined_factors
                                if factor in selected_factors
                            ]

                            parsing_details.append({
                                'retry': retry,
                                'raw_response': raw_response,
                                'parsed_successfully': True,
                                'refined_factors': refined_factors,
                                'final_factors': final_factors
                            })

                            logger.debug(f"🤔 Self-reflection complete: {len(final_factors)} factors retained")
                            return final_factors, raw_response, parsing_details
                    except json.JSONDecodeError:
                        pass

                # Strategy 2: Try to parse as JSON array directly
                try:
                    refined_factors = json.loads(raw_response)
                    if isinstance(refined_factors, list):
                        # Ensure all returned factors are from the original selection
                        final_factors = [
                            factor for factor in refined_factors
                            if factor in selected_factors
                        ]

                        parsing_details.append({
                            'retry': retry,
                            'raw_response': raw_response,
                            'parsed_successfully': True,
                            'refined_factors': refined_factors,
                            'final_factors': final_factors
                        })

                        logger.debug(f"🤔 Self-reflection complete: {len(final_factors)} factors retained")
                        return final_factors, raw_response, parsing_details

                except json.JSONDecodeError:
                    pass

                # Strategy 3: Try to extract JSON array from text
                array_patterns = [
                    r'\[([^\[\]]*(?:"[^"]*"[^\[\]]*)*)\]',  # Standard array pattern
                    r'keep[:\s]*(\[[^\]]*\])',  # "keep: [...]"
                    r'factors[:\s]*(\[[^\]]*\])',  # "factors: [...]"
                ]
                
                for pattern in array_patterns:
                    matches = re.findall(pattern, raw_response, re.IGNORECASE | re.DOTALL)
                    for match in matches:
                        try:
                            if not match.startswith('['):
                                array_text = match
                            else:
                                array_text = f'[{match}]'
                            refined_factors = json.loads(array_text)
                            if isinstance(refined_factors, list):
                                final_factors = [
                                    factor for factor in refined_factors
                                    if factor in selected_factors
                                ]

                                parsing_details.append({
                                    'retry': retry,
                                    'raw_response': raw_response,
                                    'parsed_successfully': True,
                                    'refined_factors': refined_factors,
                                    'final_factors': final_factors
                                })

                                logger.debug(f"🤔 Self-reflection complete: {len(final_factors)} factors retained")
                                return final_factors, raw_response, parsing_details
                        except json.JSONDecodeError:
                            continue

                logger.debug(f"🤔 Self-reflection parsing failed (retry {retry + 1}/{self.max_retries})")
                if retry < self.max_retries - 1:
                    logger.debug(f"  Raw response: {response[:200]}...")

            except Exception as e:
                logger.debug(f"🤔 Self-reflection LLM call failed (retry {retry + 1}/{self.max_retries}): {e}")
                raw_response = f"Error: {str(e)}"

        # If all retries failed, return original selection
        parsing_details.append({
            'retry': self.max_retries,
            'raw_response': raw_response,
            'parsed_successfully': False,
            'error': 'All parsing attempts failed',
            'fallback_to_original': True
        })

        logger.debug(f"🤔 Self-reflection failed, retaining original selection: {len(selected_factors)} factors")

        return selected_factors, raw_response, parsing_details

    def map_condition_clustered(self, scenario, condition):
        """
        Map a single condition to clustered factors.

        Runs hierarchical retrieval to obtain candidate factors, lets the
        voting step pick among them, optionally applies self-reflection to the
        voted selection, and finally collects the clusters that the surviving
        factors belong to.

        Args:
            scenario: Context text forwarded unchanged to the voting step.
            condition: Condition text to map.

        Returns:
            dict with the keys 'condition', 'top_clusters',
            'candidate_factors', 'selected_factors', 'final_factors',
            'candidate_factor_statement_mapping' (always an empty dict here,
            left for the caller to fill), 'relevant_clusters',
            'voting_details', 'raw_responses', 'parsing_details' and
            'knn_results'. The same key set is returned whether or not any
            candidates were retrieved.
        """
        # Step 1: Hierarchical retrieval for clustered factors
        top_clusters, candidates, factor_cluster_mapping, knn_results = self._hierarchical_retrieval_clustered(condition)

        if not candidates:
            # Nothing to vote on: return an empty result that carries the same
            # keys as the regular result so callers can index it uniformly.
            return {
                'condition': condition,
                'top_clusters': [],
                'candidate_factors': [],
                'selected_factors': [],
                'final_factors': [],
                'candidate_factor_statement_mapping': {},
                'relevant_clusters': [],
                'voting_details': {},
                'raw_responses': {'voting_rounds': [], 'self_reflection': ''},
                'parsing_details': {'voting_rounds': [], 'self_reflection': []},
                'knn_results': {}
            }

        # Step 2: LLM mapping with voting
        selected_factors, voting_details, raw_responses, parsing_details = self._llm_mapping_with_voting(
            scenario, condition, candidates
        )

        # Step 3: Self-reflection (optional)
        if self.enable_self_reflection:
            final_factors, reflection_raw_response, reflection_parsing_details = self._self_reflection(
                condition, selected_factors
            )
        else:
            logger.debug("🤔 Skipping self-reflection step")
            final_factors = selected_factors
            reflection_raw_response = "Self-reflection disabled"
            reflection_parsing_details = [{"disabled": True}]

        # Step 4: Determine relevant clusters.
        # dict.fromkeys de-duplicates while keeping first-seen order, so the
        # saved output is stable across runs (a plain set has arbitrary order).
        relevant_clusters = list(dict.fromkeys(
            factor_cluster_mapping[factor]
            for factor in final_factors
            if factor in factor_cluster_mapping
        ))

        return {
            'condition': condition,
            'top_clusters': top_clusters,
            'candidate_factors': candidates,
            'selected_factors': selected_factors,
            'final_factors': final_factors,
            'candidate_factor_statement_mapping': {},
            'relevant_clusters': relevant_clusters,
            'voting_details': dict(voting_details),
            'raw_responses': {
                'voting_rounds': raw_responses,
                'self_reflection': reflection_raw_response
            },
            'parsing_details': {
                'voting_rounds': parsing_details,
                'self_reflection': reflection_parsing_details
            },
            'knn_results': knn_results
        }

    def map_condition_unclustered(self, scenario, condition):
        """
        Map one condition onto the unclustered factor pool.

        Candidates come straight from `_retrieval_unclustered`; they are then
        passed through the voting step and, when enabled, the self-reflection
        step.

        Args:
            scenario: Context text forwarded unchanged to the voting step.
            condition: Condition text to map.

        Returns:
            dict with the keys 'condition', 'candidate_factors',
            'selected_factors', 'final_factors', 'voting_details',
            'raw_responses', 'parsing_details' and 'knn_results'. When no
            candidates are retrieved every field except 'condition' is empty.
        """
        # Start from the "nothing found" shape; it is returned as-is when
        # retrieval yields no candidates and filled in otherwise.
        outcome = {
            'condition': condition,
            'candidate_factors': [],
            'selected_factors': [],
            'final_factors': [],
            'voting_details': {},
            'raw_responses': {'voting_rounds': [], 'self_reflection': ''},
            'parsing_details': {'voting_rounds': [], 'self_reflection': []},
            'knn_results': {}
        }

        # Retrieval: direct lookup, no cluster hierarchy
        candidates, knn_results = self._retrieval_unclustered(condition)
        if not candidates:
            return outcome

        # Voting over the retrieved candidates
        voted, vote_tally, vote_raw, vote_parsing = self._llm_mapping_with_voting(
            scenario, condition, candidates
        )

        # Optional self-reflection over the voted selection
        if not self.enable_self_reflection:
            logger.debug("🤔 Skipping self-reflection step")
            kept = voted
            reflection_raw = "Self-reflection disabled"
            reflection_parsing = [{"disabled": True}]
        else:
            kept, reflection_raw, reflection_parsing = self._self_reflection(condition, voted)

        # Overwriting existing keys keeps the template's key order.
        outcome['candidate_factors'] = candidates
        outcome['selected_factors'] = voted
        outcome['final_factors'] = kept
        outcome['voting_details'] = dict(vote_tally)
        outcome['raw_responses'] = {
            'voting_rounds': vote_raw,
            'self_reflection': reflection_raw
        }
        outcome['parsing_details'] = {
            'voting_rounds': vote_parsing,
            'self_reflection': reflection_parsing
        }
        outcome['knn_results'] = knn_results
        return outcome

    def map_conditions_dual(self, statement, conditions):
        """
        Map every condition twice: against clustered and unclustered factors.

        For each condition, `map_condition_clustered` and
        `map_condition_unclustered` are both called with `statement` as their
        first argument, and progress is reported through tqdm and the logger.

        Args:
            statement: Text passed as the first argument to both mappers.
            conditions: Sized iterable of condition strings.

        Returns:
            Tuple `(clustered_results, unclustered_results)`; each is a dict
            keyed by condition text holding that mapper's result dict.
        """
        logger.info(f"🎯 Starting dual condition mapping: {len(conditions)} conditions")

        by_cluster, flat = {}, {}
        progress = tqdm(conditions, desc="Mapping conditions")

        for position, condition in enumerate(progress, start=1):
            logger.info(f"🎯 Processing condition {position}/{len(conditions)}: '{condition[:50]}...'")

            # Hierarchical (clustered) pass
            logger.info("  📊 Mapping to clustered factors...")
            by_cluster[condition] = self.map_condition_clustered(statement, condition)
            clustered_kept = len(by_cluster[condition]['final_factors'])
            logger.info(f"    ✅ Clustered mapping complete: {clustered_kept} final factors")

            # Flat (unclustered) pass
            logger.info("  📋 Mapping to unclustered factors...")
            flat[condition] = self.map_condition_unclustered(statement, condition)
            flat_kept = len(flat[condition]['final_factors'])
            logger.info(f"    ✅ Unclustered mapping complete: {flat_kept} final factors")

        logger.info("🎯 Dual condition mapping complete!")
        return by_cluster, flat

    # Keep the original methods for backward compatibility
    def map_condition(self, scenario, statement, opposite_statement, condition):
        """
        Map a single condition to factors and clusters (backward compatibility).

        Legacy entry point that delegates to `map_condition_clustered`.

        Args:
            scenario: Context text forwarded to `map_condition_clustered`.
            statement: Accepted only to keep the legacy signature; unused.
            opposite_statement: Accepted only to keep the legacy signature; unused.
            condition: Condition text to map.

        Returns:
            The result dict produced by `map_condition_clustered`.
        """
        # map_condition_clustered takes only (scenario, condition); forwarding
        # all four legacy arguments raised TypeError on every call.
        return self.map_condition_clustered(scenario, condition)

    def map_conditions(self, scenario, statement, opposite_statement, conditions):
        """
        Map several conditions for one scenario (backward compatibility).

        Each condition is passed to `map_condition` together with the given
        scenario, statement and opposite statement, with a tqdm progress bar.

        Returns:
            dict keyed by condition text holding each `map_condition` result;
            a repeated condition keeps the result of its last occurrence.
        """
        progress = tqdm(conditions, desc="Mapping conditions")
        return {
            condition: self.map_condition(scenario, statement, opposite_statement, condition)
            for condition in progress
        }


def parse_args():
    """Build the command-line parser, parse the arguments and echo them.

    Returns:
        The parsed `argparse.Namespace`, after it has been handed to
        `print_args`.
    """
    cli = argparse.ArgumentParser(description="Hierarchical Condition Mapping with Mathematical Optimization")

    # String options whose defaults are read from the project config module.
    config_backed_options = (
        ("--model_name", config.model_name),
        ("--dataset_name", config.dataset_name),
        ("--dataset_file_dic", config.dataset_file_dic),
        ("--save_file_dic", config.save_file_dic),
    )
    for flag, default in config_backed_options:
        cli.add_argument(flag, type=str, default=default)

    # Index range of the dataset slice.
    for flag, default in (("--start", 0), ("--end", 1000)):
        cli.add_argument(flag, type=int, default=default)

    # Numeric settings: (flag, type, default, help).
    numeric_options = (
        ("--K1", int, 3, "Number of top clusters to retrieve"),
        ("--K2", int, 5, "Number of top factors per cluster"),
        ("--num_prompts", int, 3, "Number of LLM voting rounds"),
        ("--vote_ratio", float, 0.5, "Voting threshold ratio (default: 0.5 for more inclusive selection)"),
        ("--alpha", float, 0.5, "Weight for cluster name vs factor content"),
        ("--max_retries", int, 5, "Maximum retries for LLM response parsing"),
    )
    for flag, value_type, default, help_text in numeric_options:
        cli.add_argument(flag, type=value_type, default=default, help=help_text)

    # Boolean switches. Self-reflection is on by default; the --disable_* flag
    # writes False into the same destination.
    cli.add_argument("--enable_preprocessing", action='store_true', help="Enable text preprocessing")
    cli.add_argument("--enable_self_reflection", action='store_true', default=True, help="Enable self-reflection step")
    cli.add_argument("--disable_self_reflection", dest='enable_self_reflection', action='store_false', help="Disable self-reflection step")

    cli.add_argument("--log_level", type=str, default="INFO", choices=["DEBUG", "INFO", "WARNING", "ERROR"],
                     help="Set logging level")

    parsed = cli.parse_args()
    print_args(parsed)
    return parsed


def _write_json_atomic(path, payload):
    """Write *payload* as JSON to *path* without exposing a half-written file.

    The data is dumped to a sibling ``.tmp`` file and then moved over the
    target with ``os.replace``, so an interruption during the dump leaves the
    previous checkpoint untouched.
    """
    tmp_path = f"{path}.tmp"
    with open(tmp_path, 'w', encoding='utf-8') as f:
        json.dump(payload, f, ensure_ascii=False, indent=2)
    os.replace(tmp_path, path)


def main():
    """Run pruned-cluster condition mapping over a dataset slice with checkpointing.

    Loads the dataset and the factor file, builds a
    ``HierarchicalConditionMapper`` from the CLI arguments, and for every
    scenario in ``[args.start, args.end)`` maps each condition against the
    scenario's pruned clusters. The full result list is rewritten to the
    output file after every scenario; on restart, scenarios whose
    ``scenario + statement`` key is already in that file are skipped.
    """
    args = parse_args()
    suffix = ''
    # Set logging level based on argument
    numeric_level = getattr(logging, args.log_level.upper(), None)
    if isinstance(numeric_level, int):
        logger.setLevel(numeric_level)
        # Also set the root logger level
        logging.getLogger().setLevel(numeric_level)

    logger.info(f"📋 Logging level set to: {args.log_level}")
    logger.info("🚀 Starting hierarchical condition mapping...")

    # File paths
    model_tag = args.model_name.replace(':', '-')
    basic_infile = os.path.join(args.dataset_file_dic, args.dataset_name + '.json')
    factor_file = f"{args.dataset_name}_{model_tag}_0_1000_factors{suffix}.json"
    factor_infile = os.path.join(args.save_file_dic, factor_file)

    logger.info(f'📋Basic input file: {basic_infile}')
    logger.info(f'📋Factor input file: {factor_infile}')

    # Load data
    with open(basic_infile, 'r', encoding='utf-8') as f:
        data = json.load(f)

    with open(factor_infile, 'r', encoding='utf-8') as f:
        df_factor = json.load(f)

    # Create factor dictionary
    df_factor_dict = {}
    for df in df_factor:
        key = df['scenario'] + df['statement']
        df_factor_dict[key] = df

    # Initialize mapper
    mapper = HierarchicalConditionMapper(
        model_name=args.model_name,
        K1=args.K1,
        K2=args.K2,
        num_prompts=args.num_prompts,
        vote_ratio=args.vote_ratio,
        alpha=args.alpha,
        max_retries=args.max_retries,
        enable_preprocessing=args.enable_preprocessing,
        enable_self_reflection=args.enable_self_reflection
    )

    # Process data
    os.makedirs(args.save_file_dic, exist_ok=True)

    out_path = os.path.join(
        args.save_file_dic,
        f"{args.dataset_name}_{model_tag}_{args.start}_{args.end}_condition_mapping.json"
    )
    logger.info(f'📋Output file: {out_path}')

    # The requested end may exceed the dataset; clamp it so the progress bar
    # total and the "everything done" check reflect the real amount of work.
    end_index = min(args.end, len(data))

    # Resume from checkpoint: check if there are processed data
    out_objs = []
    checkpoint_keys = set()
    start_index = args.start

    if os.path.exists(out_path):
        try:
            with open(out_path, 'r', encoding='utf-8') as f:
                existing_data = json.load(f)

            if existing_data:
                # Identify processed items by scenario and statement combination
                checkpoint_keys = {obj['scenario'] + obj['statement'] for obj in existing_data}
                out_objs = existing_data

                # Fast-forward over the contiguous run of already processed items
                while start_index < end_index:
                    item = data[start_index]
                    if item['scenario'] + item['statement'] not in checkpoint_keys:
                        break
                    start_index += 1

                if start_index < end_index:
                    logger.info(f"🔄 Resuming from checkpoint: Found processed data, total {len(existing_data)} scenarios")
                    logger.info(f"🔄 Continuing from index {start_index} (original plan: {args.start}-{args.end})")
                else:
                    logger.info(f"✅ All data processed, no further processing needed")
                    logger.info(f"💾 Result file: {out_path}")
                    return
            else:
                logger.info("📋 Output file exists but is empty, starting from scratch")

        except (json.JSONDecodeError, KeyError, TypeError) as e:
            # TypeError covers a checkpoint that is valid JSON but not a list of dicts
            logger.warning(f"⚠️ Failed to read existing output file: {e}")
            logger.info("📋 Starting from scratch")
            out_objs = []
            checkpoint_keys = set()
            start_index = args.start
    else:
        logger.info("📋 No existing output file found, starting from scratch")

    processed_count = len(out_objs)

    progress = tqdm(data[start_index:end_index], desc='🚀🚀Processing progress',
                    initial=start_index - args.start, total=max(end_index - args.start, 0))
    for actual_index, df in enumerate(progress, start=start_index):
        key = df['scenario'] + df['statement']

        # A scenario skipped in an earlier run (no factors / no conditions)
        # breaks the contiguous prefix, so items after it can already be in
        # the checkpoint; skip them instead of appending duplicates.
        if key in checkpoint_keys:
            logger.debug(f"⏭️ Already in checkpoint, skipping: {df['scenario']}")
            continue

        logger.info(f"🚀 Processing scenario {actual_index + 1}/{len(data)}: {df['scenario']}, statement: {df['statement']}")
        if key not in df_factor_dict:
            logger.warning(f"No factors found for scenario: {df['scenario']}")
            continue

        start_time = time.time()

        # Extract data
        scenario = df['scenario']
        statement = df['statement']

        factor_data = df_factor_dict[key]

        # Extract clustered and unclustered factors
        clustering_applied = factor_data.get('clustering_applied', False)
        clustering_stats = factor_data.get('clustering_stats') or {}

        if clustering_applied and clustering_stats.get('initial_clusters'):
            # Get both initial and pruned clusters from clustering results
            initial_clusters = clustering_stats['initial_clusters']
            pruned_clusters = clustering_stats.get('pruned_clusters', {})
            factor_statement_mapping = factor_data.get('factor_statement_mapping', {})
            # Get unclustered factors - directly use factors_before_clustering (all original factors without topics)
            unclustered_factors = factor_data.get('factors_before_clustering', [])

            initial_clustered_count = sum(len(factors) for factors in initial_clusters.values())
            pruned_clustered_count = sum(len(factors) for factors in pruned_clusters.values())
            unclustered_count = len(unclustered_factors)

            logger.info(f"📊 Clustering applied: ")
            logger.info(f"  - Initial clusters: {initial_clustered_count} factors ({len(initial_clusters)} clusters) → Hierarchical processing")
            logger.info(f"  - Pruned clusters: {pruned_clustered_count} factors ({len(pruned_clusters)} clusters) → Hierarchical processing")
            logger.info(f"  - Original factors: {unclustered_count} items → Direct matching (no hierarchy)")
        else:
            # If clustering was not applied, treat all factors as unclustered
            initial_clusters = {}
            pruned_clusters = {}
            factor_statement_mapping = {}
            unclustered_factors = factor_data.get('factors_before_clustering', [])
            logger.info(f"📋 No clustering applied: {len(unclustered_factors)} original factors → Direct matching (no hierarchy)")

        # Get conditions based on dataset
        if args.dataset_name == 'common2sense':
            conditions = df['added_information'] + df['oppo_added_information']
        else:
            conditions = df['additional_sentences']

        # Skip if no conditions or factors
        if not conditions:
            logger.warning(f"⚠️ Skipping scenario (no conditions): {df['scenario']}")
            continue

        if not initial_clusters and not pruned_clusters and not unclustered_factors:
            logger.warning(f"⚠️ Skipping scenario (no factors): {df['scenario']}")
            continue

        logger.debug(f"🔧 Conditions to process: {len(conditions)}")

        # Initialize mapping results. Only the pruned-cluster mapping is run
        # below; the other two stay empty and show up as 0 in the statistics.
        initial_clustered_mapping_results = {}
        pruned_clustered_mapping_results = {}
        unclustered_mapping_results = {}

        # Process pruned clusters mapping (hierarchical)
        if pruned_clusters:
            logger.info(f"🎯 Starting Pruned clusters hierarchical mapping ({len(conditions)} conditions)...")
            mapper.fit(pruned_clusters, None)  # Only fit on clusters, no unclustered factors
            condition_progress = tqdm(conditions, desc="Pruned clusters conditions processing")
            for cond_number, condition in enumerate(condition_progress, start=1):
                logger.debug(f"  🔄 Processing condition {cond_number}/{len(conditions)}: '{condition[:50]}...'")
                result = mapper.map_condition_clustered(statement, condition)
                # setdefault keeps this working even if the mapper's result
                # does not carry the mapping key yet
                statement_mapping = result.setdefault('candidate_factor_statement_mapping', {})
                for factor in result.get('candidate_factors', []):
                    statement_mapping[factor] = factor_statement_mapping.get(factor, 'unmapped')
                pruned_clustered_mapping_results[condition] = result
                logger.debug(f"  ✅ Condition {cond_number} complete, selected factors: {len(result.get('final_factors', []))}")
            logger.info(f"✅ Pruned clusters hierarchical mapping complete")

        # Prepare output
        output_obj = {
            'scenario': scenario,
            'statement': statement,
            'conditions': conditions,
            'pruned_clusters': pruned_clusters,
            'clustering_applied': clustering_applied,
            'pruned_clustered_mapping_results': pruned_clustered_mapping_results,
            'processing_time': time.time() - start_time,
            'parameters': {
                'K1': args.K1,
                'K2': args.K2,
                'num_prompts': args.num_prompts,
                'vote_ratio': args.vote_ratio,
                'alpha': args.alpha,
                'max_retries': args.max_retries,
                'enable_preprocessing': args.enable_preprocessing,
                'enable_self_reflection': args.enable_self_reflection
            },
        }

        out_objs.append(output_obj)
        processed_count += 1

        # Checkpoint after every scenario so an interrupted run can resume
        _write_json_atomic(out_path, out_objs)

        # Log processing statistics
        total_conditions = len(conditions)
        initial_clustered_mapped = sum(1 for result in initial_clustered_mapping_results.values() if result['final_factors'])
        pruned_clustered_mapped = sum(1 for result in pruned_clustered_mapping_results.values() if result['final_factors'])
        unclustered_mapped = sum(1 for result in unclustered_mapping_results.values() if result['final_factors'])
        methods_run = [
            r for r in (initial_clustered_mapping_results, pruned_clustered_mapping_results, unclustered_mapping_results)
            if r
        ]

        logger.info(f"💾 Saving intermediate results: Processed {processed_count} scenarios")
        logger.info(f"📈 Current scenario statistics: {total_conditions} conditions")
        logger.info(f"  - Initial clusters hierarchical mapping successful: {initial_clustered_mapped} items")
        logger.info(f"  - Pruned clusters hierarchical mapping successful: {pruned_clustered_mapped} items")
        logger.info(f"  - Original factors direct mapping successful: {unclustered_mapped} items")
        logger.info(f"  - Total {len(methods_run)} mapping methods processed")

    # Final save
    _write_json_atomic(out_path, out_objs)

    logger.info(f"🎉 Processing complete! Total {processed_count} scenarios processed")
    logger.info(f"💾 Results saved to: {out_path}")


if __name__ == '__main__':
    main()