#!/usr/bin/env python3
"""
Monthly CVE Diversity Sampling Script

Usage:
    # Sample every month from July through November 2025 (two ascending months form an inclusive range)
    python run_monthly_sampling.py --months 2025-07 2025-11

    # Sample specific months
    python run_monthly_sampling.py --months 2025-07 2025-08 2025-09

    # Custom parameters
    python run_monthly_sampling.py --months 2025-07 2025-11 --count 50 --min-score 60

    # Exclude existing CVEs
    python run_monthly_sampling.py --months 2025-07 2025-11 --exclude-dir /path/to/existing
"""

import argparse
import json
import os
import glob
from diverse_cve_sampler import DiverseCVESampler


def _parse_year_month(value: str) -> tuple:
    """Return ``(year, month)`` for a 'YYYY-MM' string; raise ValueError if malformed."""
    year_text, sep, month_text = value.partition('-')
    if not sep or not year_text.isdigit() or not month_text.isdigit():
        raise ValueError(f"invalid month {value!r}: expected YYYY-MM")
    year, month = int(year_text), int(month_text)
    if not 1 <= month <= 12:
        raise ValueError(f"invalid month {value!r}: month must be between 01 and 12")
    return year, month


def parse_months(months_args: list) -> list:
    """
    Parse month arguments.

    Supports two formats:
    - Range: exactly two months in ascending order are expanded to every month
      between them, inclusive (ranges may cross a year boundary):
      ['2025-07', '2025-11'] -> ['2025-07', '2025-08', '2025-09', '2025-10', '2025-11']
      ['2024-11', '2025-02'] -> ['2024-11', '2024-12', '2025-01', '2025-02']
    - List: any other input is returned as-is:
      ['2025-07', '2025-08', '2025-09'] -> ['2025-07', '2025-08', '2025-09']
      Two months in descending order are also treated as a plain list.

    Two identical months (['2025-07', '2025-07']) yield that single month, so
    the same month is not sampled twice.

    Raises:
        ValueError: if exactly two months are given and either is not a valid
            'YYYY-MM' value (month 01-12).
    """
    if len(months_args) != 2:
        return months_args

    start_year, start_month = _parse_year_month(months_args[0])
    end_year, end_month = _parse_year_month(months_args[1])

    # Number months consecutively so a range becomes a plain integer interval.
    start_index = start_year * 12 + (start_month - 1)
    end_index = end_year * 12 + (end_month - 1)

    if end_index < start_index:
        # Descending pair: not a range, keep it as an explicit list.
        return months_args

    months = []
    for index in range(start_index, end_index + 1):
        year, month_zero_based = divmod(index, 12)
        months.append(f"{year}-{month_zero_based + 1:02d}")
    return months


def get_existing_cve_ids(directory: str) -> set:
    """Return the IDs of CVEs that already exist in *directory*.

    An existing CVE is any file named ``CVE-*.md`` directly inside the
    directory. Its ID is the file name without the trailing ``.md`` extension,
    e.g. ``CVE-2025-1234.md`` -> ``CVE-2025-1234``.

    Returns an empty set when *directory* is empty/None or is not an existing
    directory.
    """
    if not directory or not os.path.isdir(directory):
        return set()

    # Escape the directory so characters like '[', '*' or '?' in the path are
    # matched literally instead of being interpreted as glob wildcards.
    pattern = os.path.join(glob.escape(directory), 'CVE-*.md')
    # splitext strips only the trailing extension; str.replace('.md', '')
    # would also delete any '.md' occurring earlier in the name.
    return {os.path.splitext(os.path.basename(path))[0] for path in glob.glob(pattern)}


def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the monthly sampling script."""
    parser = argparse.ArgumentParser(
        description='Monthly CVE Diversity Sampling Tool',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Sample from July to November 2025 (range)
  python run_monthly_sampling.py --months 2025-07 2025-11

  # Sample specific months
  python run_monthly_sampling.py --months 2025-07 2025-08 2025-09

  # 50 CVEs per month, minimum score 60
  python run_monthly_sampling.py --months 2025-07 2025-11 --count 50 --min-score 60

  # Exclude existing CVEs
  python run_monthly_sampling.py --months 2025-07 2025-11 --exclude-dir ./existing_cves
        """
    )

    parser.add_argument('--months', nargs='+', required=True,
                        help='Month range or list, e.g., "2025-07 2025-11" or "2025-07 2025-08 2025-09"')
    parser.add_argument('--count', type=int, default=100,
                        help='Number of CVEs per month (default: 100)')
    parser.add_argument('--min-score', type=int, default=50,
                        help='Minimum reproducibility score (default: 50)')
    parser.add_argument('--top25-per-cwe', type=int, default=2,
                        help='Guaranteed CVEs per Top25 CWE (default: 2)')
    parser.add_argument('--max-per-cwe', type=int, default=10,
                        help='Maximum CVEs per CWE type (default: 10)')
    parser.add_argument('--max-per-repo', type=int, default=10,
                        help='Maximum CVEs per repository (default: 10)')
    parser.add_argument('--exclude-dir', type=str, default=None,
                        help='Directory containing existing CVEs to exclude')
    parser.add_argument('--output-dir', type=str, default='monthly_samples',
                        help='Output directory (default: monthly_samples)')
    parser.add_argument('--summary', type=str, default='output/summary.json',
                        help='Path to summary.json (default: output/summary.json)')
    return parser


