#!/usr/bin/env python3
"""
Step 2: Extract Embeddings

Extract embeddings for REF and ALT sequences, compute L2 distances.
Models: seq, struct, full, grover, distilled
"""

import sys
sys.path.append('..')

import numpy as np
import torch
import json
from pathlib import Path
from tqdm import tqdm
import argparse
import importlib.util
import os

from transformers import AutoTokenizer, AutoModel

from config import get_model_paths, get_model_config, get_benchmark_paths

# Input/output directories, resolved from the shared benchmark configuration.
BENCHMARK_PATHS = get_benchmark_paths()
TOKENS_DIR, EMBEDDINGS_DIR = BENCHMARK_PATHS['tokens'], BENCHMARK_PATHS['embeddings']
# Ensure the output directory exists before any embeddings are written.
EMBEDDINGS_DIR.mkdir(parents=True, exist_ok=True)

# Benchmark grid: main() looks up one token file per
# (model, dataset, window, position) combination.
DATASETS = ['mendelian', 'complex', 'eqtl', 'clinvar']
MODELS = ['seq', 'struct', 'full', 'grover', 'distilled']
WINDOW_SIZES = [250, 500, 750]
VARIANT_POSITIONS = [0.25, 0.5, 0.75]


def load_anndna_model(model_name, device):
    """Build an annDNA model from its best training checkpoint.

    The model class is imported dynamically from ``2_train/model.py`` (a sibling
    of this script's directory). It is instantiated with the hyper-parameters
    from ``get_model_config(model_name)``. Its weights are read from
    ``get_model_paths(model_name)['best_model']``. The model is then switched
    to eval mode and moved to ``device``.

    Args:
        model_name: Model key understood by the project's ``config`` helpers.
        device: Target device for the returned model.

    Returns:
        ``(model, config)`` where ``config`` is the model's config mapping.
    """
    config = get_model_config(model_name)
    paths = get_model_paths(model_name)

    # Import the training-time model definition straight from its file so the
    # training package does not have to be on sys.path.
    module_file = Path(__file__).parent.parent / '2_train' / 'model.py'
    module_spec = importlib.util.spec_from_file_location("model_module", module_file)
    train_module = importlib.util.module_from_spec(module_spec)
    module_spec.loader.exec_module(train_module)
    model_cls = train_module.annDNA

    # NOTE: torch.load unpickles the file; only point it at trusted checkpoints.
    ckpt = torch.load(paths['best_model'], map_location='cpu')
    # Vocab size: checkpoint value first, then the config value, then 268.
    fallback_vocab = config.get('vocab_size', 268)
    vocab_size = ckpt.get('vocab_size', fallback_vocab)

    model = model_cls(
        vocab_size=vocab_size,
        d_model=config['d_model'],
        nhead=config['nhead'],
        num_layers=config['num_layers'],
        max_seq_len=config['max_seq_len'],
    )

    # Accept both a training-checkpoint dict and a bare state_dict.
    state_dict = ckpt['model_state_dict'] if 'model_state_dict' in ckpt else ckpt
    model.load_state_dict(state_dict)

    model.eval()
    model = model.to(device)

    print(f"Loaded {model_name}: d_model={config['d_model']}, layers={config['num_layers']}")
    return model, config


def load_grover_model(device):
    """Load the GROVER encoder and its tokenizer via HuggingFace ``from_pretrained``.

    The checkpoint identifier is ``get_model_config('grover')['hf_model_name']``.
    Weights are loaded with ``use_safetensors=True``, moved to ``device`` and
    the model is put in eval mode.

    Returns:
        ``(model, tokenizer, config)``.
    """
    grover_cfg = get_model_config('grover')
    repo_id = grover_cfg['hf_model_name']

    print(f"Loading GROVER from HuggingFace: {repo_id}")
    grover_tokenizer = AutoTokenizer.from_pretrained(repo_id)
    encoder = AutoModel.from_pretrained(repo_id, use_safetensors=True).to(device)
    encoder.eval()

    return encoder, grover_tokenizer, grover_cfg


