import torch
import pandas as pd
import numpy as np
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
import pickle
import json
import os
import argparse
from typing import List, Dict, Tuple
import warnings
warnings.filterwarnings("ignore")

class TrainingEmbeddingsGenerator:
    """Embed the text column of a labelled dataset and pickle the result.

    The generator loads a SentenceTransformer model onto a single device,
    encodes texts batch by batch, pairs every row with a tuple of its label
    values, and writes everything (embeddings, texts, labels, metadata) to a
    pickle file.
    """

    def __init__(self, 
                 sentence_transformer_name: str,
                 batch_size: int = 64,
                 device_ids: List[int] = None):
        """Load the sentence transformer and pick the device to run it on.

        Args:
            sentence_transformer_name: Model name or path handed to
                ``SentenceTransformer``.
            batch_size: Number of texts encoded per ``encode`` call.
            device_ids: CUDA device indices. ``None`` means "all visible
                GPUs". Only the first entry is actually used for the model;
                an empty list (or no CUDA device) selects the CPU.

        Raises:
            ValueError: If ``batch_size`` is not a positive integer.
        """
        # Fail early: a zero/negative batch size would otherwise only blow up
        # later inside range() while iterating over the texts.
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

        self.sentence_transformer_name = sentence_transformer_name
        self.batch_size = batch_size
        
        # Set up devices
        self.device_count = torch.cuda.device_count()
        if device_ids is None:
            self.device_ids = list(range(self.device_count))
        else:
            self.device_ids = list(device_ids)

        # Explicit GPU ids are useless without CUDA; fall back to the CPU
        # instead of building an invalid "cuda:N" device string.
        if self.device_ids and self.device_count == 0:
            print("Warning: GPU device IDs were requested but no CUDA device is available; using CPU")
            self.device_ids = []
            
        self.primary_device = f"cuda:{self.device_ids[0]}" if self.device_ids else "cpu"
        
        print(f"Available GPUs: {self.device_count}")
        print(f"Using device: {self.primary_device}")
        
        # Load sentence transformer
        print(f"Loading sentence transformer: {sentence_transformer_name}")
        self.model = SentenceTransformer(sentence_transformer_name, device=self.primary_device)
        print("Sentence transformer loaded successfully")

    @staticmethod
    def _as_column_list(label_columns) -> List[str]:
        """Return ``label_columns`` as a list, wrapping a single column name."""
        if isinstance(label_columns, str):
            return [label_columns]
        return list(label_columns)
        
    def load_training_data(self, file_path: str, text_column: str, label_columns: List[str]) -> pd.DataFrame:
        """Read the dataset into a DataFrame and check the required columns.

        The reader is chosen from the file suffix (case-insensitive):
        ``.csv``, ``.json``, ``.jsonl`` (line-delimited JSON) or ``.parquet``.

        Args:
            file_path: Path of the dataset file.
            text_column: Name of the column holding the texts.
            label_columns: Names of the label columns (a single name is
                accepted as well).

        Returns:
            The loaded DataFrame, unmodified.

        Raises:
            ValueError: If the suffix is unsupported or a required column is
                missing.
        """
        print(f"Loading training data from: {file_path}")
        label_columns = self._as_column_list(label_columns)
        
        # Detect file format and load accordingly
        readers = (
            ('.csv', pd.read_csv),
            ('.json', pd.read_json),
            ('.jsonl', lambda path: pd.read_json(path, lines=True)),
            ('.parquet', pd.read_parquet),
        )
        lowered_path = file_path.lower()
        for suffix, reader in readers:
            if lowered_path.endswith(suffix):
                df = reader(file_path)
                break
        else:
            raise ValueError(f"Unsupported file format: {file_path}")
            
        print(f"Loaded {len(df)} training samples")
        print(f"Columns: {df.columns.tolist()}")
        print(f"Text column: {text_column}")
        print(f"Label columns: {label_columns}")
        
        # Validate columns exist
        missing_cols = [col for col in [text_column] + label_columns if col not in df.columns]
        if missing_cols:
            raise ValueError(f"Missing columns in dataset: {missing_cols}")
            
        return df
    
    def compute_embeddings(self, texts: List[str]) -> np.ndarray:
        """Encode ``texts`` in batches and stack the results row-wise.

        Entries that are not strings, or that contain only whitespace, are
        encoded as the empty string. If encoding a batch raises, a warning is
        printed and that batch is filled with zero vectors so the output stays
        aligned with ``texts``.

        Args:
            texts: The texts to embed.

        Returns:
            Array with one row per input text. For empty input the shape is
            ``(0, embedding_dimension)``.
        """
        print(f"Computing embeddings for {len(texts)} texts...")

        # np.vstack cannot stack an empty list, so handle "no texts" up front.
        if len(texts) == 0:
            embedding_dim = self.model.get_sentence_embedding_dimension()
            final_embeddings = np.zeros((0, embedding_dim), dtype=np.float32)
            print(f"Computed embeddings shape: {final_embeddings.shape}")
            return final_embeddings
        
        embeddings = []
        failed_batches = 0
        num_batches = (len(texts) + self.batch_size - 1) // self.batch_size
        
        for start in tqdm(range(0, len(texts), self.batch_size), 
                          desc="Computing embeddings", 
                          total=num_batches):
            # Replace non-string / blank entries with "" so encode() always
            # receives plain strings.
            batch_texts = [
                text if isinstance(text, str) and text.strip() else ""
                for text in texts[start:start + self.batch_size]
            ]
            
            try:
                batch_embeddings = np.asarray(self.model.encode(
                    batch_texts, 
                    convert_to_tensor=False,
                    show_progress_bar=False
                ))
            except Exception as e:
                # Deliberate best-effort: keep going and mark the batch with zeros.
                print(f"Warning: Error computing embeddings for batch {start//self.batch_size}: {e}")
                failed_batches += 1
                embedding_dim = self.model.get_sentence_embedding_dimension()
                # float32 rather than NumPy's float64 default so one failed
                # batch does not upcast the whole stacked matrix.
                # NOTE(review): assumes encode() yields float32 (its documented
                # default precision) -- confirm for custom models.
                batch_embeddings = np.zeros((len(batch_texts), embedding_dim), dtype=np.float32)
            embeddings.append(batch_embeddings)
        
        if failed_batches:
            print(f"Warning: {failed_batches} of {num_batches} batches failed and were zero-filled")

        final_embeddings = np.vstack(embeddings)
        print(f"Computed embeddings shape: {final_embeddings.shape}")
        
        return final_embeddings
    
    def extract_label_combinations(self, df: pd.DataFrame, label_columns: List[str]) -> List[Tuple]:
        """Return one tuple of label values per row, in row order.

        Args:
            df: DataFrame containing the label columns.
            label_columns: Column names whose values form each tuple, in this
                order (a single name is accepted as well).

        Returns:
            A list with ``len(df)`` tuples.
        """
        columns = self._as_column_list(label_columns)
        if not columns:
            # Selecting zero columns yields no tuples at all; keep one (empty)
            # combination per row instead.
            return [() for _ in range(len(df))]
        # itertuples keeps each column's own dtype, whereas iterrows first
        # coerces every row to a single dtype (e.g. ints become floats), and
        # it is much faster on large frames.
        return list(df[columns].itertuples(index=False, name=None))
    
    def save_embeddings(self, 
                       embeddings: np.ndarray,
                       texts: List[str],
                       label_combinations: List[Tuple],
                       metadata: Dict,
                       output_path: str):
        """Pickle embeddings, texts, labels and metadata into one file.

        Args:
            embeddings: 2-D array of embeddings (its second dimension is
                stored as ``embedding_dimension``).
            texts: The texts the embeddings belong to.
            label_combinations: One label tuple per text.
            metadata: Free-form description of the run.
            output_path: Destination file. ``.pkl`` is appended unless the
                path already ends in ``.pkl`` or ``.pickle``. Missing parent
                directories are created.
        """
        # Default to pickle when no pickle extension was given.
        if not output_path.endswith(('.pkl', '.pickle')):
            output_path = output_path + '.pkl'

        print(f"Saving embeddings to: {output_path}")
        
        # Prepare data to save
        save_data = {
            'embeddings': embeddings,
            'texts': texts,
            'label_combinations': label_combinations,
            'metadata': metadata,
            'embedding_dimension': embeddings.shape[1],
            'num_samples': len(texts)
        }

        # Create the target directory so a long run is not lost at the very
        # last step because of a missing folder.
        parent_dir = os.path.dirname(output_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        
        # Save as pickle for efficiency with large arrays
        with open(output_path, 'wb') as f:
            pickle.dump(save_data, f, protocol=pickle.HIGHEST_PROTOCOL)
        
        print("Embeddings saved successfully!")
        print(f"  - Shape: {embeddings.shape}")
        print(f"  - File size: {os.path.getsize(output_path) / (1024**3):.2f} GB")
    
    def process_training_data(self,
                            input_file: str,
                            text_column: str,
                            label_columns: List[str],
                            output_file: str):
        """Run the full pipeline: load, embed, collect labels, save.

        Args:
            input_file: Path of the dataset file.
            text_column: Name of the column holding the texts.
            label_columns: Names of the label columns (a single name is
                accepted as well).
            output_file: Where to write the pickle file.
        """
        print("=" * 50)
        print("TRAINING DATA EMBEDDINGS GENERATOR")
        print("=" * 50)

        label_columns = self._as_column_list(label_columns)
        
        # Load training data
        df = self.load_training_data(input_file, text_column, label_columns)
        
        # Extract texts and labels
        texts = df[text_column].tolist()
        label_combinations = self.extract_label_combinations(df, label_columns)
        
        # Compute embeddings
        embeddings = self.compute_embeddings(texts)
        
        # Prepare metadata
        metadata = {
            'sentence_transformer_model': self.sentence_transformer_name,
            'text_column': text_column,
            'label_columns': label_columns,
            'batch_size': self.batch_size,
            'input_file': input_file,
            'total_samples': len(texts),
            'unique_label_combinations': len(set(label_combinations))
        }
        
        # Save embeddings
        self.save_embeddings(embeddings, texts, label_combinations, metadata, output_file)
        
        print("=" * 50)
        print("PROCESSING COMPLETED SUCCESSFULLY!")
        print("=" * 50)

def load_training_embeddings(file_path: str) -> Dict:
    """Unpickle a saved embeddings bundle and print a short summary of it.

    The summary reads the ``embeddings``, ``metadata`` (its
    ``sentence_transformer_model`` entry) and ``num_samples`` keys of the
    loaded object.

    NOTE: unpickling can execute arbitrary code, so only open files that come
    from a trusted source.

    Args:
        file_path: Path of the pickle file to read.

    Returns:
        The unpickled object, returned as-is.
    """
    print(f"Loading pre-computed embeddings from: {file_path}")

    with open(file_path, 'rb') as handle:
        bundle = pickle.load(handle)

    matrix_shape = bundle['embeddings'].shape
    model_name = bundle['metadata']['sentence_transformer_model']
    sample_count = bundle['num_samples']

    summary_lines = (
        "Loaded embeddings:",
        f"  - Shape: {matrix_shape}",
        f"  - Model: {model_name}",
        f"  - Samples: {sample_count}",
    )
    for line in summary_lines:
        print(line)

    return bundle

def main():
    """Parse the command-line options and run the embedding pipeline.

    Builds a ``TrainingEmbeddingsGenerator`` from ``--sentence_transformer``,
    ``--batch_size`` and ``--device_ids``, then processes ``--input_file``
    into ``--output_file`` using ``--text_column`` and ``--label_columns``.
    """
    parser = argparse.ArgumentParser(description='Generate embeddings for training data')
    
    parser.add_argument('--input_file', type=str, default="yelp_train.csv",
                       help='Path to training data file')
    parser.add_argument('--output_file', type=str, default="yelp_train_embeddings.pkl",
                       help='Path to output embeddings file')
    parser.add_argument('--text_column', type=str, default="text",
                       help='Name of the text column')
    # nargs='+' makes a user-supplied value a list like the default; without
    # it the option yields a single string, which breaks column validation.
    parser.add_argument('--label_columns', type=str, nargs='+',
                       default=['label1', 'label2', 'label3', 'label4', 'label5'],
                       help='Names of label columns (space-separated)')
    parser.add_argument('--sentence_transformer', type=str, 
                       default='models--sentence-transformers--stsb-roberta-base-v2',
                       help='Sentence transformer model name')
    parser.add_argument('--batch_size', type=int, default=16,
                       help='Batch size for embedding computation')
    parser.add_argument('--device_ids', type=int, nargs='*', default=None,
                       help='GPU device IDs to use')
    
    args = parser.parse_args()
    
    # Initialize generator
    generator = TrainingEmbeddingsGenerator(
        sentence_transformer_name=args.sentence_transformer,
        batch_size=args.batch_size,
        device_ids=args.device_ids
    )
    
    # Process training data
    generator.process_training_data(
        input_file=args.input_file,
        text_column=args.text_column,
        label_columns=args.label_columns,
        output_file=args.output_file
    )


if __name__ == "__main__":
    # Run the CLI only when this file is executed as a script, not on import.
    main()
