import sys
import os
import json
import time
import argparse
import re
from collections import defaultdict
from tqdm import tqdm
import numpy as np
import umap.umap_ as umap
import hdbscan
from sentence_transformers import SentenceTransformer

from utils import config
from utils.utils import ask_gpt, print_args
from utils.prompt_shots import FEW_SHOT_GEN, FEW_SHOT_VOTE, FEW_SHOT_EXT, FEW_SHOT_CLUSTER, FEW_SHOT_PRUNE, \
    JSON_CORRECTION_PROMPT, CLUSTER_NAME_PROMPT
from utils.response_parse_utils import safe_json_parse,extract_choice_from_response



class IterativeFactorGenerator:
    def __init__(self, model_name, target_count=100, batch_size=10, max_rounds=20, do_cluster=False):
        """Set up generation limits, result stores and (optionally) the embedding model.

        Args:
            model_name: Model identifier forwarded to ``ask_gpt`` by the other methods.
            target_count: ``run`` stops generating once this many factors are collected.
            batch_size: Number of sentences requested per generation round.
            max_rounds: Upper bound on generation rounds in ``run``.
            do_cluster: When true, load the sentence-embedding model now and let
                ``run`` cluster and prune the voted factors.
        """
        # Run configuration.
        self.model = model_name
        self.do_cluster = do_cluster
        self.max_rounds = max_rounds
        self.batch_size = batch_size
        self.target_count = target_count
        self.max_retries = 5

        # Accumulated state; these are also exposed in the result of run().
        self.factors = set()
        self.pruning_records = {}  # raw pruning responses, keyed by cluster name
        self.voting_records = {}
        self.generated_sentences = []

        # The embedding model is only needed for clustering.
        if not do_cluster:
            return

        print("[Init] Loading embedding model for clustering...")
        try:
            # Prefer a model directory that already exists on disk.
            model_dir = config.MiniLM_L6_PATH
            if not os.path.isdir(model_dir):
                raise FileNotFoundError("Local model not found")
            print(f"[Init] Trying local model path: {model_dir}")
            self.embedding_model = SentenceTransformer(model_dir)
        except Exception as err:
            print(f"[Init] Local model failed ({err}), downloading from Hugging Face...")
            # Second attempt: hand the configured value to SentenceTransformer again.
            self.embedding_model = SentenceTransformer(config.MiniLM_L6_PATH)
        print("[Init] Embedding model loaded successfully")

    def generate_batch_sentences(self, scenario, statement, oppo_statement, n):
        """Ask the model for ``n`` sentences about the scenario and return them.

        The reply is split on newlines; every non-blank line (stripped) counts as
        one sentence, so the result may hold more or fewer than ``n`` entries.
        The sentences are also appended to ``self.generated_sentences``.

        Returns:
            list[str]: the non-empty lines of the model reply.
        """
        print(f"\n[Generate] {n} sentences for scenario: '{scenario}'")

        system_msg = {
            'role': 'system',
            'content': (
                f"Generate {n} diverse supporting or refuting sentences for scenario: {scenario}, "
                f"comparing '{statement}' vs '{oppo_statement}'."
            ),
        }
        request_msg = {
            'role': 'user',
            'content': (
                f"Scenario: {scenario}\nStatement: {statement}\nOpposite: {oppo_statement}\nGenerate {n} sentences."
            ),
        }
        # Conversation layout: instruction, few-shot examples, then the actual request.
        conversation = [system_msg, *FEW_SHOT_GEN, request_msg]

        reply = ask_gpt(conversation, model_name=self.model, max_token=512)
        sentences = [text for text in map(str.strip, reply.split("\n")) if text]
        print(f"[Generate] Sentences: {sentences}")

        # Keep a running log of everything generated for the final report.
        self.generated_sentences.extend(sentences)
        return sentences

    def extract_factors(self, sentences):
        """Extract a set of distinct factor strings from generated sentences.

        The sentences are numbered and sent to the model, whose reply is parsed
        with ``safe_json_parse``. Only non-empty string entries of the parsed list
        are kept (whitespace-stripped); anything else the model may have emitted
        (nested objects, numbers, blanks) is skipped so that building the set
        cannot fail on unhashable items.

        Args:
            sentences: Sequence of sentence strings.

        Returns:
            set[str]: extracted factors; empty when there is nothing to extract
            or the reply does not parse to a list.
        """
        print("[Extract] Extracting factors from sentences...")

        # Nothing to analyse: skip the model call instead of prompting on an empty list.
        if not sentences:
            print("[Extract] No sentences provided, skipping extraction")
            return set()

        numbered = "\n".join(f"{i}. {s}" for i, s in enumerate(sentences, start=1))
        msgs = [
            {'role': 'system', 'content': 'Extract distinct factors from these sentences. Think step by step about what factors are mentioned, then provide your final answer as a JSON array.'}
        ]
        msgs += FEW_SHOT_EXT
        msgs.append({'role': 'user', 'content': (
            f"Extract distinct factors from these sentences. Think about what key elements or aspects are mentioned, "
            f"then provide your final answer as a JSON array.\n\n"
            f"{numbered}")})
        resp = ask_gpt(msgs, model_name=self.model, max_token=1024)

        # safe_json_parse handles the parsing details and returns the default on failure.
        factors_list = safe_json_parse(resp, default_value=[], context="Factor extraction")
        if not isinstance(factors_list, list):
            factors_list = []

        factors = set()
        skipped = 0
        for entry in factors_list:
            # Model output is untrusted: a dict or list entry would make set() raise.
            if isinstance(entry, str) and entry.strip():
                factors.add(entry.strip())
            else:
                skipped += 1
        if skipped:
            print(f"[Extract] Skipped {skipped} non-string or empty entries")

        print(f"[Extract] Factors: {factors}")
        return factors

    def vote_single_factor_with_retry(self, factor, scenario, stmt1, stmt2):
        """
        Ask the model which statement a factor supports, retrying on unusable replies.

        Up to ``self.max_retries`` calls are made. A reply is accepted when it parses
        to a non-empty dict from which a choice of 'Statement1', 'Statement2' or
        'Both' can be extracted. After the second failed attempt the failed reply
        and a JSON-correction request are appended to the conversation for the
        remaining attempts. ``scenario`` is accepted but not used in the prompt.

        Returns a ``(choice, raw_response)`` tuple; ``choice`` is 'Neutral' when
        every attempt failed.
        """
        valid_choices = ('Statement1', 'Statement2', 'Both')

        instruction = {
            'role': 'system',
            'content': (
                "Decide which statement the factor supports. "
                "Reason briefly (1–2 sentences), then provide your final answer as a JSON object. "
                "Keep the explanation as short as possible—no extra commentary."
            )
        }
        question = {'role': 'user', 'content': (
            f"Statement1: {stmt1}\n"
            f"Statement2: {stmt2}\n"
            f"Factor: {factor}\n"
            f"Decide which statement this factor supports: Statement1, Statement2, or Both. "
            f"Think step by step about how this factor relates to each statement, then provide your final answer as a JSON object with the factor as key and choice as value.")}

        # Instruction first, then the few-shot examples, then the question itself.
        conversation = [instruction, *FEW_SHOT_VOTE, question]

        for attempt_no in range(1, self.max_retries + 1):
            reply = ask_gpt(conversation, model_name=self.model, max_token=256)
            print(f"[Vote] Response for factor '{factor}' (attempt {attempt_no}): {reply}")

            parsed = safe_json_parse(
                reply,
                default_value={},
                context=f"Voting for factor '{factor}', attempt {attempt_no}",
            )

            if isinstance(parsed, dict) and parsed:
                choice = extract_choice_from_response(parsed, factor)
                if choice in valid_choices:
                    return choice, reply

            # The correction prompt is added once, right after the second failed attempt.
            if attempt_no != 2:
                continue
            print("[Vote] JSON parsing failed, retrying with correction prompt...")
            conversation.extend([
                {'role': 'assistant', 'content': reply},
                JSON_CORRECTION_PROMPT,
                {'role': 'user', 'content': (
                    f"Please return a valid JSON object for this factor: {factor}\n"
                    f"Choose Statement1, Statement2, or Both.")},
            ])

        print(f"[Vote] All retries failed for factor '{factor}', using default 'Neutral'")
        return 'Neutral', reply


    def vote_factors(self, factors, scenario, stmt1, stmt2, n_votes=1):
        """Label every factor with the statement it supports.

        Each factor is voted on ``n_votes`` times through
        ``vote_single_factor_with_retry`` and the most frequent choice wins; on a
        tie the choice that was cast first wins, so the outcome does not depend on
        set ordering. A winning 'Both' is recorded as 'Neutral'. Per-factor details
        are stored in ``self.voting_records``.

        Args:
            factors: Iterable of factor strings.
            scenario: Scenario text, forwarded to the single-vote helper.
            stmt1: Text of Statement1.
            stmt2: Text of Statement2.
            n_votes: Number of model votes per factor (values below 1 are treated as 1).

        Returns:
            dict: factor -> final choice as produced by the voting helper, with
            'Both' replaced by 'Neutral'.
        """
        n_votes = max(1, int(n_votes))
        print(f"[Vote] Voting on factor support with majority vote ({n_votes} call(s) per factor)...")
        mapping = {}
        conversion_log = {}  # Track Both→Neutral conversions

        for f in tqdm(factors, desc="Voting factors", unit="factor"):
            votes = []
            vote_details = []

            for attempt in range(1, n_votes + 1):
                choice, resp = self.vote_single_factor_with_retry(f, scenario, stmt1, stmt2)

                votes.append(choice)
                vote_details.append({
                    'attempt': attempt,
                    'response': resp,
                    'choice': choice
                })
                print(f"[Vote] Attempt {attempt} for factor '{f}' => {choice}")

            # Majority vote. Iterating the ordered list (not a set) makes ties
            # resolve deterministically to the earliest-cast choice.
            original_choice = max(votes, key=votes.count)

            # Convert "Both" to "Neutral" as per requirement
            final_choice = original_choice
            if original_choice == 'Both':
                final_choice = 'Neutral'
                conversion_log[f] = {
                    'original': 'Both',
                    'converted_to': 'Neutral',
                    'votes': votes
                }
                print(f"[Vote] Factor '{f}' converted: Both → Neutral")

            print(f"[Vote] Factor '{f}' final choice: {final_choice} (original: {original_choice})")
            mapping[f] = final_choice

            # Store voting records
            self.voting_records[f] = {
                'votes': votes,
                'original_choice': original_choice,
                'final_choice': final_choice,
                'vote_details': vote_details,
                'was_converted': f in conversion_log
            }

        # Log conversion summary
        if conversion_log:
            print(f"[Vote] Conversion Summary: {len(conversion_log)} factors converted from 'Both' to 'Neutral'")
            for factor, info in conversion_log.items():
                print(f"  - {factor}: votes={info['votes']}")

        return mapping

    def cluster_factors(self, factors):
        """Group factors into named clusters (embeddings -> UMAP -> HDBSCAN -> LLM names).

        Fewer than two factors are returned as a single ``'cluster_0'`` group.
        Otherwise the factors are embedded with ``self.embedding_model``, reduced
        with UMAP, clustered with HDBSCAN, and every cluster is named through
        ``_generate_cluster_theme``. Points HDBSCAN labels as noise (-1) are
        collected under ``'uncategorized'``. Cluster names are made unique, so two
        clusters that receive the same theme never overwrite each other.

        Any exception during this pipeline is caught and reported, and all factors
        are returned as one ``'all_factors'`` group.

        Note: the embedding model is deleted after encoding to free memory, so a
        second call on the same instance ends up in the ``'all_factors'`` fallback.

        Args:
            factors: Iterable of factor strings.

        Returns:
            dict: cluster name -> list of factors.
        """
        import gc

        print("[Cluster] Clustering factors using UMAP + HDBSCAN...")
        factor_list = list(factors)

        if len(factor_list) < 2:
            print("[Cluster] Too few factors for clustering, returning single cluster")
            return {'cluster_0': factor_list}

        try:
            # Step 1: Generate embeddings
            print(f"[Cluster] Generating embeddings for {len(factor_list)} factors...")
            embeddings = self.embedding_model.encode(factor_list, show_progress_bar=True)
            print(f"[Cluster] Embeddings shape: {embeddings.shape}")
            # The model is not needed after encoding; drop it to release memory.
            del self.embedding_model
            gc.collect()

            # Step 2: UMAP dimensionality reduction
            # Reduce to 10-50 dimensions, dynamically chosen from the factor count
            n_factors = len(factor_list)
            umap_dims = min(50, max(10, n_factors // 5))  # 10-50 dims, scale with factor count
            umap_dims = min(umap_dims, n_factors - 1)  # Cannot exceed n_factors - 1

            print(f"[Cluster] Applying UMAP reduction to {umap_dims} dimensions...")
            umap_reducer = umap.UMAP(
                n_components=umap_dims,
                n_neighbors=min(15, n_factors - 1),  # Adjust neighbors based on factor count
                min_dist=0.1,
                metric='cosine',
                random_state=42
            )
            reduced_embeddings = umap_reducer.fit_transform(embeddings)
            print(f"[Cluster] UMAP reduced embeddings shape: {reduced_embeddings.shape}")

            # Step 3: HDBSCAN clustering
            min_cluster_size = max(2, n_factors // 20)  # ~5% of factors, minimum 2
            print(f"[Cluster] Applying HDBSCAN with min_cluster_size={min_cluster_size}...")
            clusterer = hdbscan.HDBSCAN(
                min_cluster_size=min_cluster_size,
                min_samples=1,
                metric='euclidean',
                cluster_selection_epsilon=0.0
            )
            cluster_labels = clusterer.fit_predict(reduced_embeddings)

            # Organize factors by cluster label; -1 marks noise points
            clusters_dict = defaultdict(list)
            noise_factors = []
            for factor, label in zip(factor_list, cluster_labels):
                if label == -1:
                    noise_factors.append(factor)
                else:
                    clusters_dict[f'cluster_{label}'].append(factor)

            # Add noise as separate cluster if exists
            if noise_factors:
                clusters_dict['noise'] = noise_factors

            print(f"[Cluster] Found {len(clusters_dict)} clusters:")
            for cluster_name, members in clusters_dict.items():
                print(f"  - {cluster_name}: {len(members)} factors")

            # Step 4: Use LLM to generate meaningful cluster names
            final_clusters = {}
            # 'uncategorized' is reserved for the noise group so an LLM-chosen
            # theme with that name cannot be overwritten by (or merged into) it.
            taken_names = {'uncategorized'}
            for cluster_name, members in clusters_dict.items():
                if cluster_name == 'noise':
                    final_clusters['uncategorized'] = members
                    continue

                theme_name = self._generate_cluster_theme(members)
                # The LLM may give two clusters the same theme; add a numeric
                # suffix instead of silently replacing the earlier cluster.
                unique_name = theme_name
                suffix = 2
                while unique_name in taken_names:
                    unique_name = f"{theme_name}_{suffix}"
                    suffix += 1
                if unique_name != theme_name:
                    print(f"[Cluster] Theme '{theme_name}' already used, renamed to '{unique_name}'")
                taken_names.add(unique_name)
                final_clusters[unique_name] = members

            print(f"[Cluster] Final clusters with themes: {list(final_clusters.keys())}")
            return final_clusters

        except Exception as e:
            print(f"[Cluster] Error in UMAP+HDBSCAN clustering: {e}")
            print("[Cluster] Falling back to single cluster")
            return {'all_factors': factor_list}

    def _generate_cluster_theme(self, cluster_factors):
        """Generate a meaningful theme name for a cluster using LLM.

        Clusters with fewer than two factors get a ``theme_<n>`` name without a
        model call. Otherwise the model reply is cut at the earliest leaked
        special-token marker, reduced to letters, digits, underscores and hyphens,
        joined with underscores and limited to 50 characters. If that leaves fewer
        than two characters, or anything raises, a ``cluster_<n>`` name is used.

        The numeric suffix ``<n>`` (0-9999) is derived from a SHA-256 digest of the
        sorted factors, so the same cluster gets the same fallback name in every
        run (the built-in ``hash`` of strings is salted per process).

        Args:
            cluster_factors: List of factor strings belonging to one cluster.

        Returns:
            str: the theme name.
        """
        import hashlib  # local import: only this helper needs it

        def fallback_name(prefix):
            # Order-independent and stable across interpreter runs.
            canonical = json.dumps(sorted(map(str, cluster_factors)))
            digest = hashlib.sha256(canonical.encode('utf-8')).hexdigest()
            return f"{prefix}_{int(digest, 16) % 10000}"

        if len(cluster_factors) < 2:
            return fallback_name("theme")

        try:
            msgs = CLUSTER_NAME_PROMPT.copy()
            msgs.insert(0, {'role': 'system',
                            'content': 'Generate a concise English theme name (1-3 words) that captures the common topic of these factors. Return only the theme name, no explanation.'},
                )
            msgs.append(
           {'role': 'user',
            'content': f"Generate a theme name for these related factors:\n{json.dumps(cluster_factors[:10], ensure_ascii=False)}"}  # Limit to first 10 for context
            )

            resp = ask_gpt(msgs, model_name=self.model, max_token=50)
            theme = resp.strip().strip('"').strip("'")

            # Cut the reply at the EARLIEST leaked chat-template fragment.
            # Markers and underscore token names are matched anywhere; the plain
            # role words only as whole words, so themes such as "Ecosystem" are
            # not truncated at the embedded "system".
            leak = re.search(
                r'<\||\|>|</|/>|###'
                r'|eot_id|start_header_id|end_header_id|bot_id|human_id'
                r'|\b(?:user|assistant|system)\b',
                theme,
            )
            if leak:
                theme = theme[:leak.start()]

            # Clean up theme name - keep only alphanumeric, spaces, underscores, and hyphens
            theme = re.sub(r'[^a-zA-Z0-9\s_-]', '', theme)
            theme = '_'.join(theme.split())

            # Limit length first, then trim underscores so the cut cannot leave one dangling
            theme = theme[:50].strip('_')

            if len(theme) < 2:
                theme = fallback_name("cluster")

            print(f"[Cluster] Generated theme '{theme}' for {len(cluster_factors)} factors")
            return theme

        except Exception as e:
            print(f"[Cluster] Failed to generate theme name: {e}")
            return fallback_name("cluster")

    def prune_redundancy(self, clusters, mapping):
        """Ask the model to drop redundant factors inside each cluster.

        The model reply is treated as untrusted: only entries that correspond to a
        factor of the cluster are kept (compared ignoring case and whitespace
        differences, and mapped back to the original spelling), duplicates are
        dropped, and non-string entries are ignored. If no returned entry matches
        any factor of a non-empty cluster, the cluster is left unpruned rather than
        emptied. Raw responses and outcomes are stored in ``self.pruning_records``.

        Args:
            clusters: dict of cluster name -> list of factor strings.
            mapping: factor -> vote mapping; accepted for interface compatibility
                and not used here.

        Returns:
            tuple: (set of all kept factors, dict of cluster name -> kept factors
            for clusters that still have factors).
        """
        print("[Prune] Pruning redundant factors in each cluster...")
        pruned = []
        pruned_clusters = {}  # Store clusters after pruning

        def match_key(item):
            # Comparison key tolerant to case and whitespace changes by the model.
            if isinstance(item, str):
                return ' '.join(item.split()).casefold()
            return None

        for cname, facs in clusters.items():
            print(f"[Prune] Cluster '{cname}': {facs}")
            msgs = FEW_SHOT_PRUNE.copy()
            msgs.insert(0, {'role': 'system', 'content': f"From these factors in cluster '{cname}', remove redundant ones. Think about which factors are essentially the same concept, then return a JSON array of the remaining unique factors."})
            msgs.append({'role': 'user', 'content': (
                f"From these factors in cluster '{cname}', remove redundant ones. "
                f"Think about which factors are essentially the same concept, then return a JSON array of the remaining unique factors.\n\n"
                f"{json.dumps(facs, ensure_ascii=False)}")})
            resp = ask_gpt(msgs, model_name=self.model, max_token=512)

            # Store the raw response for this cluster
            self.pruning_records[cname] = {
                'original_factors': facs,
                'raw_response': resp,
                'prompt_used': msgs[-1]['content']
            }

            # Use unified JSON parser
            parsed = safe_json_parse(resp, default_value=facs, context=f"Pruning cluster '{cname}'")
            if not isinstance(parsed, list):
                print(f"[Prune] Warning: Expected list but got {type(parsed)}, using original factors")
                parsed = facs

            # Map returned entries back onto the cluster's own factors so that
            # reworded or invented items (which have no vote mapping) cannot
            # enter the final factor set.
            originals_by_key = {}
            for original in facs:
                key = match_key(original)
                if key is not None:
                    originals_by_key.setdefault(key, original)

            keep = []
            seen_keys = set()
            for entry in parsed:
                key = match_key(entry)
                if key is None or key not in originals_by_key or key in seen_keys:
                    continue
                seen_keys.add(key)
                keep.append(originals_by_key[key])

            if facs and not keep:
                # Removing redundancy cannot validly leave nothing behind.
                print(f"[Prune] Warning: no returned factor matches cluster '{cname}', using original factors")
                keep = list(facs)

            kept_lookup = set(keep)
            removed = [factor for factor in facs if factor not in kept_lookup]

            # Update pruning record with parsed result
            self.pruning_records[cname]['parsed_factors'] = keep
            self.pruning_records[cname]['factors_removed'] = removed
            self.pruning_records[cname]['factors_kept_count'] = len(keep)
            self.pruning_records[cname]['factors_removed_count'] = len(set(removed))

            print(f"[Prune] Kept after pruning: {keep}")
            pruned += keep

            # Store pruned cluster (only if it has factors left)
            if keep:
                pruned_clusters[cname] = keep

        return set(pruned), pruned_clusters

    def run(self, scenario, statement, oppo_statement=''):
        """Run the full pipeline for one instance: generate, extract, vote, cluster/prune.

        Generation rounds repeat until ``self.target_count`` factors are collected
        or ``self.max_rounds`` is reached. All collected factors are then voted on,
        and, when ``self.do_cluster`` is set, clustered and pruned.

        Returns:
            dict with the keys ``factors_before_clustering``,
            ``factors_after_clustering``, ``factor_statement_mapping``,
            ``generated_sentences``, ``voting_records``, ``pruning_records``,
            ``clustering_applied``, ``clustering_stats`` and ``pruned_clusters``.
        """
        print(f"\n=== Start instance: '{scenario}' ===")

        # --- Phase 1: iterative sentence generation and factor extraction ---
        for round_idx in tqdm(range(self.max_rounds), desc="Gen rounds", unit="round"):
            if len(self.factors) >= self.target_count:
                break
            batch = self.generate_batch_sentences(scenario, statement, oppo_statement, self.batch_size)
            extracted = self.extract_factors(batch)
            count_before = len(self.factors)
            self.factors |= extracted
            count_after = len(self.factors)
            print(f"[Round] {round_idx + 1}: +{count_after - count_before} factors, total={count_after}")
        print(f"[Result] Collected {len(self.factors)} factors.")

        # --- Phase 2: voting (a winning 'Both' is already mapped to 'Neutral') ---
        mapping = self.vote_factors(self.factors, scenario, statement, oppo_statement)
        accepted_labels = ('Statement1', 'Statement2', 'Neutral')
        factors_after_voting = {
            factor for factor, label in mapping.items() if label in accepted_labels
        }
        print(f"[Vote] After filter: {len(factors_after_voting)} factors remain.")

        # Snapshot of the voted factors and their labels before any pruning.
        factors_before_clustering = list(factors_after_voting)
        final_factor_mapping = {
            factor: label for factor, label in mapping.items() if factor in factors_after_voting
        }

        # --- Phase 3: optional clustering and redundancy pruning ---
        clustering_stats = {
            'applied': self.do_cluster,
            'factors_before_count': len(factors_before_clustering),
            'factors_after_count': 0,
            'factors_removed_count': 0,
            'factors_removed': [],
            'reduction_rate': 0.0,
            'initial_clusters': None,
            'cluster_count': 0
        }
        factors_after_clustering = factors_after_voting
        pruned_clusters = None

        if not self.do_cluster:
            # Without clustering nothing is removed.
            clustering_stats['factors_after_count'] = len(factors_before_clustering)
        else:
            initial_clusters = self.cluster_factors(factors_after_voting)
            factors_after_clustering, pruned_clusters = self.prune_redundancy(initial_clusters, mapping)
            print(f"[Prune] After clustering prune: {len(factors_after_clustering)} factors remain.")

            # Everything present before clustering but missing afterwards counts as removed.
            removed = set(factors_before_clustering) - set(factors_after_clustering)
            if factors_before_clustering:
                reduction = len(removed) / len(factors_before_clustering)
            else:
                reduction = 0.0

            clustering_stats['factors_after_count'] = len(factors_after_clustering)
            clustering_stats['factors_removed_count'] = len(removed)
            clustering_stats['factors_removed'] = list(removed)
            clustering_stats['reduction_rate'] = reduction
            clustering_stats['initial_clusters'] = initial_clusters
            clustering_stats['cluster_count'] = len(initial_clusters) if initial_clusters else 0

            print(f"[Clustering Stats] Before: {clustering_stats['factors_before_count']}, "
                  f"After: {clustering_stats['factors_after_count']}, "
                  f"Removed: {clustering_stats['factors_removed_count']}, "
                  f"Reduction: {clustering_stats['reduction_rate']:.2%}, "
                  f"Clusters: {clustering_stats['cluster_count']}")

        print(f"=== End instance: '{scenario}', final factors count: {len(factors_after_clustering)} ===\n")

        # Full report: both factor sets, labels, raw records and statistics.
        report = {
            'factors_before_clustering': factors_before_clustering,
            'factors_after_clustering': list(factors_after_clustering),
            'factor_statement_mapping': final_factor_mapping,
            'generated_sentences': self.generated_sentences,
            'voting_records': self.voting_records,
            'pruning_records': self.pruning_records,
            'clustering_applied': self.do_cluster,
            'clustering_stats': clustering_stats,
            'pruned_clusters': pruned_clusters
        }
        return report


def parse_args():
    """Parse the command-line options, print them via ``print_args`` and return them.

    Clustering is on by default; ``--no-cluster`` switches it off.
    """
    cli = argparse.ArgumentParser()

    # String options whose defaults come from the project config.
    string_options = (
        ("--model_name", config.model_name),
        ("--dataset_name", config.dataset_name),
        ("--dataset_file_dic", config.dataset_file_dic),
        ("--save_file_dic", config.save_file_dic),
    )
    for flag, default in string_options:
        cli.add_argument(flag, type=str, default=default)

    # Integer options: dataset slice and generation budget.
    # (An alternative preset kept for reference: --target 60 --batch 6 --rounds 15.)
    integer_options = (
        ("--start", 0),
        ("--end", 1000),
        ("--target", 40),
        ("--batch", 5),
        ("--rounds", 10),
    )
    for flag, default in integer_options:
        cli.add_argument(flag, type=int, default=default)

    # Paired on/off switches that share the 'cluster' destination.
    cli.add_argument("--cluster", action='store_true', help="enable clustering prune")
    cli.add_argument("--no-cluster", dest='cluster', action='store_false', help="disable clustering prune")
    cli.set_defaults(cluster=True)

    parsed = cli.parse_args()
    print_args(parsed)
    return parsed


if __name__ == '__main__':
    # Script entry point: process dataset items [start, end) one by one and
    # checkpoint the accumulated results after every item so a run can resume.
    args = parse_args()
    infile = os.path.join(args.dataset_file_dic, args.dataset_name + '.json')
    print('input file--------->: {}'.format(infile))
    with open(infile, 'r', encoding='utf-8') as f:
        data = json.load(f)
    # suffix = '_early_stop_vote'
    suffix = ''
    os.makedirs(args.save_file_dic, exist_ok=True)
    out_path = os.path.join(
        args.save_file_dic,
        f"{args.dataset_name}_{args.model_name.replace(':', '-')}_{args.start}_{args.end}_factors{suffix}.json"
    )
    print('output file---------->: {}'.format(out_path))

    # Resume: records already saved for this output file are skipped.
    if os.path.exists(out_path):
        with open(out_path, 'r', encoding='utf-8') as f:
            out_objs = json.load(f)
        start_index = args.start + len(out_objs)
        print(f"Loaded {len(out_objs)} existing records. Continuing from index {start_index}")
    else:
        out_objs = []
        start_index = args.start

    if start_index >= args.end:
        print(f"All records from {args.start} to {args.end} have been processed. Exiting.")
        # sys.exit instead of the interactive-helper exit(), which the site
        # module is not guaranteed to install.
        sys.exit(0)

    # Slice once so the progress-bar total always matches the work actually done
    # (it can no longer become negative when start_index exceeds len(data)).
    pending = data[start_index:args.end]
    result_keys = (
        'factors_before_clustering',
        'factors_after_clustering',
        'factor_statement_mapping',
        'generated_sentences',
        'voting_records',
        'pruning_records',
        'clustering_applied',
        'clustering_stats',
        'pruned_clusters',
    )
    tmp_path = out_path + '.tmp'

    for idx, item in tqdm(enumerate(pending, start=start_index),
                          total=len(pending),
                          desc="Processing instances", unit="instance"):
        start_time = time.time()  # Record start time for this iteration

        # A fresh generator per item keeps factors and records separate.
        generator = IterativeFactorGenerator(
            args.model_name, args.target, args.batch, args.rounds, args.cluster
        )

        results = generator.run(item['scenario'], item['statement'], item['opposite_statement'])

        elapsed_time = time.time() - start_time  # Elapsed time in seconds

        record = {
            'scenario': item['scenario'],
            'statement': item['statement'],
            'opposite_statement': item['opposite_statement'],
        }
        for key in result_keys:
            record[key] = results[key]
        record['elapsed_time'] = elapsed_time
        out_objs.append(record)

        # Write the checkpoint to a temporary file and swap it into place, so an
        # interruption mid-write cannot leave a truncated file that breaks resume.
        with open(tmp_path, 'w', encoding='utf-8') as outf:
            json.dump(out_objs, outf, indent=2, ensure_ascii=False)
        os.replace(tmp_path, out_path)
        print(f"[Saved] index {idx} to {out_path}")