#!/usr/bin/env python3
"""
Step 3: Evaluate Benchmark Performance

Evaluate models using:
1. L2 distance-based classification
2. Linear probing with rich features (PyTorch GPU)
"""

import sys
sys.path.append('..')

import numpy as np
import pandas as pd
from pathlib import Path
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import roc_auc_score, average_precision_score
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import StandardScaler
import os

from config import get_benchmark_paths

# Paths are resolved once from the project-level config. The results
# directory is created at import time so later CSV writes cannot fail on a
# missing folder.
BENCHMARK_PATHS = get_benchmark_paths()
EMBEDDINGS_DIR = BENCHMARK_PATHS['embeddings']
RESULTS_DIR = BENCHMARK_PATHS['results']
RESULTS_DIR.mkdir(parents=True, exist_ok=True)

# Evaluation grid. Each (model, dataset, window, position) combination maps to
# one embedding file in EMBEDDINGS_DIR.
DATASETS = [
    'mendelian',
    'complex',
    'eqtl',
    'clinvar',
]
MODELS = [
    'seq',
    'struct',
    'full',
    'grover',
    'distilled',
]
WINDOW_SIZES = [250, 500, 750]
# NOTE(review): presumably the fractional offset of the variant inside the
# window (0.5 = centred) — confirm against the embedding-extraction step.
VARIANT_POSITIONS = [0.25, 0.5, 0.75]


class LogisticRegressionTorch(nn.Module):
    """Logistic regression as a single linear layer that returns raw logits.

    ``forward`` maps an ``(n_samples, input_dim)`` tensor to ``(n_samples, 1)``
    logits; apply ``torch.sigmoid`` to turn them into probabilities.
    """

    def __init__(self, input_dim):
        super().__init__()
        self.linear = nn.Linear(input_dim, 1)

    def forward(self, x):
        return self.linear(x)


def train_logistic_regression_gpu(X_train, y_train, X_test, device, lr=1.0, epochs=50,
                                  max_iter=20, line_search_fn=None):
    """Fit a class-balanced, L2-regularized logistic regression and score ``X_test``.

    Despite the name, this runs on whatever ``device`` is passed (CPU works).

    Objective: sample-weighted mean binary cross-entropy, where each sample of
    a class gets weight ``n / (2 * n_class)`` so both classes contribute
    equally, plus ``0.5 * ||theta||^2 / n`` (theta includes the bias). It is
    minimized with L-BFGS for ``epochs`` outer steps of up to ``max_iter``
    iterations each.

    Args:
        X_train: (n_train, d) array of training features.
        y_train: (n_train,) array of 0/1 labels.
        X_test: (n_test, d) array of features to score.
        device: torch device the tensors and model are placed on.
        lr: L-BFGS learning rate.
        epochs: number of ``optimizer.step`` calls.
        max_iter: maximum L-BFGS iterations per step.
        line_search_fn: forwarded to ``torch.optim.LBFGS`` (e.g.
            ``'strong_wolfe'``); ``None`` keeps the fixed-step behaviour.

    Returns:
        (n_test,) numpy array of positive-class probabilities (always 1-D,
        even when ``n_test == 1``).
    """
    X_train = np.ascontiguousarray(X_train, dtype=np.float32)
    y_train = np.ascontiguousarray(y_train, dtype=np.float32).reshape(-1)
    X_test = np.ascontiguousarray(X_test, dtype=np.float32)

    n_train = len(y_train)
    input_dim = X_train.shape[1]

    # Balanced class weights: n / (2 * n_class); fall back to 1.0 for an absent class.
    n_pos = float(y_train.sum())
    n_neg = n_train - n_pos
    weight_pos = n_train / (2 * n_pos) if n_pos > 0 else 1.0
    weight_neg = n_train / (2 * n_neg) if n_neg > 0 else 1.0

    X_train_t = torch.from_numpy(X_train).to(device)
    y_train_t = torch.from_numpy(y_train).to(device)
    X_test_t = torch.from_numpy(X_test).to(device)

    sample_weights = torch.where(
        y_train_t == 1,
        torch.full_like(y_train_t, float(weight_pos)),
        torch.full_like(y_train_t, float(weight_neg)),
    )

    model = LogisticRegressionTorch(input_dim).to(device)
    # The L2 term makes the objective strictly convex, so the optimum does not
    # depend on the starting point; starting from zero (instead of nn.Linear's
    # random init) makes repeated runs reproducible.
    with torch.no_grad():
        model.linear.weight.zero_()
        model.linear.bias.zero_()

    optimizer = optim.LBFGS(
        model.parameters(), lr=lr, max_iter=max_iter, line_search_fn=line_search_fn
    )

    def closure():
        optimizer.zero_grad()
        # squeeze(-1) rather than squeeze(): keeps an (n,) tensor when n == 1,
        # so it still matches the (n,) target expected by the BCE loss.
        logits = model(X_train_t).squeeze(-1)
        loss = nn.functional.binary_cross_entropy_with_logits(
            logits, y_train_t, weight=sample_weights
        )
        # sklearn-style L2 penalty (C=1) rescaled to the mean loss.
        l2_reg = 0.5 * sum(p.pow(2).sum() for p in model.parameters())
        loss = loss + l2_reg / n_train
        loss.backward()
        return loss

    for _ in range(epochs):
        optimizer.step(closure)

    model.eval()
    with torch.no_grad():
        test_logits = model(X_test_t).squeeze(-1)
        probs = torch.sigmoid(test_logits).cpu().numpy()

    return probs


