import os
import glob
from pathlib import Path
import numpy as np
from sentence_transformers import SentenceTransformer
from datasets import Dataset
import pyarrow as pa
import pyarrow.parquet as pq
from typing import List, Dict, Any, Tuple
from tqdm import tqdm
import torch
import gc
import json
import pickle

def path_to_string(path: List[List[str]]) -> str:
    """
    Convert a path (list of triples) to a single arrow-separated string.

    path: [["entity1", "relation1", "entity2"], ["entity2", "relation2", "entity3"]]
    Convert to: "entity1 -> relation1 -> entity2 -> relation2 -> entity3"

    Triples that do not contain exactly three elements are skipped. The first
    well-formed triple contributes all three of its elements; every later one
    contributes only its relation and tail, because its head is expected to
    repeat the previous tail.

    Returns an empty string when the path is empty or holds no well-formed
    triple.
    """
    if not path:
        return ""

    elements = []
    for triple in path:
        if len(triple) != 3:
            continue
        if not elements:
            # Nothing emitted yet: this is the first usable triple, so keep its
            # head entity too. Testing the loop index instead would drop the
            # head whenever the leading triple(s) were malformed and skipped.
            elements.extend(triple)
        else:
            # Subsequent triples: only relation and tail (avoid duplicate head)
            elements.extend(triple[1:])

    return " -> ".join(elements)

def compute_similarity_scores(question_embedding: np.ndarray, 
                            path_embeddings: np.ndarray) -> np.ndarray:
    """
    Return the cosine similarity of every path embedding to the question embedding.

    Norms of `path_embeddings` are taken along axis 1, so it is indexed as one
    embedding per row; `question_embedding` is dotted against each row.
    Every norm is floored at 1e-8 so zero vectors yield a similarity of 0
    instead of a division by zero.
    """
    floor = 1e-8

    # Length of the question vector, floored to keep the division safe
    q_len = np.linalg.norm(question_embedding)
    if q_len < floor:
        q_len = floor

    # Per-row lengths of the path vectors, floored the same way
    row_lens = np.maximum(np.linalg.norm(path_embeddings, axis=1), floor)

    dot_products = np.dot(path_embeddings, question_embedding)
    return dot_products / (row_lens * q_len)

def rank_paths_by_similarity(question: str, 
                           paths: List[List[List[str]]], 
                           model: SentenceTransformer) -> Tuple[List[List[List[str]]], List[float]]:
    """
    Order `paths` from most to least similar to `question`.

    Each path is rendered with `path_to_string`, encoded with `model`, and
    scored with `compute_similarity_scores`. Paths whose rendered text is
    blank or 10000+ characters long are dropped from the result.

    Returns a `(ranked_paths, ranked_scores)` pair of equal length. If anything
    unexpected fails, the original `paths` are returned with all-zero scores.
    """
    if not paths:
        return [], []

    try:
        # Embed the question once
        question_vec = model.encode([question], convert_to_numpy=True)[0]

        # Render every path, then keep only the usable renderings
        rendered = [(candidate, path_to_string(candidate)) for candidate in paths]
        usable = [
            (candidate, text)
            for candidate, text in rendered
            if text.strip() and len(text) < 10000  # Limit string length
        ]
        if not usable:
            return [], []

        kept_paths = [candidate for candidate, _ in usable]
        kept_texts = [text for _, text in usable]

        # Encode in small chunks to keep memory pressure low
        chunk_size = 8
        encoded_chunks = []

        for start in range(0, len(kept_texts), chunk_size):
            try:
                chunk = kept_texts[start:start + chunk_size]
                # Anything blank or 5000+ characters is swapped for a placeholder
                safe_chunk = [
                    text if len(text.strip()) > 0 and len(text) < 5000 else "empty path"
                    for text in chunk
                ]

                encoded_chunks.append(
                    model.encode(safe_chunk, convert_to_numpy=True,
                                 show_progress_bar=False, batch_size=4)
                )

                # Clear GPU cache
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

            except Exception as e:
                print(f"Error encoding batch {start//chunk_size}: {str(e)}")
                # Stand in zero vectors so row counts still line up with kept_paths
                dim = model.get_sentence_embedding_dimension()
                encoded_chunks.append(np.zeros((len(chunk), dim)))

        if not encoded_chunks:
            return [], []

        embedding_matrix = np.vstack(encoded_chunks)
        scores = compute_similarity_scores(question_vec, embedding_matrix)

        # Highest similarity first
        order = np.argsort(scores)[::-1]

        ranked_paths = [kept_paths[j] for j in order]
        ranked_scores = [float(scores[j]) for j in order]
        return ranked_paths, ranked_scores

    except Exception as e:
        print(f"Error in rank_paths_by_similarity: {str(e)}")
        return paths, [0.0] * len(paths)  # Return original paths and zero scores