def _format_month_stats(month: str, cves: list, top25_cwes) -> str:
    """Build the one-line statistics summary printed for one sampled month.

    Args:
        month: Month key as returned by the sampler.
        cves: Sampled CVE records (dicts) for that month. Every record is read
            for ``score``; ``vendor`` and ``cwe_id`` are optional.
        top25_cwes: Container of Top25 CWE identifiers (the sampler's
            ``CWE_TOP25``); only ``in`` and ``len()`` are used on it.
            NOTE(review): assumes its entries use the same format as the
            records' ``cwe_id`` values — confirm against the sampler.
    """
    if not cves:
        # min()/max() would raise on an empty month; report it instead.
        return f"{month}: 0 CVE | no CVEs sampled"

    scores = [c['score'] for c in cves]
    vendors = len({c.get('vendor', '') for c in cves})
    cwe_ids = {c.get('cwe_id', '') for c in cves}
    # Count distinct CWEs of this month that belong to the Top25 list
    # (membership is tested on the CWE id, not the CVE id).
    top25 = len({cwe for cwe in cwe_ids if cwe in top25_cwes})

    return (f"{month}: {len(cves)} CVE | {vendors} vendors | {len(cwe_ids)} CWEs | "
            f"score {min(scores)}-{max(scores)} | Top25: {top25}/{len(top25_cwes)}")


def main():
    """Command-line entry point: sample diverse CVEs per month and save them.

    Loads candidate CVEs from ``--summary``, optionally drops those already
    present in ``--exclude-dir``, samples per month via
    ``DiverseCVESampler.sample_by_month``, then writes into ``--output-dir``:
    one ``<month>_top<count>.json`` file per month and ``all_cve_ids.txt``
    holding every sampled CVE ID separated by single spaces.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    # Parse months; report malformed values as a usage error (exit status 2)
    # instead of an unhandled traceback.
    try:
        months = parse_months(args.months)
    except ValueError as exc:
        parser.error(f"invalid --months value: {exc}")
    # NOTE(review): only the first month's year is passed to the sampler. For a
    # range spanning a year boundary (e.g. 2024-11 2025-02), confirm that
    # sample_by_month relies on the full 'YYYY-MM' strings in `months`.
    year = months[0][:4]

    print("=" * 60)
    print("Monthly CVE Sampling Script")
    print("=" * 60)

    # 1. Get existing CVEs
    existing_ids = get_existing_cve_ids(args.exclude_dir)
    if existing_ids:
        print(f"\nExisting CVEs: {len(existing_ids)} (from {args.exclude_dir})")

    # 2. Load sampler
    print(f"\nLoading data: {args.summary}")
    sampler = DiverseCVESampler(summary_path=args.summary)

    # 3. Filter existing CVEs out of the candidate pool before sampling
    if existing_ids:
        original_count = len(sampler.all_cves)
        sampler.all_cves = [c for c in sampler.all_cves if c['cve_id'] not in existing_ids]
        filtered_count = original_count - len(sampler.all_cves)
        print(f"After filtering: {len(sampler.all_cves)} (excluded {filtered_count})")

    # 4. Print sampling parameters
    print("\nSampling parameters:")
    print(f"  - Months: {months}")
    print(f"  - Target per month: {args.count}")
    print(f"  - Minimum score: {args.min_score}")
    print(f"  - Top25 guarantee: {args.top25_per_cwe} per CWE")
    print(f"  - Max per CWE: {args.max_per_cwe}")
    print(f"  - Max per repo: {args.max_per_repo}")

    # 5. Execute sampling
    results = sampler.sample_by_month(
        year=year,
        target_per_month=args.count,
        months=months,
        min_score=args.min_score,
        top25_per_cwe=args.top25_per_cwe,
        max_per_cwe=args.max_per_cwe,
        max_per_repo=args.max_per_repo
    )

    # 6. Save results
    os.makedirs(args.output_dir, exist_ok=True)

    print("\n" + "=" * 60)
    print("Saving Results")
    print("=" * 60)

    all_cve_ids = []

    for month, cves in sorted(results.items()):
        # Save JSON (an empty month still gets an empty list file)
        json_path = os.path.join(args.output_dir, f'{month}_top{args.count}.json')
        with open(json_path, 'w', encoding='utf-8') as f:
            json.dump(cves, f, indent=2, ensure_ascii=False)

        print(_format_month_stats(month, cves, sampler.CWE_TOP25))

        # Collect CVE IDs
        all_cve_ids.extend(c['cve_id'] for c in cves)

    # 7. Save CVE ID list (single line, space-separated)
    ids_path = os.path.join(args.output_dir, 'all_cve_ids.txt')
    with open(ids_path, 'w', encoding='utf-8') as f:
        f.write(' '.join(all_cve_ids))

    print(f"\nTotal: {len(all_cve_ids)} CVEs")
    print(f"CVE ID list: {ids_path}")
    print(f"JSON files: {args.output_dir}/")


if __name__ == "__main__":
    # Run the CLI when executed as a script; main() returns None, so the
    # process exits with status 0 on success.
    raise SystemExit(main())
