import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import pandas as pd
import numpy as np
from collections import Counter, defaultdict
from itertools import product
import json
from tqdm import tqdm
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
import os
from sklearn.metrics.pairwise import cosine_similarity
import pickle
from typing import List, Dict, Tuple, Any
import random
import warnings
import faiss
import matplotlib.pyplot as plt
warnings.filterwarnings("ignore")


class SyntheticDatasetPipeline:
    """Hold the pipeline configuration and plan per-label-combination quotas.

    The instance stores generation settings, remembers the on-disk format of
    the training data (so output can mirror it), and derives how many
    synthetic samples to request for every combination of label values.
    """

    def __init__(self, 
                 llm_model_name: str,
                 sentence_transformer_name: str,
                 num_synthetic_samples: int,
                 num_final_samples: int,
                 k_neighbors: int = 5,
                 batch_size: int = 32,
                 max_length: int = 128,
                 temperature: float = 0.8,
                 top_p: float = 0.9):
        """Store the configuration and report the number of visible CUDA devices.

        Args:
            llm_model_name: Model identifier passed to the text generator.
            sentence_transformer_name: Model identifier for the embedder.
            num_synthetic_samples: Total number of samples to generate.
            num_final_samples: Number of samples kept after filtering.
            k_neighbors: Neighbours each training sample votes for.
            batch_size: Prompts per generation call.
            max_length: ``max_length`` forwarded to generation.
            temperature: Sampling temperature.
            top_p: Nucleus-sampling threshold.
        """
        self.llm_model_name = llm_model_name
        self.sentence_transformer_name = sentence_transformer_name
        self.num_synthetic_samples = num_synthetic_samples
        self.num_final_samples = num_final_samples
        self.k_neighbors = k_neighbors
        self.batch_size = batch_size
        self.max_length = max_length
        self.temperature = temperature
        self.top_p = top_p

        self.device_count = torch.cuda.device_count()
        print(f"Available GPUs: {self.device_count}")

    def load_training_data(self, file_path: str, text_column: str, label_columns: List[str]) -> pd.DataFrame:
        """Load the training table and remember its format and column roles.

        Args:
            file_path: Path ending in ``.csv``, ``.json`` or ``.jsonl``.
            text_column: Name of the column holding the text.
            label_columns: Names of the label columns.

        Returns:
            The loaded DataFrame.

        Raises:
            ValueError: If the file extension is not one of the supported ones.
        """
        # Detect file format and load accordingly
        if file_path.endswith('.csv'):
            df = pd.read_csv(file_path)
        elif file_path.endswith('.json'):
            df = pd.read_json(file_path)
        elif file_path.endswith('.jsonl'):
            df = pd.read_json(file_path, lines=True)
        else:
            raise ValueError(f"Unsupported file format: {file_path}")

        # Store format information for saving later
        self.data_format = {
            'file_extension': file_path.split('.')[-1],
            'columns': df.columns.tolist(),
            'dtypes': df.dtypes.to_dict()
        }

        self.text_column = text_column
        self.label_columns = label_columns

        print(f"Loaded {len(df)} training samples")
        print(f"Text column: {text_column}")
        print(f"Label columns: {label_columns}")

        return df

    def analyze_label_combinations(self, df: pd.DataFrame) -> Dict[Tuple, int]:
        """Count how often each tuple of label values occurs in ``df``.

        Reads ``self.label_columns`` (set by ``load_training_data``).

        Returns:
            Mapping from label-value tuple to its number of rows.
        """
        combo_counts = Counter(
            tuple(row[col] for col in self.label_columns)
            for _, row in df.iterrows()
        )

        print(f"Found {len(combo_counts)} unique label combinations")
        for combo, count in combo_counts.most_common(10):
            print(f"  {combo}: {count}")

        return dict(combo_counts)

    def calculate_samples_per_combination(self, combo_counts: Dict[Tuple, int]) -> Dict[Tuple, int]:
        """Split ``self.num_synthetic_samples`` across label combinations.

        Each combination receives a share proportional to its training
        frequency, but never fewer than one sample. The quotas are then
        rebalanced, most frequent combinations first, until they sum to
        ``self.num_synthetic_samples``. The only case in which the total
        cannot be met is when there are more combinations than requested
        samples, because no quota is ever reduced below one.

        Args:
            combo_counts: Training frequency of every label combination.

        Returns:
            Mapping from label combination to the number of samples to generate.
        """
        total_training = sum(combo_counts.values())
        samples_per_combo = {
            # Ensure at least 1 sample per combination
            combo: max(1, int(count / total_training * self.num_synthetic_samples))
            for combo, count in combo_counts.items()
        }

        # Adjust to meet exact total. A single pass is not enough when many
        # rare combinations were bumped up to 1, so keep cycling until the
        # difference is gone or nothing can be changed any more.
        diff = self.num_synthetic_samples - sum(samples_per_combo.values())
        ordered_combos = [
            combo
            for combo, _ in sorted(combo_counts.items(), key=lambda x: x[1], reverse=True)
        ]
        while diff != 0:
            progressed = False
            for combo in ordered_combos:
                if diff == 0:
                    break
                if diff > 0:
                    samples_per_combo[combo] += 1
                    diff -= 1
                    progressed = True
                elif samples_per_combo[combo] > 1:
                    samples_per_combo[combo] -= 1
                    diff += 1
                    progressed = True
            if not progressed:
                # Every quota is already at the minimum of 1 (or there are
                # no combinations at all); the target is unreachable.
                break

        print(f"Samples per combination (top 10):")
        sorted_samples = sorted(samples_per_combo.items(), key=lambda x: x[1], reverse=True)
        for combo, num_samples in sorted_samples[:10]:
            print(f"  {combo}: {num_samples}")

        return samples_per_combo