def _remove_path(path: str, ignore_errors: bool = False) -> None:
    """Delete `path` (directory tree or single file); do nothing if it is absent."""
    import shutil

    try:
        if os.path.isdir(path) and not os.path.islink(path):
            shutil.rmtree(path)
        elif os.path.lexists(path):
            os.remove(path)
    except OSError as e:
        if not ignore_errors:
            raise
        print(f"WARNING: could not remove {path}: {str(e)}")


def _save_to_disk_verified(dataset: Dataset, final_dir: str) -> None:
    """Save `dataset` to a temp directory, reload it to verify, then move it to `final_dir`."""
    temp_dir = final_dir + ".tmp"
    _remove_path(temp_dir)
    try:
        dataset.save_to_disk(temp_dir)

        # Explicit checks instead of `assert`, which is stripped under `python -O`
        reloaded = Dataset.load_from_disk(temp_dir)
        if len(reloaded) != len(dataset):
            raise ValueError(
                f"row count mismatch after save: {len(reloaded)} != {len(dataset)}"
            )
        if reloaded.column_names != dataset.column_names:
            raise ValueError(
                f"column mismatch after save: {reloaded.column_names} != {dataset.column_names}"
            )
        del reloaded

        # Only replace a previous result once the new one is known to be good
        _remove_path(final_dir)
        os.rename(temp_dir, final_dir)
    except Exception:
        _remove_path(temp_dir, ignore_errors=True)
        raise


def save_dataset_safely(dataset: Dataset, output_path: str) -> bool:
    """
    Save `dataset`, falling back through progressively simpler formats.

    The formats are tried in this order, stopping at the first that succeeds:
      1. HuggingFace dataset directory  -> "<base>.arrow"
      2. Batches of 500 rows            -> "<base>_batch_NNNN.arrow"
      3. Parquet file                   -> "<base>.parquet"
      4. JSON Lines file                -> "<base>.jsonl"
    where "<base>" is `output_path` without a trailing ".arrow".

    Every method writes to a ".tmp" sibling first and only moves it into place
    after the write (and, for methods 1-3, a reload check) succeeded. A failed
    method removes what it wrote, including already completed batch
    directories, so a partial result is never mistaken for a finished one.

    Returns True if any method succeeded, False otherwise.
    """
    # Strip only a trailing ".arrow"; str.replace would also rewrite directory names
    suffix = '.arrow'
    base_path = output_path[:-len(suffix)] if output_path.endswith(suffix) else output_path

    # Method 1: Use HuggingFace datasets standard save method
    try:
        print("Attempting to save using HuggingFace datasets format...")
        dataset_dir = f"{base_path}.arrow"
        _save_to_disk_verified(dataset, dataset_dir)
        print(f"Successfully saved as HuggingFace dataset: {dataset_dir}")
        return True
    except Exception as e:
        print(f"HuggingFace dataset save failed: {str(e)}")

    # Method 2: Save in batches as multiple small directories
    saved_batches = []
    try:
        print("Attempting to save as multiple batch files...")
        batch_size = 500  # Use smaller batch size
        total_samples = len(dataset)
        if total_samples == 0:
            # Zero batches would otherwise be reported as a success with nothing on disk
            raise ValueError("dataset is empty, nothing to split into batches")
        total_batches = (total_samples + batch_size - 1) // batch_size

        for batch_idx, start in enumerate(range(0, total_samples, batch_size)):
            end = min(start + batch_size, total_samples)
            batch_dir = f"{base_path}_batch_{batch_idx:04d}.arrow"
            _save_to_disk_verified(dataset.select(range(start, end)), batch_dir)
            saved_batches.append(batch_dir)

            if (batch_idx + 1) % 5 == 0:
                print(f"Saved batch {batch_idx + 1}/{total_batches}")
                # Clean up memory
                gc.collect()

        print(f"Successfully saved as {len(saved_batches)} batch files")
        return True
    except Exception as e:
        print(f"Batch save failed: {str(e)}")
        # Remove the batches that did complete: an incomplete set must not
        # look like a finished result to later runs or to the loader.
        for batch_dir in saved_batches:
            _remove_path(batch_dir, ignore_errors=True)
        for leftover in glob.glob(f"{glob.escape(base_path)}_batch_*.arrow.tmp"):
            _remove_path(leftover, ignore_errors=True)

    # Method 3: Save as Parquet format (more stable)
    parquet_file = f"{base_path}.parquet"
    temp_parquet = parquet_file + ".tmp"
    try:
        print("Attempting to save as Parquet format...")
        dataset.to_parquet(temp_parquet)

        # Verify
        reloaded = Dataset.from_parquet(temp_parquet)
        if len(reloaded) != len(dataset):
            raise ValueError(
                f"row count mismatch after save: {len(reloaded)} != {len(dataset)}"
            )
        del reloaded

        # os.replace overwrites an existing target in a single step
        os.replace(temp_parquet, parquet_file)

        print(f"Successfully saved as Parquet: {parquet_file}")
        return True
    except Exception as e:
        print(f"Parquet save failed: {str(e)}")
        _remove_path(temp_parquet, ignore_errors=True)

    # Method 4: Save as JSON Lines format (last resort)
    jsonl_file = f"{base_path}.jsonl"
    temp_jsonl = jsonl_file + ".tmp"
    try:
        print("Attempting to save as JSON Lines...")
        with open(temp_jsonl, 'w', encoding='utf-8') as f:
            for sample in dataset:
                json.dump(sample, f, ensure_ascii=False)
                f.write('\n')

        os.replace(temp_jsonl, jsonl_file)

        print(f"Successfully saved as JSONL: {jsonl_file}")
        return True
    except Exception as e:
        print(f"JSONL save failed: {str(e)}")
        _remove_path(temp_jsonl, ignore_errors=True)

    return False

