import os
import torch
import pickle
import pandas as pd
from tqdm import tqdm
from sentence_transformers import SentenceTransformer
import re
from typing import List, Tuple

# Embedding model used by extract_table_feats (model.encode); it is constructed once,
# at module import time. NOTE(review): the first construction presumably downloads
# 'BAAI/bge-large-en-v1.5' unless it is already in the sentence-transformers cache — confirm.
model = SentenceTransformer('BAAI/bge-large-en-v1.5')
# Tokenizer belonging to the same model. NOTE(review): not referenced by the functions
# in this section; presumably kept for external importers — confirm before removing.
tokenizer = model.tokenizer

def parse_table_file(filepath: str) -> Tuple[str, List[str], List[List[str]]]:
    """
    Parse a table file and extract table name, headers, and rows.
    
    Args:
        filepath (str): Path to the table file.
        
    Returns:
        Tuple[str, List[str], List[List[str]]]: Table name, headers, and rows.
    """
    with open(filepath, 'r', encoding='utf-8') as f:
        content = f.read().strip()
    
    lines = content.split('\n')
    
    table_name = ""
    for line in lines:
        if line.startswith("Table:"):
            table_name = line.replace("Table:", "").strip()
            break
    
    headers = []
    rows = []
    
    found_separator = False
    for line in lines:
        if not line.strip() or line.startswith("Table:"):
            continue
        
        if re.match(r'^-+$', line.strip()):
            found_separator = True
            continue
        
        if '|' in line:
            parts = [part.strip() for part in line.split('|')]
            parts = [part for part in parts if part]
            
            if not found_separator and not headers:
                headers = parts
            elif found_separator:
                if len(parts) == len(headers):
                    rows.append(parts)
    
    return table_name, headers, rows

def create_dense_row_text(headers: List[str], row: List[str], table_name: str = "") -> str:
    """
    Serialize a single table row into a sentence suitable for a text embedder.

    Every cell that is non-empty and not whitespace-only becomes the fragment
    "<header> is <value>". The fragments are joined with ", ", and the result is
    prefixed with "<table_name>: " whenever a table name is supplied.

    Args:
        headers (List[str]): Column names, paired positionally with ``row``.
        row (List[str]): Cell values; pairing stops at the shorter of the two lists.
        table_name (str): Optional prefix; omitted when empty.

    Returns:
        str: The serialized row text.
    """
    body = ", ".join(
        f"{column} is {cell}"
        for column, cell in zip(headers, row)
        if cell and cell.strip()
    )
    return f"{table_name}: {body}" if table_name else body

def _collect_row_texts(input_path, include_table_name):
    """Gather (texts, identifiers) for every row to embed, from a .parquet file or a directory of .txt tables."""
    if input_path.endswith('.parquet'):
        df = pd.read_parquet(input_path)
        if 'table_text' not in df.columns:
            raise ValueError(f"Parquet file must contain 'table_text' column for table data. Available columns: {list(df.columns)}")
        # Parquet rows are used verbatim; include_table_name does not apply here.
        texts = df['table_text'].fillna('').astype(str).tolist()
        return texts, [f"row_{i}" for i in range(len(texts))]

    texts = []
    identifiers = []
    for filename in sorted(f for f in os.listdir(input_path) if f.endswith('.txt')):
        try:
            table_name, headers, rows = parse_table_file(os.path.join(input_path, filename))
            prefix = table_name if include_table_name else ""
            file_texts = [create_dense_row_text(headers, row, prefix) for row in rows]
        except Exception as e:
            # Best-effort: one unreadable or malformed file should not abort the whole run.
            print(f"Warning: Failed to process {filename}: {str(e)}")
            continue
        # Extend only after the whole file succeeded, so a failure never leaves partial rows.
        texts.extend(file_texts)
        identifiers.extend(f"{filename}_{row_idx}" for row_idx in range(len(file_texts)))
    return texts, identifiers


def _select_split(items, num_splits, split_index):
    """Return the contiguous ceil(len(items) / num_splits)-sized chunk for split_index, or all items if it is None."""
    if split_index is None:
        return items
    split_size = (len(items) + num_splits - 1) // num_splits
    start = split_index * split_size
    return items[start:start + split_size]


def _split_output_path(output_path, split_index):
    """Return output_path, or '<base>_split<k><ext>' (k = split_index + 1) when a split is selected."""
    if split_index is None:
        return output_path
    base, ext = os.path.splitext(output_path)
    return f"{base}_split{split_index + 1}{ext}"


