import json
import os
import numpy as np
import pandas as pd
from collections import Counter
import re
from sklearn.metrics.pairwise import cosine_similarity
from sentence_transformers import SentenceTransformer
from typing import Dict, List, Tuple, Optional

class IterativeRefinementAnalyzer:
    """Compute iterative-refinement metrics over saved multi-turn runs.

    Each run is a JSON object with a ``turns`` list; a turn may carry a
    ``response`` string. Per run the analyzer computes four core metrics
    (verbosity inflation, time to plateau, final drift from origin, jargon
    density) plus consecutive-turn semantic similarity, and can export all of
    it to CSV files.
    """

    # Word tokenizer shared by every lexical metric.
    _WORD_RE = re.compile(r'\b\w+\b')

    # Substrings that mark a word as likely jargon wherever they occur. Several
    # are whole words ('framework', 'algorithm'), so a plain substring test is
    # intentional here.
    _JARGON_PATTERNS = (
        'ization', 'ological', 'metric', 'algorithm', 'analysis',
        'synthesis', 'framework', 'paradigm', 'methodology', 'systematic',
        'comprehensive',
    )
    # Prefixes that mark a word as likely jargon only when the word starts
    # with them (a substring test would also match them mid-word).
    _JARGON_PREFIXES = (
        'meta', 'hyper', 'multi', 'pseudo', 'quasi', 'ultra', 'macro', 'micro',
    )

    # Scalar per-run metrics aggregated and summarised across runs.
    _NUMERIC_METRIC_COLUMNS = (
        'verbosity_inflation_ratio', 'time_to_plateau',
        'final_drift_from_origin', 'jargon_density',
    )

    DEFAULT_EMBEDDING_MODEL = "Qwen/Qwen3-Embedding-0.6B"

    def __init__(self, results_path, llm_client=None,
                 embedding_model_name=DEFAULT_EMBEDDING_MODEL,
                 novelty_threshold=0.15):
        """
        Initialize the analyzer with the path to results directory

        Args:
            results_path: Path to the results directory
            llm_client: Optional LLM client for jargon detection (e.g., OpenAI client)
            embedding_model_name: SentenceTransformer model used for the
                semantic metrics (defaults to Qwen3-Embedding-0.6B)
            novelty_threshold: Lexical-novelty level below which a turn counts
                toward a plateau (default 0.15, i.e. 15% new n-grams)
        """
        self.results_path = results_path
        self.llm_client = llm_client
        self.novelty_threshold = novelty_threshold
        self.embedding_model_name = embedding_model_name

        print(f"Loading embedding model {embedding_model_name}...")
        self.embedding_model = SentenceTransformer(embedding_model_name)
        print("Model loaded successfully!")

    def load_json_data(self, filepath):
        """Load a JSON file and return the parsed data, or None on failure.

        Read and decode errors are reported and swallowed so that one bad file
        does not abort a whole batch.
        """
        try:
            with open(filepath, 'r', encoding='utf-8') as f:
                return json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            print(f"Error loading {filepath}: {e}")
            return None

    @staticmethod
    def _response_text(turn) -> Optional[str]:
        """Return the turn's ``response`` when it is a string, else None.

        Missing keys, ``null`` values and non-string payloads all map to None,
        so callers never call string methods on a non-string.
        """
        response = turn.get('response') if isinstance(turn, dict) else None
        return response if isinstance(response, str) else None

    # ==================== METRIC 1: Verbosity Score ====================
    def calculate_verbosity_scores(self, turns):
        """
        Metric 1: Calculate word count for each turn.

        Returns:
            (verbosity_scores, inflation_ratio) where verbosity_scores holds the
            whitespace-delimited word count per turn (0 for turns without a
            string response) and inflation_ratio is Turn 12 / Turn 1 word count.
            With fewer than 12 turns the ratio is 1.0 (NOTE: this reads as "no
            inflation", not "unknown"); it is None when Turn 1 has no words.
        """
        verbosity_scores = []
        for turn in turns:
            text = self._response_text(turn)
            verbosity_scores.append(len(text.split()) if text is not None else 0)

        first_turn_words = verbosity_scores[0] if verbosity_scores else 0
        if first_turn_words <= 0:
            inflation_ratio = None
        elif len(verbosity_scores) >= 12:
            inflation_ratio = verbosity_scores[11] / first_turn_words  # Turn 12 / Turn 1
        else:
            inflation_ratio = 1.0

        return verbosity_scores, inflation_ratio

    # ==================== METRIC 2: Time to Plateau (TTP) ====================
    def extract_ngrams(self, text, n_range=(2, 3)):
        """Return the set of word n-grams of *text* for n in n_range (inclusive)."""
        words = self._WORD_RE.findall(text.lower())
        return {
            ' '.join(words[i:i + n])
            for n in range(n_range[0], n_range[1] + 1)
            for i in range(len(words) - n + 1)
        }

    def calculate_lexical_novelty(self, turns):
        """Calculate lexical novelty (fraction of unseen n-grams) for each turn.

        The first turn scores 1.0. Each later turn scores the share of its
        n-grams not seen in any earlier turn. Turns without a string response
        score 0 and add nothing to the seen set.
        """
        novelty_scores = []
        seen_ngrams = set()

        for i, turn in enumerate(turns):
            text = self._response_text(turn)
            if text is None:
                novelty_scores.append(0)
                continue

            current_ngrams = self.extract_ngrams(text)
            if i == 0:
                novelty_score = 1.0  # everything in the first turn is new
            elif current_ngrams:
                novelty_score = len(current_ngrams - seen_ngrams) / len(current_ngrams)
            else:
                novelty_score = 0

            novelty_scores.append(novelty_score)
            seen_ngrams.update(current_ngrams)

        return novelty_scores

    def calculate_time_to_plateau(self, turns):
        """
        Metric 2: Calculate Time to Plateau (TTP).

        Returns the 1-indexed turn (2..10) at which novelty first stays below
        ``self.novelty_threshold`` for two consecutive turns, or 12 if no such
        plateau is found.
        """
        return self._time_to_plateau_from_scores(self.calculate_lexical_novelty(turns))

    def _time_to_plateau_from_scores(self, novelty_scores: List[float]) -> int:
        """Compute TTP from precomputed novelty scores (see calculate_time_to_plateau)."""
        # Start from turn 2 (index 1) and look at pairs up to turn 10/11.
        for i in range(1, min(len(novelty_scores) - 1, 10)):
            if (novelty_scores[i] < self.novelty_threshold
                    and novelty_scores[i + 1] < self.novelty_threshold):
                return i + 1  # 1-indexed turn number
        return 12  # no plateau found

    # ==================== METRIC 3: Final Drift from Origin ====================
    def calculate_final_drift(self, turns):
        """
        Metric 3: Semantic similarity between the first and last non-empty
        responses, using the embedding model's similarity function.

        Returns a float (lower = more drift). Returns 1.0 when there are fewer
        than two turns or no non-empty string responses.
        """
        if len(turns) < 2:
            return 1.0

        responses = [text for text in map(self._response_text, turns) if text]
        if not responses:
            return 1.0

        embeddings = self.embedding_model.encode([responses[0], responses[-1]])
        # Only the single first-vs-last score is needed, not the full matrix.
        similarity = self.embedding_model.similarity(embeddings[0:1], embeddings[1:2])
        return float(similarity[0, 0])

    # ==================== METRIC 4: Jargon Density ====================
    def identify_jargon_candidates(self, early_corpus, late_corpus):
        """Return words whose relative frequency is >= 3x higher in late_corpus.

        Words that appear in late_corpus but never in early_corpus are always
        candidates. Order follows first appearance in late_corpus.
        """
        early_words = self._WORD_RE.findall(early_corpus.lower())
        late_words = self._WORD_RE.findall(late_corpus.lower())
        early_counts = Counter(early_words)
        late_counts = Counter(late_words)

        candidates = []
        for word, late_count in late_counts.items():
            early_count = early_counts.get(word, 0)
            if early_count == 0:
                candidates.append(word)  # only appears in the late corpus
                continue
            late_freq = late_count / len(late_words)
            early_freq = early_count / len(early_words)
            if late_freq / early_freq >= 3:
                candidates.append(word)

        return candidates

    def filter_jargon_with_llm(self, candidates, domain):
        """Filter *candidates* down to technical jargon for *domain*.

        The LLM call is not implemented yet: the heuristic filter is used both
        when no client is configured and when one is.
        """
        if not self.llm_client or not candidates:
            return self.heuristic_jargon_filter(candidates)

        prompt = f"""From the following list of words, select only those that are considered technical, scientific, or exceptionally sophisticated jargon in the context of {domain}: 

{', '.join(candidates)}

Return a JSON list of the jargon words only."""  # noqa: F841
        # TODO: send `prompt` to self.llm_client and parse the returned JSON
        # list; until that client integration exists, use the heuristic.
        return self.heuristic_jargon_filter(candidates)

    def heuristic_jargon_filter(self, candidates):
        """Heuristic fallback for jargon detection.

        Keeps a word if it is longer than 10 characters, contains one of
        _JARGON_PATTERNS, or starts with one of _JARGON_PREFIXES.
        """
        return [
            word for word in candidates
            if len(word) > 10
            or any(pattern in word for pattern in self._JARGON_PATTERNS)
            or word.startswith(self._JARGON_PREFIXES)
        ]

    def calculate_jargon_density(self, turns, domain):
        """
        Metric 4: Calculate jargon density in the final (12th) turn.

        Jargon is learned by comparing turns 1-4 (early) with turns 9-12 (late).
        Density is the share of Turn 12 words that are jargon.

        Returns:
            (density, jargon_words); (0.0, []) when there are fewer than 12 turns.
        """
        if len(turns) < 12:
            return 0.0, []

        early_corpus = ' '.join(self._response_text(t) or '' for t in turns[:4])
        late_corpus = ' '.join(self._response_text(t) or '' for t in turns[8:12])

        candidates = self.identify_jargon_candidates(early_corpus, late_corpus)
        jargon_words = self.filter_jargon_with_llm(candidates, domain)

        final_text = self._response_text(turns[11])
        if not final_text:
            return 0.0, jargon_words

        final_words = self._WORD_RE.findall(final_text.lower())
        if not final_words:
            return 0.0, jargon_words

        jargon_set = set(jargon_words)  # O(1) membership per word
        jargon_count = sum(1 for word in final_words if word in jargon_set)
        return jargon_count / len(final_words), jargon_words

    # ==================== ADDITIONAL ANALYSIS ====================
    def calculate_turn_to_turn_similarity(self, turns):
        """
        Semantic similarity between consecutive turns.

        Returns a list where entry k compares turn k with turn k+1 (0-based),
        so it lines up with turn indices. An entry is None when either turn
        lacks a non-empty response. Returns [] when fewer than two turns have
        responses.
        """
        if len(turns) < 2:
            return []

        texts = [self._response_text(turn) for turn in turns]
        valid_indices = [i for i, text in enumerate(texts) if text]
        if len(valid_indices) < 2:
            return []

        # Encode all responses at once for efficiency.
        embeddings = self.embedding_model.encode([texts[i] for i in valid_indices])
        row_for_turn = {turn_idx: row for row, turn_idx in enumerate(valid_indices)}

        similarities = []
        for k in range(len(turns) - 1):
            a = row_for_turn.get(k)
            b = row_for_turn.get(k + 1)
            if a is None or b is None:
                similarities.append(None)
                continue
            matrix = self.embedding_model.similarity(embeddings[a:a + 1], embeddings[b:b + 1])
            similarities.append(float(matrix[0, 0]))

        return similarities

    def extract_domain_from_task_id(self, task_id):
        """Map a task_id to 'coding', 'math', 'ideas' or 'unknown' by substring."""
        task_id_upper = str(task_id).upper() if task_id else ''
        if 'CODE' in task_id_upper or 'DS1000' in task_id_upper:
            return 'coding'
        if 'MATH' in task_id_upper or 'OMNI' in task_id_upper:
            return 'math'
        if 'IDEA' in task_id_upper:
            return 'ideas'
        return 'unknown'

    def analyze_single_run(self, json_data):
        """Compute every metric for one run's JSON object and return them as a dict."""
        turns = json_data.get('turns') or []
        task_type = self.extract_domain_from_task_id(json_data.get('task_id', ''))

        verbosity_scores, verbosity_inflation = self.calculate_verbosity_scores(turns)
        # Novelty is computed once and reused for TTP.
        novelty_scores = self.calculate_lexical_novelty(turns)
        ttp = self._time_to_plateau_from_scores(novelty_scores)
        final_drift = self.calculate_final_drift(turns)
        jargon_density, jargon_words = self.calculate_jargon_density(turns, task_type)
        turn_similarities = self.calculate_turn_to_turn_similarity(turns)

        return {
            'task_id': json_data.get('task_id'),
            'model_name': json_data.get('model_name'),
            'task_type': task_type,
            'run_number': json_data.get('run_number'),

            # Core Metrics
            'verbosity_scores': verbosity_scores,
            'verbosity_inflation_ratio': verbosity_inflation,
            'time_to_plateau': ttp,
            'final_drift_from_origin': final_drift,
            'jargon_density': jargon_density,
            'jargon_words': jargon_words,

            # Additional data for analysis
            'novelty_scores': novelty_scores,
            'turn_similarities': turn_similarities,
            'num_turns': len(turns),
        }

    def process_all_runs(self, batch_size=10):
        """
        Analyze every ``*.json`` file under results_path.

        Files are visited in sorted order, so results are reproducible. Empty
        or unreadable files are skipped, and so are files whose top level is
        not a JSON object.

        Args:
            batch_size: A cumulative count is printed every batch_size runs;
                other runs print their task id and model name.
        """
        all_results = []
        count = 0

        for root, dirs, files in os.walk(self.results_path):
            dirs.sort()  # in-place sort makes os.walk descend deterministically
            for file in sorted(files):
                if not file.endswith('.json'):
                    continue
                filepath = os.path.join(root, file)
                json_data = self.load_json_data(filepath)
                if not json_data:
                    continue
                if not isinstance(json_data, dict):
                    print(f"Skipping {filepath}: expected a JSON object at top level")
                    continue

                result = self.analyze_single_run(json_data)
                all_results.append(result)
                count += 1

                if batch_size and count % batch_size == 0:
                    print(f"Processed {count} runs...")
                else:
                    print(f"Processed: {result['task_id']} - {result['model_name']}")

        print(f"Total runs processed: {count}")
        return all_results

    def create_summary_statistics(self, all_results):
        """Return mean/std/median of the scalar metrics grouped by task_type.

        A 'sample_count' column holds the number of runs per domain. An empty
        DataFrame is returned when there are no results.
        """
        df = pd.DataFrame(all_results)
        if df.empty or 'task_type' not in df.columns:
            return pd.DataFrame()

        numeric_columns = list(self._NUMERIC_METRIC_COLUMNS)
        # None values (e.g. an undefined inflation ratio) can make a column
        # object-dtype, which breaks mean/std; coerce them to NaN.
        for column in numeric_columns:
            df[column] = pd.to_numeric(df[column], errors='coerce')

        domain_stats = df.groupby('task_type')[numeric_columns].agg(['mean', 'std', 'median']).round(3)
        domain_stats['sample_count'] = df.groupby('task_type').size()
        return domain_stats

    def export_results(self, all_results, output_dir='analysis_output'):
        """Write the analysis to CSV files under *output_dir*.

        Files written:
          * individual_run_analysis.csv: one row per run with scalar fields,
            plus avg_turn_similarity when any similarity was computed.
          * domain_comparison_summary.csv: create_summary_statistics output.
          * per_turn_metrics.csv: one row per turn (verbosity, novelty,
            similarity to the previous turn).
          * jargon_analysis.csv: runs with jargon words; written only if any.

        Returns:
            (individual-run DataFrame, summary DataFrame, per-turn DataFrame)
        """
        os.makedirs(output_dir, exist_ok=True)
        list_fields = {'verbosity_scores', 'jargon_words', 'novelty_scores', 'turn_similarities'}

        df_data = []
        for result in all_results:
            row = {k: v for k, v in result.items() if k not in list_fields}
            valid_similarities = [s for s in result['turn_similarities'] if s is not None]
            if valid_similarities:
                row['avg_turn_similarity'] = np.mean(valid_similarities)
            df_data.append(row)

        df = pd.DataFrame(df_data)
        df.to_csv(os.path.join(output_dir, 'individual_run_analysis.csv'), index=False)

        summary = self.create_summary_statistics(all_results)
        summary.to_csv(os.path.join(output_dir, 'domain_comparison_summary.csv'))

        turn_data = []
        for result in all_results:
            novelty = result['novelty_scores']
            similarities = result['turn_similarities']
            for turn_num, score in enumerate(result['verbosity_scores']):
                row = {
                    'task_id': result['task_id'],
                    'model_name': result['model_name'],
                    'task_type': result['task_type'],
                    'turn_number': turn_num + 1,
                    'verbosity_score': score,
                    'novelty_score': novelty[turn_num] if turn_num < len(novelty) else None,
                }
                # similarities[k] compares turn k with turn k+1 (0-based), so
                # this turn vs. its predecessor lives at k = turn_num - 1.
                if 0 < turn_num <= len(similarities):
                    row['similarity_to_previous'] = similarities[turn_num - 1]
                turn_data.append(row)

        turn_df = pd.DataFrame(turn_data)
        turn_df.to_csv(os.path.join(output_dir, 'per_turn_metrics.csv'), index=False)

        jargon_analysis = [
            {
                'task_id': result['task_id'],
                'model_name': result['model_name'],
                'task_type': result['task_type'],
                'jargon_words': ', '.join(result['jargon_words']),
                'jargon_count': len(result['jargon_words']),
                'jargon_density': result['jargon_density'],
            }
            for result in all_results
            if result['jargon_words']
        ]
        if jargon_analysis:
            pd.DataFrame(jargon_analysis).to_csv(
                os.path.join(output_dir, 'jargon_analysis.csv'), index=False)

        print(f"Results exported to {output_dir}/")
        return df, summary, turn_df

    @staticmethod
    def _extreme_row(df, column, largest):
        """Return the row holding the largest/smallest non-NaN value of *column*, or None."""
        if column not in df.columns:
            return None
        values = df[column].dropna()
        if values.empty:
            return None
        return df.loc[values.idxmax() if largest else values.idxmin()]

    def print_analysis_summary(self, all_results):
        """Print overall counts, per-domain metric means and notable runs."""
        df = pd.DataFrame(all_results)
        model_name = getattr(self, 'embedding_model_name', self.DEFAULT_EMBEDDING_MODEL)

        print("\n" + "=" * 60)
        print("ITERATIVE REFINEMENT ANALYSIS SUMMARY")
        print(f"Embedding Model: {model_name}")
        print("=" * 60)

        print(f"\nTotal runs analyzed: {len(all_results)}")
        if df.empty:
            return

        # None values would otherwise make these columns object-dtype.
        for column in self._NUMERIC_METRIC_COLUMNS:
            if column in df.columns:
                df[column] = pd.to_numeric(df[column], errors='coerce')

        print(f"Unique tasks: {df['task_id'].nunique()}")
        models = ', '.join(str(m) for m in df['model_name'].unique())
        print(f"Models analyzed: {models}")

        print("\n--- METRICS BY DOMAIN ---")
        for domain in sorted(df['task_type'].unique()):
            domain_df = df[df['task_type'] == domain]
            print(f"\n{domain.upper()}:")
            print(f"  Samples: {len(domain_df)}")
            print(f"  Avg Verbosity Inflation: {domain_df['verbosity_inflation_ratio'].mean():.2f}x")
            print(f"  Avg Time to Plateau: Turn {domain_df['time_to_plateau'].mean():.1f}")
            print(f"  Avg Final Drift: {domain_df['final_drift_from_origin'].mean():.3f}")
            print(f"  Avg Jargon Density: {domain_df['jargon_density'].mean():.3f}")

        print("\n--- KEY FINDINGS ---")

        worst_inflation = self._extreme_row(df, 'verbosity_inflation_ratio', largest=True)
        if worst_inflation is not None:
            print("\nHighest Verbosity Inflation:")
            print(f"  Task: {worst_inflation['task_id']}")
            print(f"  Model: {worst_inflation['model_name']}")
            print(f"  Inflation: {worst_inflation['verbosity_inflation_ratio']:.2f}x")

        earliest_plateau = self._extreme_row(df, 'time_to_plateau', largest=False)
        if earliest_plateau is not None:
            print("\nEarliest Plateau:")
            print(f"  Task: {earliest_plateau['task_id']}")
            print(f"  Model: {earliest_plateau['model_name']}")
            print(f"  Plateau at Turn: {earliest_plateau['time_to_plateau']}")

        # Lowest first-vs-last similarity means the most drift.
        highest_drift = self._extreme_row(df, 'final_drift_from_origin', largest=False)
        if highest_drift is not None:
            print("\nHighest Semantic Drift:")
            print(f"  Task: {highest_drift['task_id']}")
            print(f"  Model: {highest_drift['model_name']}")
            print(f"  Similarity: {highest_drift['final_drift_from_origin']:.3f}")

def main():
    """Entry point: analyze every run under the configured results path.

    Processes all JSON runs, writes the CSV exports, then prints the per-run
    summary and the per-domain comparison table. Prints a hint and stops
    early if no runs were found.
    """
    analyzer = IterativeRefinementAnalyzer(
        results_path='results/ideas/claude-sonnet-4-0',  # Adjust path as needed
    )

    print("Processing all runs...")
    runs = analyzer.process_all_runs()
    if not runs:
        print("No results found. Check your results path.")
        return

    print("\nExporting results...")
    # Only the domain summary is printed below; the other frames are on disk.
    _, domain_summary, _ = analyzer.export_results(runs)

    analyzer.print_analysis_summary(runs)

    banner = "=" * 60
    print("\n" + banner)
    print("DOMAIN COMPARISON SUMMARY")
    print(banner)
    print(domain_summary)


# Run the pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()