def process_dataset_file(file_path: str, model: SentenceTransformer, output_dir: str):
    """
    Rank the hop paths of every sample in one arrow file and save the result.

    For each sample, `two_hop_paths` and `three_hop_paths` are re-ordered by
    similarity to the sample's `question` (via `rank_paths_by_similarity`), and
    the matching scores are stored in `two_hop_paths_scores` and
    `three_hop_paths_scores`. All other columns are copied through unchanged.

    A sample that fails keeps its other column values (or None when the sample
    could not be read at all) and gets empty path and score lists, so every
    column always receives exactly one value per sample.

    The result is written to "<output_dir>/<name>_ranked.arrow" through
    `save_dataset_safely`. Nothing is returned; progress and errors are printed.
    """
    import shutil
    import traceback

    print(f"Processing file: {file_path}")

    try:
        # Read dataset
        dataset = Dataset.from_file(file_path)
    except Exception as e:
        print(f"Failed to read dataset: {str(e)}")
        return

    print(f"Original dataset columns: {dataset.column_names}")
    print(f"Dataset size: {len(dataset)}")

    total_samples = len(dataset)
    if total_samples == 0:
        print("Dataset is empty, nothing to process")
        return

    # Check disk space (the directory must exist for disk_usage to work)
    os.makedirs(output_dir, exist_ok=True)
    free_space = shutil.disk_usage(output_dir).free / (1024**3)  # GB
    print(f"Available disk space: {free_space:.2f} GB")

    if free_space < 10:  # If available space is less than 10GB
        print("WARNING: Low disk space, consider cleaning up")

    path_columns = ['two_hop_paths', 'three_hop_paths']
    score_columns = [f"{name}_scores" for name in path_columns]
    computed_columns = set(path_columns + score_columns)
    # Columns copied verbatim; computed ones are excluded so they are filled exactly once
    passthrough_columns = [c for c in dataset.column_names if c not in computed_columns]

    # Accumulate straight into the final structure instead of keeping per-batch
    # dicts and merging them afterwards (which held the data twice).
    final_data = {column_name: [] for column_name in dataset.column_names}
    for column_name in path_columns + score_columns:
        # Also covers inputs that lack one of the path columns
        final_data.setdefault(column_name, [])

    # Process large datasets in batches
    batch_size = 500  # Reduce batch size
    total_batches = (total_samples + batch_size - 1) // batch_size

    for batch_start in range(0, total_samples, batch_size):
        batch_end = min(batch_start + batch_size, total_samples)
        print(f"\nProcessing batch {batch_start//batch_size + 1}/{total_batches}")
        print(f"Samples {batch_start} to {batch_end-1}")

        for i in range(batch_start, batch_end):
            # Compute the whole row first and append afterwards, so an error can
            # never leave some columns longer than others. `sample` is reset on
            # each iteration so a failed read cannot reuse the previous row.
            sample = {}
            ranked = {name: [] for name in path_columns}
            scores = {name: [] for name in path_columns}
            try:
                sample = dataset[i]
                question = sample['question']
                for name in path_columns:
                    candidate_paths = sample.get(name)
                    if candidate_paths:
                        ranked[name], scores[name] = rank_paths_by_similarity(
                            question, candidate_paths, model
                        )
            except Exception as e:
                print(f"Error processing sample {i}: {str(e)}")
                # Add empty data to maintain consistency
                ranked = {name: [] for name in path_columns}
                scores = {name: [] for name in path_columns}

            for column_name in passthrough_columns:
                final_data[column_name].append(sample.get(column_name))
            for name in path_columns:
                final_data[name].append(ranked[name])
                final_data[f"{name}_scores"].append(scores[name])

            if (i - batch_start + 1) % 50 == 0:
                print(f"  Processed {i - batch_start + 1}/{batch_end - batch_start} samples in current batch")
                # Clean up memory
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

        # Clean up memory
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    # Verify data integrity (safety net; rows are appended atomically above)
    print("Data integrity check:")
    data_lengths = {key: len(values) for key, values in final_data.items()}
    for key, length in data_lengths.items():
        print(f"  {key}: {length} samples")

    if any(length != total_samples for length in data_lengths.values()):
        print("ERROR: Inconsistent data lengths detected!")
        for key, length in data_lengths.items():
            if length != total_samples:
                print(f"  {key} has {length} samples, expected {total_samples}")
        return

    try:
        # Create new dataset
        print("Creating final dataset...")
        new_dataset = Dataset.from_dict(final_data)
        print(f"Final dataset columns: {new_dataset.column_names}")
        print(f"Final dataset size: {len(new_dataset)}")

        # Clean up memory
        del final_data
        gc.collect()

        # Build the output name from the file name only, touching nothing but
        # a trailing ".arrow" (str.replace on the joined path would also
        # rewrite directory names).
        file_name = os.path.basename(file_path)
        stem = file_name[:-len('.arrow')] if file_name.endswith('.arrow') else file_name
        output_arrow_file = os.path.join(output_dir, f"{stem}_ranked.arrow")

        # Safe save
        success = save_dataset_safely(new_dataset, output_arrow_file)

        if success:
            print(f"Successfully saved processed file")
        else:
            print("Failed to save the processed dataset")

    except Exception as e:
        print(f"Error creating final dataset: {str(e)}")
        traceback.print_exc()