class LLMTextGenerator:
    """Wrap a causal language model and its tokenizer for batched sampling."""

    def __init__(self, model_name: str, device_ids: List[int]):
        """Load the tokenizer and model and place the model on its device.

        Args:
            model_name: Identifier passed to ``from_pretrained``.
            device_ids: CUDA device indices to use. The first one is the
                primary device that receives the inputs. An empty list
                falls back to CPU execution in float32.
        """
        self.model_name = model_name
        self.device_ids = device_ids
        use_cuda = bool(device_ids)
        self.primary_device = f"cuda:{device_ids[0]}" if use_cuda else "cpu"

        # Setup LLM with multi-GPU support
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        # Decoder-only models continue from the last position, so padding
        # must sit on the left or the model would continue from pad tokens.
        self.tokenizer.padding_side = "left"

        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            pad_token_id=self._pad_token_id(),
            torch_dtype=torch.float16 if use_cuda else torch.float32,
        )

        # The vocabulary may have grown by the '[PAD]' token added above.
        self.model.resize_token_embeddings(len(self.tokenizer))

        if use_cuda:
            # Keep the wrapper and the inputs on the same primary device;
            # generation itself goes through ``self.model.module``.
            self.model = torch.nn.DataParallel(self.model, device_ids=device_ids)
            self.model.to(self.primary_device)
            self.is_parallel = True
        else:
            self.is_parallel = False

    def _pad_token_id(self):
        """Return the tokenizer's pad id, falling back to the EOS id.

        Compared against ``None`` explicitly because 0 is a valid token id.
        """
        pad_token_id = self.tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = self.tokenizer.eos_token_id
        return pad_token_id

    def generate_batch(self, prompts: List[str], max_length: int = 128, 
                      temperature: float = 0.8, top_p: float = 0.9) -> List[str]:
        """Sample one continuation per prompt.

        Args:
            prompts: Prompt strings; may differ in length.
            max_length: ``max_length`` forwarded to ``generate``.
            temperature: Sampling temperature.
            top_p: Nucleus-sampling threshold.

        Returns:
            The decoded continuations (prompt tokens removed, special tokens
            skipped, surrounding whitespace stripped), one per prompt. An
            empty prompt list yields an empty list.
        """
        if not prompts:
            return []

        # Tokenize prompts; padding lets prompts of unequal length share a batch
        inputs = self.tokenizer(
            prompts,
            return_tensors="pt",
            padding=True,
        ).to(self.primary_device)

        # ``generate`` lives on the wrapped module, not on DataParallel itself
        model = self.model.module if self.is_parallel else self.model

        # Generate
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_length=max_length,
                temperature=temperature,
                top_p=top_p,
                do_sample=True,
                pad_token_id=self._pad_token_id(),
                num_return_sequences=1
            )

        # Decode outputs. All rows share the padded prompt width, so the
        # continuation starts at the same offset in every row.
        prompt_length = inputs['input_ids'].shape[1]
        return [
            self.tokenizer.decode(output[prompt_length:], skip_special_tokens=True).strip()
            for output in outputs
        ]

def generate_synthetic_data(pipeline: SyntheticDatasetPipeline, 
                          combo_counts: Dict[Tuple, int],
                          samples_per_combo: Dict[Tuple, int]) -> List[Dict]:
    """Generate labelled synthetic texts for every requested label combination.

    The prompt for a combination is its label values joined by tabs and
    followed by two newlines; it is repeated once per requested sample and
    sent to the generator in chunks of ``pipeline.batch_size``.

    Args:
        pipeline: Supplies model name, device count, generation settings and
            the text / label column names.
        combo_counts: Training frequency per combination (not used here).
        samples_per_combo: Number of samples to generate per combination.

    Returns:
        One dict per generated text, holding the text under
        ``pipeline.text_column`` and the label values under
        ``pipeline.label_columns``.
    """
    # The LLM gets at most the first four GPUs.
    gpu_ids = list(range(min(4, pipeline.device_count)))
    text_generator = LLMTextGenerator(pipeline.llm_model_name, gpu_ids)

    records = []

    progress = tqdm(samples_per_combo.items(), desc="Generating synthetic data")
    for combo, num_samples in progress:
        if num_samples == 0:
            continue

        # Every sample of a combination shares the same prompt text.
        prompts = [
            "\t".join(str(label) for label in combo) + "\n\n"
            for _ in range(num_samples)
        ]

        for start in range(0, len(prompts), pipeline.batch_size):
            texts = text_generator.generate_batch(
                prompts[start:start + pipeline.batch_size],
                max_length=pipeline.max_length,
                temperature=pipeline.temperature,
                top_p=pipeline.top_p
            )

            # Attach the combination's label values to each generated text.
            for text in texts:
                record = {pipeline.text_column: text}
                record.update(
                    (col, combo[pos])
                    for pos, col in enumerate(pipeline.label_columns)
                )
                records.append(record)

    print(f"Generated {len(records)} synthetic samples")
    return records

def compute_embeddings(texts: List[str], model_name: str, device_ids: List[int], 
                      batch_size: int = 64) -> np.ndarray:
    """Embed ``texts`` with a sentence-transformer model.

    Args:
        texts: Texts to embed.
        model_name: Identifier passed to ``SentenceTransformer``.
        device_ids: CUDA device indices; only the first is used. An empty
            list selects the CPU.
        batch_size: Number of texts handed to ``encode`` per call.

    Returns:
        A 2-D array with one row per text. When ``texts`` is empty the model
        is not loaded and an array of shape ``(0, 0)`` is returned, since the
        embedding width is unknown without the model.
    """
    if not texts:
        # np.vstack rejects an empty list, and loading a model to embed
        # nothing would be wasted work.
        return np.empty((0, 0), dtype=np.float32)

    primary_device = f"cuda:{device_ids[0]}" if device_ids else "cpu"

    # Load sentence transformer
    model = SentenceTransformer(model_name, device=primary_device)

    # Compute embeddings in batches
    embeddings = [
        model.encode(texts[start:start + batch_size], convert_to_tensor=False)
        for start in tqdm(range(0, len(texts), batch_size), desc="Computing embeddings")
    ]

    return np.vstack(embeddings)

