import json
import random
from pathlib import Path
from typing import List, Dict
import argparse
from collections import defaultdict
from ..core.utils import load_jsonl
from .labeler import PipelineType

class LabeledDataChecker:
    """Inspect a labeled dataset: distribution statistics and random samples.

    Records are dicts carrying at least ``label``, ``source`` and
    ``difficulty`` keys (plus ``question_id`` and ``gold_sql`` for
    :meth:`print_samples`). Label values are shown by their ``PipelineType``
    member name when they map to one.
    """

    @staticmethod
    def load_labeled_data(file_path: str) -> List[Dict]:
        """Load and return all records from *file_path* using ``load_jsonl``."""
        return load_jsonl(file_path)

    @staticmethod
    def analyze_distribution(data: List[Dict]) -> Dict:
        """Count records per label, source and difficulty, plus label cross-tabs.

        Returns a dict with:
          - ``total``: number of records;
          - ``labels`` / ``sources`` / ``difficulties``: value -> count;
          - ``source_label`` / ``difficulty_label``: value -> {label -> count}.
        The count maps are ``defaultdict`` instances, so a missing key reads as 0.

        Raises:
            KeyError: if a record lacks ``label``, ``source`` or ``difficulty``.
        """
        stats = {
            "total": len(data),
            "labels": defaultdict(int),
            "sources": defaultdict(int),
            "difficulties": defaultdict(int),
            "source_label": defaultdict(lambda: defaultdict(int)),
            "difficulty_label": defaultdict(lambda: defaultdict(int)),
        }

        for item in data:
            label = item["label"]
            source = item["source"]
            difficulty = item["difficulty"]

            stats["labels"][label] += 1
            stats["sources"][source] += 1
            stats["difficulties"][difficulty] += 1

            stats["source_label"][source][label] += 1
            stats["difficulty_label"][difficulty][label] += 1

        return stats

    @staticmethod
    def _label_name(label) -> str:
        """Return the ``PipelineType`` member name for *label*, or ``str(label)`` if unknown."""
        try:
            return PipelineType(label).name
        except ValueError:
            # An unexpected label value in the data should not abort the whole report.
            return str(label)

    @staticmethod
    def _print_counts(counts: Dict, total: int, name_fn=str):
        """Print one ``  name: count (pct%)`` line per key of *counts*, sorted by key.

        Percentages are relative to *total*; *name_fn* renders each key.
        Nothing is printed (and *total* is never divided by) when *counts* is empty.
        """
        for key in sorted(counts):
            count = counts[key]
            percentage = count / total * 100
            print(f"  {name_fn(key)}: {count} ({percentage:.2f}%)")

    @staticmethod
    def print_distribution(stats: Dict):
        """Pretty-print the statistics produced by :meth:`analyze_distribution`.

        Label/source/difficulty shares are relative to ``stats["total"]``.
        Cross-distribution shares are relative to each source/difficulty group's
        own size. Labels are shown by ``PipelineType`` name, or by raw value
        when they are not a known member.
        """
        checker = LabeledDataChecker
        label_name = checker._label_name
        total = stats["total"]

        print("\n" + "=" * 80)
        print("Data distribution statistics")
        print("=" * 80)

        print(f"\nTotal sample num: {total}")

        print("\nlabel distribution:")
        checker._print_counts(stats["labels"], total, label_name)

        print("\nsource distribution:")
        checker._print_counts(stats["sources"], total)

        print("\nDifficulty Distribution:")
        checker._print_counts(stats["difficulties"], total)

        cross_sections = (
            ("\nSource-label cross distribution:", stats["source_label"]),
            ("\nDifficulty - Label cross distribution:", stats["difficulty_label"]),
        )
        for title, cross in cross_sections:
            print(title)
            for group in sorted(cross):
                print(f"\n{group}:")
                group_counts = cross[group]
                # Cross-tab percentages are relative to the group, not the whole dataset.
                checker._print_counts(group_counts, sum(group_counts.values()), label_name)

        print("\n" + "=" * 80)

    @staticmethod
    def sample_by_label(data: List[Dict], label: int, sample_size: int) -> List[Dict]:
        """Return up to *sample_size* records whose ``label`` equals *label*.

        If there are at most *sample_size* matches, all of them are returned in
        their original order. Otherwise a random subset is drawn with the
        module-level ``random`` state.

        Raises:
            ValueError: if *sample_size* is negative (raised by ``random.sample``).
        """
        labeled_data = [item for item in data if item["label"] == label]
        if len(labeled_data) <= sample_size:
            return labeled_data
        return random.sample(labeled_data, sample_size)

    @staticmethod
    def print_samples(samples: List[Dict]):
        """Print each sample's question id, difficulty, raw label and gold SQL between 80-char rules."""
        for i, sample in enumerate(samples, 1):
            print("\n" + "=" * 80)
            print(f"[Sample {i}/{len(samples)}] Question ID: {sample['question_id']}; Difficulty: {sample['difficulty']}; Pipeline Label: {sample['label']}")
            print(f"Gold SQL: {sample['gold_sql']}")
            print("=" * 80)