def load_distilled_model(device):
    """Load the distilled student model.

    Hyper-parameters and the checkpoint path come from
    ``config.MODELS['distilled']``. The vocabulary size is taken from the
    ``seq`` model's vocab file, because the student shares that vocabulary.
    The model class is imported from ``6_distillation/train_distilled.py``
    (a sibling of this script's directory).

    Args:
        device: Target device for the returned model.

    Returns:
        ``(model, distilled_config)``.
    """
    from config import MODELS as MODEL_CONFIGS

    distilled_config = MODEL_CONFIGS['distilled']
    checkpoint_path = distilled_config['model_path']

    # Make the distillation code importable. Guard against piling up duplicate
    # sys.path entries when this loader is called more than once.
    distill_dir = str(Path(__file__).parent.parent / '6_distillation')
    if distill_dir not in sys.path:
        sys.path.append(distill_dir)
    from train_distilled import DistilledModel

    # The student uses the seq model's vocabulary.
    seq_paths = get_model_paths('seq')
    with open(seq_paths['vocab_file']) as f:
        vocab = json.load(f)

    model = DistilledModel(
        vocab_size=len(vocab),
        d_model=distilled_config['d_model'],
        nhead=distilled_config['nhead'],
        num_layers=distilled_config['num_layers'],
        max_seq_len=distilled_config['max_seq_len']
    )

    # NOTE: torch.load unpickles the file; only point it at trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    # Accept both a training-checkpoint dict and a bare state_dict, matching
    # load_anndna_model (previously a bare state_dict raised KeyError here).
    if 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
    else:
        state_dict = checkpoint
    model.load_state_dict(state_dict)
    model.eval()
    model = model.to(device)

    print(f"Loaded Distilled: d_model={distilled_config['d_model']}, layers={distilled_config['num_layers']}")
    return model, distilled_config


def extract_embeddings_batch(model, tokens_batch, device):
    """Return mean-pooled annDNA embeddings for one batch of token ids.

    The model's own forward output is discarded. Instead, the hidden states are
    captured with a temporary forward hook on the ``transformer`` submodule of
    the unwrapped model (``model.module`` when present). The first and last
    positions of every row are excluded from the mean; they are assumed to be
    CLS / SEP.

    Args:
        model: annDNA model, optionally wrapped in an object exposing ``.module``.
        tokens_batch: Integer numpy array of token ids, one row per sequence.
        device: Device to run the forward pass on.

    Returns:
        numpy array with one pooled vector per row of ``tokens_batch``.
    """
    ids = torch.from_numpy(tokens_batch).long().to(device)
    # Every position is attended to. NOTE(review): if token files contain
    # padding, it is neither masked nor excluded from pooling here.
    attn = torch.ones_like(ids)

    outputs = []
    with torch.no_grad():
        inner = getattr(model, 'module', model)
        handle = inner.transformer.register_forward_hook(
            lambda _mod, _inp, out: outputs.append(out.detach())
        )
        try:
            model(ids, attn)
        finally:
            # Always detach the hook so it cannot leak into later calls.
            handle.remove()
    # NOTE(review): assumes the transformer output is (batch, seq, hidden),
    # i.e. batch_first; confirm against 2_train/model.py.
    hidden = outputs[0]

    # Per-position weights: 1 everywhere except the first (CLS) and last (SEP).
    keep = torch.ones(hidden.shape[:2], device=device)
    keep[:, 0] = 0
    keep[:, -1] = 0
    keep = keep[..., None]

    summed = (hidden * keep).sum(dim=1)
    counts = keep.sum(dim=1).clamp(min=1)
    return (summed / counts).cpu().numpy()


# Vocabulary entries that carry no sequence content; dropped when rebuilding DNA.
_GROVER_SKIPPED_TOKENS = frozenset({'<PAD>', '<CLS>', '<SEP>', '<MASK>', '<UNK>'})


def _token_to_base(token_str):
    """Map one vocabulary string to the DNA text GROVER should see.

    Only the part before the first '_' is considered. A/T/G/C are returned
    as-is, special tokens map to '' (dropped), and anything else maps to 'N'.
    """
    base = token_str.split('_')[0]
    if base in ('A', 'T', 'G', 'C'):
        return base
    if base in _GROVER_SKIPPED_TOKENS:
        return ''
    return 'N'


def _tokens_to_dna(tokens_array, id_to_token):
    """Convert each row of token ids into a DNA string (ids missing from the vocab -> 'N')."""
    # Resolve every vocabulary id once, instead of re-parsing each occurrence.
    id_to_base = {tid: _token_to_base(tok) for tid, tok in id_to_token.items()}
    return [
        ''.join(id_to_base.get(tid, 'N') for tid in np.asarray(tokens).tolist())
        for tokens in tokens_array
    ]


