from datasets import Dataset, load_from_disk, concatenate_datasets
from collections import defaultdict, deque
import os
from tqdm import tqdm
import pyarrow as pa
import gc
import glob

def save_single_batch_to_disk(batch_samples, output_dir, batch_idx):
    """
    Persist one processed batch as its own dataset directory.

    Args:
        batch_samples: list[dict] — samples produced for this batch
        output_dir: str — root directory that holds every batch
        batch_idx: int — batch number; the batch is written to
            ``<output_dir>/batch_<batch_idx zero-padded to 5 digits>``

    Any exception raised while saving is caught and reported on stdout;
    nothing is returned and nothing is re-raised.
    """
    import shutil

    try:
        os.makedirs(output_dir, exist_ok=True)
        target_dir = os.path.join(output_dir, f"batch_{batch_idx:05d}")

        # A directory left over from an earlier run is removed first so the
        # write always starts from a clean location.
        if os.path.exists(target_dir):
            shutil.rmtree(target_dir)

        batch_dataset = Dataset.from_list(batch_samples)
        batch_dataset.save_to_disk(target_dir, max_shard_size="50MB")

        # Drop the in-memory dataset right away; batches can be large.
        del batch_dataset
        gc.collect()

        sample_count = len(batch_samples)
        print(f"[SAVE] Batch {batch_idx} saved to {target_dir}  (samples: {sample_count})")
    except Exception as err:
        print(f"[ERROR] Failed to save batch {batch_idx}: {err}")

def build_graph_index(graph):
    """
    Index a triplet graph by subject.

    Args:
        graph: iterable of triplets [subject, relation, object]; entries whose
            length is not 3 are ignored

    Returns:
        (forward_index, all_entities) where forward_index maps each subject to
        the list of (relation, object) pairs in order of appearance, and
        all_entities is the set of every subject and object seen.
    """
    forward_index = defaultdict(list)
    all_entities = set()

    well_formed = (entry for entry in graph if len(entry) == 3)
    for head, rel, tail in well_formed:
        forward_index[head].append((rel, tail))
        all_entities.update((head, tail))

    return forward_index, all_entities
def path_to_tuple(path):
    """
    Return a hashable copy of a path: a tuple of triplet tuples.

    Used as a deduplication key for paths.
    """
    return tuple(map(tuple, path))
def tuple_to_path(path_tuple):
    """
    Turn a tuple-of-tuples path back into a list of triplet lists.
    """
    return list(map(list, path_tuple))
def find_all_k_hop_paths_streaming(topic_entities, graph, k, max_paths=None):
    """
    Lazily enumerate the unique, cycle-free k-hop paths starting at topic entities.

    The search is a depth-first walk with backtracking, so besides the
    adjacency list and the deduplication set it only keeps the branch that is
    currently being explored. Nothing is expanded beyond what is needed to
    produce the next path, which matters when `max_paths` is small compared to
    the number of partial paths. Paths are produced in the same order a
    level-by-level expansion would give: topic entities in the given order,
    edges in their order of appearance in `graph`.

    Args:
        topic_entities: iterable of start entities; entities without outgoing
            edges are skipped
        graph: iterable of triplets [subject, relation, object]; entries whose
            length is not 3 are ignored
        k: number of hops per path; nothing is yielded when k <= 0
        max_paths: stop after this many paths; None means no limit, a value
            <= 0 yields nothing

    Yields:
        One path at a time, as a list of k triplets [subject, relation, object].
        No entity appears twice within a path and identical paths are yielded
        only once.
    """
    # Adjacency list: subject -> [(relation, object), ...] in input order
    adj = defaultdict(list)
    for triplet in graph:
        if len(triplet) != 3:
            continue
        subject, relation, obj = triplet
        adj[subject].append((relation, obj))

    if k <= 0 or (max_paths is not None and max_paths <= 0):
        return

    def walk(node, path, on_path):
        """Yield every k-hop extension of `path` that continues from `node`."""
        # .get() avoids inserting empty entries into the defaultdict
        for relation, neighbor in adj.get(node, ()):
            # Prevent cycles: an entity may occur only once per path
            if neighbor in on_path:
                continue
            path.append([node, relation, neighbor])
            if len(path) == k:
                # Hand out a copy; `path` itself keeps being mutated
                yield list(path)
            else:
                on_path.add(neighbor)
                yield from walk(neighbor, path, on_path)
                on_path.discard(neighbor)
            path.pop()

    seen_paths = set()
    paths_found = 0

    for start in topic_entities:
        if start not in adj:
            continue

        for path in walk(start, [], {start}):
            # Duplicate edges or repeated topic entities produce identical
            # paths; a tuple-of-tuples key filters them out
            key = tuple(tuple(triplet) for triplet in path)
            if key in seen_paths:
                continue
            seen_paths.add(key)
            yield path
            paths_found += 1

            if max_paths is not None and paths_found >= max_paths:
                return
