#!/usr/bin/env python3
"""
Step 2: Extract Embeddings

Extract embeddings from up to 5 models (selected with --model):
- grover
- seq
- struct
- full
- distilled
"""

import sys
sys.path.append('..')
sys.path.append('../2_train')

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from pathlib import Path
from tqdm import tqdm
import argparse
import json

from transformers import AutoTokenizer, AutoModel
from model import annDNA
import config
from config import get_model_paths, get_model_config

# Paths
# NOTE: both assignments were commented out, so the module failed with
# NameError at the mkdir() call below (and again at INPUT_DIR in main()).
# The values are placeholders: Path('') resolves to the current working
# directory. Set them to the real locations before running.
INPUT_DIR = Path('')  # TODO: results/3_embedding path
OUTPUT_DIR = Path('')  # TODO: results/3_embedding/embeddings path
OUTPUT_DIR.mkdir(exist_ok=True, parents=True)

# Number of samples per forward pass in the extract_embeddings_* functions.
BATCH_SIZE = 32


def load_grover_model(device):
    """Load the pretrained GROVER model and its tokenizer from HuggingFace.

    Args:
        device: Torch device the model is moved to.

    Returns:
        Tuple ``(model, tokenizer)``; the model has been switched to eval mode.
    """
    print("Loading GROVER...")
    repo_id = "PoetschLab/GROVER"
    tokenizer = AutoTokenizer.from_pretrained(repo_id)
    grover = AutoModel.from_pretrained(repo_id, use_safetensors=True)
    grover = grover.to(device)
    grover.eval()
    return grover, tokenizer


def load_anndna_model(model_name: str, device):
    """Build an annDNA model (seq, struct, full) and restore its best checkpoint.

    Args:
        model_name: Key passed to ``get_model_config`` / ``get_model_paths``.
        device: Torch device used both as ``map_location`` for the checkpoint
            and as the final location of the model.

    Returns:
        Tuple ``(model, vocab)`` where ``vocab`` is the JSON content of the
        model's vocabulary file and the model is in eval mode.
    """
    print(f"Loading {model_name}...")

    cfg = get_model_config(model_name)
    paths = get_model_paths(model_name)

    # Vocabulary is stored as JSON next to the model artefacts.
    with open(paths['vocab_file'], 'r') as vocab_fh:
        vocab = json.load(vocab_fh)

    # Architecture hyper-parameters come straight from the model config.
    hyperparams = {
        key: cfg[key]
        for key in ('vocab_size', 'd_model', 'nhead', 'num_layers', 'max_seq_len')
    }
    model = annDNA(**hyperparams)

    # NOTE: weights_only=False unpickles arbitrary objects; only load
    # checkpoints from a trusted source.
    checkpoint = torch.load(paths['best_model'], map_location=device, weights_only=False)
    # A checkpoint is either a wrapper dict holding 'model_state_dict' or the
    # state dict itself.
    state_dict = checkpoint['model_state_dict'] if 'model_state_dict' in checkpoint else checkpoint
    model.load_state_dict(state_dict)

    model.to(device)
    model.eval()

    return model, vocab