def extract_grover_embeddings(model, tokenizer, tokens_array, id_to_token, device, batch_size=64, desc="GROVER"):
    """Embed token rows with GROVER by first converting them back to DNA text.

    Each token id is looked up in ``id_to_token`` and reduced to a base letter.
    Special tokens are dropped, and unknown or non-ACGT tokens become 'N'.
    The resulting strings are re-tokenized with ``tokenizer`` (padded and
    truncated to 512 tokens). The last hidden layer is mean-pooled over the
    non-padding positions, excluding the first and the last non-padding
    position of each row.

    Args:
        model: HuggingFace encoder whose output exposes ``last_hidden_state``.
        tokenizer: Matching HuggingFace tokenizer.
        tokens_array: Rows of token ids (one row per sequence).
        id_to_token: Mapping from token id to vocabulary string.
        device: Device to run on.
        batch_size: Number of sequences per forward pass.
        desc: Progress-bar label.

    Returns:
        float numpy array of shape (n_sequences, hidden_size). Empty input
        yields a (0, hidden_size) array, where hidden_size comes from
        ``model.config.hidden_size`` when available, else 768. Previously,
        empty input made ``np.vstack`` raise.
    """
    model.eval()
    sequences = _tokens_to_dna(tokens_array, id_to_token)

    if not sequences:
        hidden_size = getattr(getattr(model, 'config', None), 'hidden_size', 768)
        return np.zeros((0, hidden_size), dtype=np.float32)

    all_embeddings = []
    with torch.no_grad():
        for start in tqdm(range(0, len(sequences), batch_size), desc=desc, leave=False):
            batch_seqs = sequences[start:start + batch_size]

            encoded = tokenizer(batch_seqs, padding=True, truncation=True,
                                max_length=512, return_tensors='pt')
            encoded = {k: v.to(device) for k, v in encoded.items()}

            last_hidden = model(**encoded).last_hidden_state
            attention_mask = encoded['attention_mask']

            # Exclude position 0 and the last non-padding position of each row.
            # These are CLS/SEP under right padding (assumed for this tokenizer).
            # Vectorised to avoid one host sync per row.
            rows = torch.arange(attention_mask.size(0), device=attention_mask.device)
            last_valid = (attention_mask.sum(dim=1) - 1).clamp(min=0)
            special_tokens_mask = torch.zeros_like(attention_mask, dtype=torch.bool)
            special_tokens_mask[rows, 0] = True
            special_tokens_mask[rows, last_valid] = True

            token_mask = (attention_mask.bool() & ~special_tokens_mask).unsqueeze(-1).float()
            sum_embeddings = (last_hidden * token_mask).sum(dim=1)
            sum_mask = token_mask.sum(dim=1).clamp(min=1e-9)
            emb_mean = sum_embeddings / sum_mask

            all_embeddings.append(emb_mean.cpu().numpy())

    return np.vstack(all_embeddings)


def extract_distilled_embeddings_batch(model, tokens_batch, device):
    """Return mean-pooled hidden states of the distilled student for one batch.

    The student is called as ``model(input_ids, attention_mask)`` and must
    return a ``(logits, hidden_states)`` pair; only ``hidden_states`` is used.
    The first and last positions of every row (treated as CLS / SEP) are left
    out of the average. Rows with nothing left to average yield zeros.

    Args:
        model: Distilled student model.
        tokens_batch: Integer numpy array of token ids, one row per sequence.
        device: Device to run the forward pass on.

    Returns:
        numpy array with one pooled vector per row of ``tokens_batch``.
    """
    ids = torch.from_numpy(tokens_batch).long().to(device)
    with torch.no_grad():
        _logits, hidden = model(ids, torch.ones_like(ids))

    n_rows, n_positions = hidden.size(0), hidden.size(1)
    weights = torch.ones(n_rows, n_positions, device=device)
    weights[:, 0] = 0   # CLS
    weights[:, -1] = 0  # SEP
    weights = weights.unsqueeze(-1)

    pooled = (hidden * weights).sum(dim=1) / weights.sum(dim=1).clamp(min=1)
    return pooled.cpu().numpy()


def _parse_token_filename(filename):
    """Parse ``{model}_{dataset}_w{window}_pos{digits}[_tokens]`` into its fields.

    The ``pos`` suffix is the inverse of ``f"pos{position}".replace(".", "")``.
    The first digit is the integer part and the remaining digits are decimals:
    ``pos05`` -> 0.5, ``pos025`` -> 0.25, ``pos0125`` -> 0.125.

    Returns:
        ``(dataset, window_size, position)``; a field that cannot be parsed is ``None``.
    """
    dataset = window_size = position = None
    for part in filename.replace('_tokens', '').split('_'):
        if part.startswith('w') and part[1:].isdigit():
            window_size = int(part[1:])
        elif part.startswith('pos'):
            digits = part[3:]
            if digits.isdigit():
                position = float(f"{digits[0]}.{digits[1:]}")
        elif part in DATASETS:
            dataset = part
    return dataset, window_size, position


