#!/usr/bin/env python3
"""
Step 4: Evaluate Embedding Quality (GPU)

Calculate all classification metrics:
- Silhouette score
- Separation score (inter/intra class distance ratio)
- Linear probing (macro AUROC/AUPRC) - GPU accelerated
- Per-category one-vs-rest AUROC - GPU accelerated
"""

import sys
sys.path.append('..')

import numpy as np
import pandas as pd
from pathlib import Path
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import silhouette_score, roc_auc_score, average_precision_score
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import StandardScaler, LabelEncoder, label_binarize
import argparse
import warnings
warnings.filterwarnings('ignore')

# Paths
# NOTE(review): the directory values below are reconstructed from the original
# placeholder comments and are resolved relative to the current working
# directory -- confirm them against the actual project layout.
INPUT_DIR = Path('results/3_embedding')  # holds <analysis>_<split>_samples.tsv
EMB_DIR = INPUT_DIR / 'embeddings'  # holds <model>_<analysis>_<split>_embeddings.npy
OUTPUT_DIR = INPUT_DIR / 'metrics'  # receives embedding_metrics.csv
OUTPUT_DIR.mkdir(exist_ok=True, parents=True)

# Models
MODELS = ['grover', 'seq', 'struct', 'full', 'distilled']


class LogisticRegressionTorch(nn.Module):
    """Multinomial logistic regression expressed as a single linear layer.

    ``forward`` returns raw class scores (logits); no softmax is applied
    inside the module.
    """

    def __init__(self, input_dim, n_classes):
        """Create the affine map from ``input_dim`` features to ``n_classes`` logits."""
        super(LogisticRegressionTorch, self).__init__()
        self.linear = nn.Linear(in_features=input_dim, out_features=n_classes)

    def forward(self, x):
        """Return the logits the linear layer produces for ``x``."""
        logits = self.linear(x)
        return logits


def train_logistic_gpu(X_train, y_train, X_test, n_classes, device, lr=1.0, epochs=50):
    """Fit a softmax classifier with L-BFGS and score the test rows.

    The whole training set is optimised as a single batch. The objective is
    the cross-entropy loss plus half the squared norm of every model
    parameter, with that penalty divided by the number of training samples.

    Args:
        X_train: Training features, one row per sample.
        y_train: Integer class index for each training row.
        X_test: Features of the rows to score.
        n_classes: Number of output classes (width of the logit layer).
        device: Torch device on which the tensors and the model are placed.
        lr: Step size handed to the L-BFGS optimiser.
        epochs: Number of outer ``step`` calls; each call may run up to 20
            inner L-BFGS iterations (``max_iter=20``).

    Returns:
        NumPy array with one row per test sample and one column per class,
        holding the softmax probabilities.
    """
    n_train = len(y_train)
    n_features = X_train.shape[1]

    # Features become float32, targets int64, all moved to the target device.
    train_x = torch.tensor(X_train, dtype=torch.float32).to(device)
    train_y = torch.tensor(y_train, dtype=torch.long).to(device)
    test_x = torch.tensor(X_test, dtype=torch.float32).to(device)

    classifier = LogisticRegressionTorch(n_features, n_classes).to(device)
    cross_entropy = nn.CrossEntropyLoss()
    lbfgs = optim.LBFGS(classifier.parameters(), lr=lr, max_iter=20)

    def objective():
        """Recompute the penalised loss and its gradients for L-BFGS."""
        lbfgs.zero_grad()
        data_loss = cross_entropy(classifier(train_x), train_y)
        # L2 penalty over all parameters (weights and bias alike).
        squared_norm = sum(param.pow(2).sum() for param in classifier.parameters())
        total = data_loss + (0.5 * squared_norm) / n_train
        total.backward()
        return total

    for _epoch in range(epochs):
        lbfgs.step(objective)

    classifier.eval()
    with torch.no_grad():
        test_probs = classifier(test_x).softmax(dim=1)

    return test_probs.cpu().numpy()