def compute_l2_metrics(l2_distances, labels):
    """Use the raw ref/alt L2 distance as a zero-training classifier score.

    A larger distance is treated as stronger evidence for the positive class.

    Returns:
        ``(auroc, auprc)``; both are NaN when ``labels`` holds fewer than two
        distinct classes, since neither metric is defined then.
    """
    n_classes = np.unique(labels).size
    if n_classes >= 2:
        return (
            roc_auc_score(labels, l2_distances),
            average_precision_score(labels, l2_distances),
        )
    return np.nan, np.nan


def compute_lp_metrics(ref_emb, alt_emb, labels, device, n_seeds=3, n_splits=5):
    """Linear-probe ref/alt embeddings with repeated stratified k-fold CV.

    Each variant is represented by ``[ref, alt, diff, ref*alt, |diff|]`` with
    ``diff = alt - ref``. For every seed, a shuffled ``StratifiedKFold`` split
    is drawn. Features are standardized on the training fold only, a
    class-balanced logistic regression is fitted with
    ``train_logistic_regression_gpu``, and the held-out fold is scored.

    Fold metrics are averaged per seed. The returned mean/std are taken across
    seeds, so the std reflects split-to-split variability.

    Args:
        ref_emb, alt_emb: (n_samples, d) embedding arrays.
        labels: (n_samples,) array of 0/1 labels.
        device: torch device used for training.
        n_seeds: number of CV repetitions. Seeds 42, 123, 456 are used first
            (matching earlier results); further seeds are 1000, 1001, ...
        n_splits: number of CV folds.

    Returns:
        dict with ``auroc_mean``, ``auroc_std``, ``auprc_mean``, ``auprc_std``.
        All values are NaN when no fold could be scored.
    """
    labels = np.asarray(labels)
    diff_emb = alt_emb - ref_emb
    features = np.concatenate([
        ref_emb,
        alt_emb,
        diff_emb,
        ref_emb * alt_emb,
        np.abs(diff_emb),
    ], axis=1)

    default_seeds = [42, 123, 456]
    n_extra = max(0, n_seeds - len(default_seeds))
    seeds = (default_seeds + [1000 + i for i in range(n_extra)])[:max(n_seeds, 0)]

    aurocs = []
    auprcs = []

    for seed in seeds:
        skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=seed)

        fold_aurocs = []
        fold_auprcs = []

        for train_idx, test_idx in skf.split(features, labels):
            y_train, y_test = labels[train_idx], labels[test_idx]

            # AUROC/AUPRC are undefined on a single-class test fold; skip it
            # before paying for a GPU fit whose result would be discarded.
            if np.unique(y_test).size < 2:
                continue

            scaler = StandardScaler()
            X_train = scaler.fit_transform(features[train_idx])
            X_test = scaler.transform(features[test_idx])

            y_prob = train_logistic_regression_gpu(
                X_train, y_train, X_test, device, lr=1.0, epochs=50
            )

            try:
                fold_auroc = roc_auc_score(y_test, y_prob)
                fold_auprc = average_precision_score(y_test, y_prob)
            except ValueError as err:
                # e.g. non-finite probabilities from a diverged fit: skip the
                # fold but say so instead of hiding it.
                print(f"    [warn] seed={seed}: fold skipped ({err})")
                continue

            fold_aurocs.append(fold_auroc)
            fold_auprcs.append(fold_auprc)

        if fold_aurocs:
            aurocs.append(np.mean(fold_aurocs))
            auprcs.append(np.mean(fold_auprcs))

    if not aurocs:
        return {
            'auroc_mean': np.nan,
            'auroc_std': np.nan,
            'auprc_mean': np.nan,
            'auprc_std': np.nan,
        }

    return {
        'auroc_mean': np.mean(aurocs),
        'auroc_std': np.std(aurocs),
        'auprc_mean': np.mean(auprcs),
        'auprc_std': np.std(auprcs),
    }


