#!/usr/bin/env python3
"""
Step 2: Extract attention weights from models

Extract attention weights for sampled sequences from:
- seq (sequence only)
- struct (sequence + GENCODE only)
- full (sequence + GENCODE + ENCODE)
- distilled (distilled model; uses the seq tokens)
"""

import sys
sys.path.append('..')
sys.path.append('../2_train')

import numpy as np
import json
import torch
from pathlib import Path
from tqdm import tqdm
import argparse
import h5py

from model import annDNA
from config import get_model_paths, get_model_config, get_attention_paths, MODELS

# Models this script can extract attention from; `--model all` processes
# every entry, in this order.
MODEL_LIST = ['seq', 'struct', 'full', 'distilled']


def _load_checkpoint_into(model, checkpoint_path, device):
    """Load weights from ``checkpoint_path`` into ``model`` in place.

    Accepts either a dict holding the weights under ``'model_state_dict'``
    or a bare state dict.
    """
    # weights_only=False lets torch.load unpickle arbitrary objects, so only
    # point this at checkpoint files you trust.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    if 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
    else:
        state_dict = checkpoint
    model.load_state_dict(state_dict)


def load_model(model_name, device):
    """Build a trained model, load its weights and put it in eval mode.

    Args:
        model_name: one of MODEL_LIST. ``'distilled'`` builds ``DistilledModel``
            (imported from ``train_distilled`` in the ``6_distillation``
            directory next to this script's parent) and loads
            ``MODELS['distilled']['model_path']``; any other name builds
            ``annDNA`` and loads ``get_model_paths(model_name)['best_model']``.
        device: torch device the checkpoint is mapped to and the model moved to.

    Returns:
        The model on ``device``, in eval mode.
    """
    print(f"Loading {model_name}...")

    model_config = get_model_config(model_name)
    model_paths = get_model_paths(model_name)

    if model_name == 'distilled':
        # Make the distillation code importable; guard so repeated calls do
        # not keep growing sys.path.
        distill_dir = str(Path(__file__).parent.parent / '6_distillation')
        if distill_dir not in sys.path:
            sys.path.append(distill_dir)
        from train_distilled import DistilledModel

        model_cls = DistilledModel
        checkpoint_path = MODELS['distilled']['model_path']
    else:
        model_cls = annDNA
        checkpoint_path = model_paths['best_model']

    # Both architectures share the same constructor arguments.
    model = model_cls(
        vocab_size=model_config['vocab_size'],
        d_model=model_config['d_model'],
        nhead=model_config['nhead'],
        num_layers=model_config['num_layers'],
        max_seq_len=model_config['max_seq_len']
    )

    _load_checkpoint_into(model, checkpoint_path, device)
    print(f"  Loaded from {checkpoint_path}")

    model.to(device)
    model.eval()
    print(f"  Parameters: {model.get_num_params():,}")

    return model


def load_tokens_for_sample(sample, model_name):
    """Return the tokens of one sample window ``[start, end)`` as an in-memory array.

    Reads ``<tokens_dir>/<chrom>_tokens.npy`` memory-mapped, so only the
    requested window is read, and copies the slice so the result does not
    depend on the memory map. The ``'distilled'`` model reads the ``'seq'``
    model's tokens.

    Args:
        sample: mapping with ``'chrom'``, ``'start'`` and ``'end'`` keys.
        model_name: model whose token directory is used (see above for
            ``'distilled'``).

    Raises:
        ValueError: if the window is empty, inverted, negative or extends past
            the end of the token file.
    """
    token_model = 'seq' if model_name == 'distilled' else model_name
    tokens_dir = get_model_paths(token_model)['tokens_dir']
    tokens_path = tokens_dir / f"{sample['chrom']}_tokens.npy"

    start, end = sample['start'], sample['end']
    tokens = np.load(tokens_path, mmap_mode='r')

    # NumPy slicing silently clips an out-of-range `end` and wraps a negative
    # `start`, which would return a window of the wrong length/position.
    if not 0 <= start < end <= len(tokens):
        raise ValueError(
            f"Invalid window {sample['chrom']}:{start}-{end} for {tokens_path} "
            f"(length {len(tokens)})"
        )

    return tokens[start:end].copy()