def extract_embeddings_grover(model, tokenizer, sequences, device):
    """Compute one mean-pooled GROVER embedding per input sequence.

    Sequences are tokenized in batches of ``BATCH_SIZE`` (padded to the
    longest item, truncated to 512 tokens), passed through the model, and the
    last hidden layer is averaged over the non-padding tokens, excluding the
    first token and the last non-padding token of every row.

    Args:
        model: GROVER model, called with ``input_ids``, ``attention_mask`` and
            ``output_hidden_states=True``.
        tokenizer: Matching tokenizer, called on a list of strings.
        sequences: Sliceable collection of DNA strings.
        device: Torch device the token tensors are moved to.

    Returns:
        ``np.ndarray`` of shape ``(len(sequences), hidden_size)``.

    Raises:
        ValueError: If ``sequences`` is empty (previously surfaced as an
            opaque ``np.vstack`` error).
    """
    # len() rather than truthiness so array-like inputs are handled too.
    if len(sequences) == 0:
        raise ValueError("extract_embeddings_grover: no sequences to embed")

    embeddings = []

    for start_idx in tqdm(range(0, len(sequences), BATCH_SIZE), desc="GROVER"):
        batch_sequences = sequences[start_idx:start_idx + BATCH_SIZE]

        encoded = tokenizer(
            batch_sequences,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512
        )

        input_ids = encoded['input_ids'].to(device)
        attention_mask = encoded['attention_mask'].to(device)

        with torch.no_grad():
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                output_hidden_states=True
            )

            # Last entry of hidden_states: (batch, seq_len, hidden_size).
            hidden_states = outputs.hidden_states[-1]

            # Mark the special tokens of every row in one vectorized step
            # instead of a Python loop with a per-row .item() call.
            # NOTE(review): assumes right-padding with CLS at position 0 and
            # SEP as the last non-padding token -- confirm for this tokenizer.
            lengths = attention_mask.sum(dim=1)
            rows = torch.arange(attention_mask.size(0), device=attention_mask.device)
            special_tokens_mask = torch.zeros_like(attention_mask, dtype=torch.bool)
            special_tokens_mask[:, 0] = True  # CLS
            # For rows with length <= 1 the clamped index is 0, which is
            # already marked, so those rows only lose the first token.
            special_tokens_mask[rows, (lengths - 1).clamp(min=0)] = True  # SEP

            token_mask = attention_mask.bool() & (~special_tokens_mask)
            weights = token_mask.unsqueeze(-1).float()  # (batch, seq_len, 1)

            sum_embeddings = torch.sum(hidden_states * weights, dim=1)
            # Clamp avoids division by zero for rows without any content token.
            token_counts = torch.clamp(weights.sum(dim=1), min=1e-9)
            mean_pooled = sum_embeddings / token_counts

            embeddings.append(mean_pooled.cpu().numpy())

    return np.vstack(embeddings)


def extract_embeddings_anndna(model, tokens_list, vocab, device):
    """Compute one mean-pooled embedding per token window with an annDNA model.

    Windows are processed in batches of ``BATCH_SIZE``. Each batch is
    right-padded with token id 0 to its longest window, embedded with the
    model's token and position embeddings, passed through ``model.transformer``
    and averaged over all positions whose token id is not 0.

    Args:
        model: Model exposing ``token_embedding``, ``pos_embedding``,
            ``layer_norm``, ``dropout`` and ``transformer``.
        tokens_list: Sliceable collection of 1-D integer token arrays.
        vocab: Unused; kept so existing callers keep working.
        device: Torch device the batches are moved to.

    Returns:
        ``np.ndarray`` with one row per entry of ``tokens_list``.

    Raises:
        ValueError: If ``tokens_list`` is empty (previously surfaced as an
            opaque ``np.vstack`` error).
    """
    if len(tokens_list) == 0:
        raise ValueError("extract_embeddings_anndna: no token windows to embed")

    embeddings = []

    for start_idx in tqdm(range(0, len(tokens_list), BATCH_SIZE), desc="annDNA"):
        batch_tokens_raw = tokens_list[start_idx:start_idx + BATCH_SIZE]

        # One zero-initialised (batch, max_len) array; zeros act as PAD (0).
        max_len = max(len(t) for t in batch_tokens_raw)
        batch_array = np.zeros((len(batch_tokens_raw), max_len), dtype=np.int64)
        for row, tokens in zip(batch_array, batch_tokens_raw):
            row[:len(tokens)] = tokens

        batch_tensor = torch.from_numpy(batch_array).to(device)
        # NOTE(review): padding is detected by token id, so a genuine token
        # with id 0 inside a window is also excluded -- confirm 0 is PAD-only.
        padding_mask = batch_tensor == 0
        attention_mask = (~padding_mask).float()

        with torch.no_grad():
            positions = torch.arange(batch_tensor.size(1), device=device).unsqueeze(0).expand_as(batch_tensor)
            embedded = model.token_embedding(batch_tensor) + model.pos_embedding(positions)
            embedded = model.layer_norm(embedded)
            embedded = model.dropout(embedded)

            hidden_states = model.transformer(
                embedded,
                src_key_padding_mask=padding_mask
            )

            # Mean pooling over all non-padding tokens; any CLS/SEP tokens
            # present in the windows are deliberately not excluded.
            masked_hidden = hidden_states * attention_mask.unsqueeze(-1)
            summed = torch.sum(masked_hidden, dim=1)
            count = torch.sum(attention_mask, dim=1, keepdim=True)
            # Epsilon guards against division by zero for all-padding rows.
            mean_pooled = summed / (count + 1e-10)

            embeddings.append(mean_pooled.cpu().numpy())

    return np.vstack(embeddings)


