#!/usr/bin/env python3
"""
Step 5: Export t-SNE coordinates to CSV for plotting
"""

import sys
sys.path.append('..')

import numpy as np
import pandas as pd
from pathlib import Path
from sklearn.manifold import TSNE
import argparse

# Paths
# NOTE(review): these assignments were previously commented out with empty
# placeholders, so OUTPUT_DIR.mkdir() below raised NameError at import time
# and main() could not resolve EMB_DIR / DATA_DIR. The values follow the
# locations named in the old placeholder comments and are resolved relative
# to the current working directory -- confirm against the project layout.
DATA_DIR = Path('results/3_embedding')  # holds {analysis}_{split}_samples.tsv
EMB_DIR = DATA_DIR / 'embeddings'  # holds {model}_{analysis}_{split}_embeddings.npy
OUTPUT_DIR = DATA_DIR / 'tsne'  # destination of tsne_coordinates.csv
OUTPUT_DIR.mkdir(exist_ok=True, parents=True)

# Model configuration
# Display label per model key. Insertion order is also the processing order,
# so MODELS is derived from it to keep the two in lockstep.
# NOTE(review): the 'Seq ' label carries a trailing space -- confirm intended.
MODEL_LABELS = dict(
    grover='GROVER',
    seq='Seq ',
    struct='Seq+Struct',
    full='Seq+Struct+Reg',
    distilled='Distilled',
)
MODELS = list(MODEL_LABELS)

# Values accepted by --analysis and --split (besides 'all').
ANALYSIS_TYPES = [
    'structural',
    'regulatory',
]
SPLITS = [
    'train',
    'val',
]


def compute_tsne(embeddings, perplexity=30, random_state=42):
    """Project embeddings to two dimensions with t-SNE.

    Args:
        embeddings: 2-D array-like with one row per sample.
        perplexity: Requested t-SNE perplexity. scikit-learn rejects a
            perplexity that is not smaller than the number of samples, so
            it is lowered to ``n_samples - 1`` (with a warning) when the
            input has too few rows for the requested value.
        random_state: Seed forwarded to ``TSNE``.

    Returns:
        The array produced by ``TSNE(n_components=2).fit_transform``, i.e.
        one (x, y) row per input sample.

    Raises:
        ValueError: If ``embeddings`` is not 2-D or has fewer than two rows.
    """
    import warnings

    # np.shape reads the .shape attribute when present, so the input object
    # itself is handed to TSNE untouched.
    shape = np.shape(embeddings)
    if len(shape) != 2:
        raise ValueError(
            f"embeddings must be 2-D (n_samples, n_features), got shape {shape}"
        )

    n_samples = shape[0]
    if n_samples < 2:
        raise ValueError(
            f"t-SNE needs at least 2 samples, got {n_samples}"
        )

    if perplexity >= n_samples:
        adjusted = n_samples - 1
        warnings.warn(
            f"perplexity={perplexity} is not smaller than n_samples={n_samples}; "
            f"using perplexity={adjusted} instead",
            stacklevel=2,
        )
        perplexity = adjusted

    tsne = TSNE(n_components=2, perplexity=perplexity, random_state=random_state)
    return tsne.fit_transform(embeddings)


def _load_samples(analysis, split, cache):
    """Return the sample table for (analysis, split), or None if its file is missing.

    Results are memoised in ``cache`` because the sample file name depends
    only on the analysis and split, so every model reuses the same table.
    """
    key = (analysis, split)
    if key not in cache:
        sample_file = DATA_DIR / f'{analysis}_{split}_samples.tsv'
        cache[key] = (
            pd.read_csv(sample_file, sep='\t') if sample_file.exists() else None
        )
    return cache[key]


