import json
import argparse
from collections import defaultdict, Counter
import sys
import random
import numpy as np

try:
    import pandas as pd
except ImportError:
    print("Error: This script requires 'pandas' and 'pyarrow' libraries to read Parquet files.", file=sys.stderr)
    print("Please install them using 'pip install pandas pyarrow'.", file=sys.stderr)
    sys.exit(1)

def load_jsonl(filepath):
    """Load records from a JSONL (one JSON document per line) file.

    Blank lines are ignored. Lines that are not valid JSON are skipped, and a
    single summary warning (count and first offending line number) is printed
    to stderr so that data loss is never silent.

    Args:
        filepath: Path of the UTF-8 encoded JSONL file to read.

    Returns:
        A list with one decoded JSON value per valid line, in file order.

    Raises:
        SystemExit: With exit code 1 if the file does not exist.
    """
    data = []
    malformed = 0
    first_bad_line = None
    try:
        with open(filepath, 'r', encoding='utf-8') as f:
            for line_no, line in enumerate(f, start=1):
                line = line.strip()
                if not line:
                    # Blank lines carry no record; do not count them as errors.
                    continue
                try:
                    data.append(json.loads(line))
                except json.JSONDecodeError:
                    malformed += 1
                    if first_bad_line is None:
                        first_bad_line = line_no
    except FileNotFoundError:
        print(f"Error: File not found {filepath}", file=sys.stderr)
        sys.exit(1)
    if malformed:
        print(
            f"Warning: Skipped {malformed} malformed line(s) in {filepath} "
            f"(first at line {first_bad_line}).",
            file=sys.stderr,
        )
    return data

def save_jsonl(data, filepath):
    """Write every element of ``data`` to ``filepath`` as one JSON line.

    Non-ASCII characters are written verbatim (``ensure_ascii=False``) and
    the file is UTF-8 encoded. Any failure while opening, serialising or
    writing is reported on stderr and terminates the process with exit
    code 1.
    """
    try:
        with open(filepath, 'w', encoding='utf-8') as out:
            out.writelines(
                json.dumps(record, ensure_ascii=False) + '\n'
                for record in data
            )
    except Exception as err:
        print(f"Error writing to file {filepath}: {err}", file=sys.stderr)
        sys.exit(1)

def load_parquet(filepath):
    """Read a Parquet file and return its rows as a list of dicts.

    Each row becomes one ``{column: value}`` dictionary (pandas
    ``to_dict('records')``). Any error raised while reading or converting
    is reported on stderr and terminates the process with exit code 1.
    """
    try:
        frame = pd.read_parquet(filepath)
        return frame.to_dict('records')
    except Exception as err:
        print(f"Error reading Parquet file {filepath}: {err}", file=sys.stderr)
        sys.exit(1)

def _random_candidates(base_data):
    """Tag every base item that has a 'cot_type' with a random placeholder score."""
    candidates = []
    for item in base_data:
        if 'cot_type' in item:
            item['attention_influence_score'] = random.random()
            candidates.append(item)
    return candidates


def _entropy_candidates(base_data):
    """Score base items by mean token entropy; return (candidates, n_missing, n_non_finite)."""
    candidates = []
    missing = 0
    non_finite = 0
    for item in base_data:
        if 'cot_type' not in item:
            continue
        entropies = item.get('token_entropy')
        if not (isinstance(entropies, list) and entropies):
            missing += 1
            continue
        avg_entropy = float(np.mean(entropies))
        if not np.isfinite(avg_entropy):
            # NaN/inf would make the descending sort order undefined.
            non_finite += 1
            continue
        item['attention_influence_score'] = avg_entropy
        candidates.append(item)
    return candidates, missing, non_finite


def _score_candidates(base_data, ref_data):
    """Score base items by relative loss change vs. the ref model; return (candidates, skip_counts)."""
    ref_map = {item['text']: item for item in ref_data if 'text' in item}
    candidates = []
    skipped = Counter()
    for base_item in base_data:
        text = base_item.get('text')
        if not text or text not in ref_map:
            continue
        if 'cot_type' not in base_item:
            # Grouping below indexes item['cot_type']; without this guard a
            # single untyped sample would abort the run with a KeyError.
            skipped['missing cot_type'] += 1
            continue

        base_loss = base_item.get('item_loss', 0)
        ref_loss = ref_map[text].get('item_loss', 0)
        if not (isinstance(base_loss, (int, float)) and isinstance(ref_loss, (int, float))):
            skipped['non-numeric item_loss'] += 1
            continue
        if base_loss == 0:
            # Relative change is undefined for a zero base loss.
            skipped['zero base item_loss'] += 1
            continue

        score = (ref_loss - base_loss) / base_loss
        if not np.isfinite(score):
            skipped['non-finite score'] += 1
            continue

        new_entry = base_item.copy()
        new_entry['attention_influence_score'] = score
        new_entry['base_loss'] = base_loss
        new_entry['ref_loss'] = ref_loss
        # Drop the heavy per-token fields from the output records.
        new_entry.pop('token_loss', None)
        new_entry.pop('token_entropy', None)
        candidates.append(new_entry)
    return candidates, skipped