def load_processed_dataset(file_path: str) -> Dataset:
    """
    Load a processed ("ranked") dataset, supporting multiple on-disk formats.

    `file_path` may name either the original file ("dir/name.arrow") or the
    processed one ("dir/name_ranked.arrow"); only the file name is normalised,
    the directory part is used as given. The following candidates are probed in
    order, all inside that directory:
      1. "name_ranked.arrow"             HuggingFace dataset directory
      2. "name_ranked_batch_*.arrow"     batch directories, concatenated in
                                         numeric batch order
      3. "name_ranked.parquet"           Parquet file
      4. "name_ranked.jsonl"             JSON Lines file

    Raises FileNotFoundError if none of them exists.
    """
    dir_path, base_name = os.path.split(file_path)

    # Strip one known extension and one trailing "_ranked" from the file name
    # only. str.replace over the whole path would also rewrite directory names
    # and occurrences in the middle of the name.
    for extension in ('.arrow', '.parquet', '.jsonl'):
        if base_name.endswith(extension):
            base_name = base_name[:-len(extension)]
            break
    if base_name.endswith('_ranked'):
        base_name = base_name[:-len('_ranked')]
    base_path = os.path.join(dir_path, base_name)

    # Try to load HuggingFace dataset format
    dataset_dir = f"{base_path}_ranked.arrow"
    if os.path.isdir(dataset_dir):
        print(f"Loading from HuggingFace dataset: {dataset_dir}")
        return Dataset.load_from_disk(dataset_dir)

    # Try to load batch directories
    def batch_order(path: str):
        # Order by the numeric batch index so ordering survives more than
        # 9999 batches; anything unparsable sorts last.
        stem = os.path.basename(path)[:-len('.arrow')]
        index = stem.rsplit('_batch_', 1)[-1]
        return (int(index), path) if index.isdigit() else (float('inf'), path)

    batch_pattern = os.path.join(
        glob.escape(dir_path), f"{glob.escape(base_name)}_ranked_batch_*.arrow"
    )
    batch_dirs = [bd for bd in glob.glob(batch_pattern) if os.path.isdir(bd)]
    if batch_dirs:
        print(f"Loading from batch directories: {len(batch_dirs)} batches")
        from datasets import concatenate_datasets
        return concatenate_datasets(
            [Dataset.load_from_disk(bd) for bd in sorted(batch_dirs, key=batch_order)]
        )

    # Try to load Parquet file
    parquet_file = f"{base_path}_ranked.parquet"
    if os.path.exists(parquet_file):
        print(f"Loading from Parquet: {parquet_file}")
        return Dataset.from_parquet(parquet_file)

    # Try to load JSONL format
    jsonl_file = f"{base_path}_ranked.jsonl"
    if os.path.exists(jsonl_file):
        print(f"Loading from JSONL: {jsonl_file}")
        with open(jsonl_file, 'r', encoding='utf-8') as f:
            data = [json.loads(line) for line in f if line.strip()]
        return Dataset.from_list(data)

    raise FileNotFoundError(f"No processed dataset found for {file_path}")