def find_k_hop_paths(topic_entities, graph, k, max_paths=None):
    """
    Collect the deduplicated k-hop paths that start at the topic entities.

    Args:
        topic_entities: list of topic entity strings
        graph: list of triplets
        k: number of hops (2 or 3)
        max_paths: maximum number of paths, None means no limit

    Returns:
        list of paths, each path being a list of k triplets. An empty list is
        returned when there are no topic entities, no graph, k <= 0, or when
        the search raises (the error is printed).
    """
    nothing_to_search = not topic_entities or not graph or k <= 0
    if nothing_to_search:
        return []

    try:
        # Drain the path generator into a list
        path_stream = find_all_k_hop_paths_streaming(topic_entities, graph, k, max_paths)
        return [*path_stream]
    except Exception as err:
        print(f"Error in find_k_hop_paths: {err}")
        return []
def process_single_example(example, max_two_hop_paths=None, max_three_hop_paths=None):
    """
    Add 'two_hop_paths' and 'three_hop_paths' to one sample.

    Args:
        example: a single data sample; its 'q_entity' and 'graph' fields are read
        max_two_hop_paths: max number of 2-hop paths, None means no limit
        max_three_hop_paths: max number of 3-hop paths, None means no limit

    Returns:
        The same `example` object with both path fields set. If anything
        raises, the error is printed and both fields are set to empty lists.
    """
    def summarize(paths, hops, limit):
        # e.g. "12 2-hop paths (limited to 12)" once the cap has been reached
        text = f"{len(paths)} {hops}-hop paths"
        if limit is not None and len(paths) >= limit:
            text += f" (limited to {limit})"
        return text

    try:
        topic_entities = example['q_entity']
        edges = example['graph']

        print(f"Processing sample with {len(topic_entities)} topic entities and {len(edges)} graph edges")

        paths_2 = find_k_hop_paths(topic_entities, edges, 2, max_two_hop_paths)
        paths_3 = find_k_hop_paths(topic_entities, edges, 3, max_three_hop_paths)

        example['two_hop_paths'] = paths_2
        example['three_hop_paths'] = paths_3

        summary_2 = summarize(paths_2, 2, max_two_hop_paths)
        summary_3 = summarize(paths_3, 3, max_three_hop_paths)
        print(f"Found {summary_2} and {summary_3}")

        return example

    except Exception as err:
        print(f"Error processing example: {err}")
        # Keep the sample, but with empty path fields
        for field in ('two_hop_paths', 'three_hop_paths'):
            example[field] = []
        return example