def load_data(analysis_type, split):
    """Read the sample table and every available embedding matrix for one run.

    Args:
        analysis_type: Analysis name used as a file-name component.
        split: Data split name used as a file-name component.

    Returns:
        Tuple ``(df, embeddings)``. ``df`` is the tab-separated table
        ``<analysis_type>_<split>_samples.tsv`` read from ``INPUT_DIR``.
        ``embeddings`` maps each model in ``MODELS`` whose
        ``<model>_<analysis_type>_<split>_embeddings.npy`` file exists in
        ``EMB_DIR`` to the loaded array; models without a file are omitted.
    """
    tag = f'{analysis_type}_{split}'
    samples = pd.read_csv(INPUT_DIR / f'{tag}_samples.tsv', sep='\t')

    # Pair every known model with its expected file, then keep the existing ones.
    candidates = ((name, EMB_DIR / f'{name}_{tag}_embeddings.npy') for name in MODELS)
    loaded = {name: np.load(path) for name, path in candidates if path.exists()}

    return samples, loaded


def _take_class_rows(embeddings, member_idx, limit, rng):
    """Gather one class's rows, subsampling without replacement above ``limit``."""
    if len(member_idx) > limit:
        member_idx = member_idx[rng.choice(len(member_idx), limit, replace=False)]
    return embeddings[member_idx]


def _pairwise_distances(a, b):
    """Return the ``(len(a), len(b))`` matrix of Euclidean distances.

    The matrix is built one row of ``a`` at a time, so the temporary
    difference array is ``len(b) x dim`` rather than
    ``len(a) x len(b) x dim``.
    """
    return np.stack([np.linalg.norm(b - row, axis=1) for row in a])


def compute_separation_score(embeddings, labels, random_state=None,
                             max_intra_samples=500, max_inter_samples=200):
    """Compute the inter-class / intra-class mean distance ratio.

    Intra-class distance is the mean, over classes with at least two
    members, of the mean pairwise Euclidean distance inside the class.
    Inter-class distance is the mean, over all unordered class pairs, of
    the mean Euclidean distance between members of the two classes. Large
    classes are subsampled without replacement before distances are taken.

    Args:
        embeddings: Array-like of shape ``(n_samples, dim)``.
        labels: Array-like of ``n_samples`` class labels (lists are accepted).
        random_state: ``None`` to draw subsamples from NumPy's global random
            state (the original behaviour), or a seed for a private
            ``np.random.RandomState`` so the score is reproducible.
        max_intra_samples: Per-class cap on rows used for intra-class distances.
        max_inter_samples: Per-class cap on rows used for inter-class distances.

    Returns:
        The ratio as a float. ``0.0`` is returned when the intra-class
        distance is undefined (no class has two members) or not positive;
        ``nan`` is returned when the intra-class distance is positive but
        there are fewer than two classes.
    """
    embeddings = np.asarray(embeddings)
    # Coerce so that ``labels == label`` is an element-wise mask even for lists.
    labels = np.asarray(labels)
    rng = np.random if random_state is None else np.random.RandomState(random_state)

    # Row indices of each class, computed once instead of once per class pair.
    members = [np.flatnonzero(labels == label) for label in np.unique(labels)]

    # Intra-class distances
    intra_means = []
    for member_idx in members:
        if len(member_idx) < 2:
            continue
        rows = _take_class_rows(embeddings, member_idx, max_intra_samples, rng)
        if len(rows) < 2:
            continue
        dists = _pairwise_distances(rows, rows)
        # Upper triangle only: each unordered pair once, no self-distances.
        intra_means.append(dists[np.triu_indices_from(dists, k=1)].mean())

    # Inter-class distances (sample-based for efficiency)
    inter_means = []
    for i, first_idx in enumerate(members):
        for second_idx in members[i + 1:]:
            first = _take_class_rows(embeddings, first_idx, max_inter_samples, rng)
            second = _take_class_rows(embeddings, second_idx, max_inter_samples, rng)
            inter_means.append(_pairwise_distances(first, second).mean())

    if not intra_means:
        return 0.0
    intra_dist = float(np.mean(intra_means))
    # ``not (x > 0)`` also routes NaN to the 0.0 fallback.
    if not intra_dist > 0:
        return 0.0
    if not inter_means:
        return float('nan')

    return float(np.mean(inter_means)) / intra_dist