def _show_label_samples(checker, data, label, sample_size, label_count):
    """Print up to *sample_size* random examples with *label*, with a header and footer line.

    *label_count* is the total number of records carrying *label*; it is only
    used in the summary text.
    """
    label_name = PipelineType(label).name

    print(f"\nSampling {sample_size} examples from {label_name} (Total: {label_count})")

    samples = checker.sample_by_label(data, label, sample_size)
    checker.print_samples(samples)

    print(f"\nShowed {len(samples)} samples out of {label_count} {label_name} examples")


def main():
    """CLI entry point: show distribution statistics and/or random samples for one label.

    Modes:
      - default: print distribution statistics, then samples if ``--label`` is given;
      - ``--distribution_only``: statistics only;
      - ``--sample_only``: samples only (requires ``--label``).
    Invalid argument combinations exit through ``parser.error``.
    """
    parser = argparse.ArgumentParser(description='Check labeled data distribution and sample examples')
    parser.add_argument('file_path', type=str, help='Path to the labeled data file')
    parser.add_argument('--label', type=int, choices=[1, 2, 3, 4],
                      help='Label to sample (1: BASIC, 2: INTERMEDIATE, 3: ADVANCED, 4: UNSOLVED)')
    parser.add_argument('--sample_size', type=int, default=5, help='Number of samples to show')
    parser.add_argument('--seed', type=int, default=42, help='Random seed for reproducibility')
    parser.add_argument('--distribution_only', action='store_true',
                      help='Only show distribution statistics without samples')
    parser.add_argument('--sample_only', action='store_true',
                      help='Only show samples without distribution statistics')

    args = parser.parse_args()

    # Validate every argument combination before loading the (possibly large) data file.
    if args.sample_only and args.distribution_only:
        parser.error("Cannot specify both --sample_only and --distribution_only")
    if args.sample_only and args.label is None:
        parser.error("Must specify --label when using --sample_only")
    if args.sample_size < 0:
        # random.sample would otherwise fail with a raw ValueError traceback.
        parser.error("--sample_size must be non-negative")

    random.seed(args.seed)

    checker = LabeledDataChecker()
    data = checker.load_labeled_data(args.file_path)

    if args.sample_only:
        label_count = sum(1 for item in data if item["label"] == args.label)
        _show_label_samples(checker, data, args.label, args.sample_size, label_count)
        return

    stats = checker.analyze_distribution(data)
    checker.print_distribution(stats)

    if not args.distribution_only and args.label is not None:
        # .get avoids inserting a zero entry into the defaultdict of counts.
        label_count = stats["labels"].get(args.label, 0)
        _show_label_samples(checker, data, args.label, args.sample_size, label_count)

if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()