def process_dataset_in_small_batches(
        dataset,
        batch_size=1,
        max_two_hop_paths=None,
        max_three_hop_paths=None,
        incremental_save_dir=None
):
    """
    Run process_single_example over a dataset, one batch at a time.

    Args:
        dataset: object supporting len() and .select(list_of_indices)
        batch_size: number of samples per batch
        max_two_hop_paths: max number of 2-hop paths per sample, None means no limit
        max_three_hop_paths: max number of 3-hop paths per sample, None means no limit
        incremental_save_dir: when not None, every batch is written to this
            directory right after processing and no samples are kept in memory

    Returns:
        (samples, stats): `samples` is the list of processed samples, or an
        empty list when incremental saving is used; `stats` holds
        'total_samples', 'total_two_hop' and 'total_three_hop'.
    """
    total_samples = len(dataset)
    print(f"Processing {total_samples} samples in batches of {batch_size}")
    if max_two_hop_paths is not None:
        print(f"2-hop paths limited to: {max_two_hop_paths}")
    if max_three_hop_paths is not None:
        print(f"3-hop paths limited to: {max_three_hop_paths}")

    save_incrementally = incremental_save_dir is not None
    # Samples are only accumulated in memory when they are not written per batch
    collected = None if save_incrementally else []

    two_hop_total = 0
    three_hop_total = 0
    gc_period = batch_size * 5

    for batch_start in tqdm(range(0, total_samples, batch_size), desc="Processing batches"):
        batch_stop = min(batch_start + batch_size, total_samples)
        batch = dataset.select(list(range(batch_start, batch_stop)))

        batch_processed = []
        for sample_number, example in enumerate(batch, start=batch_start + 1):
            print(f"Processing sample {sample_number}/{total_samples}")
            processed = process_single_example(
                example,
                max_two_hop_paths,
                max_three_hop_paths
            )
            two_hop_total += len(processed['two_hop_paths'])
            three_hop_total += len(processed['three_hop_paths'])
            batch_processed.append(processed)

        # Write the batch immediately, or keep it for the caller
        if save_incrementally:
            save_single_batch_to_disk(
                batch_processed,
                incremental_save_dir,
                batch_start // batch_size
            )
        else:
            collected.extend(batch_processed)

        # Force a garbage collection every fifth batch
        if batch_start % gc_period == 0:
            gc.collect()

    stats = {
        "total_samples"  : total_samples,
        "total_two_hop"  : two_hop_total,
        "total_three_hop": three_hop_total,
    }
    result_samples = [] if collected is None else collected
    return result_samples, stats

def save_as_arrow_dataset(processed_samples, output_path):
    """
    Write all samples to `output_path` as a single HuggingFace dataset.

    If building or saving the dataset raises, the error is printed and the
    samples are handed to save_as_arrow_with_manual_chunking instead.
    """
    sample_count = len(processed_samples)
    print(f"Saving {sample_count} samples as Arrow dataset...")

    try:
        # Build the full dataset and write it with small shards
        Dataset.from_list(processed_samples).save_to_disk(
            output_path,
            max_shard_size="50MB",
        )
        print(f"Successfully saved dataset to {output_path}")
    except Exception as err:
        print(f"Error saving complete dataset: {err}")
        print("Trying to save in smaller chunks...")

        # Fallback: many small datasets instead of one big one
        save_as_arrow_with_manual_chunking(processed_samples, output_path)
