#!/usr/bin/env python3
"""
Step 3: Attention Density Analysis

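Attention density per layer × head for each functional category, where the
density of a category is the attention received by its positions (summed over
the attending positions), averaged over all positions carrying that label.
Category labels always come from the full-vocabulary tokens, for every model.

Output: CSV rows of (model, sample_idx, layer, head, category, density, count),
plus aggregates per model × layer × head × category and per model × category.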
"""

import sys
sys.path.append('..')

import numpy as np
import pandas as pd
import json
import h5py
from pathlib import Path
from collections import defaultdict
from tqdm import tqdm
import argparse

from config import get_model_paths, get_attention_paths

# Functional categories used by the density analysis.
# Structural labels are mutually exclusive: a token gets at most one of them.
STRUCTURAL_CATEGORIES = ['CDS', 'UTR', 'intron', 'intergenic']
# Regulatory labels may co-occur with each other and with a structural label.
REGULATORY_CATEGORIES = ['promoter', 'enhancer', 'CTCF', 'H3K4me3']
# Structural first, then regulatory; density rows are emitted in this order.
ALL_CATEGORIES = [*STRUCTURAL_CATEGORIES, *REGULATORY_CATEGORIES]


def get_structural_category(token_name):
    """Return the single structural category of a vocabulary token.

    Token names appear to be ``<base>_<tag>_<tag>...`` where ``<base>`` is one
    nucleotide letter and the tags are annotations.

    Args:
        token_name: vocabulary token string.

    Returns:
        'CDS', 'UTR', 'intron' or 'intergenic' (hierarchy CDS > UTR > intron >
        intergenic), or None for '<PAD>', 'N', or any token whose base is not
        exactly one of A/T/G/C.
    """
    if token_name in ('<PAD>', 'N'):
        return None

    base, *tags = token_name.split('_')
    # Exact membership: a substring test (`base in 'ATGC'`) would also accept
    # the empty string and multi-letter bases such as 'AT'.
    if base not in {'A', 'T', 'G', 'C'}:
        return None

    annotations = set(tags)

    # Hierarchy: CDS > UTR > intron > intergenic
    if 'CDS' in annotations:
        return 'CDS'
    if 'UTR' in annotations:
        return 'UTR'
    # Inside a gene/transcript but neither CDS nor UTR -> treated as intron.
    if 'gene' in annotations or 'transcript' in annotations:
        return 'intron'

    return 'intergenic'


def get_regulatory_categories(token_name):
    """Return the regulatory categories of a vocabulary token (zero or more).

    Token names appear to be ``<base>_<tag>_<tag>...``. Regulatory tags map to
    categories as: prom -> promoter, enhP/enhD -> enhancer, CTCF -> CTCF,
    K4m3 -> H3K4me3.

    Args:
        token_name: vocabulary token string.

    Returns:
        A new list of category names in the fixed order promoter, enhancer,
        CTCF, H3K4me3. Empty for '<PAD>', 'N', or any token whose base is not
        exactly one of A/T/G/C.
    """
    if token_name in ('<PAD>', 'N'):
        return []

    base, *tags = token_name.split('_')
    # Exact membership: a substring test (`base in 'ATGC'`) would also accept
    # the empty string and multi-letter bases such as 'GC'.
    if base not in {'A', 'T', 'G', 'C'}:
        return []

    annotations = set(tags)

    categories = []
    if 'prom' in annotations:
        categories.append('promoter')
    # Both enhancer tag variants collapse into a single 'enhancer' category.
    if 'enhP' in annotations or 'enhD' in annotations:
        categories.append('enhancer')
    if 'CTCF' in annotations:
        categories.append('CTCF')
    if 'K4m3' in annotations:
        categories.append('H3K4me3')

    return categories


def get_all_categories(token_name):
    """Combine the structural and regulatory labels of one token.

    Returns a new list: the structural category first (left out when the
    token has none), followed by the regulatory categories in the order
    get_regulatory_categories returns them.
    """
    structural = get_structural_category(token_name)
    leading = [structural] if structural else []
    return leading + get_regulatory_categories(token_name)


def build_full_category_mapping():
    """Build category lookups from the full model's vocabulary (ground truth).

    The full vocabulary's token names are parsed with get_all_categories, and
    the result is used to label positions for every model.

    Returns:
        (id_to_categories, id_to_token):
            id_to_categories -- token id -> list of category names
            id_to_token      -- token id -> token string
    """
    model_paths = get_model_paths('full')

    # JSON is UTF-8 by specification; don't depend on the platform locale.
    with open(model_paths['vocab_file'], encoding='utf-8') as f:
        vocab = json.load(f)

    # The vocab file maps token string -> id; invert it for id-based lookup.
    id_to_token = {token_id: token for token, token_id in vocab.items()}
    id_to_categories = {
        token_id: get_all_categories(token)
        for token_id, token in id_to_token.items()
    }

    return id_to_categories, id_to_token


def load_full_tokens_for_sample(sample):
    """Return the full-vocabulary token ids covering one sample's window.

    Args:
        sample: mapping with 'chrom', 'start' and 'end'; selects the slice
            [start:end] of the per-chromosome token array.

    Returns:
        An in-memory copy of that slice.
    """
    tokens_dir = get_model_paths('full')['tokens_dir']
    token_file = tokens_dir / f"{sample['chrom']}_tokens.npy"
    # Memory-map the chromosome's array so only the requested window is read,
    # then copy the window so the result does not depend on the mapping.
    mapped = np.load(token_file, mmap_mode='r')
    window = mapped[sample['start']:sample['end']]
    return window.copy()


def compute_layer_head_density(attention, position_categories_list, categories=None):
    """
    Compute attention density for each layer × head × category.

    For every (layer, head), the attention received by a position is its sum
    over axis 2 of ``attention`` (presumably the query axis — TODO confirm).
    A category's density is the mean received attention over the positions
    labelled with that category.

    Args:
        attention: array of shape [num_layers, num_heads, seq_len, seq_len]
        position_categories_list: list of lists, each position's categories;
            entry i labels sequence position i. It may be shorter than seq_len
            (trailing positions then count toward no category) but not longer.
        categories: categories to report, in output order; labels not in this
            list are ignored. Defaults to ALL_CATEGORIES.

    Returns:
        list of dicts with keys layer, head, category, density, count, ordered
        by layer, then head, then category. density is NaN when count is 0.

    Raises:
        ValueError: if position_categories_list is longer than seq_len.
    """
    if categories is None:
        categories = ALL_CATEGORIES
    categories = list(categories)
    cat_index = {cat: j for j, cat in enumerate(categories)}

    attention = np.asarray(attention)
    num_layers, num_heads, seq_len, _ = attention.shape
    n_labelled = len(position_categories_list)
    if n_labelled > seq_len:
        raise ValueError(
            f"{n_labelled} position labels for an attention map of length {seq_len}"
        )

    # Label matrix [seq_len, n_categories], built once and shared by every
    # layer/head. Entries are counts rather than flags, so a category listed
    # twice at one position is weighted twice (as a per-label sum would be).
    label_counts = np.zeros((seq_len, len(categories)), dtype=np.float64)
    for pos, cats in enumerate(position_categories_list):
        for cat in cats:
            j = cat_index.get(cat)
            if j is not None:
                label_counts[pos, j] += 1.0

    counts = label_counts.sum(axis=0)  # [n_categories]
    # Attention received by each position, accumulated in float64:
    # [num_layers, num_heads, seq_len]
    received = attention.sum(axis=2, dtype=np.float64)
    # Total received attention per category: [num_layers, num_heads, n_categories]
    cat_attn = received @ label_counts

    results = []
    for layer in range(num_layers):
        for head in range(num_heads):
            for j, cat in enumerate(categories):
                count = int(counts[j])
                if count > 0:
                    density = float(cat_attn[layer, head, j]) / count
                else:
                    density = np.nan

                results.append({
                    'layer': layer,
                    'head': head,
                    'category': cat,
                    'density': density,
                    'count': count,
                })

    return results


def analyze_model(model_name, samples, full_id_to_categories, attention_file):
    """Compute per-sample layer × head × category attention density for one model.

    Args:
        model_name: label written into the 'model' column.
        samples: sample dicts (with 'chrom', 'start', 'end'); sample i is
            paired with HDF5 group 'sample_{i}'.
        full_id_to_categories: full-vocabulary token id -> list of categories.
        attention_file: HDF5 file with attrs num_samples/num_layers/num_heads
            and a 'sample_{i}/attention' dataset per sample.

    Returns:
        DataFrame with one row per (sample, layer, head, category) and columns
        layer, head, category, density, count, model, sample_idx.
    """
    print(f"\nAnalyzing {model_name}...")

    all_results = []

    # Open the HDF5 file once for the whole run rather than once per sample.
    with h5py.File(attention_file, 'r') as hf:
        num_samples = hf.attrs['num_samples']
        num_layers = hf.attrs['num_layers']
        num_heads = hf.attrs['num_heads']

        print(f"  Layers: {num_layers}, Heads: {num_heads}")
        # Only process samples present in both the attention file and the list.
        n_samples = min(num_samples, len(samples))

        for i in tqdm(range(n_samples), desc="  Processing"):
            attention = hf[f'sample_{i}']['attention'][:]

            # Get position categories from full tokens (ground truth)
            full_tokens = load_full_tokens_for_sample(samples[i])
            position_categories_list = [
                full_id_to_categories.get(int(t), []) for t in full_tokens
            ]

            # Compute layer × head × category density and tag each row
            for r in compute_layer_head_density(attention, position_categories_list):
                r['model'] = model_name
                r['sample_idx'] = i
                all_results.append(r)

    return pd.DataFrame(all_results)


def main():
    """Run the density analysis for the requested models and write CSV summaries.

    Writes into the configured results directory:
      * attention_density_layer_head.csv -- one row per model × sample × layer
        × head × category
      * attention_density_aggregated.csv -- density mean/std and mean count over
        samples, per model × layer × head × category
      * attention_density_summary.csv -- density mean/std per model × category
    and prints a per-model summary plus a layer-wise summary for key
    categories. Exits with an error message when none of the requested models
    has an attention file.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--models', default='seq,struct,full,distilled',
                        help='Comma-separated model names')
    args = parser.parse_args()

    attention_paths = get_attention_paths()
    output_dir = attention_paths['results']
    output_dir.mkdir(parents=True, exist_ok=True)

    # Load full category mapping (ground truth)
    print("Loading full vocabulary as ground truth...")
    full_id_to_categories, _ = build_full_category_mapping()

    # Load samples (JSON is UTF-8 by specification)
    samples_file = attention_paths['root'] / 'samples' / 'full_samples.json'
    with open(samples_file, encoding='utf-8') as f:
        samples_data = json.load(f)
    samples = samples_data['samples']
    print(f"Loaded {len(samples)} samples")

    # Ignore empty entries, e.g. from a trailing comma in --models.
    models = [m.strip() for m in args.models.split(',') if m.strip()]
    all_results = []

    for model_name in models:
        attention_file = attention_paths['attention'] / f'{model_name}_attention.h5'

        if not attention_file.exists():
            print(f"Warning: {attention_file} not found, skipping...")
            continue

        df = analyze_model(model_name, samples, full_id_to_categories, attention_file)
        all_results.append(df)

    # pd.concat raises an opaque "No objects to concatenate" on an empty list.
    if not all_results:
        sys.exit(f"Error: no attention files found for models: {', '.join(models)}")

    # Combine results
    df_all = pd.concat(all_results, ignore_index=True)

    # Save detailed results
    output_file = output_dir / 'attention_density_layer_head.csv'
    df_all.to_csv(output_file, index=False)
    print(f"\nSaved: {output_file}")

    # Aggregate: mean density per model × layer × head × category
    agg_df = df_all.groupby(['model', 'layer', 'head', 'category']).agg({
        'density': ['mean', 'std'],
        'count': 'mean'
    }).reset_index()
    agg_df.columns = ['model', 'layer', 'head', 'category', 'density_mean', 'density_std', 'count_mean']

    agg_file = output_dir / 'attention_density_aggregated.csv'
    agg_df.to_csv(agg_file, index=False)
    print(f"Saved: {agg_file}")

    # Summary: mean density per model × category
    summary_df = df_all.groupby(['model', 'category']).agg({
        'density': ['mean', 'std']
    }).reset_index()
    summary_df.columns = ['model', 'category', 'density_mean', 'density_std']

    summary_file = output_dir / 'attention_density_summary.csv'
    summary_df.to_csv(summary_file, index=False)
    print(f"Saved: {summary_file}")

    # Print summary
    print("\n" + "=" * 70)
    print("SUMMARY: Mean Attention Density by Model × Category")
    print("=" * 70)

    for model in models:
        model_df = summary_df[summary_df['model'] == model]
        if model_df.empty:
            continue

        print(f"\n{model}:")
        for _, row in model_df.iterrows():
            print(f"  {row['category']:15s}: {row['density_mean']:.4f} ± {row['density_std']:.4f}")

    # Layer-wise summary
    print("\n" + "=" * 70)
    print("LAYER-WISE SUMMARY (averaged across heads)")
    print("=" * 70)

    layer_summary = df_all.groupby(['model', 'layer', 'category'])['density'].mean().reset_index()

    # Must be category names as produced by get_all_categories, not the raw
    # token tags ('prom', 'enhP', 'enhD'), which never appear in 'category'.
    key_cats = ['CDS', 'UTR', 'promoter', 'enhancer']

    for model in models:
        model_df = layer_summary[layer_summary['model'] == model]
        if model_df.empty:
            continue

        n_layers = model_df['layer'].max() + 1
        print(f"\n{model} ({n_layers} layers):")

        for cat in key_cats:
            cat_df = model_df[model_df['category'] == cat]
            if not cat_df.empty:
                densities = cat_df.sort_values('layer')['density'].values
                density_str = ' '.join(f'{d:.3f}' for d in densities)
                print(f"  {cat:8s}: [{density_str}]")

    print(f"\nResults saved to {output_dir}")


# Script entry point: run the analysis only when executed directly, not on import.
if __name__ == '__main__':
    main()
