#!/usr/bin/env python3
"""
Diverse CVE Sampler v2

Goal: Balance reproducibility, importance, and diversity in CVE sampling.

Core Algorithm: Two-Phase Sampling
1. Phase 1: Guarantee Top 25 CWE coverage (up to 2 CVEs per CWE by default, subject to candidate availability and per-repo limits)
2. Phase 2: Fill remaining slots using composite scoring
"""

import json
import os
from collections import Counter, defaultdict
from typing import List, Dict, Optional, Set, Tuple


class DiverseCVESampler:
    """Diversity-aware CVE sampler.

    Loads a ``summary.json`` CVE index and selects subsets that balance
    reproducibility score, CWE danger ranking, CVSS severity, and spread
    across CWE types and (vendor, product) repositories.
    """

    # ==================== 2024 CWE Top 25 ====================
    # Source: https://cwe.mitre.org/top25/archive/2024/2024_cwe_top25.html
    # Ranked by danger score (higher = more dangerous)

    CWE_TOP25 = {
        'CWE-79':  {'rank': 1,  'name': 'Cross-site Scripting (XSS)', 'score': 56.92},
        'CWE-787': {'rank': 2,  'name': 'Out-of-bounds Write', 'score': 45.20},
        'CWE-89':  {'rank': 3,  'name': 'SQL Injection', 'score': 35.88},
        'CWE-352': {'rank': 4,  'name': 'Cross-Site Request Forgery', 'score': 19.57},
        'CWE-22':  {'rank': 5,  'name': 'Path Traversal', 'score': 12.74},
        'CWE-125': {'rank': 6,  'name': 'Out-of-bounds Read', 'score': 11.42},
        'CWE-78':  {'rank': 7,  'name': 'OS Command Injection', 'score': 11.30},
        'CWE-416': {'rank': 8,  'name': 'Use After Free', 'score': 10.19},
        'CWE-862': {'rank': 9,  'name': 'Missing Authorization', 'score': 10.11},
        'CWE-434': {'rank': 10, 'name': 'Unrestricted Upload', 'score': 10.03},
        'CWE-94':  {'rank': 11, 'name': 'Code Injection', 'score': 7.13},
        'CWE-20':  {'rank': 12, 'name': 'Improper Input Validation', 'score': 6.78},
        'CWE-77':  {'rank': 13, 'name': 'Command Injection', 'score': 6.74},
        'CWE-287': {'rank': 14, 'name': 'Improper Authentication', 'score': 5.94},
        'CWE-269': {'rank': 15, 'name': 'Improper Privilege Management', 'score': 5.22},
        'CWE-502': {'rank': 16, 'name': 'Deserialization of Untrusted Data', 'score': 5.07},
        'CWE-200': {'rank': 17, 'name': 'Information Exposure', 'score': 5.07},
        'CWE-863': {'rank': 18, 'name': 'Incorrect Authorization', 'score': 4.05},
        'CWE-918': {'rank': 19, 'name': 'Server-Side Request Forgery', 'score': 4.05},
        'CWE-119': {'rank': 20, 'name': 'Buffer Overflow', 'score': 3.69},
        'CWE-476': {'rank': 21, 'name': 'NULL Pointer Dereference', 'score': 3.58},
        'CWE-798': {'rank': 22, 'name': 'Hard-coded Credentials', 'score': 3.46},
        'CWE-190': {'rank': 23, 'name': 'Integer Overflow', 'score': 3.37},
        'CWE-400': {'rank': 24, 'name': 'Uncontrolled Resource Consumption', 'score': 3.23},
        'CWE-306': {'rank': 25, 'name': 'Missing Authentication', 'score': 2.73},
    }

    # ==================== CWE Category Mapping ====================
    # Groups similar CWEs to prevent over-representation
    # Includes Top 25 and related CWEs.
    # Consulted by smart_sample only when group_by_category=True.

    CWE_CATEGORY_MAP = {
        # XSS (Top25 #1)
        'CWE-79': 'xss',
        'CWE-80': 'xss',   # Basic XSS

        # Memory Write (Top25 #2)
        'CWE-787': 'memory_write',
        'CWE-121': 'memory_write',  # Stack-based Buffer Overflow
        'CWE-122': 'memory_write',  # Heap-based Buffer Overflow

        # SQL Injection (Top25 #3)
        'CWE-89': 'sqli',
        'CWE-564': 'sqli',  # Hibernate Injection

        # CSRF (Top25 #4)
        'CWE-352': 'csrf',

        # Path Traversal (Top25 #5)
        'CWE-22': 'path_traversal',
        'CWE-23': 'path_traversal',
        'CWE-36': 'path_traversal',
        'CWE-35': 'path_traversal',
        'CWE-73': 'path_traversal',  # External Control of File Name

        # Memory Read (Top25 #6)
        'CWE-125': 'memory_read',

        # OS Command Injection (Top25 #7)
        'CWE-78': 'os_command',

        # Use After Free (Top25 #8)
        'CWE-416': 'use_after_free',
        'CWE-415': 'use_after_free',  # Double Free

        # Missing Authorization (Top25 #9)
        'CWE-862': 'missing_authz',

        # File Upload (Top25 #10)
        'CWE-434': 'file_upload',

        # Code Injection (Top25 #11)
        'CWE-94': 'code_injection',
        'CWE-95': 'code_injection',   # Eval Injection
        'CWE-917': 'code_injection',  # Expression Language Injection
        'CWE-1321': 'code_injection', # Prototype Pollution

        # Input Validation (Top25 #12)
        'CWE-20': 'input_validation',

        # Command Injection (Top25 #13)
        'CWE-77': 'command_injection',

        # Authentication (Top25 #14)
        'CWE-287': 'authentication',
        'CWE-288': 'authentication',  # Authentication Bypass

        # Privilege Management (Top25 #15)
        'CWE-269': 'privilege_mgmt',
        'CWE-266': 'privilege_mgmt',  # Incorrect Privilege Assignment
        'CWE-250': 'privilege_mgmt',  # Execution with Unnecessary Privileges

        # Deserialization (Top25 #16)
        'CWE-502': 'deserialization',

        # Information Exposure (Top25 #17)
        'CWE-200': 'info_exposure',
        'CWE-209': 'info_exposure',  # Error Message Info Leak
        'CWE-532': 'info_exposure',  # Log File Info Leak
        'CWE-497': 'info_exposure',  # Exposure of System Data
        'CWE-201': 'info_exposure',  # Insertion of Sensitive Information

        # Incorrect Authorization (Top25 #18)
        'CWE-863': 'incorrect_authz',
        'CWE-639': 'incorrect_authz',  # IDOR

        # SSRF (Top25 #19)
        'CWE-918': 'ssrf',

        # Buffer Operations (Top25 #20)
        'CWE-119': 'buffer_ops',
        'CWE-120': 'buffer_ops',  # Buffer Copy without Size Check

        # NULL Pointer (Top25 #21)
        'CWE-476': 'null_pointer',

        # Hardcoded Credentials (Top25 #22)
        'CWE-798': 'hardcoded_creds',
        'CWE-321': 'hardcoded_creds',  # Hard-coded Cryptographic Key
        'CWE-522': 'hardcoded_creds',  # Insufficiently Protected Credentials

        # Integer Overflow (Top25 #23)
        'CWE-190': 'integer_overflow',
        'CWE-191': 'integer_overflow',  # Integer Underflow

        # Resource Consumption (Top25 #24)
        'CWE-400': 'resource_consumption',
        'CWE-770': 'resource_consumption',  # Allocation without Limits
        'CWE-1333': 'resource_consumption', # ReDoS
        'CWE-401': 'resource_consumption',  # Memory Leak

        # Missing Authentication (Top25 #25)
        'CWE-306': 'missing_authn',

        # ==================== Other Common CWEs ====================

        # XXE
        'CWE-611': 'xxe',

        # File Inclusion
        'CWE-98': 'file_inclusion',

        # Open Redirect
        'CWE-601': 'open_redirect',

        # Access Control
        'CWE-284': 'access_control',

        # Race Condition
        'CWE-362': 'race_condition',

        # Permission Issues
        'CWE-276': 'permission',
        'CWE-732': 'permission',

        # Certificate Validation
        'CWE-295': 'cert_validation',

        # Brute Force
        'CWE-307': 'brute_force',

        # Spoofing
        'CWE-290': 'spoofing',
    }

    # Divisor for the Phase 2 importance bonus: roughly the highest Top 25
    # danger score (56.92), so the #1 CWE earns close to the full 30 points.
    IMPORTANCE_NORMALIZER = 57.0

    def __init__(self, summary_path: str = 'reproduce_cves_score0/summary.json'):
        """
        Initialize the sampler.

        Args:
            summary_path: Path to a JSON file whose top-level object holds a
                ``'cves'`` list of CVE dicts.

        Raises:
            OSError: If the file cannot be opened.
            json.JSONDecodeError: If the file is not valid JSON.
            KeyError: If the JSON object has no ``'cves'`` key.
        """
        # Decode explicitly as UTF-8 (save_sample writes UTF-8) so loading does
        # not depend on the platform's default encoding.
        with open(summary_path, 'r', encoding='utf-8') as f:
            self.summary = json.load(f)
        self.all_cves = self.summary['cves']
        print(f"Loaded {len(self.all_cves)} CVEs")

    @staticmethod
    def _numeric(cve: dict, key: str):
        """Return ``cve[key]``, treating a missing or null/empty value as 0."""
        return cve.get(key) or 0

    def get_cwe_category(self, cwe_id: str) -> str:
        """
        Get CWE category.

        - Empty/None CWE: Return 'unknown'
        - Mapped CWEs: Return merged category (e.g., CWE-121 -> 'memory_write')
        - Unmapped CWEs: Return the CWE ID itself to preserve diversity
        """
        if not cwe_id:
            return 'unknown'
        # Unmapped CWEs keep their ID to avoid losing diversity in an 'other' bucket
        return self.CWE_CATEGORY_MAP.get(cwe_id, cwe_id)

    def is_top25_cwe(self, cwe_id: str) -> bool:
        """Return True if ``cwe_id`` is in the 2024 CWE Top 25."""
        return cwe_id in self.CWE_TOP25

    def get_cwe_importance(self, cwe_id: str) -> float:
        """
        Get CWE importance score.

        Top 25: Returns danger score (2.73 - 56.92)
        Non-Top 25: Returns 1.0
        """
        if cwe_id in self.CWE_TOP25:
            return self.CWE_TOP25[cwe_id]['score']
        return 1.0

    def get_repo_key(self, cve: dict) -> Tuple[str, str]:
        """Return the normalized ``(vendor, product)`` pair identifying a CVE's repo.

        Missing or null fields fall back to ``'unknown'``. Values are
        lower-cased and stripped, so cosmetic differences do not split one repo
        into several.
        """
        def norm(field: str) -> str:
            value = cve.get(field)
            # A present-but-null field must not crash .lower().
            return ('unknown' if value is None else str(value)).lower().strip()

        return (norm('vendor'), norm('product'))

    def _composite_score(self, cve: dict, cwe_count: int, is_new_repo: bool) -> float:
        """Return the Phase 2 composite score for ``cve``.

        The score is the sum of:
        - the reproducibility score;
        - the CWE danger score scaled to roughly 0-30;
        - CVSS x 2 (0-20 for CVSS 0-10);
        - a diversity bonus of 20 / 10 / 0 when the CWE key already has
          0 / 1-2 / 3+ picks;
        - +10 when the repo has no picks yet.

        Args:
            cve: Candidate CVE dict.
            cwe_count: Number of already-selected CVEs sharing this CVE's CWE key.
            is_new_repo: True if no selected CVE shares this CVE's repo.
        """
        base_score = self._numeric(cve, 'score')
        importance_bonus = (self.get_cwe_importance(cve.get('cwe_id', ''))
                            / self.IMPORTANCE_NORMALIZER) * 30
        cvss_bonus = self._numeric(cve, 'cvss_score') * 2

        if cwe_count == 0:
            diversity_bonus = 20  # New CWE, highest bonus
        elif cwe_count < 3:
            diversity_bonus = 10
        else:
            diversity_bonus = 0

        novelty_bonus = 10 if is_new_repo else 0

        return base_score + importance_bonus + cvss_bonus + diversity_bonus + novelty_bonus

    def smart_sample(
        self,
        cves: List[dict],
        target_count: int = 100,
        min_score: int = 30,
        top25_per_cwe: int = 2,
        max_per_cwe: int = 5,
        max_per_repo: int = 2,
        verbose: bool = True,
        group_by_category: bool = False,
    ) -> List[dict]:
        """
        Smart sampling algorithm - core method.

        Two-phase sampling:
        1. Phase 1: For each Top 25 CWE (in rank order), take up to
           ``top25_per_cwe`` of its highest-scoring CVEs, at most one per repo.
        2. Phase 2: Fill the remaining slots greedily by composite score
           (see ``_composite_score``), subject to ``max_per_cwe`` and
           ``max_per_repo``.

        Phase 2 diversity/novelty bonuses are computed once, from the state
        after Phase 1. They are not updated as Phase 2 picks accumulate.
        The input dicts are not modified.

        Args:
            cves: Candidate CVE list
            target_count: Target sample count
            min_score: Minimum reproducibility score threshold
            top25_per_cwe: CVEs targeted per Top 25 CWE in Phase 1
            max_per_cwe: Maximum CVEs per CWE key, enforced in Phase 2
            max_per_repo: Maximum CVEs per (vendor, product) repo
            verbose: Print detailed information
            group_by_category: If True, per-CWE limits and diversity bonuses
                use the merged category from ``CWE_CATEGORY_MAP`` instead of
                the raw CWE ID.

        Returns:
            List of sampled CVEs (the same dict objects passed in)
        """
        if group_by_category:
            cwe_key = self.get_cwe_category
        else:
            def cwe_key(cwe_id):
                return cwe_id

        # Filter low-score CVEs (missing/None score counts as 0)
        candidates = [c for c in cves if self._numeric(c, 'score') >= min_score]

        if verbose:
            print(f"\nCandidate CVEs: {len(candidates)} (score >= {min_score})")

        selected: List[dict] = []
        selected_ids: Set[str] = set()
        cwe_counts: Counter = Counter()
        repo_counts: Counter = Counter()

        def take(cve: dict, repo: Tuple[str, str]) -> None:
            """Record ``cve`` as selected and update all bookkeeping."""
            selected.append(cve)
            selected_ids.add(cve['cve_id'])
            cwe_counts[cwe_key(cve.get('cwe_id', ''))] += 1
            repo_counts[repo] += 1

        # ==================== Phase 1: Top 25 CWE Guarantee ====================
        if verbose:
            print(f"\n=== Phase 1: Top 25 CWE Guarantee ({top25_per_cwe} each) ===")

        by_cwe: Dict[str, List[dict]] = defaultdict(list)
        for cve in candidates:
            by_cwe[cve.get('cwe_id')].append(cve)

        for cwe_id in self.CWE_TOP25:
            # Stable sort: ties keep input order.
            pool = sorted(by_cwe.get(cwe_id, []),
                          key=lambda x: self._numeric(x, 'score'), reverse=True)

            count = 0
            seen_repos_in_cwe: Set[Tuple[str, str]] = set()
            for cve in pool:
                if count >= top25_per_cwe:
                    break
                repo = self.get_repo_key(cve)
                # One CVE per repo within a CWE, global repo cap, no duplicate IDs.
                if (repo in seen_repos_in_cwe
                        or repo_counts[repo] >= max_per_repo
                        or cve['cve_id'] in selected_ids):
                    continue
                take(cve, repo)
                seen_repos_in_cwe.add(repo)
                count += 1

            if verbose and count > 0:
                cwe_name = self.CWE_TOP25[cwe_id]['name'][:30]
                print(f"  {cwe_id}: {count} - {cwe_name}")

        phase1_count = len(selected)
        if verbose:
            print(f"\nPhase 1 selected: {phase1_count}")

        # ==================== Phase 2: Composite Scoring ====================
        remaining_target = target_count - phase1_count

        if remaining_target <= 0:
            if verbose:
                print("Target reached, skipping Phase 2")
            return selected[:max(target_count, 0)]

        if verbose:
            print(f"\n=== Phase 2: {remaining_target} remaining slots, composite scoring ===")

        # Scores live in a local list rather than on the caller's dicts.
        scored = [
            (
                self._composite_score(
                    cve,
                    cwe_counts[cwe_key(cve.get('cwe_id', ''))],
                    self.get_repo_key(cve) not in repo_counts,
                ),
                cve,
            )
            for cve in candidates
            if cve['cve_id'] not in selected_ids
        ]
        # Key on the score only: comparing dicts on ties would raise TypeError.
        scored.sort(key=lambda pair: pair[0], reverse=True)

        # Greedy selection
        for _, cve in scored:
            if len(selected) >= target_count:
                break
            if cve['cve_id'] in selected_ids:  # duplicate entry in the input
                continue
            repo = self.get_repo_key(cve)
            if cwe_counts[cwe_key(cve.get('cwe_id', ''))] >= max_per_cwe:
                continue
            if repo_counts[repo] >= max_per_repo:
                continue
            take(cve, repo)

        if verbose:
            print(f"Phase 2 selected: {len(selected) - phase1_count}")
            print(f"\nTotal selected: {len(selected)}")
            self._print_selection_stats(selected)

        return selected

    def _print_selection_stats(self, selected: List[dict]):
        """Print Top 25 coverage, other CWEs, vendor count and score range."""
        print("\n=== Selection Statistics ===")

        if not selected:
            # min()/max()/average below are undefined for an empty selection.
            print("No CVEs selected")
            return

        # Top 25 coverage
        top25_coverage = Counter()
        non_top25_cwes = Counter()

        for cve in selected:
            cwe_id = cve.get('cwe_id', 'unknown')
            if cwe_id in self.CWE_TOP25:
                top25_coverage[cwe_id] += 1
            else:
                non_top25_cwes[cwe_id] += 1

        print(f"\nTop 25 CWE Coverage: {len(top25_coverage)}/{len(self.CWE_TOP25)}")
        for cwe_id, count in sorted(top25_coverage.items(),
                                    key=lambda x: self.CWE_TOP25[x[0]]['rank']):
            rank = self.CWE_TOP25[cwe_id]['rank']
            name = self.CWE_TOP25[cwe_id]['name'][:25]
            print(f"  #{rank:2} {cwe_id}: {count} - {name}")

        print(f"\nNon-Top 25 CWEs: {len(non_top25_cwes)} types, {sum(non_top25_cwes.values())} CVEs")
        for cwe_id, count in non_top25_cwes.most_common(10):
            print(f"  {cwe_id}: {count}")

        # Vendor statistics
        vendors = Counter(cve.get('vendor', 'unknown') for cve in selected)
        print(f"\nUnique vendors: {len(vendors)}")

        # Score statistics
        scores = [self._numeric(cve, 'score') for cve in selected]
        print(f"Reproducibility score: {min(scores)} - {max(scores)}, avg {sum(scores)/len(scores):.1f}")

    def sample_by_month(
        self,
        year: str = '2025',
        target_per_month: int = 100,
        months: Optional[List[str]] = None,
        **kwargs
    ) -> Dict[str, List[dict]]:
        """
        Sample CVEs by month.

        CVEs are filtered by the year in their ID (``CVE-<year>-...``). They are
        then bucketed by the first 7 characters of ``date_published``
        (e.g. '2025-07'). CVEs without ``date_published`` are skipped.

        Args:
            year: Year
            target_per_month: Target count per month
            months: List of months e.g., ['2025-07', '2025-08'], None for all
            **kwargs: Parameters passed to smart_sample

        Returns:
            Dict of {month: CVE list}
        """
        # Trailing '-' so e.g. year='202' does not match every CVE-202x ID.
        id_prefix = f'CVE-{year}-'

        # Group by month
        monthly_cves: Dict[str, List[dict]] = defaultdict(list)
        for cve in self.all_cves:
            if not cve.get('cve_id', '').startswith(id_prefix):
                continue
            date_published = cve.get('date_published', '')
            if not date_published:
                continue
            monthly_cves[date_published[:7]].append(cve)

        # Determine target months
        target_months = months if months else sorted(monthly_cves.keys())

        print(f"=== {year} CVE Monthly Distribution ===\n")
        for month in sorted(monthly_cves.keys()):
            marker = " *" if month in target_months else ""
            print(f"{month}: {len(monthly_cves[month])}{marker}")

        # Sample each month
        results = {}
        for month in target_months:
            if month not in monthly_cves:
                print(f"\n{month} no data, skipping")
                continue

            print(f"\n{'='*60}")
            print(f"Processing {month}")
            print('='*60)

            results[month] = self.smart_sample(
                monthly_cves[month],
                target_count=target_per_month,
                **kwargs
            )

        return results

    def save_sample(self, cves: List[dict], output_path: str):
        """Write ``cves`` to ``output_path`` as indented UTF-8 JSON."""
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(cves, f, indent=2, ensure_ascii=False)
        print(f"\nSaved to {output_path}")

    def export_markdown_files(
        self,
        cves: List[dict],
        output_dir: str,
        source_dir: str = 'reproduce_cves_score0'
    ):
        """Copy ``<cve_id>.md`` for each CVE from ``source_dir`` into ``output_dir``.

        Missing source files are skipped. The directory is created if needed.
        """
        import shutil
        os.makedirs(output_dir, exist_ok=True)

        copied = 0
        for cve in cves:
            cve_id = cve['cve_id']
            src = os.path.join(source_dir, f"{cve_id}.md")
            dst = os.path.join(output_dir, f"{cve_id}.md")

            if os.path.exists(src):
                shutil.copy(src, dst)
                copied += 1

        print(f"Copied {copied} MD files to {output_dir}")

    def export_cve_ids(self, cves: List[dict], separator: str = ' ') -> str:
        """Return the CVE IDs of ``cves`` joined by ``separator``."""
        return separator.join(cve['cve_id'] for cve in cves)