def save_as_arrow_with_manual_chunking(processed_samples, output_path, chunk_size=10):
    """
    Save samples as many small datasets plus a JSON index.

    Each run of `chunk_size` samples is written to ``chunk_NNNN`` below
    `output_path`. If a chunk cannot be saved, its samples are retried one by
    one and written to ``chunk_<chunk index>_<offset in chunk>``. Finally
    ``chunks_info.json`` is written; it records the total sample count, the
    number of saved directories and the directory names.

    The directory names are recorded in the order they were written, which is
    the order of the samples, and only for saves that succeeded. Listing the
    output directory instead would give an arbitrary order and could include
    half-written directories of failed saves.

    Args:
        processed_samples: list of sample dicts
        output_path: output directory; an existing one is deleted first
        chunk_size: number of samples per chunk
    """
    import json
    import shutil

    print(f"Manually chunking and saving {len(processed_samples)} samples...")

    # Start from an empty output directory
    if os.path.exists(output_path):
        shutil.rmtree(output_path)
    os.makedirs(output_path, exist_ok=True)

    total_samples = len(processed_samples)
    successful_chunks = 0
    saved_directories = []

    for i in tqdm(range(0, total_samples, chunk_size), desc="Saving chunks"):
        chunk_end = min(i + chunk_size, total_samples)
        chunk_idx = i // chunk_size
        chunk_name = f"chunk_{chunk_idx:04d}"
        chunk_path = os.path.join(output_path, chunk_name)

        try:
            chunk_dataset = Dataset.from_list(processed_samples[i:chunk_end])
            chunk_dataset.save_to_disk(chunk_path, max_shard_size="10MB")
        except Exception as e:
            print(f"Error saving chunk {chunk_idx} ({i}-{chunk_end}): {e}")

            # Remove whatever the failed save left behind so it cannot be
            # mistaken for a valid chunk
            shutil.rmtree(chunk_path, ignore_errors=True)

            # Fall back to saving the chunk's samples one by one
            print("Retrying with individual samples...")

            for j in range(i, chunk_end):
                single_sample_idx = f"{chunk_idx}_{j-i}"
                single_sample_name = f"chunk_{single_sample_idx}"
                single_sample_path = os.path.join(output_path, single_sample_name)

                try:
                    single_sample_dataset = Dataset.from_list([processed_samples[j]])
                    single_sample_dataset.save_to_disk(single_sample_path, max_shard_size="5MB")
                except Exception as e2:
                    print(f"Failed to save single sample {single_sample_idx}: {e2}")
                    shutil.rmtree(single_sample_path, ignore_errors=True)
                else:
                    saved_directories.append(single_sample_name)
                    successful_chunks += 1
                    del single_sample_dataset
                    gc.collect()
        else:
            saved_directories.append(chunk_name)
            successful_chunks += 1

            # Free memory
            del chunk_dataset
            gc.collect()

    # Index file describing every directory that was written successfully
    chunk_info = {
        "total_samples": total_samples,
        "successful_chunks": successful_chunks,
        "chunk_directories": saved_directories
    }

    with open(os.path.join(output_path, "chunks_info.json"), 'w', encoding="utf-8") as f:
        json.dump(chunk_info, f, indent=2)

    print(f"Successfully saved {successful_chunks} chunks to {output_path}")
def load_chunked_dataset(dataset_path):
    """
    Load a dataset that was saved either in chunks or as a whole.

    If `dataset_path` contains ``chunks_info.json``, every directory listed
    under its "chunk_directories" key that exists on disk is loaded and the
    pieces are concatenated in the listed order. Otherwise `dataset_path` is
    loaded as a regular dataset directory.

    Args:
        dataset_path: directory to load from

    Returns:
        The loaded dataset, or None when no listed chunk exists or when the
        regular load fails (the reason is printed).
    """
    import json

    chunks_info_path = os.path.join(dataset_path, "chunks_info.json")

    if not os.path.exists(chunks_info_path):
        # No index file: try loading the complete dataset
        try:
            return load_from_disk(dataset_path)
        except Exception as e:
            # Exception rather than a bare except, so that KeyboardInterrupt
            # and SystemExit are not swallowed
            print(f"Could not load dataset from {dataset_path}: {e}")
            return None

    with open(chunks_info_path, 'r', encoding="utf-8") as f:
        chunks_info = json.load(f)

    datasets = []
    for chunk_dir in chunks_info["chunk_directories"]:
        chunk_path = os.path.join(dataset_path, chunk_dir)
        # Listed directories that are missing on disk are skipped
        if os.path.exists(chunk_path):
            datasets.append(load_from_disk(chunk_path))

    if not datasets:
        return None
    return concatenate_datasets(datasets)
        