def _embed_in_batches(embed_batch, model, tokens, device, batch_size, desc):
    """Apply ``embed_batch`` to ``tokens`` in slices of ``batch_size`` rows.

    The embedding width is taken from the model's actual outputs. The result
    therefore no longer depends on the model exposing a ``d_model`` attribute,
    which a ``.module``-wrapped model does not.
    """
    chunks = [
        embed_batch(model, tokens[start:start + batch_size], device)
        for start in tqdm(range(0, len(tokens), batch_size), desc=desc, leave=False)
    ]
    if not chunks:
        # No rows: fall back to the declared width so the result stays 2-D.
        return np.zeros((0, getattr(model, 'd_model', 768)), dtype=np.float32)
    return np.concatenate(chunks, axis=0).astype(np.float32, copy=False)


def process_file(model, model_name, tokens_file, device, batch_size=64,
                 grover_tokenizer=None, id_to_token=None, is_distilled=False):
    """Embed the REF and ALT tokens of one token file and save the results.

    The dataset, window size and variant position are parsed from the file
    name. ``ref_tokens``, ``alt_tokens`` and ``labels`` are loaded from the
    ``.npz`` file. Both token sets are embedded with the selected model
    (GROVER, distilled or annDNA). The per-sample L2 distance between ALT and
    REF embeddings is computed. The output is written to
    ``EMBEDDINGS_DIR/{model_name}_{dataset}_w{window}_pos{..}_embeddings.npz``
    with keys ``ref``, ``alt``, ``diff``, ``l2_distance`` and ``labels``.

    Returns:
        ``{'l2_distances': ..., 'labels': ...}``, or ``None`` if the file name
        cannot be parsed.

    Raises:
        ValueError: if the token arrays and labels have different lengths.
    """
    filename = tokens_file.stem
    dataset, window_size, position = _parse_token_filename(filename)

    if window_size is None or position is None or dataset is None:
        print(f"  Cannot parse: {filename}")
        return None

    print(f"\n  {filename}")
    print(f"    Dataset: {dataset}, Window: {window_size}, Position: {position}")

    # Close the underlying .npz file handle once the arrays are read.
    with np.load(tokens_file) as data:
        ref_tokens = data['ref_tokens']
        alt_tokens = data['alt_tokens']
        labels = data['labels']

    n_samples = len(labels)
    print(f"    Samples: {n_samples:,}")

    # Mismatched rows would silently misalign embeddings and labels.
    if len(ref_tokens) != n_samples or len(alt_tokens) != n_samples:
        raise ValueError(
            f"{filename}: {len(ref_tokens)} ref / {len(alt_tokens)} alt token rows "
            f"but {n_samples} labels"
        )

    if model_name == 'grover':
        ref_embeddings = extract_grover_embeddings(
            model, grover_tokenizer, ref_tokens, id_to_token, device, batch_size, "    REF"
        )
        alt_embeddings = extract_grover_embeddings(
            model, grover_tokenizer, alt_tokens, id_to_token, device, batch_size, "    ALT"
        )
    else:
        embed_batch = extract_distilled_embeddings_batch if is_distilled else extract_embeddings_batch
        ref_embeddings = _embed_in_batches(embed_batch, model, ref_tokens, device, batch_size, "    REF")
        alt_embeddings = _embed_in_batches(embed_batch, model, alt_tokens, device, batch_size, "    ALT")

    # Per-sample L2 distance between ALT and REF embeddings.
    diff_embeddings = alt_embeddings - ref_embeddings
    l2_distances = np.linalg.norm(diff_embeddings, axis=1)

    pos_str = f"pos{position}".replace(".", "")
    output_prefix = f"{model_name}_{dataset}_w{window_size}_{pos_str}"

    np.savez_compressed(
        EMBEDDINGS_DIR / f"{output_prefix}_embeddings.npz",
        ref=ref_embeddings,
        alt=alt_embeddings,
        diff=diff_embeddings,
        l2_distance=l2_distances,
        labels=labels,
    )

    print(f"    Saved: {output_prefix}_embeddings.npz")
    # Skip summary statistics for an empty file (mean/std would be NaN with warnings).
    if l2_distances.size:
        print(f"    L2 distance - mean: {l2_distances.mean():.4f}, std: {l2_distances.std():.4f}")

    return {'l2_distances': l2_distances, 'labels': labels}