def evaluate_file(model_name, dataset, window_size, var_position, device):
    """Evaluate one embedding file with L2-distance scoring and linear probing.

    Reads ``{model}_{dataset}_w{window}_pos{position without '.'}_embeddings.npz``
    from EMBEDDINGS_DIR. The archive must contain the arrays ``ref``, ``alt``,
    ``l2_distance`` and ``labels``.

    Returns:
        One results row as a dict (sample counts, L2 metrics, LP metrics), or
        None when the embedding file does not exist.
    """
    pos_str = f"pos{var_position}".replace(".", "")
    emb_file = EMBEDDINGS_DIR / f"{model_name}_{dataset}_w{window_size}_{pos_str}_embeddings.npz"

    if not emb_file.exists():
        print(f"    Skipped: {emb_file.name} not found")
        return None

    # np.load on an .npz returns an NpzFile that keeps the archive open.
    # Indexing it reads each member into memory, so the file can be closed
    # right after the arrays are extracted.
    with np.load(emb_file) as data:
        ref_emb = data['ref']
        alt_emb = data['alt']
        l2_distances = data['l2_distance']
        labels = data['labels']

    n_samples = len(labels)
    n_pos = np.sum(labels == 1)
    n_neg = np.sum(labels == 0)

    print(f"    Samples: {n_samples:,} (pos={n_pos:,}, neg={n_neg:,})")

    # L2-based evaluation
    l2_auroc, l2_auprc = compute_l2_metrics(l2_distances, labels)
    print(f"    L2-AUROC: {l2_auroc:.4f}, L2-AUPRC: {l2_auprc:.4f}")

    # Linear probing
    print("    Running linear probing (5-fold, 3 seeds)...")
    lp_results = compute_lp_metrics(ref_emb, alt_emb, labels, device)
    print(f"    LP-AUROC: {lp_results['auroc_mean']:.4f} +/- {lp_results['auroc_std']:.4f}")

    return {
        'model': model_name,
        'dataset': dataset,
        'window_size': window_size,
        'var_position': var_position,
        'n_samples': n_samples,
        'n_pos': n_pos,
        'n_neg': n_neg,
        'l2_auroc': l2_auroc,
        'l2_auprc': l2_auprc,
        'lp_auroc': lp_results['auroc_mean'],
        'lp_auroc_std': lp_results['auroc_std'],
        'lp_auprc': lp_results['auprc_mean'],
        'lp_auprc_std': lp_results['auprc_std'],
    }