def _select_by_cot_type(candidates, cot_type_counts, use_random):
    """Pick up to cot_type_counts[t] candidates per cot_type (random sample or top score)."""
    grouped_items = defaultdict(list)
    for item in candidates:
        if item['cot_type'] in cot_type_counts:
            grouped_items[item['cot_type']].append(item)

    final_selection = []
    for cot_type, items in grouped_items.items():
        target_k = cot_type_counts[cot_type]
        if use_random:
            if len(items) > target_k:
                selected_batch = random.sample(items, target_k)
            else:
                selected_batch = items
        else:
            ranked = sorted(items, key=lambda x: x['attention_influence_score'], reverse=True)
            selected_batch = ranked[:target_k]
        final_selection.extend(selected_batch)
    return final_selection


def main():
    """Command-line entry point: filter base-model samples to match a cot_type distribution.

    The per-cot_type quota is the number of rows with that cot_type in the
    Parquet ``--stats_file``. Candidates come from ``--base_file`` and are
    ranked in one of three modes:

    * Influence Score Mode (default): ``(ref_loss - base_loss) / base_loss``,
      where the ref record is matched to the base record by identical
      ``text``; requires ``--ref_file``. Highest scores are kept.
    * ``--entropy_mode``: mean of the ``token_entropy`` list; highest kept.
    * ``--random_mode``: uniform random sample per cot_type (``--seed`` makes
      it reproducible). A random placeholder is stored in
      ``attention_influence_score``.

    Samples without a ``cot_type``, or whose score cannot be computed or is
    not finite, are skipped and counted in a warning. The selected records
    are written to ``--output_file`` as JSONL.

    Raises:
        SystemExit: On invalid argument combinations, unreadable inputs, or
            an empty base file.
    """
    parser = argparse.ArgumentParser(description="Data Filtering Script: Supports Influence Score, Average Entropy, and Random modes.")

    parser.add_argument('--base_file', type=str, required=True, help='JSONL output from Base model (containing item_loss or token_entropy)')
    parser.add_argument('--ref_file', type=str, required=False, help='Output from Ref model (Required for Score mode only)')
    parser.add_argument('--stats_file', type=str, required=True, help='Parquet file for aligning cot_type distribution')
    parser.add_argument('--output_file', type=str, required=True, help='Output path')

    parser.add_argument('--random_mode', action='store_true', help='[Mode 1] Random Mode: Random sampling')
    parser.add_argument('--entropy_mode', action='store_true', help='[Mode 2] Entropy Mode: Filter by Average Token Entropy (Hard Sample)')
    parser.add_argument('--seed', type=int, default=None, help='Optional random seed for reproducible sampling in Random Mode')

    args = parser.parse_args()
    if args.random_mode and args.entropy_mode:
        parser.error("Error: Cannot enable both --random_mode and --entropy_mode.")

    is_score_mode = not args.random_mode and not args.entropy_mode
    if is_score_mode and not args.ref_file:
        parser.error("Error: --ref_file must be provided in Influence Score Mode (default).")

    if args.seed is not None:
        # Seed before any random draw so scores and sampling are repeatable.
        random.seed(args.seed)

    if args.random_mode:
        print(">>> Current Mode: Random Mode <<<")
    elif args.entropy_mode:
        print(">>> Current Mode: Entropy Mode <<<")
        print("    Note: Selecting samples with highest average entropy (Top-N).")
    else:
        print(">>> Current Mode: Influence Score Mode <<<")

    print(f"Reading Base file: {args.base_file}")
    base_data = load_jsonl(args.base_file)

    print(f"Reading Stats file: {args.stats_file}")
    stats_data = load_parquet(args.stats_file)

    ref_data = []
    if is_score_mode:
        print(f"Reading Ref file: {args.ref_file}")
        ref_data = load_jsonl(args.ref_file)

    if not base_data:
        print("Error: Base file is empty.", file=sys.stderr)
        sys.exit(1)

    print("Calculating target distribution...")
    cot_type_counts = Counter(item.get('cot_type') for item in stats_data if 'cot_type' in item)
    print(f"Total {len(cot_type_counts)} cot_types found.")
    if not cot_type_counts:
        print("Warning: Stats file contains no 'cot_type' values; nothing will be selected.", file=sys.stderr)

    if args.random_mode:
        print("Random Mode: Skipping calculation, preparing data...")
        processed_candidates = _random_candidates(base_data)
    elif args.entropy_mode:
        print("Entropy Mode: Calculating average Token Entropy...")
        processed_candidates, skipped_no_entropy, skipped_non_finite = _entropy_candidates(base_data)
        print(f"Entropy calculation complete. Valid samples: {len(processed_candidates)}")
        if skipped_no_entropy > 0:
            print(f"Warning: {skipped_no_entropy} samples skipped due to missing 'token_entropy' field.")
        if skipped_non_finite > 0:
            print(f"Warning: {skipped_non_finite} samples skipped due to non-finite average entropy.")
    else:
        print("Normal Mode: Calculating AttentionInfluence scores...")
        processed_candidates, skipped = _score_candidates(base_data, ref_data)
        for reason, count in skipped.items():
            print(f"Warning: {count} samples skipped due to {reason}.")

    print("Filtering by group based on cot_type...")
    final_selection = _select_by_cot_type(processed_candidates, cot_type_counts, args.random_mode)

    print(f"Filtering complete: Selected {len(final_selection)} items.")
    print(f"Saving to: {args.output_file}")
    save_jsonl(final_selection, args.output_file)
    print("Done.")

# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()