#!/usr/bin/env python3

import argparse
import json
import os
import sys
import random
from typing import Any, Tuple, Union


def load_json(path: str) -> Any:
    """Read and decode a JSON file.

    The file is decoded as UTF-8; a leading UTF-8 byte-order mark (BOM), as
    written by some Windows tools, is accepted and skipped.

    Args:
        path: Path of the JSON file to read.

    Returns:
        The decoded JSON value.

    Raises:
        FileNotFoundError: If *path* does not exist.
        ValueError: If the content is not valid UTF-8 or not valid JSON. The
            original decoding error is chained as ``__cause__``.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f"文件不存在: {path}")
    try:
        # "utf-8-sig" reads plain UTF-8 identically but also strips a BOM,
        # which json.load would otherwise reject.
        with open(path, "r", encoding="utf-8-sig") as f:
            return json.load(f)
    except (json.JSONDecodeError, UnicodeDecodeError) as e:
        # Keep the original error as the cause so line/column details survive.
        raise ValueError(f"JSON 解析失败: {path}: {e}") from e


def sample_data(
    data: Any,
    count: Union[str, int],
    rng: Union[random.Random, None] = None,
) -> Any:
    """Randomly keep at most *count* items of a JSON value.

    Args:
        data: Decoded JSON value. Lists are sampled by element, dicts by key;
            any other type is returned unchanged.
        count: Positive integer number of items to keep, or the string
            ``"all"`` to keep everything.
        rng: Optional ``random.Random`` instance to draw from. When omitted,
            the module-level ``random`` generator is used, so a global
            ``random.seed(...)`` still controls the result.

    Returns:
        *data* itself when ``count == "all"``, when *count* is at least the
        number of items, or when *data* is not a list/dict. Otherwise, a new
        list (or dict) with *count* randomly chosen items.

    Raises:
        ValueError: If *count* is neither ``"all"`` nor a positive int.
            ``bool`` values are rejected even though ``bool`` subclasses ``int``.
    """
    if count == "all":
        return data

    if isinstance(count, bool) or not isinstance(count, int) or count <= 0:
        raise ValueError(f"无效的数量参数: {count}，应为正整数或 'all'")

    sampler = random if rng is None else rng

    if isinstance(data, list):
        if count >= len(data):
            return data
        return sampler.sample(data, count)

    if isinstance(data, dict):
        keys = list(data)
        if count >= len(keys):
            return data
        return {k: data[k] for k in sampler.sample(keys, count)}

    # Scalars (str, number, bool, None) have nothing to sample.
    return data


def concat_values(a: Any, b: Any) -> Any:
    """Combine two JSON values into a single value.

    - list + list -> a new list with *a*'s items followed by *b*'s.
    - dict + dict -> a new shallowly merged dict; on key collisions the
      value from *b* wins, and the key keeps its position from *a*.
    - any other combination (scalars or mismatched types) -> ``[a, b]``.

    Neither input is mutated.
    """
    if isinstance(a, list):
        if isinstance(b, list):
            return [*a, *b]
    elif isinstance(a, dict):
        if isinstance(b, dict):
            return {**a, **b}
    # Fallback: wrap both values as a two-element array.
    return [a, b]


def parse_args(argv: list) -> Tuple[str, str, str, bool, Union[str, int], Union[str, int]]:
    """Parse command-line arguments.

    Args:
        argv: Argument list without the program name (e.g. ``sys.argv[1:]``).

    Returns:
        ``(input1, input2, output, compact, count1, count2)``. Each count is
        either a positive ``int`` or the string ``"all"``.

    Raises:
        ValueError: If ``--count1``/``--count2`` is not ``all`` (any case)
            or a positive integer.
        SystemExit: Raised by argparse for usage errors or ``--help``.
    """
    parser = argparse.ArgumentParser(
        description="拼接/合并两个 JSON 文件并输出为新的 JSON 文件。"
    )
    parser.add_argument("input1", help="第一个 JSON 文件路径")
    parser.add_argument("input2", help="第二个 JSON 文件路径")
    parser.add_argument(
        "-o",
        "--output",
        required=True,
        help="输出 JSON 文件路径（若目录不存在将尝试创建）",
    )
    parser.add_argument(
        "--compact",
        action="store_true",
        help="紧凑输出（默认美化缩进）",
    )
    parser.add_argument(
        "--count1",
        default="all",
        help="第一个 JSON 文件保留的数据条数，数字或 'all'（默认: all）",
    )
    parser.add_argument(
        "--count2",
        default="all",
        help="第二个 JSON 文件保留的数据条数，数字或 'all'（默认: all）",
    )

    args = parser.parse_args(argv)

    def parse_count(value: str) -> Union[str, int]:
        """Convert a count option to a positive int, or the literal ``"all"``."""
        if value.strip().lower() == "all":
            return "all"
        try:
            count = int(value)
        except ValueError:
            # Not an integer at all (e.g. "abc" or "1.5").
            raise ValueError(f"无效的数量参数: {value}，应为正整数或 'all'") from None
        # The range check is kept outside the try block. Otherwise this
        # specific message would be caught and replaced by the generic one.
        if count <= 0:
            raise ValueError(f"数量必须为正整数: {value}")
        return count

    count1 = parse_count(args.count1)
    count2 = parse_count(args.count2)

    return args.input1, args.input2, args.output, args.compact, count1, count2


def ensure_parent_dir(path: str) -> None:
    """Create the directory that will contain *path* if it does not exist yet.

    The path is made absolute first, so a bare file name resolves to the
    current working directory. If something already exists at the parent
    location, nothing is done.
    """
    target_dir = os.path.dirname(os.path.abspath(path))
    if not target_dir or os.path.exists(target_dir):
        return
    os.makedirs(target_dir, exist_ok=True)


def main(argv: list) -> int:
    """Concatenate or merge two JSON files, optionally sampling each first.

    Args:
        argv: Command-line arguments without the program name.

    Returns:
        0 on success. 1 on any error; the message is printed to stderr
        instead of a traceback.
    """
    try:
        input1, input2, output, compact, count1, count2 = parse_args(argv)

        # Fixed seed so repeated runs on the same inputs pick the same samples.
        random.seed(42)

        data1 = load_json(input1)
        data2 = load_json(input2)

        sampled_data1 = sample_data(data1, count1)
        sampled_data2 = sample_data(data2, count2)

        # Report how many items were actually kept. The requested count may
        # exceed the data size, in which case sample_data returns everything,
        # so printing the requested number would overstate the sample.
        for path, requested, sampled in (
            (input1, count1, sampled_data1),
            (input2, count2, sampled_data2),
        ):
            if requested != "all":
                kept = len(sampled) if isinstance(sampled, (list, dict)) else requested
                print(f"从 {path} 随机采样了 {kept} 条数据")

        result = concat_values(sampled_data1, sampled_data2)

        ensure_parent_dir(output)
        with open(output, "w", encoding="utf-8") as f:
            if compact:
                json.dump(result, f, ensure_ascii=False, separators=(",", ":"))
            else:
                json.dump(result, f, ensure_ascii=False, indent=2)
        print(f"已生成: {output}")
        return 0
    except Exception as e:
        # Top-level CLI boundary: turn any failure into a message and exit code 1.
        print(f"错误: {e}", file=sys.stderr)
        return 1


if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    raise SystemExit(main(sys.argv[1:]))



### 