def main():
    """Run the demo: smart-sample CVE-2025 entries, then save and print them.

    Loads the default summary.json and keeps CVEs whose ID starts with
    'CVE-2025'. It samples 100 of them, writes the result to
    'top100_smart_sample.json', and prints the selected IDs.
    """
    sampler = DiverseCVESampler()

    year_prefix = 'CVE-2025'
    cve_2025 = list(
        filter(lambda entry: entry.get('cve_id', '').startswith(year_prefix),
               sampler.all_cves)
    )
    print(f"\nCVE-2025 total: {len(cve_2025)}")

    banner = "=" * 60
    print("\n" + banner)
    print("Smart Sampling (Top 25 Guarantee + Composite Scoring)")
    print(banner)

    sampling_options = {
        'target_count': 100,
        'min_score': 30,
        'top25_per_cwe': 2,
        'max_per_cwe': 5,
        'max_per_repo': 2,
    }
    picked = sampler.smart_sample(cve_2025, **sampling_options)

    # Persist, then echo the IDs for copy/paste.
    sampler.save_sample(picked, 'top100_smart_sample.json')

    print("\nCVE IDs:")
    print(sampler.export_cve_ids(picked))


if __name__ == "__main__":
    # Run the demo only when executed as a script, not when imported.
    main()
