# intersection_combinations.py
import functools
from itertools import combinations, chain
from cblearn.datasets import make_all_triplet_indices
import numpy as np
import math

def get_remaining_combinations(A, num_items):
    """
    Get the triplet combinations that are not present in a given set of triplets.

    Triplets are compared as unordered sets of items: every triplet (from ``A``
    and from the full candidate set) is canonicalised by sorting its items.

    Args:
        A (iterable): Triplets, each a sequence of three integers (e.g. a list
                      of tuples or an (n, 3) array). Items must be integers
                      monotonically increasing from 0 (e.g., range(25)).
        num_items (int): The total number of items used to generate all
                         possible combinations.

    Returns:
        numpy.ndarray: Array of shape ``(m, 3)`` holding every remaining
                       triplet (not present in ``A``), one per row, sorted
                       lexicographically. ``m`` may be 0.
    """
    # A is only read, so no defensive copy is needed; this also lets callers
    # pass any iterable of triplets, not just objects with a .copy() method.
    seen = {tuple(sorted(a)) for a in A}
    all_triplets = make_all_triplet_indices(num_items, monotonic=True)
    # Sorting gives a stable, reproducible row order instead of set order.
    remaining = sorted({tuple(sorted(u)) for u in all_triplets} - seen)
    if not remaining:
        # np.array([]) would be a 1-D float array of shape (0,); keep the
        # (m, 3) contract so callers can index columns unconditionally.
        return np.empty((0, 3), dtype=np.asarray(all_triplets).dtype)
    return np.array(remaining)

def can_generate_test_set(train_triplets, num_items, num_test_set):
    """Report whether enough unseen triplets exist for a test set of the requested size.

    Training triplets are treated as unordered (canonicalised by sorting);
    triplets with a repeated item are ignored, and duplicates count once.

    Args:
        train_triplets: Iterable of training triplets.
        num_items: Total number of items.
        num_test_set: Requested number of test triplets.

    Returns:
        tuple: ``(feasible, remaining)`` where ``remaining`` is
        ``C(num_items, 3)`` minus the number of distinct valid training
        triplets, and ``feasible`` is ``num_test_set <= remaining``.
    """
    distinct_train = set()
    for triplet in train_triplets:
        if len(set(triplet)) != 3:
            continue  # degenerate triplet (repeated item) never counts
        distinct_train.add(tuple(sorted(triplet)))

    available = math.comb(num_items, 3) - len(distinct_train)
    feasible = num_test_set <= available
    return feasible, available

def sample_test_set(train_triplets, num_items, num_test_set, all=False, adaptive=True, seed=69):
    """Sample test triplets that do not occur among the training triplets.

    Triplets are treated as unordered: each is canonicalised to ascending
    order ``(i, j, k)`` with ``i < j < k`` before comparison, and training
    triplets containing a repeated item are ignored.

    Args:
        train_triplets: Iterable of training triplets (rows of three item indices).
        num_items: Total number of items; indices are expected in ``range(num_items)``.
        num_test_set: Number of test triplets to sample.
        all: If True, return all remaining triplets. If False, sample.
            Only consulted when ``adaptive`` is False.
        adaptive: Auto-choose between 'all' and sampling based on the remaining
            count. When at most 10,000 triplets remain, *all* of them are
            returned and ``num_test_set`` is not applied; otherwise exactly
            ``num_test_set`` triplets are sampled.
        seed: Seed passed to ``np.random.seed`` (this reseeds NumPy's global
            random state as a side effect).

    Returns:
        Array of test triplets. In the sampling path it has shape
        ``(num_test_set, 3)``, rows sorted lexicographically, drawn uniformly
        without replacement from the triplets not in the training set.

    Raises:
        ValueError: If ``num_test_set`` is negative, or (adaptive mode) more
            test triplets are requested than remain.
        RuntimeError: If sampling cannot collect ``num_test_set`` triplets
            within the attempt budget.
    """
    np.random.seed(seed)

    # Canonicalise training triplets once. Converting to Python int avoids
    # overflow in the hash arithmetic below when the input has a narrow
    # fixed-width dtype (e.g. uint32).
    train_set = set()
    for t in train_triplets:
        canon = tuple(sorted(int(x) for x in t))
        if len(set(canon)) == 3:
            train_set.add(canon)

    if adaptive:
        # Calculate remaining triplets without full generation
        remaining = math.comb(num_items, 3) - len(train_set)
        if remaining <= 10_000:
            # Small pool: enumerating everything is cheaper than rejection sampling.
            return get_remaining_combinations(train_triplets, num_items)
        if num_test_set > remaining:
            raise ValueError(f"Cannot sample {num_test_set} test triplets. Only {remaining} available.")
    elif all:
        return get_remaining_combinations(train_triplets, num_items)

    if num_test_set < 0:
        raise ValueError(f"num_test_set must be non-negative, got {num_test_set}")
    if num_test_set == 0:
        return np.empty((0, 3), dtype=np.int64)

    batch_size = 10_000
    max_attempts = 100
    n2 = num_items * num_items

    # Encode (i, j, k) as i*n^2 + j*n + k: a bijection for 0 <= i, j, k < n.
    # `excluded` holds training triplets plus everything already sampled.
    excluded = {i * n2 + j * num_items + k for i, j, k in train_set}
    chunks = []
    collected = 0

    for _ in range(max_attempts):
        # Uniform ordered draws; sorting maps each distinct-item draw to its
        # combination with equal probability (6 orderings per combination).
        candidates = np.random.randint(0, num_items, (batch_size, 3), dtype=np.int64)
        candidates.sort(axis=1)

        # After sorting, strict increase <=> three distinct items.
        valid = (candidates[:, 0] < candidates[:, 1]) & (candidates[:, 1] < candidates[:, 2])
        candidates = candidates[valid]
        if candidates.size == 0:
            continue

        hashes = candidates[:, 0] * n2 + candidates[:, 1] * num_items + candidates[:, 2]

        # Drop within-batch duplicates but keep first occurrences in *draw*
        # order. np.unique alone yields them hash-sorted (lexicographic), which
        # would bias the truncated final batch toward small item indices.
        _, first_idx = np.unique(hashes, return_index=True)
        first_idx.sort()
        candidates = candidates[first_idx]
        hashes = hashes[first_idx]

        # Exclude training and already sampled triplets.
        hash_list = hashes.tolist()
        fresh = np.fromiter(
            (h not in excluded for h in hash_list), dtype=bool, count=len(hash_list)
        )
        take = np.flatnonzero(fresh)[: num_test_set - collected]
        if take.size == 0:
            continue

        chunks.append(candidates[take])
        excluded.update(hashes[take].tolist())
        collected += take.size
        if collected == num_test_set:
            result = np.concatenate(chunks)
            order = np.lexsort((result[:, 2], result[:, 1], result[:, 0]))
            return result[order]

    raise RuntimeError(f"Failed to sample {num_test_set} triplets after {max_attempts} attempts")