def verify_processed_dataset(file_path: str):
    """
    Load a processed dataset and report whether it looks complete.

    Prints the dataset size, its columns, the presence of the two score
    columns, and path/score counts for the first sample.

    Returns True only if the dataset loads, both score columns
    (`two_hop_paths_scores`, `three_hop_paths_scores`) are present, and the
    first sample has as many scores as paths for each hop type. Returns False
    otherwise, including when loading raises.
    """
    try:
        dataset = load_processed_dataset(file_path)
        print(f"Successfully loaded dataset with {len(dataset)} samples")
        print(f"Dataset type: {type(dataset)}")
        print(f"Columns: {dataset.column_names}")

        # Check new fields
        required_fields = ['two_hop_paths_scores', 'three_hop_paths_scores']
        missing_fields = [f for f in required_fields if f not in dataset.column_names]
        for field in required_fields:
            if field in missing_fields:
                print(f"✗ {field} field missing")
            else:
                print(f"✓ {field} field present")

        # Check first sample
        counts_match = True
        if len(dataset) > 0:
            sample = dataset[0]

            def count(key):
                # A missing or null column counts as zero entries
                return len(sample.get(key) or [])

            print("\nFirst sample verification:")
            print(f"  two_hop_paths: {count('two_hop_paths')} paths")
            print(f"  three_hop_paths: {count('three_hop_paths')} paths")
            print(f"  two_hop_paths_scores: {count('two_hop_paths_scores')} scores")
            print(f"  three_hop_paths_scores: {count('three_hop_paths_scores')} scores")

            # Every ranked path must have exactly one score
            for path_key in ('two_hop_paths', 'three_hop_paths'):
                score_key = f"{path_key}_scores"
                if path_key in sample and score_key in sample and count(path_key) != count(score_key):
                    print(f"✗ {path_key} has {count(path_key)} paths but {count(score_key)} scores")
                    counts_match = False

            # Display some score examples
            if sample.get('two_hop_paths_scores'):
                scores = sample['two_hop_paths_scores'][:3]
                print(f"  Sample two_hop scores: {scores}")
            if sample.get('three_hop_paths_scores'):
                scores = sample['three_hop_paths_scores'][:3]
                print(f"  Sample three_hop scores: {scores}")

        # A dataset without its score columns is not a valid processed result
        return not missing_fields and counts_match

    except Exception as e:
        print(f"Verification failed: {str(e)}")
        import traceback
        traceback.print_exc()
        return False

