
import argparse
import json
import os
import random
from pathlib import Path
from typing import Any, Dict, List, Tuple


class GeneralDataProcessor:
    """Merge several JSON datasets and split them into test/train samples.

    Each input file is expected to hold a top-level JSON array of record
    objects. Records from all files are pooled, non-string ``episode_id``
    values are converted to ``str``, records whose ``gt_action`` is missing
    or ``null`` are dropped, and the remainder is shuffled with a fixed seed
    and split into ``test_general.json`` / ``train_general.json`` inside
    ``output_dir``.
    """

    def __init__(self, input_files: List[str], output_dir: str, test_size: int = 1000,
                 train_size: int = 4000, random_seed: int = 42):
        """Store the configuration and create ``output_dir`` if needed.

        Args:
            input_files: Paths of the JSON files to merge.
            output_dir: Directory receiving the output files; created (with
                parents) if it does not exist.
            test_size: Number of records requested for the test split.
            train_size: Number of records requested for the train split.
            random_seed: Seed for the shuffle, making splits reproducible.

        Raises:
            ValueError: If ``test_size`` or ``train_size`` is negative.
        """
        if test_size < 0 or train_size < 0:
            # Negative sizes would turn the slices in sample_data into
            # "all but the last N" and silently produce nonsense splits.
            raise ValueError(
                f"test_size and train_size must be non-negative, "
                f"got test_size={test_size}, train_size={train_size}"
            )
        self.input_files = [Path(f) for f in input_files]
        self.output_dir = Path(output_dir)
        self.test_size = test_size
        self.train_size = train_size
        self.random_seed = random_seed

        self.output_dir.mkdir(parents=True, exist_ok=True)

    def load_json_file(self, file_path: Path) -> List[Dict[str, Any]]:
        """Load the list of record objects stored in ``file_path``.

        A missing, unreadable, or malformed file is reported and yields an
        empty list, so one bad input does not abort the whole run. The file
        must contain a top-level JSON array. Elements that are not JSON
        objects (e.g. ``null`` or numbers) are skipped with a warning,
        because later steps call ``dict`` methods on every record.

        Returns:
            The JSON objects found in the file, in file order.
        """
        if not file_path.exists():
            print(f"Warning: File does not exist: {file_path}")
            return []

        print(f"Loading: {file_path}")
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            print(f"Error loading file {file_path}: {e}")
            return []

        if not isinstance(data, list):
            print(f"Error loading file {file_path}: expected a JSON array at the top level, "
                  f"got {type(data).__name__}")
            return []

        records = [item for item in data if isinstance(item, dict)]
        skipped = len(data) - len(records)
        if skipped:
            print(f"  Warning: skipped {skipped} non-object entries")
        print(f"  Loaded {len(records)} samples")
        return records

    def normalize_episode_id(self, item: Dict[str, Any]) -> Dict[str, Any]:
        """Convert a non-string ``episode_id`` to ``str`` in place; return ``item``.

        Items without an ``episode_id`` key, or whose id is ``None``, are
        returned unchanged. ``str(None)`` would fabricate the id ``"None"``
        shared by every null-id record. Non-dict input is also returned
        unchanged.
        """
        if not isinstance(item, dict):
            return item
        episode_id = item.get("episode_id")
        if episode_id is not None and not isinstance(episode_id, str):
            item["episode_id"] = str(episode_id)
        return item

    def sample_data(self, all_data: List[Dict[str, Any]]) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
        """Shuffle ``all_data`` reproducibly and split it into ``(test, train)``.

        With enough data, the first ``test_size`` shuffled records form the
        test split and the next ``train_size`` form the train split; any
        surplus is discarded. With too little data, every record is used,
        split in the ratio ``test_size : train_size``. The test count is
        rounded down using exact integer arithmetic.

        The shuffle uses a private ``random.Random`` seeded with
        ``random_seed``, so it is reproducible and does not disturb the
        global ``random`` module state. ``all_data`` itself is not reordered.
        """
        rng = random.Random(self.random_seed)
        shuffled_data = list(all_data)
        rng.shuffle(shuffled_data)

        total_needed = self.test_size + self.train_size
        if len(shuffled_data) >= total_needed:
            test_data = shuffled_data[:self.test_size]
            train_data = shuffled_data[self.test_size:total_needed]
            return test_data, train_data

        print(f"Warning: Total data ({len(shuffled_data)}) is less than required ({total_needed})")
        print("Will use all available data")
        # Here total_needed > len(shuffled_data) >= 0, so the division is safe.
        # Integer floor division avoids float truncation errors such as
        # int(100 * 0.29) == 28.
        actual_test_size = len(shuffled_data) * self.test_size // total_needed
        return shuffled_data[:actual_test_size], shuffled_data[actual_test_size:]

    def save_json_file(self, data: List[Dict[str, Any]], file_path: Path) -> None:
        """Write ``data`` to ``file_path`` as pretty-printed UTF-8 JSON.

        The JSON is first written to a sibling ``<name>.tmp`` file and then
        moved over ``file_path`` with ``os.replace``. A failed write therefore
        never leaves a truncated output file behind. Errors are reported and
        swallowed, matching ``load_json_file``'s best-effort style.
        """
        file_path.parent.mkdir(parents=True, exist_ok=True)

        print(f"Saving: {file_path}")
        tmp_path = file_path.with_name(file_path.name + ".tmp")
        try:
            with open(tmp_path, 'w', encoding='utf-8') as f:
                json.dump(data, f, ensure_ascii=False, indent=2)
            os.replace(tmp_path, file_path)
        except (OSError, TypeError, ValueError) as e:
            # TypeError/ValueError: unserializable or circular record.
            print(f"Error saving file {file_path}: {e}")
            try:
                tmp_path.unlink()
            except OSError:
                pass  # temp file was never created or is already gone
            return
        print(f"  Saved {len(data)} samples")

    def run(self):
        """Run the pipeline: load, normalize, filter, sample, and save.

        Records are pooled in input-file order. Records whose ``gt_action``
        is missing or ``None`` are dropped before the seeded split. The two
        splits are written to ``test_general.json`` and
        ``train_general.json`` in ``output_dir``.
        """
        print("=" * 60)
        print("Starting data sampling...")
        print("=" * 60)

        all_data = []
        for file_path in self.input_files:
            all_data.extend(self.normalize_episode_id(item) for item in self.load_json_file(file_path))

        print(f"\nTotal loaded: {len(all_data)} samples")

        original_count = len(all_data)
        all_data = [item for item in all_data if item.get('gt_action') is not None]
        filtered_count = len(all_data)
        print(f"Filtered {original_count - filtered_count} samples with null gt_action, remaining: {filtered_count}")

        print("\nStarting random sampling...")
        test_data, train_data = self.sample_data(all_data)

        print("Sampling results:")
        print(f"  Test set: {len(test_data)} samples")
        print(f"  Train set: {len(train_data)} samples")

        print("\nSaving results...")
        test_output = self.output_dir / "test_general.json"
        train_output = self.output_dir / "train_general.json"

        self.save_json_file(test_data, test_output)
        self.save_json_file(train_data, train_output)

        print("\n" + "=" * 60)
        print("Processing completed!")
        print("=" * 60)
        print(f"Test set saved to: {test_output}")
        print(f"Train set saved to: {train_output}")