def extract_attention_batch(model, tokens_list, device):
    """
    Extract attention weights from model for a batch

    Args:
        model: object exposing ``get_attention(input_ids, attention_mask)``
        tokens_list: list of equal-length 1-D token arrays (any integer dtype)
        device: torch device to run the forward pass on

    Returns:
        list with one attention array per input, taken as ``attention[:, i]``
        of the model output; with an output of shape
        (num_layers, batch, num_heads, seq_len, seq_len) each entry is
        (num_layers, num_heads, seq_len, seq_len). Empty input returns [].
    """
    if len(tokens_list) == 0:
        # np.stack cannot stack zero arrays.
        return []

    # Cast explicitly to int64: token files may be stored in narrower integer
    # dtypes (e.g. uint16) that the legacy torch.LongTensor(ndarray)
    # constructor does not convert reliably.
    batch_tokens = np.stack(tokens_list).astype(np.int64, copy=False)
    input_ids = torch.from_numpy(batch_tokens).to(device)
    # Windows are never padded, so every position is attended to.
    attention_mask = torch.ones_like(input_ids, dtype=torch.float)

    with torch.no_grad():
        attention = model.get_attention(input_ids, attention_mask)

    if isinstance(attention, torch.Tensor):
        attention = attention.detach().cpu().numpy()

    # Batch is the second axis of the model output.
    return [attention[:, i, :, :, :] for i in range(len(tokens_list))]


def _extract_model_attention(model_name, samples, output_dir, batch_size, device):
    """Extract attention for every sample with one model and write it to HDF5.

    Writes ``<output_dir>/<model_name>_attention.h5`` (overwriting any existing
    file) with file attrs ``model``, ``num_layers``, ``num_heads`` and
    ``num_samples``, plus one group ``sample_<idx>`` per sample containing a
    gzip-compressed ``attention`` dataset and the sample's ``chrom``,
    ``start``, ``end`` and ``diversity`` as attrs.
    """
    model = load_model(model_name, device)
    model_config = get_model_config(model_name)

    output_file = output_dir / f'{model_name}_attention.h5'

    with h5py.File(output_file, 'w') as hf:
        # Store metadata
        hf.attrs['model'] = model_name
        hf.attrs['num_layers'] = model_config['num_layers']
        hf.attrs['num_heads'] = model_config['nhead']
        hf.attrs['num_samples'] = len(samples)

        for start_idx in tqdm(range(0, len(samples), batch_size),
                              desc=f"Extracting {model_name}"):
            batch_samples = samples[start_idx:start_idx + batch_size]

            tokens_list = [load_tokens_for_sample(s, model_name) for s in batch_samples]
            attention_list = extract_attention_batch(model, tokens_list, device)

            # Group names use the index into the full sample list, so they do
            # not depend on the batch size.
            for sample_idx, (sample, attention) in enumerate(
                    zip(batch_samples, attention_list), start=start_idx):
                grp = hf.create_group(f'sample_{sample_idx}')
                grp.create_dataset('attention', data=attention, compression='gzip')
                grp.attrs['chrom'] = sample['chrom']
                grp.attrs['start'] = sample['start']
                grp.attrs['end'] = sample['end']
                grp.attrs['diversity'] = sample['diversity']

    print(f"Saved to {output_file}")

    # Drop the model reference before emptying the CUDA cache so its memory
    # can be released before the next model is loaded.
    del model
    torch.cuda.empty_cache()


def main():
    """CLI entry point: extract attention for the selected model(s).

    Reads samples from ``<attention root>/samples/full_samples.json`` and
    writes one HDF5 file per model into the attention output directory.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', default='all', choices=MODEL_LIST + ['all'],
                       help='Model to process (default: all)')
    parser.add_argument('--batch_size', type=int, default=4,
                       help='Batch size for extraction')
    parser.add_argument('--gpu', type=int, default=0)
    args = parser.parse_args()

    # batch_size 0 would crash mid-run and a negative one would silently
    # write files with no samples; reject both up front.
    if args.batch_size < 1:
        parser.error('--batch_size must be >= 1')

    device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")

    # Setup paths
    attention_paths = get_attention_paths()
    output_dir = attention_paths['attention']
    output_dir.mkdir(parents=True, exist_ok=True)

    # Load samples (use full samples)
    samples_file = attention_paths['root'] / 'samples' / 'full_samples.json'
    print(f"Loading samples from {samples_file}...")
    with open(samples_file, encoding='utf-8') as f:
        samples_data = json.load(f)

    samples = samples_data['samples']
    print(f"Processing {len(samples)} samples")

    models_to_process = MODEL_LIST if args.model == 'all' else [args.model]

    for model_name in models_to_process:
        print(f"\n{'='*60}")
        print(f"Model: {model_name}")
        print('='*60)

        _extract_model_attention(model_name, samples, output_dir,
                                 args.batch_size, device)

    print(f"\n{'='*60}")
    print("DONE")
    print('='*60)


# Run the extraction only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