def _build_tsne_frame(model, analysis, split, perplexity, samples_cache):
    """Build the t-SNE table for one model/analysis/split, or None if skipped.

    Raises:
        ValueError: If the embedding and sample tables have different row counts.
    """
    # Load embeddings
    emb_file = EMB_DIR / f'{model}_{analysis}_{split}_embeddings.npy'
    if not emb_file.exists():
        print("  Skipping (file not found)")
        return None

    embeddings = np.load(emb_file)
    print(f"  Embeddings shape: {embeddings.shape}")

    # Load labels; a missing sample file is skipped the same way a missing
    # embedding file is, instead of aborting the whole export.
    samples_df = _load_samples(analysis, split, samples_cache)
    if samples_df is None:
        print("  Skipping (sample file not found)")
        return None

    # Check alignment before the expensive t-SNE fit; otherwise the mismatch
    # only surfaces afterwards as a pandas length error on column assignment.
    if embeddings.shape[0] != len(samples_df):
        raise ValueError(
            f"{emb_file.name} has {embeddings.shape[0]} rows but "
            f"{analysis}_{split}_samples.tsv has {len(samples_df)} rows"
        )

    # Compute t-SNE
    print("  Computing t-SNE...")
    coords = compute_tsne(embeddings, perplexity=perplexity)

    # Copy so the cached sample table is never mutated across models.
    df = samples_df.copy()
    df['tsne_x'] = coords[:, 0]
    df['tsne_y'] = coords[:, 1]
    df['model'] = model
    df['model_label'] = MODEL_LABELS[model]
    df['analysis'] = analysis
    df['split'] = split
    return df


def main():
    """Compute t-SNE for the selected embeddings and write one combined CSV.

    Command-line options ``--model``, ``--analysis`` and ``--split`` each take
    a single value or ``all``; ``--perplexity`` is forwarded to
    ``compute_tsne``. Combinations whose embedding or sample file is missing
    are skipped. The results are concatenated and written to
    ``OUTPUT_DIR / 'tsne_coordinates.csv'``; nothing is written when no
    combination produced data.

    Raises:
        ValueError: If an embedding file and its sample table disagree on the
            number of rows.
    """
    parser = argparse.ArgumentParser(description='Step 5: Export t-SNE Coordinates')
    parser.add_argument('--model', type=str, default='all',
                       choices=['all'] + MODELS, help='Which model to process')
    parser.add_argument('--analysis', type=str, default='all',
                       choices=['all'] + ANALYSIS_TYPES, help='Which analysis to process')
    parser.add_argument('--split', type=str, default='all',
                       choices=['all'] + SPLITS, help='Which split to process')
    parser.add_argument('--perplexity', type=int, default=30, help='t-SNE perplexity')
    args = parser.parse_args()

    models = MODELS if args.model == 'all' else [args.model]
    analyses = ANALYSIS_TYPES if args.analysis == 'all' else [args.analysis]
    splits = SPLITS if args.split == 'all' else [args.split]

    print("=" * 60)
    print("STEP 5: EXPORT t-SNE COORDINATES")
    print("=" * 60)
    print(f"Models: {models}")
    print(f"Analyses: {analyses}")
    print(f"Splits: {splits}")
    print(f"Perplexity: {args.perplexity}")

    all_data = []
    samples_cache = {}  # (analysis, split) -> DataFrame or None

    for model in models:
        for analysis in analyses:
            for split in splits:
                print(f"\nProcessing: {model} / {analysis} / {split}")
                df = _build_tsne_frame(
                    model, analysis, split, args.perplexity, samples_cache
                )
                if df is None:
                    continue
                all_data.append(df)
                print(f"  Done: {len(df)} points")

    # Combine all data
    if all_data:
        combined_df = pd.concat(all_data, ignore_index=True)

        # Save to CSV
        output_file = OUTPUT_DIR / 'tsne_coordinates.csv'
        combined_df.to_csv(output_file, index=False)
        print(f"\n{'=' * 60}")
        print(f"Saved: {output_file}")
        print(f"Total points: {len(combined_df)}")
        print(f"Columns: {list(combined_df.columns)}")
    else:
        print("\nNo data to export.")

    print("\n" + "=" * 60)
    print("STEP 5 COMPLETE")
    print("=" * 60)


# Run the export only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
