#!/usr/bin/env python3

import argparse
import json
import random
from pathlib import Path
from typing import Any, Dict, List


class TrapDataSampler:
    """Split click-action samples from JSON files into test and train sets.

    The sampler loads every input file, keeps the entries whose
    ``gt_action`` is a dict with ``action == 'click'``, shuffles them with
    a seeded generator and writes two disjoint subsets to
    ``test_trap_click.json`` and ``train_trap_click.json`` in the output
    directory.
    """

    def __init__(self, input_files: List[str], output_dir: str, test_size: int = 700, train_size: int = 2000, random_seed: int = 42):
        """Store the configuration and create the output directory.

        Args:
            input_files: Paths of the JSON files to read.
            output_dir: Directory that receives the two output files; it is
                created (with parents) if it does not exist.
            test_size: Number of samples requested for the test set.
            train_size: Number of samples requested for the train set.
            random_seed: Seed for the shuffle, so a run can be reproduced.

        Raises:
            ValueError: If ``test_size`` or ``train_size`` is negative.
        """
        # Negative sizes would silently turn into "all but the last N"
        # slices in sample_data, so reject them up front.
        if test_size < 0 or train_size < 0:
            raise ValueError(
                f"test_size and train_size must be non-negative, got {test_size} and {train_size}"
            )

        self.input_files = [Path(f) for f in input_files]
        self.output_dir = Path(output_dir)
        self.test_size = test_size
        self.train_size = train_size
        self.random_seed = random_seed

        self.output_dir.mkdir(parents=True, exist_ok=True)

    def load_json_file(self, file_path: Path) -> List[Dict[str, Any]]:
        """Load one JSON file and return its top-level list.

        Loading is best-effort: a missing file, an unreadable file, invalid
        JSON, or a top-level value that is not a list is reported on stdout
        and yields an empty list instead of raising.

        Args:
            file_path: File to read (decoded as UTF-8).

        Returns:
            The parsed list, or ``[]`` if the file could not be used.
        """
        if not file_path.exists():
            print(f"Warning: File does not exist: {file_path}")
            return []

        print(f"Loading: {file_path}")
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except (OSError, ValueError) as e:
            # json.JSONDecodeError and UnicodeDecodeError are both
            # ValueError subclasses.
            print(f"Error loading file {file_path}: {e}")
            return []

        # Callers extend a list with the result; extending with a dict
        # would add its keys as bogus samples, so only a list is accepted.
        if not isinstance(data, list):
            print(f"Error loading file {file_path}: expected a top-level JSON list, got {type(data).__name__}")
            return []

        print(f"  Loaded {len(data)} samples")
        return data

    @staticmethod
    def _gt_action_pairs(data: List[Any]) -> List[tuple[Dict[str, Any], Dict[str, Any]]]:
        """Return ``(item, gt_action)`` for every dict item whose gt_action is a dict.

        A dict item without a ``gt_action`` key is paired with an empty
        dict; non-dict items and non-dict ``gt_action`` values are skipped.
        """
        pairs = []
        for item in data:
            if not isinstance(item, dict):
                continue
            gt_action = item.get('gt_action', {})
            if isinstance(gt_action, dict):
                pairs.append((item, gt_action))
        return pairs

    def filter_click_samples(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Filter samples where gt_action action is 'click'.

        Entries that are not dicts, or whose ``gt_action`` is not a dict,
        are ignored. When nothing matches, a histogram of the action types
        that were seen is printed to help diagnose the input.

        Args:
            data: Loaded samples.

        Returns:
            The matching samples, in input order.
        """
        pairs = self._gt_action_pairs(data)
        click_samples = [item for item, gt_action in pairs if gt_action.get('action') == 'click']

        if not click_samples:
            print("Warning: No click action samples found. Checking all action types...")
            action_types = {}
            for _, gt_action in pairs:
                action = gt_action.get('action', 'unknown')
                action_types[action] = action_types.get(action, 0) + 1
            print(f"Available action types: {action_types}")

        return click_samples

    def sample_data(self, all_data: List[Dict[str, Any]]) -> tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
        """Shuffle the samples and cut disjoint test and train subsets.

        If fewer than ``test_size + train_size`` samples are available, all
        of them are used and split in the proportion
        ``test_size : train_size`` (test share rounded down).

        Args:
            all_data: Candidate samples; the list itself is not modified.

        Returns:
            ``(test_data, train_data)``.
        """
        # A private generator gives the same sequence as seeding the module
        # generator with this seed, without reseeding random state that
        # other code in the process may rely on.
        rng = random.Random(self.random_seed)
        shuffled_data = list(all_data)
        rng.shuffle(shuffled_data)

        available = len(shuffled_data)
        total_needed = self.test_size + self.train_size

        if available >= total_needed:
            test_end = self.test_size
            train_end = total_needed
        else:
            print(f"Warning: Total data ({available}) is less than required ({total_needed})")
            print("Will use all available data")
            # Integer arithmetic: a float ratio can land just below an
            # exact integer and truncate one too low.
            test_end = available * self.test_size // total_needed
            train_end = available

        return shuffled_data[:test_end], shuffled_data[test_end:train_end]

    def save_json_file(self, data: List[Dict[str, Any]], file_path: Path) -> bool:
        """Write ``data`` to ``file_path`` as indented UTF-8 JSON.

        The parent directory is created if needed. Failures are reported on
        stdout rather than raised.

        Args:
            data: Samples to serialize.
            file_path: Destination file.

        Returns:
            ``True`` if the file was written, ``False`` if creating the
            directory, opening the file or serializing the data failed.
        """
        print(f"Saving: {file_path}")
        try:
            file_path.parent.mkdir(parents=True, exist_ok=True)
            with open(file_path, 'w', encoding='utf-8') as f:
                json.dump(data, f, ensure_ascii=False, indent=2)
        except (OSError, TypeError, ValueError) as e:
            # TypeError/ValueError: json.dump on non-serializable or
            # circular data.
            print(f"Error saving file {file_path}: {e}")
            return False

        print(f"  Saved {len(data)} samples")
        return True

    def run(self):
        """Run the pipeline: load, filter, sample and save.

        Progress is printed to stdout. The method returns early if no click
        samples are found, and reports a failure instead of success if
        either output file could not be written.
        """
        print("=" * 60)
        print("Starting trap data sampling (click actions only)...")
        print("=" * 60)

        all_data = []
        for file_path in self.input_files:
            all_data.extend(self.load_json_file(file_path))

        print(f"\nTotal loaded: {len(all_data)} samples")

        click_samples = self.filter_click_samples(all_data)
        print(f"Filtered to click actions: {len(click_samples)} samples")

        if not click_samples:
            print("Error: No click action samples found!")
            return

        print("\nStarting random sampling...")
        test_data, train_data = self.sample_data(click_samples)

        print("Sampling results:")
        print(f"  Test set: {len(test_data)} samples")
        print(f"  Train set: {len(train_data)} samples")

        print("\nSaving results...")
        test_output = self.output_dir / "test_trap_click.json"
        train_output = self.output_dir / "train_trap_click.json"

        # Attempt both writes even if the first one fails, so every problem
        # is reported in a single run.
        test_saved = self.save_json_file(test_data, test_output)
        train_saved = self.save_json_file(train_data, train_output)

        print("\n" + "=" * 60)
        if test_saved and train_saved:
            print("Processing completed!")
            print("=" * 60)
            print(f"Test set saved to: {test_output}")
            print(f"Train set saved to: {train_output}")
        else:
            print("Processing finished with errors: not all output files were written")
            print("=" * 60)


def main():
    """Parse the command-line options and run a TrapDataSampler with them."""
    cli = argparse.ArgumentParser(
        description='Sample data with click actions for trap data generation',
    )
    cli.add_argument(
        '--input_files',
        type=str,
        nargs='+',
        default=['/INPUT_FILE_1', '/INPUT_FILE_2'],
        help='Input JSON file paths',
    )
    cli.add_argument(
        '--output_dir',
        type=str,
        default='/OUTPUT_DIR',
        help='Output directory for sampled data',
    )

    # The integer options all have the same shape, so register them from a
    # table of (flag, default, help text).
    integer_options = (
        ('--test_size', 700, 'Test set size'),
        ('--train_size', 2000, 'Train set size'),
        ('--random_seed', 42, 'Random seed'),
    )
    for flag, default_value, help_text in integer_options:
        cli.add_argument(flag, type=int, default=default_value, help=help_text)

    options = cli.parse_args()

    TrapDataSampler(
        input_files=options.input_files,
        output_dir=options.output_dir,
        test_size=options.test_size,
        train_size=options.train_size,
        random_seed=options.random_seed,
    ).run()


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
