#!/usr/bin/env python3
"""
Step 1: Sample diverse sequences from tokenized data

Finds genomic windows containing multiple functional regions
(CDS, UTR, intron, promoter, enhancer, etc.) directly from token data.

Uses all chromosomes except chr22 (test) and chrY.
"""

import sys
sys.path.append('..')

import numpy as np
import json
from collections import defaultdict
from tqdm import tqdm
import argparse

from config import get_model_paths, get_attention_paths

# Chromosomes scanned for samples: chr1..chr21 plus chrX.
# chr22 is held out (test) and chrY is skipped.
CHROMOSOMES = ['chr' + str(number) for number in range(1, 22)]
CHROMOSOMES.append('chrX')

# Categories counted toward a window's diversity score.
# Structural labels: a token gets at most one of these.
STRUCTURAL_CATEGORIES = ['CDS', 'UTR', 'intron', 'intergenic']
# Regulatory labels: a token may carry several of these at once.
REGULATORY_CATEGORIES = ['promoter', 'enhancer', 'CTCF', 'H3K4me3']
# Structural labels first, then regulatory labels.
KEY_CATEGORIES = [*STRUCTURAL_CATEGORIES, *REGULATORY_CATEGORIES]

# Model names accepted on the command line
MODELS = ['seq', 'struct', 'full']


def get_structural_category(token_name):
    """Classify a token into one structural category (mutually exclusive).

    A token name is split on ``_``: the first field is the base, the
    remaining fields are annotation tags.

    Args:
        token_name: vocabulary token string, e.g. ``'A_gene_CDS'``.

    Returns:
        ``'CDS'``, ``'UTR'``, ``'intron'`` or ``'intergenic'``; ``None`` for
        ``'<PAD>'``, ``'N'`` and any token whose base is not exactly one of
        ``A``, ``T``, ``G``, ``C``.
    """
    if token_name in ('<PAD>', 'N'):
        return None

    base, *tags = token_name.split('_')
    # Exact match against single bases. A substring test against 'ATGC'
    # would also accept '' (e.g. '' or '_CDS') and fragments such as 'TG'.
    if base not in ('A', 'T', 'G', 'C'):
        return None

    annotations = set(tags)

    # Structural hierarchy: CDS > UTR > intron > intergenic
    if 'CDS' in annotations:
        return 'CDS'
    if 'UTR' in annotations:
        return 'UTR'
    # Inside a gene/transcript but neither CDS nor UTR is labelled intron.
    if 'gene' in annotations or 'transcript' in annotations:
        return 'intron'

    return 'intergenic'


def get_regulatory_categories(token_name):
    """Return the regulatory categories of a token (zero or more).

    A token name is split on ``_``: the first field is the base, the
    remaining fields are annotation tags. Tags map to categories as
    ``prom`` -> promoter, ``enhP``/``enhD`` -> enhancer, ``CTCF`` -> CTCF,
    ``K4m3`` -> H3K4me3.

    Args:
        token_name: vocabulary token string, e.g. ``'A_prom_CTCF'``.

    Returns:
        List of category names in the fixed order promoter, enhancer, CTCF,
        H3K4me3. Empty for ``'<PAD>'``, ``'N'`` and any token whose base is
        not exactly one of ``A``, ``T``, ``G``, ``C``.
    """
    if token_name in ('<PAD>', 'N'):
        return []

    base, *tags = token_name.split('_')
    # Exact match against single bases. A substring test against 'ATGC'
    # would also accept '' (e.g. '' or '_prom') and fragments such as 'GC'.
    if base not in ('A', 'T', 'G', 'C'):
        return []

    annotations = set(tags)

    # (category, tags that imply it) in output order.
    tag_table = (
        ('promoter', ('prom',)),
        ('enhancer', ('enhP', 'enhD')),
        ('CTCF', ('CTCF',)),
        ('H3K4me3', ('K4m3',)),
    )
    return [
        category
        for category, markers in tag_table
        if annotations.intersection(markers)
    ]