def linear_probing_gpu(embeddings, labels, device, n_cv=5, n_seeds=3):
    """Linear probing with repeated stratified cross-validation on GPU.

    For every seed a shuffled ``StratifiedKFold`` is run. Inside each fold
    the features are standardised with statistics from the training rows
    only, a logistic-regression probe is trained with
    ``train_logistic_gpu``, and AUROC / AUPRC are computed on the held-out
    rows (macro-averaged one-vs-rest when there are more than two classes).
    Fold scores are averaged per seed.

    Args:
        embeddings: Array-like of shape ``(n_samples, dim)``.
        labels: Class label per sample; encoded to integers internally.
        device: Torch device passed on to ``train_logistic_gpu``.
        n_cv: Number of cross-validation folds.
        n_seeds: Number of repetitions with different shuffles. The first
            three use seeds 42, 123 and 456; further repetitions use
            1000, 1001, ...

    Returns:
        Dict with ``auroc_mean``, ``auroc_std``, ``auprc_mean`` and
        ``auprc_std``: mean and standard deviation of the per-seed scores.
    """
    X = np.asarray(embeddings)
    y = LabelEncoder().fit_transform(labels)

    classes = np.unique(y)
    n_classes = len(classes)

    # Keep the historical seeds first; extend rather than silently cap n_seeds.
    seeds = [42, 123, 456][:n_seeds]
    seeds.extend(range(1000, 1000 + n_seeds - len(seeds)))

    auroc_scores = []
    auprc_scores = []

    for seed in seeds:
        # Tie the classifier's random initialisation to the repetition seed.
        torch.manual_seed(seed)
        skf = StratifiedKFold(n_splits=n_cv, shuffle=True, random_state=seed)

        fold_aurocs = []
        fold_auprcs = []

        for train_idx, test_idx in skf.split(X, y):
            # Fit the scaler on the training fold only so that no statistics
            # of the held-out rows leak into the features.
            scaler = StandardScaler().fit(X[train_idx])
            X_train = scaler.transform(X[train_idx])
            X_test = scaler.transform(X[test_idx])
            y_train, y_test = y[train_idx], y[test_idx]

            y_pred_proba = train_logistic_gpu(X_train, y_train, X_test, n_classes, device)

            if n_classes == 2:
                # Binary case: score the probability of the positive class.
                auroc = roc_auc_score(y_test, y_pred_proba[:, 1])
                auprc = average_precision_score(y_test, y_pred_proba[:, 1])
            else:
                # One indicator column per class, aligned with the probability columns.
                y_test_bin = label_binarize(y_test, classes=classes)
                auroc = roc_auc_score(y_test_bin, y_pred_proba, average='macro', multi_class='ovr')
                auprc = average_precision_score(y_test_bin, y_pred_proba, average='macro')

            fold_aurocs.append(auroc)
            fold_auprcs.append(auprc)

        auroc_scores.append(np.mean(fold_aurocs))
        auprc_scores.append(np.mean(fold_auprcs))

    return {
        'auroc_mean': np.mean(auroc_scores),
        'auroc_std': np.std(auroc_scores),
        'auprc_mean': np.mean(auprc_scores),
        'auprc_std': np.std(auprc_scores)
    }