def _normalize_rows(embeddings: np.ndarray) -> np.ndarray:
    """Scale every row to unit L2 norm; all-zero rows are left as zeros."""
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    # Dividing a zero vector by its norm would fill the row with NaN.
    norms[norms == 0] = 1.0
    return embeddings / norms


def _count_neighbor_votes(train_unit: np.ndarray,
                          synth_unit: np.ndarray,
                          k: int,
                          gpu_resources: Any,
                          label_combo: Tuple) -> np.ndarray:
    """Count, per synthetic row, how many training rows rank it in their top ``k``.

    Uses a FAISS inner-product index (on GPU when ``gpu_resources`` is given)
    and falls back to scikit-learn's ``cosine_similarity`` if FAISS raises.
    Votes are only tallied after one of the two searches has finished, so a
    failing FAISS call can never leave partial votes behind.
    """
    n_synth = synth_unit.shape[0]
    try:
        dimension = synth_unit.shape[1]
        if gpu_resources is not None:
            # GPU version
            index = faiss.GpuIndexFlatIP(gpu_resources, dimension)
        else:
            # CPU version (still faster than sklearn for large datasets)
            index = faiss.IndexFlatIP(dimension)

        # Inner product equals cosine similarity for normalized vectors
        index.add(np.ascontiguousarray(synth_unit, dtype=np.float32))
        _, neighbor_ids = index.search(
            np.ascontiguousarray(train_unit, dtype=np.float32), k
        )
        neighbor_ids = neighbor_ids.ravel()
        # FAISS reports missing neighbours as -1; a plain upper-bound check
        # would let -1 wrap around to the last synthetic sample.
        neighbor_ids = neighbor_ids[(neighbor_ids >= 0) & (neighbor_ids < n_synth)]
    except Exception as e:
        print(f"FAISS failed for {label_combo}, falling back to sklearn: {e}")
        similarities = cosine_similarity(train_unit, synth_unit)
        neighbor_ids = np.argsort(similarities, axis=1)[:, -k:].ravel()

    return np.bincount(neighbor_ids, minlength=n_synth)


def find_knn_and_vote_optimized(training_embeddings: np.ndarray,
                               synthetic_embeddings: np.ndarray,
                               training_labels: List[Tuple],
                               synthetic_labels: List[Tuple],
                               k: int,
                               use_gpu: bool = True) -> Dict[int, int]:
    """Let every training sample vote for its nearest synthetic samples.

    Samples are grouped by label combination, and voting happens only inside
    a group. When a group has at most ``k`` synthetic samples, each of them
    receives one vote from every training sample of the group. Otherwise each
    training sample votes for its ``k`` most cosine-similar synthetic samples.

    Args:
        training_embeddings: One row per training sample.
        synthetic_embeddings: One row per synthetic sample.
        training_labels: Label combination of each training sample.
        synthetic_labels: Label combination of each synthetic sample.
        k: Number of neighbours each training sample votes for (>= 1).
        use_gpu: Use a FAISS GPU index when the installed FAISS offers one.

    Returns:
        Mapping from synthetic sample index to its vote count. Samples that
        received no vote are absent.

    Raises:
        ValueError: If ``k`` is smaller than 1.
    """
    if k < 1:
        raise ValueError(f"k must be at least 1, got {k}")

    votes = defaultdict(int)

    # Group samples by label combination
    print("Grouping samples by label combinations...")
    training_by_label = defaultdict(list)
    synthetic_by_label = defaultdict(list)

    for idx, label_combo in enumerate(training_labels):
        training_by_label[label_combo].append(idx)

    for idx, label_combo in enumerate(synthetic_labels):
        synthetic_by_label[label_combo].append(idx)

    # Get unique label combinations that exist in both training and synthetic
    common_labels = set(training_by_label) & set(synthetic_by_label)
    print(f"Found {len(common_labels)} common label combinations")

    if not common_labels:
        print(f"Voting complete. {len(votes)} synthetic samples received votes.")
        return {}

    # Create the GPU resources once instead of once per label combination.
    gpu_resources = None
    if use_gpu and hasattr(faiss, 'StandardGpuResources'):
        try:
            gpu_resources = faiss.StandardGpuResources()
        except Exception as e:
            print(f"Could not initialise FAISS GPU resources, using CPU index: {e}")

    # Process each label combination separately
    for label_combo in tqdm(common_labels, desc="Processing label combinations"):
        train_indices = training_by_label[label_combo]
        synth_indices = synthetic_by_label[label_combo]

        print(f"Processing {label_combo}: {len(train_indices)} training, {len(synth_indices)} synthetic")

        if len(synth_indices) <= k:
            # All synthetic samples get votes
            for synth_idx in synth_indices:
                votes[synth_idx] += len(train_indices)
            continue

        # Normalize embeddings for cosine similarity
        train_unit = _normalize_rows(training_embeddings[train_indices])
        synth_unit = _normalize_rows(synthetic_embeddings[synth_indices])

        counts = _count_neighbor_votes(train_unit, synth_unit, k, gpu_resources, label_combo)

        # Translate group-local positions back to global synthetic indices.
        for local_idx in np.flatnonzero(counts):
            votes[synth_indices[local_idx]] += int(counts[local_idx])

    print(f"Voting complete. {len(votes)} synthetic samples received votes.")
    return dict(votes)