def _load_benchmark_model(model_name, device):
    """Load ``model_name`` together with what is needed to feed it tokens.

    Returns:
        ``(model, token_model, grover_tokenizer, id_to_token, is_distilled)``.
        ``token_model`` names the tokenization whose token files this model
        consumes. ``grover_tokenizer`` and ``id_to_token`` are set only for
        GROVER, which rebuilds DNA from seq token ids.
    """
    if model_name == 'grover':
        model, grover_tokenizer, _ = load_grover_model(device)
        # Reverse the seq vocabulary so token ids can be mapped back to DNA.
        seq_paths = get_model_paths('seq')
        with open(seq_paths['vocab_file']) as f:
            vocab = json.load(f)
        id_to_token = {v: k for k, v in vocab.items()}
        return model, 'seq', grover_tokenizer, id_to_token, False  # GROVER uses seq tokens
    if model_name == 'distilled':
        model, _ = load_distilled_model(device)
        return model, 'seq', None, None, True  # Distilled uses seq tokens
    model, _ = load_anndna_model(model_name, device)
    return model, model_name, None, None, False


def main():
    """CLI entry point: extract embeddings for every selected combination.

    Each selected model is loaded once. Then every existing
    ``{token_model}_{dataset}_w{window}_pos{..}_tokens.npz`` in TOKENS_DIR is
    passed to ``process_file``. Missing token files are reported and skipped.
    Errors while loading a model or processing a file are printed with a
    traceback, and the run continues with the next item.
    """
    import itertools
    import traceback

    parser = argparse.ArgumentParser(description='Step 2: Extract Embeddings')
    parser.add_argument('--model', default='all', choices=['all'] + MODELS)
    parser.add_argument('--dataset', default='all', choices=['all'] + DATASETS)
    parser.add_argument('--window', default='all')
    parser.add_argument('--position', default='all')
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--gpu', default='0', help='GPU ID')
    args = parser.parse_args()

    # NOTE: this overrides any CUDA_VISIBLE_DEVICES already set in the
    # environment (the --gpu default is '0'). It only takes effect if CUDA
    # has not been initialised yet in this process.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    print("=" * 70)
    print("STEP 2: EXTRACT EMBEDDINGS")
    print("=" * 70)
    print(f"Device: {device}")
    print(f"Tokens dir: {TOKENS_DIR}")
    print(f"Embeddings dir: {EMBEDDINGS_DIR}")

    models = MODELS if args.model == 'all' else [args.model]
    datasets = DATASETS if args.dataset == 'all' else [args.dataset]
    windows = WINDOW_SIZES if args.window == 'all' else [int(args.window)]
    positions = VARIANT_POSITIONS if args.position == 'all' else [float(args.position)]

    print(f"Models: {models}")
    print(f"Datasets: {datasets}")
    print(f"Windows: {windows}")
    print(f"Positions: {positions}")

    for model_name in models:
        print(f"\n{'='*70}")
        print(f"Loading Model: {model_name}")
        print("=" * 70)

        # A model that fails to load is skipped rather than aborting the run,
        # matching the per-file error handling below.
        try:
            model, token_model, grover_tokenizer, id_to_token, is_distilled = (
                _load_benchmark_model(model_name, device)
            )
        except Exception as e:
            print(f"  Error loading {model_name}: {e}")
            traceback.print_exc()
            continue

        for dataset, window, position in itertools.product(datasets, windows, positions):
            pos_str = f"pos{position}".replace(".", "")
            token_file = TOKENS_DIR / f"{token_model}_{dataset}_w{window}_{pos_str}_tokens.npz"

            if not token_file.exists():
                print(f"\n  Not found: {token_file.name}")
                continue

            try:
                process_file(
                    model, model_name, token_file, device, args.batch_size,
                    grover_tokenizer=grover_tokenizer, id_to_token=id_to_token,
                    is_distilled=is_distilled
                )
            except Exception as e:
                print(f"  Error: {e}")
                traceback.print_exc()

        # Release this model before loading the next one.
        del model
        torch.cuda.empty_cache()

    print("\n" + "=" * 70)
    print("STEP 2 COMPLETE")
    print("=" * 70)


if __name__ == "__main__":
    # Run the extraction pipeline only when executed as a script, not on import.
    main()