def per_category_auroc_gpu(embeddings, labels, device, n_cv=5):
    """Calculate one-vs-rest AUROC for each category using GPU.

    Each category in turn is treated as the positive class of a binary
    problem and evaluated with a shuffled ``StratifiedKFold``
    (``random_state=42``). Inside each fold the features are standardised
    with statistics from the training rows only, then a two-class probe is
    trained with ``train_logistic_gpu``.

    Args:
        embeddings: Array-like of shape ``(n_samples, dim)``.
        labels: Category label per sample (lists are accepted).
        device: Torch device passed on to ``train_logistic_gpu``.
        n_cv: Number of cross-validation folds.

    Returns:
        Dict mapping each category to ``{'mean': ..., 'std': ...}`` over the
        folds that could be scored; both values are ``nan`` when no fold
        could be scored.
    """
    X = np.asarray(embeddings)
    # Coerce so that ``labels == category`` is an element-wise mask even for lists.
    labels = np.asarray(labels)

    # An integer random_state yields the same folds on every split() call,
    # so one splitter can serve all categories.
    skf = StratifiedKFold(n_splits=n_cv, shuffle=True, random_state=42)

    results = {}
    for category in np.unique(labels):
        y_binary = (labels == category).astype(int)
        fold_aurocs = []

        for fold, (train_idx, test_idx) in enumerate(skf.split(X, y_binary), start=1):
            # Scale with training-fold statistics only (no test-fold leakage).
            scaler = StandardScaler().fit(X[train_idx])
            X_train = scaler.transform(X[train_idx])
            X_test = scaler.transform(X[test_idx])
            y_train, y_test = y_binary[train_idx], y_binary[test_idx]

            try:
                y_pred_proba = train_logistic_gpu(X_train, y_train, X_test, 2, device)
                fold_aurocs.append(roc_auc_score(y_test, y_pred_proba[:, 1]))
            except (ValueError, RuntimeError) as exc:
                # Best effort: drop the fold, but say why instead of hiding it.
                print(f"      Skipping fold {fold} for category {category}: {exc}")

        if fold_aurocs:
            results[category] = {
                'mean': np.mean(fold_aurocs),
                'std': np.std(fold_aurocs)
            }
        else:
            results[category] = {'mean': np.nan, 'std': np.nan}

    return results


def evaluate_model(model_name, embeddings, labels, analysis_type, split, device):
    """Run every embedding-quality metric for one model and collect the results.

    Computes, in order, the silhouette score, the separation score, the
    linear-probing AUROC/AUPRC and the per-category one-vs-rest AUROC,
    printing progress and each value as it goes.

    Args:
        model_name: Name stored under the ``model`` key of the result.
        embeddings: Embedding matrix; ``shape[1]`` is recorded as the dimension.
        labels: Category label per row of ``embeddings``.
        analysis_type: Name stored under the ``analysis`` key.
        split: Name stored under the ``split`` key.
        device: Torch device forwarded to the GPU-based metrics.

    Returns:
        Flat dict holding the identifying fields, ``n_samples``,
        ``embedding_dim``, ``silhouette``, ``separation``, the linear-probing
        statistics, and ``auroc_<category>`` / ``auroc_<category>_std`` for
        every category.
    """
    print(f"\n  {model_name}:")

    results = dict(
        model=model_name,
        analysis=analysis_type,
        split=split,
        n_samples=len(embeddings),
        embedding_dim=embeddings.shape[1],
    )

    # Silhouette score
    print("    Computing silhouette score...")
    results['silhouette'] = silhouette_score(embeddings, labels)
    print(f"    Silhouette: {results['silhouette']:.4f}")

    # Separation score
    print("    Computing separation score...")
    results['separation'] = compute_separation_score(embeddings, labels)
    print(f"    Separation: {results['separation']:.4f}")

    # Linear probing (GPU)
    print("    Computing linear probing (GPU)...")
    probe = linear_probing_gpu(embeddings, labels, device)
    results.update(probe)
    print(f"    AUROC: {probe['auroc_mean']:.4f} ± {probe['auroc_std']:.4f}")
    print(f"    AUPRC: {probe['auprc_mean']:.4f} ± {probe['auprc_std']:.4f}")

    # Per-category AUROC (GPU)
    print("    Computing per-category AUROC (GPU)...")
    per_category = per_category_auroc_gpu(embeddings, labels, device)
    formatted = []
    for category, stats in per_category.items():
        mean_auc, std_auc = stats['mean'], stats['std']
        results[f'auroc_{category}'] = mean_auc
        results[f'auroc_{category}_std'] = std_auc
        formatted.append(f"{category}: {mean_auc:.4f}±{std_auc:.4f}")
    # The summary is wrapped in literal braces, as a dict-like listing.
    print("    Per-category: {" + ', '.join(formatted) + "}")

    return results