def get_all_categories(token_name):
    """Collect every category of a token: structural label first, then regulatory ones"""
    structural = get_structural_category(token_name)
    regulatory = get_regulatory_categories(token_name)

    # Tokens without a structural label contribute only regulatory labels.
    if not structural:
        return list(regulatory)
    return [structural, *regulatory]


def build_category_mapping(vocab_path):
    """
    Load the vocabulary JSON and derive the reverse and category lookups

    Args:
        vocab_path: path to a JSON file mapping token name -> token ID

    Returns:
        (vocab, id_to_token, id_to_categories) where id_to_categories maps
        each token ID to the list returned by get_all_categories
    """
    with open(vocab_path) as handle:
        vocab = json.load(handle)

    # Invert name -> ID into ID -> name.
    id_to_token = {}
    for token, token_id in vocab.items():
        id_to_token[token_id] = token

    id_to_categories = {}
    for token_id, token in id_to_token.items():
        id_to_categories[token_id] = get_all_categories(token)

    return vocab, id_to_token, id_to_categories


def find_diverse_windows(tokens, id_to_categories, window_size=1000, stride=10000,
                         min_diversity=4, n_token_id=1):
    """
    Find windows with diverse functional regions

    Windows of `window_size` tokens are taken every `stride` tokens. A
    category counts as present when enough tokens in the window carry it:
    50 for structural categories, 10 for promoter/enhancer, 1 for
    CTCF/H3K4me3.

    Args:
        tokens: numpy array of token IDs
        id_to_categories: mapping from token ID to list of categories
        window_size: window size in bp
        stride: step between windows
        min_diversity: minimum number of categories required
        n_token_id: token ID counted as 'N' for the skip filter; defaults
            to 1, the value previously hard-coded here (presumably the
            vocab ID of 'N' - confirm against the vocabulary file)

    Returns:
        list of dicts with keys 'start', 'end', 'diversity', 'categories'
        (category -> token count; every structural/regulatory category is
        present, with 0 when absent) and 'present' (categories that met
        their threshold, structural first)
    """
    # Different thresholds for structural vs regulatory
    STRUCTURAL_MIN = 50   # structural categories need more coverage
    REGULATORY_MIN = 10   # regulatory can be narrow peaks
    RARE_MIN = 1          # CTCF, H3K4me3 are very rare
    rare_categories = ('CTCF', 'H3K4me3')

    # (category, minimum count) in reporting order; built once, not per window.
    thresholds = [(cat, STRUCTURAL_MIN) for cat in STRUCTURAL_CATEGORIES]
    thresholds += [
        (cat, RARE_MIN if cat in rare_categories else REGULATORY_MIN)
        for cat in REGULATORY_CATEGORIES
    ]

    windows = []

    # "+ 1" keeps the last full window (start == len(tokens) - window_size);
    # without it an array of exactly window_size tokens yields no window.
    for start in range(0, len(tokens) - window_size + 1, stride):
        window = tokens[start:start + window_size]

        # Skip if too many N
        n_count = np.sum(window == n_token_id)
        if n_count > window_size * 0.1:
            continue

        # Count categories (multi-label). Tally distinct token IDs once and
        # weight by occurrence instead of looping over every token.
        cat_counts = defaultdict(int)
        unique_ids, occurrences = np.unique(window, return_counts=True)
        # .tolist() yields plain Python ints so the counts stay JSON-serializable.
        for tok_id, n_tokens in zip(unique_ids.tolist(), occurrences.tolist()):
            for cat in id_to_categories.get(tok_id, []):
                cat_counts[cat] += n_tokens

        # Reading cat_counts[cat] inserts a 0 entry for absent categories,
        # so 'categories' below always lists every checked category.
        present = [cat for cat, minimum in thresholds if cat_counts[cat] >= minimum]
        diversity = len(present)

        if diversity >= min_diversity:
            windows.append({
                'start': start,
                'end': start + window_size,
                'diversity': diversity,
                'categories': dict(cat_counts),
                'present': present
            })

    return windows