def _rebalance_targets(targets: Dict[Tuple, int], ordered_combos: List[Tuple], total: int) -> None:
    """Nudge ``targets`` in place, one unit at a time, until they sum to ``total``.

    Combinations are visited in ``ordered_combos`` order, repeatedly if
    needed. No target is reduced below 1, so the loop stops early when the
    total cannot be reached.
    """
    diff = total - sum(targets.values())
    while diff != 0:
        progressed = False
        for combo in ordered_combos:
            if diff == 0:
                break
            if diff > 0:
                targets[combo] += 1
                diff -= 1
                progressed = True
            elif targets[combo] > 1:
                targets[combo] -= 1
                diff += 1
                progressed = True
        if not progressed:
            break


def sample_final_dataset_by_combination(synthetic_data: List[Dict],
                                       votes: Dict[int, int],
                                       label_columns: List[str],
                                       num_final_samples: int,
                                       preserve_distribution: bool = True) -> List[Dict]:
    """Draw the final synthetic dataset, favouring samples with more votes.

    A target size is computed for every label combination, then samples are
    drawn within each combination without replacement, with probability
    proportional to ``votes + 1``. Combinations holding no more samples than
    their target contribute all of them, so the result can be smaller than
    ``num_final_samples``.

    Args:
        synthetic_data: Generated samples (dicts containing the label columns).
        votes: Vote count per synthetic sample index; missing means 0.
        label_columns: Names of the label columns.
        num_final_samples: Desired size of the final dataset.
        preserve_distribution: If true, targets follow the label distribution
            of ``synthetic_data`` (at least 1 per combination); otherwise the
            targets are split evenly across combinations.

    Returns:
        The selected samples. An empty ``synthetic_data`` yields an empty list.
    """
    print("Sampling final dataset by label combination...")

    if not synthetic_data:
        # Nothing to sample from; also avoids dividing by zero combinations.
        print("Sampled 0 final synthetic samples")
        return []

    # Group synthetic data by label combination
    synth_by_label = defaultdict(list)
    for idx, sample in enumerate(synthetic_data):
        label_combo = tuple(sample[col] for col in label_columns)
        synth_by_label[label_combo].append(idx)

    # Calculate target samples per combination
    if preserve_distribution:
        # Preserve the distribution from synthetic data
        total_synthetic = len(synthetic_data)
        samples_per_combo = {
            # At least 1 sample
            combo: max(1, int(len(indices) / total_synthetic * num_final_samples))
            for combo, indices in synth_by_label.items()
        }

        # Adjust to meet exact total, largest combinations first. Several
        # passes may be needed when many small combinations were raised to 1.
        ordered_combos = [
            combo
            for combo, _ in sorted(synth_by_label.items(), key=lambda x: len(x[1]), reverse=True)
        ]
        _rebalance_targets(samples_per_combo, ordered_combos, num_final_samples)
    else:
        # Equal samples per combination
        base_samples, extra_samples = divmod(num_final_samples, len(synth_by_label))
        samples_per_combo = {
            combo: base_samples + (1 if i < extra_samples else 0)
            for i, combo in enumerate(synth_by_label)
        }

    print(f"Target samples per combination (top 10):")
    sorted_targets = sorted(samples_per_combo.items(), key=lambda x: x[1], reverse=True)
    for combo, target in sorted_targets[:10]:
        print(f"  {combo}: {target}")

    # Sample from each combination
    final_data = []

    for combo, target_samples in tqdm(samples_per_combo.items(), desc="Sampling by combination"):
        candidate_indices = synth_by_label[combo]

        if len(candidate_indices) <= target_samples:
            # Take all available samples
            selected_indices = candidate_indices
        else:
            # Sample based on votes; +1 keeps unvoted samples selectable
            weights = [votes.get(idx, 0) + 1 for idx in candidate_indices]
            total_weight = sum(weights)
            probabilities = [w / total_weight for w in weights]

            # Sample without replacement
            selected_indices = np.random.choice(
                candidate_indices,
                size=target_samples,
                replace=False,
                p=probabilities
            ).tolist()

        # Add selected samples to final data
        final_data.extend(synthetic_data[idx] for idx in selected_indices)

    print(f"Sampled {len(final_data)} final synthetic samples")

    # Verify label distribution
    final_combo_counts = defaultdict(int)
    for sample in final_data:
        combo = tuple(sample[col] for col in label_columns)
        final_combo_counts[combo] += 1

    print("Final label distribution (top 10):")
    sorted_final = sorted(final_combo_counts.items(), key=lambda x: x[1], reverse=True)
    for combo, count in sorted_final[:10]:
        print(f"  {combo}: {count}")

    return final_data