def _print_summary(df):
    """Print LP-AUROC summary tables for a DataFrame of ``evaluate_file`` rows."""
    print("\n" + "=" * 70)
    print("SUMMARY")
    print("=" * 70)

    # Headline table: restricted to the default configuration (w=500, pos=0.5).
    summary_df = df[(df['window_size'] == 500) & (df['var_position'] == 0.5)]
    if not summary_df.empty:
        print("\n[LP-AUROC by Dataset (w=500, pos=0.5)]")
        pivot = summary_df.pivot(index='dataset', columns='model', values='lp_auroc')
        # Present models in the canonical MODELS order rather than pivot's order.
        col_order = [c for c in MODELS if c in pivot.columns]
        if col_order:
            pivot = pivot[col_order]
        print(pivot.round(4).to_string())

    # By model (averaged over every dataset/window/position evaluated)
    print("\n[LP-AUROC by Model (averaged across all)]")
    model_avg = df.groupby('model')['lp_auroc'].mean()
    print(model_avg.round(4).to_string())

    # Window ablation
    print("\n[LP-AUROC by Window Size]")
    window_pivot = df.groupby(['model', 'window_size'])['lp_auroc'].mean().unstack()
    print(window_pivot.round(4).to_string())

    # Position ablation
    print("\n[LP-AUROC by Variant Position]")
    pos_pivot = df.groupby(['model', 'var_position'])['lp_auroc'].mean().unstack()
    print(pos_pivot.round(4).to_string())


def main():
    """Command-line entry point for step 3.

    Evaluates every existing embedding file in the selected
    model x dataset x window x position grid, writes all result rows to a CSV
    in RESULTS_DIR, and prints summary tables.
    """
    parser = argparse.ArgumentParser(description='Step 3: Evaluate Performance')
    parser.add_argument('--model', default='all', choices=['all'] + MODELS)
    parser.add_argument('--dataset', default='all', choices=['all'] + DATASETS)
    parser.add_argument('--window', default='all')
    parser.add_argument('--position', default='all')
    parser.add_argument('--gpu', default='0', help='GPU ID')
    parser.add_argument(
        '--output', default='benchmark_results.csv',
        help='Results CSV name (relative to the results dir). The file is '
             'overwritten, so use a distinct name when evaluating a subset.',
    )
    args = parser.parse_args()

    # Set before the first CUDA query below so the device mask takes effect.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    print("=" * 70)
    print("STEP 3: EVALUATE PERFORMANCE")
    print("=" * 70)
    print(f"Device: {device}")
    print(f"Embeddings dir: {EMBEDDINGS_DIR}")
    print(f"Results dir: {RESULTS_DIR}")

    models = MODELS if args.model == 'all' else [args.model]
    datasets = DATASETS if args.dataset == 'all' else [args.dataset]
    windows = WINDOW_SIZES if args.window == 'all' else [int(args.window)]
    positions = VARIANT_POSITIONS if args.position == 'all' else [float(args.position)]

    print(f"Models: {models}")
    print(f"Datasets: {datasets}")
    print(f"Windows: {windows}")
    print(f"Positions: {positions}")

    results = []

    for model_name in models:
        for dataset in datasets:
            for window in windows:
                for position in positions:
                    print(f"\n  {model_name} / {dataset} / w{window} / pos={position}")

                    result = evaluate_file(model_name, dataset, window, position, device)
                    if result is not None:
                        results.append(result)

    if not results:
        print("\nNo results to save.")
        return

    # Save results
    df = pd.DataFrame(results)
    output_file = RESULTS_DIR / args.output
    df.to_csv(output_file, index=False)
    print(f"\nResults saved to: {output_file}")

    _print_summary(df)

    print("\n" + "=" * 70)
    print("STEP 3 COMPLETE")
    print("=" * 70)


if __name__ == '__main__':
    main()