def main():
    """
    Rank the paths of every arrow file in the input directory.

    With `--verify` as the first command-line argument, no processing is done:
    for each input file the corresponding result in the output directory is
    loaded and checked. Otherwise the SentenceTransformer model is loaded and
    every input file that has no result yet is processed (smallest first) and
    verified right away.
    """
    import sys
    import traceback

    # Configuration parameters
    model_path = "sentence-transformers/all-MiniLM-L6-v2"
    data_dir = "data/shortest_path_index/RoG-cwq/test/added_2_3_hop_paths"
    output_dir = "data/shortest_path_index/RoG-cwq/test/ranked_paths"

    def output_path_for(input_path: str) -> str:
        # Results live in output_dir under the input's file name. Joining the
        # base name avoids substituting data_dir inside the path as a string.
        return os.path.join(output_dir, os.path.basename(input_path))

    # Create output directory
    os.makedirs(output_dir, exist_ok=True)

    # Check if verification of processed files is needed
    if len(sys.argv) > 1 and sys.argv[1] == "--verify":
        print("Verification mode")
        arrow_files = glob.glob(os.path.join(data_dir, "*.arrow"))
        for file_path in arrow_files:
            print(f"\n{'='*60}")
            print(f"Verifying: {os.path.basename(file_path)}")
            print(f"{'='*60}")
            # Look in output_dir: passing the input path would search the
            # input directory, where no results are ever written.
            verify_processed_dataset(output_path_for(file_path))
        return

    # Load sentence-transformer model
    print("Loading SentenceTransformer model...")
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {device}")

    # Check GPU memory
    if torch.cuda.is_available():
        gpu_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3)
        print(f"GPU memory: {gpu_memory:.2f} GB")

    try:
        model = SentenceTransformer(model_path, device=device)
        print("Model loaded successfully!")
        print(f"Model embedding dimension: {model.get_sentence_embedding_dimension()}")
    except Exception as e:
        print(f"Failed to load model: {str(e)}")
        return

    # Get all arrow files
    arrow_files = glob.glob(os.path.join(data_dir, "*.arrow"))

    if not arrow_files:
        print(f"No .arrow files found in {data_dir}")
        return

    print(f"Found {len(arrow_files)} arrow files to process:")
    for file_path in arrow_files:
        file_size = os.path.getsize(file_path) / (1024**2)  # MB
        print(f"  - {os.path.basename(file_path)} ({file_size:.1f} MB)")

    # Sort by file size, process small files first
    arrow_files.sort(key=os.path.getsize)

    # Process each file
    for i, file_path in enumerate(arrow_files):
        try:
            print(f"\n{'='*80}")
            print(f"Processing {i+1}/{len(arrow_files)}: {os.path.basename(file_path)}")
            print(f"{'='*80}")

            # Check if already processed
            base_name = os.path.basename(file_path).replace('.arrow', '')
            possible_outputs = [
                os.path.join(output_dir, f"{base_name}_ranked.arrow"),      # HuggingFace dataset directory
                os.path.join(output_dir, f"{base_name}_ranked_batch_0000.arrow"),  # Batch directory
                os.path.join(output_dir, f"{base_name}_ranked.parquet"),    # Parquet file
                os.path.join(output_dir, f"{base_name}_ranked.jsonl")       # JSONL file
            ]

            already_processed = any(os.path.exists(path) for path in possible_outputs)
            if already_processed:
                print(f"File already processed, skipping...")
                continue

            process_dataset_file(file_path, model, output_dir)

            # Clean up memory
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            print(f"Completed: {os.path.basename(file_path)}")

            # Immediately verify processing results
            print("Verifying processed file...")
            if verify_processed_dataset(output_path_for(file_path)):
                print("✓ Verification successful")
            else:
                print("✗ Verification failed")

        except KeyboardInterrupt:
            print("\nProcess interrupted by user")
            break
        except Exception as e:
            print(f"Error processing file {file_path}: {str(e)}")
            traceback.print_exc()
            continue

    print("\nProcessing completed!")
    print(f"Results saved in: {output_dir}")

def test_path_conversion():
    """
    Print a manual check of `path_to_string` on a two-triple sample path.

    Shows the input, the produced string, the expected string and whether the
    two match. Nothing is asserted or returned.
    """
    wanted = "Apennine Mountains -> geography.mountain_range.mountains -> Monte Catria -> common.topic.image -> Monte Catria - Visto da Ripalta di Arcevia"

    # Two chained triples: the tail of the first is the head of the second
    first_hop = ["Apennine Mountains", "geography.mountain_range.mountains", "Monte Catria"]
    second_hop = ["Monte Catria", "common.topic.image", "Monte Catria - Visto da Ripalta di Arcevia"]
    sample_path = [first_hop, second_hop]

    actual = path_to_string(sample_path)
    is_match = actual == wanted

    print("Test path conversion:")
    print(f"Input: {sample_path}")
    print(f"Output: {actual}")
    print(f"Expected: {wanted}")
    print(f"Match: {is_match}")

if __name__ == "__main__":
    # Print the path-conversion check first, then a separator line
    test_path_conversion()
    separator = "=" * 50
    print("\n" + separator + "\n")

    # Hand over to the actual pipeline
    main()