def save_synthetic_data(data: List[Dict], output_path: str, data_format: Dict):
    """Write the synthetic samples in the same file format as the training data.

    Args:
        data: Synthetic samples, one dict per row.
        output_path: Destination file.
        data_format: Dict with ``'file_extension'`` (``csv``, ``json`` or
            ``jsonl``), ``'columns'`` (training column order) and ``'dtypes'``
            (training dtype per column).

    Raises:
        ValueError: If ``file_extension`` is not a supported format.

    Columns are written in training order. Training columns that the
    synthetic rows do not contain are skipped with a warning, and a column
    whose dtype conversion fails is written unconverted with a warning.
    """
    file_ext = data_format['file_extension']
    if file_ext not in ('csv', 'json', 'jsonl'):
        # Fail before reporting a save that never happened.
        raise ValueError(f"Unsupported file format: {file_ext}")

    # Convert to DataFrame
    df = pd.DataFrame(data)

    # Ensure columns are in the same order as training data. The training
    # file may carry columns (ids, metadata, ...) that were never generated.
    missing_columns = [col for col in data_format['columns'] if col not in df.columns]
    if missing_columns:
        print(f"Warning: columns missing from synthetic data were skipped: {missing_columns}")
    df = df[[col for col in data_format['columns'] if col in df.columns]].copy()

    # Convert dtypes to match training data
    for col, dtype in data_format['dtypes'].items():
        if col in df.columns:
            try:
                df[col] = df[col].astype(dtype)
            except (ValueError, TypeError, OverflowError):
                print(f"Warning: Could not convert column {col} to {dtype}")

    # Save in the same format
    if file_ext == 'csv':
        df.to_csv(output_path, index=False)
    elif file_ext == 'json':
        df.to_json(output_path, orient='records', indent=2)
    else:
        df.to_json(output_path, orient='records', lines=True)

    print(f"Saved {len(df)} synthetic samples to {output_path}")


def load_training_embeddings(file_path: str) -> Dict:
    """Load a pickled embeddings bundle and print a short description of it.

    Args:
        file_path: Path to the pickle file.

    Returns:
        The unpickled object, unchanged.

    Raises:
        KeyError: If the bundle has no ``'embeddings'`` entry.

    The ``'metadata'`` and ``'num_samples'`` entries are only used for the
    printed summary, so their absence is reported as ``unknown`` rather than
    failing the load.
    """
    print(f"Loading pre-computed embeddings from: {file_path}")

    # SECURITY: unpickling can execute arbitrary code. Only load files that
    # were produced by a trusted source.
    with open(file_path, 'rb') as f:
        data = pickle.load(f)

    embeddings = data['embeddings']
    metadata = data.get('metadata')
    if not isinstance(metadata, dict):
        metadata = {}

    print(f"Loaded embeddings:")
    print(f"  - Shape: {getattr(embeddings, 'shape', 'unknown')}")
    print(f"  - Model: {metadata.get('sentence_transformer_model', 'unknown')}")
    print(f"  - Samples: {data.get('num_samples', 'unknown')}")

    return data



def _safe_filename_part(text: str, max_length: int = 100) -> str:
    """Replace characters that are unsafe in file names and cap the length."""
    unsafe = '/\\:*?"<>|'
    cleaned = "".join(
        "_" if (ch in unsafe or not ch.isprintable()) else ch
        for ch in text
    )
    return cleaned[:max_length]