def main():
    """Command-line entry point: evaluate the requested embeddings and save metrics.

    Parses ``--analysis``, ``--split``, ``--model`` and ``--gpu``, evaluates
    every requested model whose embedding file exists for each
    analysis/split combination, writes all result rows to
    ``OUTPUT_DIR / 'embedding_metrics.csv'`` and prints a summary table.

    Raises:
        ValueError: If an embedding matrix has a different number of rows
            than the sample table it is evaluated against.
    """
    parser = argparse.ArgumentParser(description='Step 4: Evaluate Embeddings (GPU)')
    parser.add_argument('--analysis', type=str, required=True,
                       choices=['structural', 'regulatory', 'all'],
                       help='Which analysis to run')
    parser.add_argument('--split', type=str, required=True,
                       choices=['train', 'val', 'all'],
                       help='Which split to use')
    parser.add_argument('--model', type=str, default='all',
                       choices=['grover', 'seq', 'struct', 'full', 'distilled', 'all'],
                       help='Which model to evaluate')
    parser.add_argument('--gpu', type=str, default='0', help='GPU ID')
    args = parser.parse_args()

    # Restrict the visible GPUs before the first CUDA query below.
    import os
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # 'all' expands to every known value; otherwise use the single choice.
    analyses = ['structural', 'regulatory'] if args.analysis == 'all' else [args.analysis]
    splits = ['train', 'val'] if args.split == 'all' else [args.split]
    models_to_process = MODELS if args.model == 'all' else [args.model]

    print("=" * 70)
    print("STEP 4: EVALUATING EMBEDDING QUALITY (GPU)")
    print("=" * 70)
    print(f"Device: {device}")

    all_results = []

    for analysis_type in analyses:
        for split in splits:
            print(f"\n{'='*70}")
            print(f"Processing: {analysis_type} - {split}")
            print("=" * 70)

            df, embeddings_dict = load_data(analysis_type, split)
            # Filter to requested models
            embeddings_dict = {k: v for k, v in embeddings_dict.items() if k in models_to_process}
            labels = df['category'].values

            print(f"Samples: {len(df)}")
            print(f"Categories: {np.unique(labels)}")
            print(f"Models: {list(embeddings_dict.keys())}")

            # load_data omits models without a file; make that visible.
            missing = [m for m in models_to_process if m not in embeddings_dict]
            if missing:
                print(f"Missing embeddings (skipped): {missing}")

            for model_name, embeddings in embeddings_dict.items():
                # Fail early with a clear message instead of deep inside a metric.
                if len(embeddings) != len(labels):
                    raise ValueError(
                        f"{model_name} ({analysis_type}/{split}): {len(embeddings)} embedding rows "
                        f"but {len(labels)} samples"
                    )
                results = evaluate_model(model_name, embeddings, labels, analysis_type, split, device)
                all_results.append(results)

    if not all_results:
        # Nothing was evaluated: an empty frame has none of the summary
        # columns, and writing it would only produce an empty metrics file.
        print("\nNo embeddings found for the requested analysis/split/model; nothing to save.")
        return

    # Save all results
    df_results = pd.DataFrame(all_results)
    output_file = OUTPUT_DIR / 'embedding_metrics.csv'
    df_results.to_csv(output_file, index=False)
    print(f"\n{'='*70}")
    print(f"Saved: {output_file}")

    # Print summary
    print(f"\n{'='*70}")
    print("SUMMARY")
    print("=" * 70)
    summary_cols = ['model', 'analysis', 'split', 'silhouette', 'separation', 'auroc_mean', 'auprc_mean']
    print(df_results[summary_cols].to_string(index=False))

    print("\n" + "=" * 70)
    print("STEP 4 COMPLETE")
    print("=" * 70)


if __name__ == '__main__':
    main()