def load_sequences_for_model(df, model_name, vocab=None, vocab_size=None, tokens_dir=None):
    """Load the model input for every sample row of ``df``.

    For ``model_name == 'grover'`` the DNA string of each
    ``chrom[start:end]`` interval is read from the reference FASTA. For any
    other model the interval is sliced out of ``<chrom>_tokens.npy`` in
    ``tokens_dir``; token ids outside ``[0, vocab_size)`` are reported on
    stdout and clipped into range.

    Args:
        df: DataFrame with ``chrom``, ``start`` and ``end`` columns.
        model_name: ``'grover'`` or the name of a token-based model.
        vocab: Unused; kept so existing callers keep working.
        vocab_size: Exclusive upper bound for valid token ids (token models).
        tokens_dir: Directory holding the per-chromosome token files; may be
            a ``str`` or a ``Path`` (token models).

    Returns:
        List of upper-case DNA strings (grover) or list of 1-D token arrays
        (other models), one entry per row in row order.
    """
    if model_name == 'grover':
        # Imported lazily so token-based models do not need pyfaidx.
        from pyfaidx import Fasta
        reference = Fasta('',  # hg38/GRCh38.primary_assembly.genome.fasta path
                         as_raw=True, sequence_always_upper=True)

        return [
            str(reference[chrom][start:end]).upper()
            for chrom, start, end in zip(df['chrom'], df['start'], df['end'])
        ]

    # Memory-map each chromosome's token file once instead of once per row.
    tokens_by_chrom = {}
    tokens_list = []

    for idx, chrom, start, end in zip(df.index, df['chrom'], df['start'], df['end']):
        tokens = tokens_by_chrom.get(chrom)
        if tokens is None:
            # Path() lets tokens_dir be given as a plain string as well.
            token_file = Path(tokens_dir) / f'{chrom}_tokens.npy'
            tokens = np.load(token_file, mmap_mode='r')
            tokens_by_chrom[chrom] = tokens

        # Copy so the window is independent of the read-only memory map.
        window_tokens = tokens[start:end].copy()

        # Validate token range
        invalid_mask = (window_tokens < 0) | (window_tokens >= vocab_size)
        if invalid_mask.any():
            print(f"Warning: Found {invalid_mask.sum()} invalid tokens in sample {idx}")
            print(f"  Range: [{window_tokens.min()}, {window_tokens.max()}], vocab_size: {vocab_size}")
            print(f"  Location: {chrom}:{start}-{end}")
            print(f"  Tokens dir: {tokens_dir}")
            # Clip to valid range
            window_tokens = np.clip(window_tokens, 0, vocab_size - 1)

        tokens_list.append(window_tokens)

    return tokens_list


def load_distilled_model(device):
    """Build the distilled student model and restore its best checkpoint.

    The student code lives in the sibling ``6_distillation`` directory, which
    is added to ``sys.path`` (once) so it can be imported.

    Args:
        device: Torch device used both as ``map_location`` for the checkpoint
            and as the final location of the model.

    Returns:
        Tuple ``(model, vocab)`` where ``vocab`` is the JSON content of the
        vocabulary file and the model is in eval mode.
    """
    print("Loading Student model...")
    distill_dir = str(Path(__file__).parent.parent / '6_distillation')
    # Guard against growing sys.path on every call (once per analysis/split).
    if distill_dir not in sys.path:
        sys.path.append(distill_dir)
    from train_distilled import DistilledModel
    from distill_config import STUDENT_CONFIG

    model_paths = get_model_paths('distilled')

    # Load vocabulary (same as seq)
    with open(model_paths['vocab_file'], 'r') as f:
        vocab = json.load(f)

    model = DistilledModel(
        vocab_size=len(vocab),
        d_model=STUDENT_CONFIG['d_model'],
        nhead=STUDENT_CONFIG['nhead'],
        num_layers=STUDENT_CONFIG['num_layers'],
        max_seq_len=STUDENT_CONFIG['max_seq_len']
    )

    # NOTE: weights_only=False unpickles arbitrary objects; only load
    # checkpoints from a trusted source.
    checkpoint = torch.load(model_paths['best_model'], map_location=device, weights_only=False)
    # Accept both a wrapper dict and a bare state dict, matching
    # load_anndna_model.
    if 'model_state_dict' in checkpoint:
        model.load_state_dict(checkpoint['model_state_dict'])
    else:
        model.load_state_dict(checkpoint)
    model.to(device)
    model.eval()

    return model, vocab