def _plot_single_profile(rank: int, combo: Tuple, stats: Dict,
                         output_dir: str, figsize: Tuple[int, int]) -> str:
    """Draw the four-panel vote profile of one combination and save it as PNG.

    Returns the path of the written file.
    """
    # Create figure with subplots
    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=figsize)
    fig.suptitle(f'Vote Profile for Label Combination: {combo}', fontsize=14, fontweight='bold')

    indices = stats['indices']
    combo_votes = stats['votes']

    # 1. Scatter plot: Sample indices vs votes
    ax1.scatter(indices, combo_votes, alpha=0.6, s=20)
    ax1.set_xlabel('Sample Index')
    ax1.set_ylabel('Number of Votes')
    ax1.set_title('Votes vs Sample Index')
    ax1.grid(True, alpha=0.3)

    # Add statistics text
    stats_text = f'Total: {stats["total_votes"]}\nAvg: {stats["avg_votes"]:.2f}\nStd: {stats["std_votes"]:.2f}'
    ax1.text(0.02, 0.98, stats_text, transform=ax1.transAxes, 
            verticalalignment='top', bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

    # 2. Histogram of vote distribution (between 10 and 50 bins)
    bins = max(10, min(50, len(set(combo_votes))))
    ax2.hist(combo_votes, bins=bins, alpha=0.7, edgecolor='black')
    ax2.set_xlabel('Number of Votes')
    ax2.set_ylabel('Frequency')
    ax2.set_title('Vote Distribution Histogram')
    ax2.grid(True, alpha=0.3)

    # Add vertical lines for mean and median
    mean_votes = np.mean(combo_votes)
    median_votes = np.median(combo_votes)
    ax2.axvline(mean_votes, color='red', linestyle='--', label=f'Mean: {mean_votes:.2f}')
    ax2.axvline(median_votes, color='orange', linestyle='--', label=f'Median: {median_votes:.2f}')
    ax2.legend()

    # 3. Cumulative distribution
    sorted_votes = np.sort(combo_votes)
    cumulative_pct = np.arange(1, len(sorted_votes) + 1) / len(sorted_votes) * 100
    ax3.plot(sorted_votes, cumulative_pct, linewidth=2)
    ax3.set_xlabel('Number of Votes')
    ax3.set_ylabel('Cumulative Percentage')
    ax3.set_title('Cumulative Vote Distribution')
    ax3.grid(True, alpha=0.3)

    # Add percentile lines
    p25 = np.percentile(combo_votes, 25)
    p75 = np.percentile(combo_votes, 75)
    ax3.axvline(p25, color='green', linestyle=':', label=f'25th: {p25:.1f}')
    ax3.axvline(p75, color='purple', linestyle=':', label=f'75th: {p75:.1f}')
    ax3.legend()

    # 4. Box plot
    ax4.boxplot(combo_votes, vert=True, patch_artist=True,
               boxprops=dict(facecolor='lightblue', alpha=0.7))
    ax4.set_ylabel('Number of Votes')
    ax4.set_title('Vote Distribution Box Plot')
    ax4.grid(True, alpha=0.3)

    # Add sample count and zero-vote percentage
    zero_votes = sum(1 for v in combo_votes if v == 0)
    zero_pct = zero_votes / len(combo_votes) * 100
    info_text = f'Samples: {len(combo_votes)}\nZero votes: {zero_votes} ({zero_pct:.1f}%)'
    ax4.text(0.02, 0.98, info_text, transform=ax4.transAxes, 
            verticalalignment='top', bbox=dict(boxstyle='round', facecolor='lightgreen', alpha=0.5))

    fig.tight_layout()

    # Save the figure. Label values come from the data and may contain path
    # separators or other characters that are not valid in a file name.
    combo_str = _safe_filename_part("_".join(str(c) for c in combo))
    filename = f"vote_profile_combo_{rank}_{combo_str}.png"
    filepath = os.path.join(output_dir, filename)
    fig.savefig(filepath, dpi=300, bbox_inches='tight')
    plt.close(fig)

    return filepath


def _plot_comparison(selected_combos: List[Tuple], output_dir: str) -> str:
    """Draw the side-by-side comparison of the selected combinations and save it.

    Returns the path of the written file.
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
    fig.suptitle('Vote Profile Comparison Across Label Combinations', fontsize=14, fontweight='bold')

    # Comparison of vote distributions; long labels are cut to 30 characters
    combo_names = []
    all_votes_data = []

    for combo, stats in selected_combos:
        combo_name = f"{combo}"[:30] + "..." if len(str(combo)) > 30 else str(combo)
        combo_names.append(combo_name)
        all_votes_data.append(stats['votes'])

    # Box plot comparison
    ax1.boxplot(all_votes_data, labels=combo_names, patch_artist=True)
    ax1.set_ylabel('Number of Votes')
    ax1.set_title('Vote Distribution Comparison')
    ax1.tick_params(axis='x', rotation=45)
    ax1.grid(True, alpha=0.3)

    # Statistics comparison: one group of bars per combination
    stats_metrics = ['total_votes', 'avg_votes', 'max_votes', 'samples_with_votes']
    x_pos = np.arange(len(combo_names))
    width = 0.2

    for j, metric in enumerate(stats_metrics):
        values = [stats[metric] for _, stats in selected_combos]
        ax2.bar(x_pos + j * width, values, width, label=metric.replace('_', ' ').title())

    ax2.set_xlabel('Label Combinations')
    ax2.set_ylabel('Count')
    ax2.set_title('Vote Statistics Comparison')
    ax2.set_xticks(x_pos + width * 1.5)
    ax2.set_xticklabels(combo_names, rotation=45)
    ax2.legend()
    ax2.grid(True, alpha=0.3)

    fig.tight_layout()

    # Save comparison plot
    comparison_filepath = os.path.join(output_dir, "vote_profile_comparison.png")
    fig.savefig(comparison_filepath, dpi=300, bbox_inches='tight')
    plt.close(fig)

    return comparison_filepath


def _write_vote_summary(summary_filepath: str,
                        synthetic_data: List[Dict],
                        synth_by_label: Dict[Tuple, List[int]],
                        votes: Dict[int, int],
                        selected_combos: List[Tuple]) -> None:
    """Write the plain-text summary of the vote statistics."""
    # Label values may contain non-ASCII text, so fix the encoding instead of
    # relying on the platform default.
    with open(summary_filepath, 'w', encoding='utf-8') as f:
        f.write("VOTE PROFILE ANALYSIS SUMMARY\n")
        f.write("=" * 50 + "\n\n")

        f.write(f"Total synthetic samples: {len(synthetic_data)}\n")
        f.write(f"Total label combinations: {len(synth_by_label)}\n")
        f.write(f"Samples with votes: {len([v for v in votes.values() if v > 0])}\n")
        f.write(f"Total votes cast: {sum(votes.values())}\n\n")

        f.write("SELECTED COMBINATIONS FOR VISUALIZATION:\n")
        f.write("-" * 40 + "\n")

        for i, (combo, stats) in enumerate(selected_combos):
            num_samples = len(stats['indices'])
            with_votes = stats['samples_with_votes']
            without_votes = num_samples - with_votes
            f.write(f"\n{i+1}. Label Combination: {combo}\n")
            f.write(f"   Number of samples: {num_samples}\n")
            f.write(f"   Total votes: {stats['total_votes']}\n")
            f.write(f"   Average votes per sample: {stats['avg_votes']:.3f}\n")
            f.write(f"   Standard deviation: {stats['std_votes']:.3f}\n")
            f.write(f"   Min votes: {stats['min_votes']}\n")
            f.write(f"   Max votes: {stats['max_votes']}\n")
            f.write(f"   Samples with votes: {with_votes} ({with_votes/num_samples*100:.1f}%)\n")
            f.write(f"   Samples with zero votes: {without_votes} ({without_votes/num_samples*100:.1f}%)\n")


def visualize_vote_profiles(synthetic_data: List[Dict],
                          votes: Dict[int, int],
                          label_columns: List[str],
                          output_dir: str = "vote_profiles",
                          num_combinations: int = 2,
                          figsize: Tuple[int, int] = (12, 8)):
    """Plot and summarise how votes are distributed within label combinations.

    For the ``num_combinations`` combinations with the most total votes, a
    four-panel PNG (scatter, histogram, cumulative distribution, box plot) is
    written. When more than one combination is selected, a comparison PNG is
    written as well. A text summary is always written.

    Args:
        synthetic_data: Generated samples (dicts containing the label columns).
        votes: Vote count per synthetic sample index; missing means 0.
        label_columns: Names of the label columns.
        output_dir: Directory for all output files; created if missing.
        num_combinations: How many combinations to visualise.
        figsize: Figure size of the per-combination plots.
    """
    # Create output directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)

    print(f"Creating vote profile visualizations...")

    # Group synthetic data by label combination
    synth_by_label = defaultdict(list)
    for idx, sample in enumerate(synthetic_data):
        label_combo = tuple(sample[col] for col in label_columns)
        synth_by_label[label_combo].append(idx)

    # Calculate vote statistics for each combination
    combo_stats = {}
    for combo, indices in synth_by_label.items():
        combo_votes = [votes.get(idx, 0) for idx in indices]
        combo_stats[combo] = {
            'indices': indices,
            'votes': combo_votes,
            'total_votes': sum(combo_votes),
            'avg_votes': np.mean(combo_votes),
            'max_votes': max(combo_votes),
            'min_votes': min(combo_votes),
            'std_votes': np.std(combo_votes),
            'samples_with_votes': sum(1 for v in combo_votes if v > 0)
        }

    # Select top combinations by total votes
    sorted_combos = sorted(combo_stats.items(), 
                          key=lambda x: x[1]['total_votes'], 
                          reverse=True)

    selected_combos = sorted_combos[:num_combinations]

    print(f"Selected {len(selected_combos)} combinations for visualization:")
    for combo, stats in selected_combos:
        print(f"  {combo}: {stats['total_votes']} total votes, "
              f"{stats['avg_votes']:.2f} avg, "
              f"{stats['samples_with_votes']}/{len(stats['indices'])} samples with votes")

    # Create visualizations
    for i, (combo, stats) in enumerate(selected_combos):
        filepath = _plot_single_profile(i + 1, combo, stats, output_dir, figsize)
        print(f"Saved: {filepath}")

    # Create a summary comparison plot
    if len(selected_combos) > 1:
        comparison_filepath = _plot_comparison(selected_combos, output_dir)
        print(f"Saved comparison: {comparison_filepath}")

    # Generate summary statistics file
    summary_filepath = os.path.join(output_dir, "vote_profile_summary.txt")
    _write_vote_summary(summary_filepath, synthetic_data, synth_by_label, votes, selected_combos)

    print(f"Saved summary: {summary_filepath}")
    print(f"\nVisualization complete! Files saved in: {output_dir}")


def add_gaussian_noise_to_votes(votes: Dict[int, int], 
                               sigma: float,
                               total_synthetic_samples: int = None,
                               random_seed: int = None) -> Dict[int, int]:
    """Perturb vote counts with zero-mean Gaussian noise.

    Each existing count gets an independent ``N(0, sigma)`` draw added, is
    rounded to the nearest integer, and is clipped at zero. If
    ``total_synthetic_samples`` is given, indices in
    ``range(total_synthetic_samples)`` that are absent from ``votes`` are
    treated as having zero votes and are perturbed too, in ascending index
    order. They appear in the result only when the rounded noisy count is
    positive.

    Args:
        votes: Vote count per synthetic sample index.
        sigma: Standard deviation of the noise.
        total_synthetic_samples: Number of synthetic samples, or ``None`` to
            leave unvoted samples untouched.
        random_seed: If given, seeds NumPy's global random state first.

    Returns:
        A new mapping from sample index to noisy integer vote count. Every
        key of ``votes`` is present (possibly with value 0).
    """
    if random_seed is not None:
        np.random.seed(random_seed)

    print(f"Adding Gaussian noise to votes (σ = {sigma})")

    # Create noisy votes dictionary
    noisy_votes = {}

    # Add noise to existing votes
    original_vote_count = len(votes)
    negative_count = 0

    for idx, vote_count in votes.items():
        noisy_vote = vote_count + np.random.normal(0, sigma)

        # Clip negative values to zero
        if noisy_vote < 0:
            noisy_votes[idx] = 0
            negative_count += 1
        else:
            noisy_votes[idx] = int(round(noisy_vote))

    # Optionally add noise to samples that originally had zero votes
    if total_synthetic_samples is not None:
        processed = 0
        zero_noise_count = 0
        positive_from_zero = 0

        for idx in range(total_synthetic_samples):
            if idx in votes:
                continue
            processed += 1

            # Round before testing so that a draw in (0, 0.5) is counted as
            # what it becomes, namely zero, and is not stored.
            rounded_vote = int(round(np.random.normal(0, sigma)))

            if rounded_vote > 0:
                noisy_votes[idx] = rounded_vote
                positive_from_zero += 1
            else:
                # Keep as zero (don't store in dictionary to save memory)
                zero_noise_count += 1

        print(f"Zero-vote samples processed: {processed}")
        print(f"  - Remained zero after noise: {zero_noise_count}")
        print(f"  - Became positive after noise: {positive_from_zero}")

    # Entries clipped or rounded to 0 stay in the dict but are not non-zero.
    non_zero_count = sum(1 for v in noisy_votes.values() if v > 0)

    print(f"Noise addition completed:")
    print(f"  - Original samples with votes: {original_vote_count}")
    print(f"  - Samples clipped to zero: {negative_count}")
    print(f"  - Final samples with non-zero votes: {non_zero_count}")

    return noisy_votes






def main():
    """Run the synthetic-data pipeline end to end.

    The steps are: load pre-computed training embeddings, generate synthetic
    samples per label combination, embed the synthetic texts, collect K-NN
    votes against the training embeddings, perturb the votes with Gaussian
    noise, sample the final dataset and save it.

    All settings live in the local ``config`` dict. The text/label column
    names stored in the embeddings file metadata are treated as authoritative
    and used for every step, because the training label tuples come from that
    same file; a warning is printed when ``config`` disagrees.
    """

    # Configuration
    config = {
        'llm_model_name': "gpt2",  # Replace with your LLM
        'sentence_transformer_name': "stsb-roberta-base-v2",  # Replace with your sentence transformer
        'training_data_path': "yelp_train.csv",  # Replace with your training data path
        'output_path': 'synthetic_data.csv',  # Output path
        'text_column': 'text',  # Name of text column
        'label_columns': ['label1', 'label2'],  # List of label columns
        'num_synthetic_samples': 50000,  # Total synthetic samples to generate
        'num_final_samples': 10000,  # Final samples after filtering
        'k_neighbors': 5,  # Number of nearest neighbors for voting
        'batch_size': 16,  # Batch size
        'max_length': 128,  # Maximum sequence length
        'temperature': 1,  # Generation temperature
        'top_p': 0.9,  # Top-p for nucleus sampling
        'training_embeddings_path': 'yelp_train_embeddings.pkl',  # Need to generate embeddings using generate_embedding.py
        'sigma': 1,  # Standard deviation of the Gaussian noise added to votes
    }

    # Initialize pipeline
    pipeline = SyntheticDatasetPipeline(
        llm_model_name=config['llm_model_name'],
        sentence_transformer_name=config['sentence_transformer_name'],
        num_synthetic_samples=config['num_synthetic_samples'],
        num_final_samples=config['num_final_samples'],
        k_neighbors=config['k_neighbors'],
        batch_size=config['batch_size'],
        max_length=config['max_length'],
        temperature=config['temperature'],
        top_p=config['top_p']
    )

    print("=== Step 1: Load and analyze training data ===")
    print("Using pre-computed training embeddings...")
    embeddings_data = load_training_embeddings(config['training_embeddings_path'])

    # Extract data from embeddings file
    training_embeddings = embeddings_data['embeddings']
    training_texts = embeddings_data['texts']
    training_labels = embeddings_data['label_combinations']
    metadata = embeddings_data['metadata']

    # Verify sentence transformer compatibility
    if metadata['sentence_transformer_model'] != config['sentence_transformer_name']:
        print(f"Warning: Embeddings were computed with {metadata['sentence_transformer_model']}")
        print(f"But current config uses {config['sentence_transformer_name']}")
        print("This may cause compatibility issues.")

    # Reconstruct combo_counts from training labels (plain dict, first-seen order)
    combo_counts = dict(Counter(training_labels))

    # Resolve the column names once, from the embeddings metadata, and use
    # them for every later step so that synthetic label tuples line up with
    # the training label tuples loaded from the same file.
    text_column = metadata['text_column']
    label_columns = metadata['label_columns']
    if (text_column != config['text_column']
            or list(label_columns) != list(config['label_columns'])):
        print("Warning: column names in the embeddings metadata differ from config.")
        print(f"  - Metadata: text={text_column!r}, labels={list(label_columns)!r}")
        print(f"  - Config:   text={config['text_column']!r}, labels={config['label_columns']!r}")
        print("Using the metadata columns for all pipeline steps.")

    # Set pipeline attributes from loaded data
    pipeline.text_column = text_column
    pipeline.label_columns = label_columns

    # Create dummy data format for saving: a CSV with every column typed as 'object'
    all_columns = [text_column] + list(label_columns)
    pipeline.data_format = {
        'file_extension': 'csv',
        'columns': all_columns,
        'dtypes': {col: 'object' for col in all_columns}
    }

    samples_per_combo = pipeline.calculate_samples_per_combination(combo_counts)

    print("\n=== Step 2: Generate synthetic data ===")
    synthetic_data = generate_synthetic_data(pipeline, combo_counts, samples_per_combo)

    print("\n=== Step 3: Compute embeddings ===")
    # Extract texts and labels.
    # NOTE(review): assumes each synthetic sample is a mapping keyed by the
    # pipeline's column names - confirm against generate_synthetic_data.
    synthetic_texts = [sample[text_column] for sample in synthetic_data]
    synthetic_labels = [tuple(sample[col] for col in label_columns)
                        for sample in synthetic_data]

    # Use remaining GPUs for embeddings (or reuse GPU 0 if only one available).
    # With zero visible GPUs this also evaluates to [0].
    embed_device_ids = list(range(1, pipeline.device_count)) if pipeline.device_count > 1 else [0]

    print("Computing synthetic embeddings...")
    synthetic_embeddings = compute_embeddings(
        synthetic_texts,
        config['sentence_transformer_name'],
        embed_device_ids,
        config['batch_size']
    )

    print("\n=== Step 4: Compute K-NN votes ===")
    votes = find_knn_and_vote_optimized(
        training_embeddings,
        synthetic_embeddings,
        training_labels,
        synthetic_labels,
        k=config['k_neighbors'],
        use_gpu=True  # Set to False if FAISS-GPU not available
    )

    visualize_vote_profiles(
        synthetic_data=synthetic_data,
        votes=votes,
        label_columns=label_columns,
        output_dir="vote_analysis",
        num_combinations=20
    )

    # Perturb the raw votes before sampling; the seed is fixed so the noise
    # is reproducible across runs.
    noisy_votes = add_gaussian_noise_to_votes(
        votes=votes,
        sigma=config['sigma'],
        total_synthetic_samples=len(synthetic_data),
        random_seed=42
    )

    print("\n=== Step 5: Sample final dataset ===")
    final_synthetic_data = sample_final_dataset_by_combination(
        synthetic_data,
        noisy_votes,
        label_columns=label_columns,
        num_final_samples=config['num_final_samples'],
        preserve_distribution=True  # Preserve synthetic data distribution
    )

    print("\n=== Step 6: Save synthetic data ===")
    save_synthetic_data(final_synthetic_data, config['output_path'], pipeline.data_format)

    print("\n=== Pipeline completed successfully! ===")

if __name__ == "__main__":
    # Select the 'spawn' start method for worker processes before the
    # pipeline runs; force=True overrides any start method already set.
    start_method = 'spawn'
    mp.set_start_method(start_method, force=True)
    main()