import json
import pandas as pd
from typing import Dict, List, Any
import os
class DataCleaner:
    def __init__(self):
        """Create a cleaner with every statistics entry initialised to zero."""
        stat_names = (
            "total_samples",
            "empty_paraphrases",
            "no_valid_paraphrases",
            "kept_samples",
            "avg_paraphrases_before",
            "avg_paraphrases_after",
            "length_filtered",
            "samples_no_paraphrases_after_length_filter",
        )
        # Counters and averages live in one dict, keyed in the order listed above.
        self.stats = dict.fromkeys(stat_names, 0)
    
    def is_empty_paraphrases(self, item: Dict) -> bool:
        """
        Report whether an item has no usable paraphrase

        Args:
            item: Data item

        Returns:
            True when the "paraphrases" key is missing, its value is empty,
            or none of its entries carries the status "valid"; False otherwise
        """
        if "paraphrases" not in item:
            return True

        candidates = item["paraphrases"]
        if not candidates:
            return True

        # Collect every status first so each entry is inspected, then look for a valid one.
        statuses = [entry.get("status") for entry in candidates]
        return "valid" not in statuses
    
    def clean_individual_item(self, item: Dict, filter_by_length: bool = True,
                             min_length_ratio: float = 0.7, max_length_ratio: float = 1.3) -> Dict:
        """
        Clean a single data item, keep only paraphrases with "valid" status, and optionally filter by length

        The input item is not modified: a shallow copy is returned whose
        "paraphrases" value is a new list. An item without a "paraphrases" key
        is returned as-is (same object).

        Args:
            item: Original data item
            filter_by_length: Whether to filter by length ratio
            min_length_ratio: Minimum length ratio (paraphrase length/original length)
            max_length_ratio: Maximum length ratio (paraphrase length/original length)

        Returns:
            Cleaned data item
        """
        if "paraphrases" not in item:
            return item

        # A JSON null loads as None; treat it (and any other falsy value) as
        # "no paraphrases" instead of failing while iterating over it.
        candidates = item["paraphrases"] or []
        valid_paraphrases = [p for p in candidates if p.get("status") == "valid"]

        # get_original_text returns "" when no text field is found, so a length
        # of 0 means the ratio cannot be computed and length filtering is skipped.
        original_length = len(self.get_original_text(item))

        if filter_by_length and original_length > 0:
            kept = []
            for p in valid_paraphrases:
                paraphrase_length = len(self.get_paraphrase_text(p))

                if paraphrase_length == 0:
                    # No measurable text: keep the paraphrase rather than
                    # dropping it on a ratio that cannot be computed.
                    kept.append(p)
                elif min_length_ratio <= paraphrase_length / original_length <= max_length_ratio:
                    kept.append(p)
                else:
                    self.stats["length_filtered"] += 1

            valid_paraphrases = kept

        cleaned_item = item.copy()
        cleaned_item["paraphrases"] = valid_paraphrases

        # Only refresh the count when the item already tracks it; never add the key.
        if "valid_paraphrases_count" in cleaned_item:
            cleaned_item["valid_paraphrases_count"] = len(valid_paraphrases)

        return cleaned_item
    
    def clean_json_data(self, input_file: str, output_file: str = None, 
                       min_valid_paraphrases: int = 1, clean_invalid: bool = True) -> Dict:
        """
        Clean data file in JSON format
        
        Args:
            input_file: Input file path
            output_file: Output file path (if None, add _cleaned to original filename)
            min_valid_paraphrases: Minimum number of valid paraphrases
            clean_invalid: Whether to remove paraphrases with "invalid" status
            
        Returns:
            Cleaned data and statistics
        """
        if output_file is None:
            name, ext = os.path.splitext(input_file)
            output_file = f"{name}_cleaned{ext}"
        
        print(f"🔄 Cleaning data: {input_file}")
        print(f"📤 Output file: {output_file}")
        print(f"🎯 Minimum valid paraphrases: {min_valid_paraphrases}")
        print("-" * 50)
        
        try:
            with open(input_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except Exception as e:
            print(f"❌ Failed to read file: {e}")
            return None
        
        if isinstance(data, dict) and "data" in data:
            items = data["data"]
            has_metadata = True
        elif isinstance(data, list):
            items = data
            has_metadata = False
        else:
            print("❌ Unsupported data format")
            return None
        
        self.stats["total_samples"] = len(items)
        
        total_paraphrases_before = 0
        for item in items:
            if "paraphrases" in item and item["paraphrases"]:
                total_paraphrases_before += len(item["paraphrases"])
        
        if self.stats["total_samples"] > 0:
            self.stats["avg_paraphrases_before"] = total_paraphrases_before / self.stats["total_samples"]
        
        cleaned_items = []
        
        for i, item in enumerate(items):
            if self.is_empty_paraphrases(item):
                if not item.get("paraphrases"):
                    self.stats["empty_paraphrases"] += 1
                else:
                    self.stats["no_valid_paraphrases"] += 1
                
                print(f"🗑️ Removing sample {i}: No valid paraphrases - {item.get('prompt', 'No prompt')[:50]}...")
                continue
            
            if clean_invalid:
                cleaned_item = self.clean_individual_item(item)
            else:
                cleaned_item = item
            
            valid_count = len([p for p in cleaned_item.get("paraphrases", []) if p.get("status") == "valid"])
            
            if valid_count < min_valid_paraphrases:
                print(f"🗑️ Removing sample {i}: Insufficient valid paraphrases ({valid_count}) - {item.get('prompt', 'No prompt')[:50]}...")
                continue
            
            cleaned_items.append(cleaned_item)
            self.stats["kept_samples"] += 1
        
        total_paraphrases_after = 0
        for item in cleaned_items:
            if "paraphrases" in item and item["paraphrases"]:
                valid_paraphrases = [p for p in item["paraphrases"] if p.get("status") == "valid"]
                total_paraphrases_after += len(valid_paraphrases)
        
        if len(cleaned_items) > 0:
            self.stats["avg_paraphrases_after"] = total_paraphrases_after / len(cleaned_items)
        
        if has_metadata:
            output_data = data.copy()
            output_data["data"] = cleaned_items
            
            if "metadata" in output_data:
                output_data["metadata"]["cleaning_stats"] = self.stats.copy()
                output_data["metadata"]["cleaned_at"] = pd.Timestamp.now().isoformat()
        else:
            output_data = cleaned_items
        
        try:
            with open(output_file, 'w', encoding='utf-8') as f:
                json.dump(output_data, f, ensure_ascii=False, indent=2)
            print(f"✅ Data saved to: {output_file}")
        except Exception as e:
            print(f"❌ Failed to save file: {e}")
            return None
        
        self.print_stats()
        
        return output_data
    
    def print_stats(self):
        """Print cleaning statistics"""
        print("\n📊 Data Cleaning Statistics")
        print("=" * 40)
        print(f"📋 Total original samples: {self.stats['total_samples']}")
        print(f"🗑️ Removed empty paraphrases: {self.stats['empty_paraphrases']}")
        print(f"🗑️ Removed invalid paraphrases: {self.stats['no_valid_paraphrases']}")
        print(f"✅ Kept samples: {self.stats['kept_samples']}")
        print(f"📈 Retention rate: {self.stats['kept_samples']/self.stats['total_samples']*100:.1f}%")
        print(f"📊 Avg paraphrases (before cleaning): {self.stats['avg_paraphrases_before']:.1f}")
        print(f"📊 Avg paraphrases (after cleaning): {self.stats['avg_paraphrases_after']:.1f}")
    
    def get_original_text(self, item: Dict) -> str:
        """
        Look up the original response text under several candidate field names

        Fields are tried in a fixed order and the first one that is present
        with a truthy value wins; its value is converted with ``str``.

        Args:
            item: Data item

        Returns:
            Original response text, empty string if not found
        """
        candidate_keys = ("response", "answer", "content", "text", "reply", "output", "original_response")
        matches = (str(item[key]) for key in candidate_keys if key in item and item[key])
        return next(matches, "")
    
    def get_paraphrase_text(self, paraphrase: Dict) -> str:
        """
        Extract the text of a paraphrase entry

        Candidate field names are checked in a fixed order; the first field
        that is present with a truthy value is returned after ``str`` conversion.

        Args:
            paraphrase: Paraphrase item

        Returns:
            Paraphrase text, empty string if not found
        """
        field_order = ("text", "content", "paraphrase", "response", "answer")
        populated = [paraphrase[name] for name in field_order if name in paraphrase and paraphrase[name]]
        if not populated:
            return ""
        return str(populated[0])
    
    def analyze_data_quality(self, file_path: str, show_detail: bool = False):
        """Analyze data quality"""
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except Exception as e:
            print(f"❌ Failed to read file: {e}")
            return
        
        if isinstance(data, dict) and "data" in data:
            items = data["data"]
        elif isinstance(data, list):
            items = data
        else:
            print("❌ Unsupported data format")
            return
        
        print(f"🔍 Data Quality Analysis: {file_path}")
        print("=" * 80)
        
        total_samples = len(items)
        empty_paraphrases = 0
        no_valid_paraphrases = 0
        similarity_scores = []
        paraphrase_counts = []
        
        all_sample_stats = []
        all_length_ratios = []
        all_original_lengths = []
        all_paraphrase_lengths = []
        
        print(f"\n📏 Individual Response Length Analysis:")
        print("-" * 80)
        
        for idx, item in enumerate(items):
            paraphrases = item.get("paraphrases", [])
            
            if not paraphrases:
                empty_paraphrases += 1
                if show_detail:
                    print(f"Sample {idx+1}: ❌ No paraphrase data")
                continue
            
            valid_paraphrases = [p for p in paraphrases if p.get("status") == "valid"]
            
            if not valid_paraphrases:
                no_valid_paraphrases += 1
                if show_detail:
                    print(f"Sample {idx+1}: ❌ No valid paraphrases")
                continue
            
            paraphrase_counts.append(len(valid_paraphrases))
            
            original_text = self.get_original_text(item)
            original_length = len(original_text) if original_text else 0
            
            if original_length == 0:
                if show_detail:
                    print(f"Sample {idx+1}: ⚠️ Cannot get original text length")
                continue
            
            sample_paraphrase_lengths = []
            sample_length_ratios = []
            sample_similarities = []
            
            for p in valid_paraphrases:
                if "similarity" in p:
                    similarity_scores.append(p["similarity"])
                    sample_similarities.append(p["similarity"])
                
                paraphrase_text = self.get_paraphrase_text(p)
                paraphrase_length = len(paraphrase_text) if paraphrase_text else 0
                
                if paraphrase_length > 0:
                    sample_paraphrase_lengths.append(paraphrase_length)
                    ratio = paraphrase_length / original_length
                    sample_length_ratios.append(ratio)
                    
                    all_original_lengths.append(original_length)
                    all_paraphrase_lengths.append(paraphrase_length)
                    all_length_ratios.append(ratio)
            
            if sample_paraphrase_lengths:
                avg_paraphrase_length = sum(sample_paraphrase_lengths) / len(sample_paraphrase_lengths)
                avg_length_ratio = sum(sample_length_ratios) / len(sample_length_ratios)
                avg_similarity = sum(sample_similarities) / len(sample_similarities) if sample_similarities else 0
                
                sample_stat = {
                    'index': idx + 1,
                    'original_length': original_length,
                    'paraphrase_count': len(sample_paraphrase_lengths),
                    'avg_paraphrase_length': avg_paraphrase_length,
                    'avg_length_ratio': avg_length_ratio,
                    'avg_similarity': avg_similarity,
                    'length_change': avg_paraphrase_length - original_length
                }
                all_sample_stats.append(sample_stat)
                
                change_indicator = "📈" if avg_length_ratio > 1.1 else "📉" if avg_length_ratio < 0.9 else "📊"
                similarity_str = f", Similarity: {avg_similarity:.3f}" if avg_similarity > 0 else ""
                
                print(f"Sample {idx+1:3d}: {change_indicator} Original: {original_length:4d}chars → Paraphrase: {avg_paraphrase_length:6.1f}chars "
                      f"(Ratio: {avg_length_ratio:.2f}, Paraphrase count: {len(sample_paraphrase_lengths)}{similarity_str})")
                
                if show_detail and len(sample_paraphrase_lengths) > 1:
                    for i, (length, ratio) in enumerate(zip(sample_paraphrase_lengths, sample_length_ratios)):
                        print(f"     Paraphrase {i+1}: {length}chars (Ratio: {ratio:.2f})")
        
        print(f"\n📊 Basic Statistics:")
        print("=" * 50)
        print(f"📋 Total samples: {total_samples}")
        print(f"🗑️ Empty paraphrases count: {empty_paraphrases}")
        print(f"🗑️ No valid paraphrases count: {no_valid_paraphrases}")
        print(f"✅ Valid samples count: {len(all_sample_stats)}")
        
        if paraphrase_counts:
            print(f"\n📊 Paraphrase Count Statistics:")
            print(f"   - Average: {sum(paraphrase_counts)/len(paraphrase_counts):.1f}")
            print(f"   - Minimum: {min(paraphrase_counts)}")
            print(f"   - Maximum: {max(paraphrase_counts)}")
        
        if similarity_scores:
            print(f"\n📊 Similarity Statistics:")
            print(f"   - Average: {sum(similarity_scores)/len(similarity_scores):.3f}")
            print(f"   - Minimum: {min(similarity_scores):.3f}")
            print(f"   - Maximum: {max(similarity_scores):.3f}")
        
        if all_sample_stats:
            print(f"\n📏 Length Change Summary Statistics:")
            print("=" * 50)
            
            sample_original_lengths = [s['original_length'] for s in all_sample_stats]
            sample_avg_paraphrase_lengths = [s['avg_paraphrase_length'] for s in all_sample_stats]
            sample_avg_ratios = [s['avg_length_ratio'] for s in all_sample_stats]
            sample_length_changes = [s['length_change'] for s in all_sample_stats]
            
            print(f"📋 Per-sample Statistics (average paraphrase of each original response):")
            print(f"   Original response length:")
            print(f"     - Average: {sum(sample_original_lengths)/len(sample_original_lengths):.1f} chars")
            print(f"     - Minimum: {min(sample_original_lengths)} chars")
            print(f"     - Maximum: {max(sample_original_lengths)} chars")
            
            print(f"   Average paraphrase length:")
            print(f"     - Average: {sum(sample_avg_paraphrase_lengths)/len(sample_avg_paraphrase_lengths):.1f} chars")
            print(f"     - Minimum: {min(sample_avg_paraphrase_lengths):.1f} chars")
            print(f"     - Maximum: {max(sample_avg_paraphrase_lengths):.1f} chars")
            
            print(f"   Length change:")
            print(f"     - Average change: {sum(sample_length_changes)/len(sample_length_changes):+.1f} chars")
            print(f"     - Maximum decrease: {min(sample_length_changes):+.1f} chars")
            print(f"     - Maximum increase: {max(sample_length_changes):+.1f} chars")
            
            print(f"   Length ratio (paraphrase/original):")
            print(f"     - Average: {sum(sample_avg_ratios)/len(sample_avg_ratios):.2f}")
            print(f"     - Minimum: {min(sample_avg_ratios):.2f}")
            print(f"     - Maximum: {max(sample_avg_ratios):.2f}")
            
            ratio_bins = {
                "Significantly shorter (<0.7)": sum(1 for r in sample_avg_ratios if r < 0.7),
                "Slightly shorter (0.7-0.9)": sum(1 for r in sample_avg_ratios if 0.7 <= r < 0.9),
                "Nearly unchanged (0.9-1.1)": sum(1 for r in sample_avg_ratios if 0.9 <= r <= 1.1),
                "Slightly longer (1.1-1.3)": sum(1 for r in sample_avg_ratios if 1.1 < r <= 1.3),
                "Significantly longer (>1.3)": sum(1 for r in sample_avg_ratios if r > 1.3)
            }
            
            print(f"\n   Sample Length Change Distribution:")
            total_samples_with_data = len(sample_avg_ratios)
            for bin_name, count in ratio_bins.items():
                percentage = count / total_samples_with_data * 100
                print(f"     - {bin_name}: {count} ({percentage:.1f}%)")
            
            if all_length_ratios:
                print(f"\n📋 Per-paraphrase Statistics (overall status of all paraphrases):")
                print(f"   Total paraphrases count: {len(all_length_ratios)}")
                print(f"   Length ratio (paraphrase/original):")
                print(f"     - Average: {sum(all_length_ratios)/len(all_length_ratios):.2f}")
                print(f"     - Minimum: {min(all_length_ratios):.2f}")
                print(f"     - Maximum: {max(all_length_ratios):.2f}")
        else:
            print(f"\n⚠️ Cannot get length statistics (no suitable text fields found)")
            print(f"   Please check the field names of original responses and paraphrases in the data")