def main(argv=None):
    """Parse command-line options and run ``GeneralDataProcessor``.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. ``None`` (the default) keeps the normal
            command-line behaviour; passing a list lets other code or tests
            drive the script.

    Exits with status 2 (via ``parser.error``) if ``--test_size`` or
    ``--train_size`` is negative.
    """
    parser = argparse.ArgumentParser(description='Sample data from multiple datasets')
    # NOTE: the default paths below are placeholders; pass real paths on the
    # command line.
    parser.add_argument('--input_files', type=str, nargs='+', default=['/INPUT_FILE_1', '/INPUT_FILE_2', '/INPUT_FILE_3', '/INPUT_FILE_4'],
                       help='Input JSON file paths')
    parser.add_argument('--output_dir', type=str, default='/OUTPUT_DIR',
                       help='Output directory for processed data')
    parser.add_argument('--test_size', type=int, default=1000,
                       help='Test set size')
    parser.add_argument('--train_size', type=int, default=4000,
                       help='Train set size')
    parser.add_argument('--random_seed', type=int, default=42,
                       help='Random seed')

    args = parser.parse_args(argv)

    # Reject negative sizes up front with a proper usage error instead of
    # letting them reach the slicing logic.
    if args.test_size < 0 or args.train_size < 0:
        parser.error("--test_size and --train_size must be non-negative")

    processor = GeneralDataProcessor(
        input_files=args.input_files,
        output_dir=args.output_dir,
        test_size=args.test_size,
        train_size=args.train_size,
        random_seed=args.random_seed
    )

    processor.run()


if __name__ == "__main__":
    main()