def process_huggingface_arrow_files(
        file_paths,
        output_path,
        max_two_hop_paths=None,
        max_three_hop_paths=None,
        batch_size=1
):
    """
    Process multiple HuggingFace Arrow files and save the results

    The files are loaded and concatenated, every sample gets its 2-hop and
    3-hop paths, and each batch is written below `output_path` as soon as it
    has been processed (see process_dataset_in_small_batches). Progress and
    summary statistics are printed.

    Args:
        file_paths: list of input file paths; missing or unreadable files are
            reported and skipped
        output_path: output path, used as the incremental save directory
        max_two_hop_paths: max number of 2-hop paths, None means no limit
        max_three_hop_paths: max number of 3-hop paths, None means no limit
        batch_size: number of samples processed and saved together

    Returns:
        None. If no input file could be loaded, a message is printed and the
        function returns without processing anything.
    """
    datasets = []

    # Load all files
    for file_path in tqdm(file_paths, desc="Loading files"):
        print(f"Loading file: {file_path}")

        if not os.path.exists(file_path):
            print(f"Warning: File {file_path} does not exist! Skipping...")
            continue

        try:
            dataset = Dataset.from_file(file_path)
            datasets.append(dataset)
            print(f"Successfully loaded {len(dataset)} samples from {os.path.basename(file_path)}")
        except Exception as e:
            print(f"Error loading {file_path}: {e}")
            continue

    if not datasets:
        print("No datasets were successfully loaded!")
        return

    # Concatenate all datasets
    print("Concatenating datasets...")
    combined_dataset = concatenate_datasets(datasets)
    print(f"Total samples: {len(combined_dataset)}")

    # Print dataset structure information
    print("Dataset features:", combined_dataset.features)
    if len(combined_dataset) > 0:
        sample = combined_dataset[0]
        print("Sample keys:", list(sample.keys()))
        print(f"Sample q_entity length: {len(sample['q_entity'])}")
        print(f"Sample graph length: {len(sample['graph'])}")

    # Process data in batches (this message is printed once, before the limits)
    print("Processing samples to find paths...")
    if max_two_hop_paths is not None or max_three_hop_paths is not None:
        print("Path limits enabled:")
        if max_two_hop_paths is not None:
            print(f"  Max 2-hop paths: {max_two_hop_paths}")
        if max_three_hop_paths is not None:
            print(f"  Max 3-hop paths: {max_three_hop_paths}")
    else:
        print("No path limits - finding ALL paths...")

    processed_samples, stats = process_dataset_in_small_batches(
        combined_dataset,
        batch_size=batch_size,
        max_two_hop_paths=max_two_hop_paths,
        max_three_hop_paths=max_three_hop_paths,
        incremental_save_dir=output_path        # enables incremental saving
    )
    # If incremental saving is used, processed_samples is an empty list;
    # only print statistics, do not perform a "big merge and save as a whole" again.
    print("\n=== Statistics ===")
    print(f"  Total samples        : {stats['total_samples']}")
    print(f"  Total 2-hop paths    : {stats['total_two_hop']}")
    print(f"  Total 3-hop paths    : {stats['total_three_hop']}")
    if stats['total_samples'] > 0:
        print(f"  Avg 2-hop / sample   : {stats['total_two_hop']/stats['total_samples']:.2f}")
        print(f"  Avg 3-hop / sample   : {stats['total_three_hop']/stats['total_samples']:.2f}")

    # Show an example of a processed sample (only possible when the samples
    # were returned in memory, i.e. when output_path is None)
    if processed_samples:
        print("\nSample processed example:")
        sample = processed_samples[0]
        print(f"  q_entity: {sample['q_entity']}")
        print(f"  Number of 2-hop paths: {len(sample['two_hop_paths'])}")
        print(f"  Number of 3-hop paths: {len(sample['three_hop_paths'])}")
        if sample['two_hop_paths']:
            print(f"  First 2-hop path: {sample['two_hop_paths'][0]}")

        # Verify deduplication effectiveness
        if len(sample['two_hop_paths']) > 1:
            paths_set = set(path_to_tuple(path) for path in sample['two_hop_paths'])
            print(f"  Unique 2-hop paths verification: {len(paths_set)} == {len(sample['two_hop_paths'])} (should be equal)")

def main():
    """
    Add 2-hop and 3-hop paths to the RoG-cwq training split.

    2-hop paths are unlimited, 3-hop paths are capped at 1000 per sample, and
    results are written in batches of 1000 samples.
    """
    # Input file paths. glob returns matches in arbitrary order, so sort them:
    # the sample order, and with it the content of every saved batch
    # directory, then stays the same from run to run.
    FILE_PATHS = sorted(glob.glob("data/shortest_path_index/RoG-cwq/train/*.arrow"))

    # Output path - will be saved as a directory
    OUTPUT_PATH = "data/added_2_3_hop_paths/RoG-cwq/train"

    # Process files
    process_huggingface_arrow_files(
        FILE_PATHS,
        OUTPUT_PATH,
        max_two_hop_paths=None,
        max_three_hop_paths=1000,
        batch_size=1000,
    )

# Run the pipeline only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()