def main():
    """
    Scan all configured chromosomes and save the most diverse windows

    Loads the vocabulary of the chosen model, runs find_diverse_windows on
    each chromosome's token file, keeps the top `--max_samples` windows
    (highest diversity first, ties broken by CTCF/H3K4me3 presence) and
    writes them to `<attention root>/samples/<model>_samples.json`.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', default='full', choices=MODELS + ['all'],
                       help='Model for vocab (full recommended)')
    parser.add_argument('--window_size', type=int, default=1000)
    parser.add_argument('--stride', type=int, default=10000)
    parser.add_argument('--min_diversity', type=int, default=4)
    parser.add_argument('--max_samples', type=int, default=300)
    args = parser.parse_args()

    # Use full for 'all' option (most complete vocab)
    model_name = 'full' if args.model == 'all' else args.model

    # Setup paths
    model_paths = get_model_paths(model_name)
    attention_paths = get_attention_paths()
    output_dir = attention_paths['root'] / 'samples'
    output_dir.mkdir(parents=True, exist_ok=True)

    # Build category mapping
    print(f"Loading vocabulary for {model_name}...")
    vocab, id_to_token, id_to_categories = build_category_mapping(model_paths['vocab_file'])
    print(f"Vocabulary size: {len(vocab)}")
    print(f"Categories: {KEY_CATEGORIES}")

    all_windows = []
    missing_chroms = []

    for chrom in tqdm(CHROMOSOMES, desc="Scanning chromosomes"):
        tokens_path = model_paths['tokens_dir'] / f'{chrom}_tokens.npy'
        if not tokens_path.exists():
            # Still best-effort, but remember what was skipped so it can be reported.
            missing_chroms.append(chrom)
            continue

        tokens = np.load(tokens_path)

        windows = find_diverse_windows(
            tokens, id_to_categories,
            window_size=args.window_size,
            stride=args.stride,
            min_diversity=args.min_diversity,
        )

        for w in windows:
            w['chrom'] = chrom

        all_windows.extend(windows)

    if missing_chroms:
        print(f"Warning: no token file found for {len(missing_chroms)} "
              f"chromosome(s): {', '.join(missing_chroms)}")

    # Sort by diversity, with bonus for rare categories (CTCF, H3K4me3).
    # list.sort is stable, so remaining ties keep chromosome scan order.
    def sort_key(w):
        rare_bonus = sum(1 for cat in ['CTCF', 'H3K4me3'] if cat in w['present'])
        return (-w['diversity'], -rare_bonus, -len(w['present']))

    all_windows.sort(key=sort_key)
    selected = all_windows[:args.max_samples]

    # Save results
    output_file = output_dir / f'{model_name}_samples.json'

    save_data = {
        'model': model_name,
        'window_size': args.window_size,
        'chromosomes': CHROMOSOMES,
        'structural_categories': STRUCTURAL_CATEGORIES,
        'regulatory_categories': REGULATORY_CATEGORIES,
        # JSON object keys must be strings
        'id_to_token': {str(k): v for k, v in id_to_token.items()},
        'samples': selected
    }

    with open(output_file, 'w') as f:
        json.dump(save_data, f, indent=2)

    print(f"\n{'='*60}")
    print("RESULTS")
    print('='*60)
    print(f"Total diverse windows found: {len(all_windows)}")
    print(f"Selected top {len(selected)} samples")
    print(f"Saved to: {output_file}")

    # Summary: report every diversity value that actually occurs. A fixed
    # 7..3 range would hide windows with all 8 categories, and anything
    # below 3 when --min_diversity is lowered.
    print(f"\nDiversity distribution:")
    for div in sorted({w['diversity'] for w in selected}, reverse=True):
        count = sum(1 for w in selected if w['diversity'] == div)
        print(f"  Diversity {div}: {count} samples")

    print(f"\nTop 5 samples:")
    for i, w in enumerate(selected[:5]):
        print(f"  {i+1}. {w['chrom']}:{w['start']:,}-{w['end']:,} (div={w['diversity']})")
        print(f"     {w['present']}")


if __name__ == '__main__':
    main()