def extract_table_feats(input_path, output_path, batch_size=256, num_splits=4, split_index=None, disable_prog=False, include_table_name=True):
    """
    Extract dense row embeddings for all table files in a directory or from a parquet file and save them as a pickle file.
    Uses row serialization format: "[column name] is [cell value], [column name] is [cell value]"

    Directory input: every '*.txt' file (sorted by name) is parsed with parse_table_file and
    each data row is serialized with create_dense_row_text. Parquet input: the 'table_text'
    column is embedded as-is (missing values become ""); include_table_name is ignored.

    The pickle holds a dict mapping a key to its normalized embedding (as returned by
    ``model.encode``). Keys are "<input_path>:row_<i>" for parquet input and
    os.path.join(input_path, "<filename>_<row_idx>") for directory input. The latter is an
    identifier, not a real file path.

    Args:
        input_path (str): Path to the directory containing table files or path to a .parquet file.
        output_path (str): Path to save the pickle file. With split_index set, "_split<k>"
            (k = split_index + 1) is inserted before the extension.
        batch_size (int): Batch size for encoding.
        num_splits (int): Number of contiguous splits to divide the rows into.
        split_index (int, optional): Index of the split to process (0-based); None processes all rows.
        disable_prog (bool): Whether to disable tqdm progress bars.
        include_table_name (bool): Whether to prefix row text with the table name (directory input only).

    Raises:
        ValueError: If num_splits <= 0, split_index is out of range, or the parquet file
            lacks a 'table_text' column.
    """
    # Validate split arguments up front so bad input fails before any data is loaded.
    if num_splits <= 0:
        raise ValueError("num_splits must be a positive integer.")
    if split_index is not None and (split_index < 0 or split_index >= num_splits):
        raise ValueError("split_index must be between 0 and num_splits - 1.")

    all_texts, all_identifiers = _collect_row_texts(input_path, include_table_name)
    split_texts = _select_split(all_texts, num_splits, split_index)
    split_identifiers = _select_split(all_identifiers, num_splits, split_index)

    is_parquet = input_path.endswith('.parquet')
    desc = 'Processing all items' if split_index is None else f'Processing split {split_index + 1}/{num_splits}'
    filepaths = [
        f"{input_path}:{identifier}" if is_parquet else os.path.join(input_path, identifier)
        for identifier in tqdm(split_identifiers, desc=desc, disable=disable_prog)
    ]

    if split_texts:
        with torch.no_grad():
            encoded_features = model.encode(split_texts, batch_size=batch_size, normalize_embeddings=True, show_progress_bar=not disable_prog)
    else:
        # Nothing to embed for this split; still write an (empty) pickle so every
        # requested split produces an output file.
        encoded_features = []

    features = dict(zip(filepaths, encoded_features))

    split_output_path = _split_output_path(output_path, split_index)
    out_dir = os.path.dirname(split_output_path)
    if out_dir:  # os.makedirs('') raises when output_path is a bare filename.
        os.makedirs(out_dir, exist_ok=True)
    with open(split_output_path, 'wb') as f:
        pickle.dump(features, f)

if __name__ == "__main__":

    import argparse

    parser = argparse.ArgumentParser(description="Extract table features using dense row embeddings and save them as a pickle file.")
    # Both paths are mandatory: extract_table_feats cannot run without them.
    parser.add_argument("--input_path", type=str, required=True, help="Path to the directory containing table files or path to a .parquet file.")
    parser.add_argument("--output_path", type=str, required=True, help="Path to save the pickle file.")
    parser.add_argument("--batch_size", type=int, default=256, help="Batch size for encoding.")
    parser.add_argument("--num_splits", type=int, default=4, help="Number of splits to divide the total files into.")
    parser.add_argument("--split_index", type=int, help="Index of the split to process (0-based).")
    parser.add_argument("--disable_prog", action="store_true", help="Disable progress bars.")
    # A store_true flag with default=True can never switch the option off, so an explicit
    # opt-out flag writes False to the same dest. --include_table_name remains accepted.
    parser.add_argument("--include_table_name", dest="include_table_name", action="store_true", default=True, help="Include table name in row text representation (default).")
    parser.add_argument("--no_include_table_name", dest="include_table_name", action="store_false", help="Omit the table name from the row text representation.")
    args = parser.parse_args()

    extract_table_feats(
        args.input_path,
        args.output_path,
        batch_size=args.batch_size,
        num_splits=args.num_splits,
        split_index=args.split_index,
        disable_prog=args.disable_prog,
        include_table_name=args.include_table_name
    )