def process_model(model_name, df, device):
    """Load one model, embed every sample of ``df`` with it and free the model.

    Args:
        model_name: ``'grover'``, ``'distilled'`` or an annDNA model name.
        df: Sample table handed on to ``load_sequences_for_model``.
        device: Torch device used for loading and inference.

    Returns:
        The embedding matrix produced by the matching extraction function.
    """
    rule = '=' * 80
    print(f"\n{rule}")
    print(f"Processing: {model_name}")
    print(f"{rule}")

    if model_name == 'grover':
        # Sequence-based path: raw DNA strings plus the HuggingFace tokenizer.
        model, tokenizer = load_grover_model(device)

        print("Loading sequences/tokens...")
        data = load_sequences_for_model(df, model_name)

        print(f"Extracting embeddings for {len(df)} samples...")
        embeddings = extract_embeddings_grover(model, tokenizer, data, device)
    else:
        # Token-based path: distilled student or one of the annDNA variants.
        if model_name == 'distilled':
            model, vocab = load_distilled_model(device)
            # Lookup retained from the original; the value is not used further.
            _unused_cfg = config.MODELS['distilled']
            paths = get_model_paths('distilled')
            vocab_size = len(vocab)
        else:
            model, vocab = load_anndna_model(model_name, device)
            cfg = get_model_config(model_name)
            paths = get_model_paths(model_name)
            vocab_size = cfg['vocab_size']
        tokens_dir = paths['tokens_dir']

        print("Loading sequences/tokens...")
        data = load_sequences_for_model(df, model_name, vocab, vocab_size, tokens_dir)

        print(f"Extracting embeddings for {len(df)} samples...")
        embeddings = extract_embeddings_anndna(model, data, vocab, device)

    print(f"Embeddings shape: {embeddings.shape}")

    # Drop the local model reference and release cached GPU memory before
    # the next model is loaded.
    del model
    torch.cuda.empty_cache()

    return embeddings


def main():
    """Command-line entry point: embed every requested analysis/split/model.

    For each selected analysis and split the sample table
    ``<analysis>_<split>_samples.tsv`` is read from ``INPUT_DIR``; for each
    selected model the embeddings are saved to
    ``<model>_<analysis>_<split>_embeddings.npy`` in ``OUTPUT_DIR``.
    """
    cli = argparse.ArgumentParser(description='Step 2: Extract Embeddings')
    cli.add_argument('--analysis', type=str, required=True,
                     choices=['structural', 'regulatory', 'all'],
                     help='Which analysis to process')
    cli.add_argument('--split', type=str, required=True,
                     choices=['train', 'val', 'all'],
                     help='Which chromosome split to use')
    cli.add_argument('--model', type=str, default='all',
                     choices=['grover', 'seq', 'struct', 'full', 'distilled', 'all'],
                     help='Which model to process')
    cli.add_argument('--gpu', type=int, default=0, help='GPU device')
    args = cli.parse_args()

    if torch.cuda.is_available():
        device = torch.device(f'cuda:{args.gpu}')
    else:
        device = torch.device('cpu')

    def expand(choice, everything):
        """Return every option for 'all', otherwise just the chosen one."""
        return list(everything) if choice == 'all' else [choice]

    analyses = expand(args.analysis, ['structural', 'regulatory'])
    splits = expand(args.split, ['train', 'val'])
    models = expand(args.model, ['grover', 'seq', 'struct', 'full', 'distilled'])

    rule = "=" * 80
    print(rule)
    print("STEP 2: EXTRACTING EMBEDDINGS")
    print(rule)
    print(f"Analyses: {analyses}")
    print(f"Splits: {splits}")
    print(f"Models: {models}")
    print(f"Device: {device}")

    # Analyses form the outer loop, splits the inner one.
    jobs = [(analysis_type, split) for analysis_type in analyses for split in splits]
    for analysis_type, split in jobs:
        print("\n" + rule)
        print(f"Processing: {analysis_type} - {split}")
        print(rule)

        # Sample table for this analysis/split combination.
        sample_file = INPUT_DIR / f'{analysis_type}_{split}_samples.tsv'
        df = pd.read_csv(sample_file, sep='\t')
        print(f"\nLoaded {len(df)} samples from {sample_file}")
        print(df['category'].value_counts())

        # The model is loaded afresh for every combination by process_model.
        for model_name in models:
            embeddings = process_model(model_name, df, device)

            output_file = OUTPUT_DIR / f'{model_name}_{analysis_type}_{split}_embeddings.npy'
            np.save(output_file, embeddings)
            print(f"✓ Saved: {output_file}")

    print("\n" + rule)
    print("STEP 2 COMPLETE")
    print(rule)


# Run the extraction only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
