#!/usr/bin/env python3
"""
Step 1: Sample Sequences for Embedding Analysis

Sample sequences from different genomic categories:
- Structural: CDS, 5'UTR, 3'UTR, Intron, Intergenic
- Regulatory: Promoter, pELS, dELS, CTCF, H3K4me3
"""

import sys
sys.path.append('..')

import numpy as np
import pandas as pd
from pathlib import Path
from tqdm import tqdm
import random
import argparse

import config

# Paths
# NOTE(review): the assignment below is commented out, so OUTPUT_DIR is never
# bound in this module and the mkdir() call that follows raises NameError as
# soon as the module is imported. Restore the assignment with the real
# embedding_analysis output directory before running.
# OUTPUT_DIR = Path('')  # embedding_analysis path
OUTPUT_DIR.mkdir(exist_ok=True, parents=True)

# Sample parameters
SAMPLE_SIZE = 1000  # target number of samples kept per category
# Upper bound on the length of one sampled region. The samplers apply it to
# token positions; the "bp" unit assumes one token per base — TODO confirm.
MAX_SEQUENCE_LENGTH = 500  # bp (max 512 for GROVER compatibility)
# Minimum distance between the start positions of two samples taken from the
# same chromosome and category.
MIN_GAP = 10000  # minimum gap between samples to avoid overlap

# Chromosome sets
TRAIN_CHROMOSOMES = config.TRAIN_CHROMOSOMES  # chr1-21, chrX (training data)
VAL_CHROMOSOMES = [config.VAL_CHROMOSOME]      # chr22 (validation data)


def load_gencode_annotations(vocab_file=None):
    """Collect the vocabulary token IDs that carry GENCODE feature labels.

    The vocabulary is a JSON object mapping token string -> token ID. A token
    is assigned to a feature by case-sensitive substring match on its string,
    so one token may land in several feature sets.

    Args:
        vocab_file: Path to the vocabulary JSON (full/processed/vocab.json).
            The previously hard-coded path was removed from this function,
            which left the name undefined; it must now be passed in.

    Returns:
        A ``(gencode_tokens, vocab)`` tuple. ``gencode_tokens`` maps each of
        'cds', 'utr', 'exon' and 'transcript' to a set of token IDs; ``vocab``
        is the loaded vocabulary dict.

    Raises:
        ValueError: If ``vocab_file`` is not provided.
    """
    import json

    print("Loading GENCODE annotations from tokens...")

    if vocab_file is None:
        raise ValueError(
            "vocab_file is required: pass the path to the vocabulary JSON "
            "(full/processed/vocab.json)"
        )

    # JSON is UTF-8 by specification; do not depend on the locale encoding.
    with open(vocab_file, encoding='utf-8') as f:
        vocab = json.load(f)

    gencode_tokens = {
        'cds': set(),
        'utr': set(),
        'exon': set(),
        'transcript': set(),
    }

    for token, tid in vocab.items():
        if 'CDS' in token:
            gencode_tokens['cds'].add(tid)
        if 'UTR' in token:
            gencode_tokens['utr'].add(tid)
        if 'exon' in token:
            gencode_tokens['exon'].add(tid)
        # 'transcript' only keeps tokens that are not also exon-labelled, so
        # this set is already disjoint from the 'exon' set.
        if 'transcript' in token and 'exon' not in token:
            gencode_tokens['transcript'].add(tid)

    return gencode_tokens, vocab


def load_encode_annotations(vocab_file=None):
    """Collect the vocabulary token IDs that carry ENCODE feature labels.

    The vocabulary is a JSON object mapping token string -> token ID. A token
    is assigned to a feature by case-sensitive substring match on its string,
    so one token may land in several feature sets.

    Args:
        vocab_file: Path to the vocabulary JSON (full/processed/vocab.json).
            The previously hard-coded path was removed from this function,
            which left the name undefined; it must now be passed in.

    Returns:
        An ``(encode_tokens, vocab)`` tuple. ``encode_tokens`` maps each of
        'promoter', 'enhP', 'enhD', 'ctcf' and 'k4m3' to a set of token IDs;
        ``vocab`` is the loaded vocabulary dict.

    Raises:
        ValueError: If ``vocab_file`` is not provided.
    """
    import json

    print("Loading ENCODE annotations from tokens...")

    if vocab_file is None:
        raise ValueError(
            "vocab_file is required: pass the path to the vocabulary JSON "
            "(full/processed/vocab.json)"
        )

    # JSON is UTF-8 by specification; do not depend on the locale encoding.
    with open(vocab_file, encoding='utf-8') as f:
        vocab = json.load(f)

    # Output key -> substring that marks the feature inside a token string.
    markers = {
        'promoter': 'prom',
        'enhP': 'enhP',
        'enhD': 'enhD',
        'ctcf': 'CTCF',
        'k4m3': 'K4m3',
    }
    encode_tokens = {name: set() for name in markers}

    for token, tid in vocab.items():
        for name, marker in markers.items():
            if marker in token:
                encode_tokens[name].add(tid)

    return encode_tokens, vocab


def sample_category_from_tokens(token_file, target_tokens, category_name,
                                n_samples=SAMPLE_SIZE, max_len=MAX_SEQUENCE_LENGTH,
                                min_gap=None):
    """Sample well-separated runs of target tokens from one chromosome.

    The token array is scanned for maximal runs of consecutive positions whose
    token ID is in ``target_tokens``. Each run contributes one candidate
    region that starts at the run start and is truncated to at most
    ``max_len`` positions. Candidates are shuffled with the module-level
    ``random`` generator (seed it beforehand for reproducible output) and
    accepted greedily while keeping region starts at least ``min_gap`` apart.

    Args:
        token_file: Path of a ``.npy`` file holding the token IDs; assumed to
            be a 1-D array with one entry per position.
        target_tokens: Iterable of token IDs that define the category.
        category_name: Label stored in the 'category' field of every sample.
        n_samples: Maximum number of samples to return.
        max_len: Maximum region length, in token positions.
        min_gap: Minimum distance between the start positions of two accepted
            samples. Defaults to the module-level ``MIN_GAP``.

    Returns:
        A list of dicts with keys 'start', 'end' (exclusive) and 'category',
        in acceptance order. Empty if no run of target tokens exists.
    """
    if min_gap is None:
        min_gap = MIN_GAP

    tokens = np.load(token_file, mmap_mode='r')

    # Boolean mask of positions holding a target token.
    target_tokens_set = set(target_tokens)
    matching_mask = np.asarray(np.isin(tokens, list(target_tokens_set)))
    if matching_mask.size == 0:
        return []

    # Run boundaries are the positions where the mask flips. Finding them with
    # array operations avoids a Python-level loop over every chromosome
    # position.
    boundaries = np.flatnonzero(matching_mask[1:] != matching_mask[:-1]) + 1
    if matching_mask[0]:
        boundaries = np.concatenate(([0], boundaries))
    if matching_mask[-1]:
        boundaries = np.concatenate((boundaries, [matching_mask.size]))
    run_starts = boundaries[0::2]
    run_ends = boundaries[1::2]

    # Truncate long runs to max_len. Every position inside a run is a target
    # token, so a separate ">= 50% target tokens" filter could never reject a
    # window and is not needed.
    lengths = np.minimum(run_ends - run_starts, max_len)
    keep = lengths > 0
    region_starts = run_starts[keep]
    region_ends = region_starts + lengths[keep]

    # Plain Python ints, in ascending start order.
    regions = list(zip(region_starts.tolist(), region_ends.tolist()))
    if not regions:
        return []

    # Greedy selection of well-separated regions in random order.
    random.shuffle(regions)
    samples = []

    for start, end in regions:
        if len(samples) >= n_samples:
            break

        too_close = any(abs(start - s['start']) < min_gap for s in samples)
        if too_close:
            continue

        samples.append({
            'start': start,
            'end': end,
            'category': category_name
        })

    return samples


def sample_structural(chromosomes, split_name, token_dir=None):
    """Sample structural elements (GENCODE-based) across chromosomes.

    For every chromosome whose token file exists, candidate regions are drawn
    for each structural category ('cds', 'utr', 'intron', 'intergenic') and
    then down-sampled to at most SAMPLE_SIZE rows per category.

    Args:
        chromosomes: Sized collection of chromosome names (e.g. 'chr1').
        split_name: Label written to the 'split' column (e.g. 'train').
        token_dir: Directory containing ``{chrom}_tokens.npy`` files
            (full/tokens). The previously hard-coded path was removed from
            this function, which left the name undefined; it must now be
            passed in.

    Returns:
        DataFrame with columns chrom, start, end, category, analysis, split.
        It is empty (columns only) when nothing could be sampled.

    Raises:
        ValueError: If ``token_dir`` is not provided.
    """
    print("\n" + "="*80)
    print(f"SAMPLING STRUCTURAL ELEMENTS ({split_name})")
    print("="*80)

    if token_dir is None:
        raise ValueError(
            "token_dir is required: pass the directory that holds the "
            "{chrom}_tokens.npy files (full/tokens)"
        )
    token_dir = Path(token_dir)

    # NOTE(review): called without arguments, so the loader has to be able to
    # locate the vocabulary file on its own.
    gencode_tokens, vocab = load_gencode_annotations()

    # Define base tokens (no annotation)
    base_tokens = {vocab[b] for b in ['A', 'T', 'G', 'C', 'N'] if b in vocab}

    # NOTE(review): 'intron' is every transcript-labelled token that is not
    # exon-labelled; a token whose name also contains 'CDS' or 'UTR' would be
    # included here as well — confirm against the vocabulary naming.
    categories = {
        'cds': gencode_tokens['cds'],
        'utr': gencode_tokens['utr'],
        'intron': gencode_tokens['transcript'] - gencode_tokens['exon'],
        'intergenic': base_tokens,  # Only base tokens (no annotations)
    }

    # Oversample per chromosome (+10) so the balancing step below has spare
    # rows when some chromosomes yield fewer regions than their share.
    per_chrom_quota = SAMPLE_SIZE // max(len(chromosomes), 1) + 10

    all_samples = []

    for chrom in tqdm(chromosomes, desc="Chromosomes"):
        token_file = token_dir / f'{chrom}_tokens.npy'
        if not token_file.exists():
            # Best effort: chromosomes without a token file are skipped.
            continue

        for cat_name, cat_tokens in categories.items():
            samples = sample_category_from_tokens(
                token_file, cat_tokens, cat_name,
                n_samples=per_chrom_quota
            )

            all_samples.extend(
                {
                    'chrom': chrom,
                    'start': sample['start'],
                    'end': sample['end'],
                    'category': cat_name,
                    'analysis': 'structural',
                    'split': split_name
                }
                for sample in samples
            )

    # Balance categories. Explicit columns keep the 'category' lookups valid
    # even when no sample was collected at all.
    columns = ['chrom', 'start', 'end', 'category', 'analysis', 'split']
    df = pd.DataFrame(all_samples, columns=columns)
    balanced_samples = []
    for cat in categories:
        cat_rows = df[df['category'] == cat]
        balanced_samples.append(
            cat_rows.sample(n=min(SAMPLE_SIZE, len(cat_rows)), random_state=42)
        )

    df_balanced = pd.concat(balanced_samples, ignore_index=True)

    print(f"\nStructural samples ({split_name}):")
    print(df_balanced['category'].value_counts())

    return df_balanced


def sample_regulatory(chromosomes, split_name, token_dir=None):
    """Sample regulatory elements (ENCODE-based) across chromosomes.

    For every chromosome whose token file exists, candidate regions are drawn
    for each regulatory category ('promoter', 'pELS', 'dELS', 'ctcf',
    'h3k4me3') and then down-sampled to at most SAMPLE_SIZE rows per category.

    Args:
        chromosomes: Sized collection of chromosome names (e.g. 'chr1').
        split_name: Label written to the 'split' column (e.g. 'train').
        token_dir: Directory containing ``{chrom}_tokens.npy`` files
            (full/tokens). The previously hard-coded path was removed from
            this function, which left the name undefined; it must now be
            passed in.

    Returns:
        DataFrame with columns chrom, start, end, category, analysis, split.
        It is empty (columns only) when nothing could be sampled.

    Raises:
        ValueError: If ``token_dir`` is not provided.
    """
    print("\n" + "="*80)
    print(f"SAMPLING REGULATORY ELEMENTS ({split_name})")
    print("="*80)

    if token_dir is None:
        raise ValueError(
            "token_dir is required: pass the directory that holds the "
            "{chrom}_tokens.npy files (full/tokens)"
        )
    token_dir = Path(token_dir)

    # NOTE(review): called without arguments, so the loader has to be able to
    # locate the vocabulary file on its own. The returned vocab is not needed
    # for the regulatory categories.
    encode_tokens, _vocab = load_encode_annotations()

    # Output category label -> token IDs carrying that ENCODE feature.
    categories = {
        'promoter': encode_tokens['promoter'],
        'pELS': encode_tokens['enhP'],
        'dELS': encode_tokens['enhD'],
        'ctcf': encode_tokens['ctcf'],
        'h3k4me3': encode_tokens['k4m3'],
    }

    # Oversample per chromosome (+10) so the balancing step below has spare
    # rows when some chromosomes yield fewer regions than their share.
    per_chrom_quota = SAMPLE_SIZE // max(len(chromosomes), 1) + 10

    all_samples = []

    for chrom in tqdm(chromosomes, desc="Chromosomes"):
        token_file = token_dir / f'{chrom}_tokens.npy'
        if not token_file.exists():
            # Best effort: chromosomes without a token file are skipped.
            continue

        for cat_name, cat_tokens in categories.items():
            samples = sample_category_from_tokens(
                token_file, cat_tokens, cat_name,
                n_samples=per_chrom_quota
            )

            all_samples.extend(
                {
                    'chrom': chrom,
                    'start': sample['start'],
                    'end': sample['end'],
                    'category': cat_name,
                    'analysis': 'regulatory',
                    'split': split_name
                }
                for sample in samples
            )

    # Balance categories. Explicit columns keep the 'category' lookups valid
    # even when no sample was collected at all.
    columns = ['chrom', 'start', 'end', 'category', 'analysis', 'split']
    df = pd.DataFrame(all_samples, columns=columns)
    balanced_samples = []
    for cat in categories:
        cat_rows = df[df['category'] == cat]
        balanced_samples.append(
            cat_rows.sample(n=min(SAMPLE_SIZE, len(cat_rows)), random_state=42)
        )

    df_balanced = pd.concat(balanced_samples, ignore_index=True)

    print(f"\nRegulatory samples ({split_name}):")
    print(df_balanced['category'].value_counts())

    return df_balanced


def main():
    """Command-line entry point: sample the requested analyses per split.

    Parses ``--analysis`` (structural / regulatory / all) and ``--split``
    (train / val / both), runs the matching sampler for every selected
    chromosome split, and writes each result to
    ``OUTPUT_DIR/<analysis>_<split>_samples.tsv`` as a tab-separated file.
    """
    parser = argparse.ArgumentParser(description='Step 1: Sample Sequences')
    parser.add_argument('--analysis', type=str, default='all',
                       choices=['structural', 'regulatory', 'all'],
                       help='Which analysis to sample for')
    parser.add_argument('--split', type=str, default='both',
                       choices=['train', 'val', 'both'],
                       help='Which chromosome split to use')
    args = parser.parse_args()

    rule = "=" * 80

    print(rule)
    print("STEP 1: SAMPLING SEQUENCES")
    print(rule)
    print(f"Sample size per category: {SAMPLE_SIZE}")
    print(f"Max sequence length: {MAX_SEQUENCE_LENGTH} bp")
    print(f"Train chromosomes: {len(TRAIN_CHROMOSOMES)} (chr1-21, chrX)")
    print(f"Val chromosomes: {len(VAL_CHROMOSOMES)} (chr22)")

    # Splits and analyses are processed in this fixed order; 'both' / 'all'
    # select every entry of the respective table.
    known_splits = (
        ('train', TRAIN_CHROMOSOMES),
        ('val', VAL_CHROMOSOMES),
    )
    known_analyses = (
        ('structural', sample_structural),
        ('regulatory', sample_regulatory),
    )

    selected_splits = [
        (label, chroms)
        for label, chroms in known_splits
        if args.split in (label, 'both')
    ]

    for split_name, chromosomes in selected_splits:
        print(f"\n{rule}")
        print(f"SPLIT: {split_name.upper()}")
        print(f"{rule}")

        for analysis_name, sampler in known_analyses:
            if args.analysis not in (analysis_name, 'all'):
                continue

            frame = sampler(chromosomes, split_name)
            output_file = OUTPUT_DIR / f'{analysis_name}_{split_name}_samples.tsv'
            frame.to_csv(output_file, sep='\t', index=False)
            print(f"\n✓ Saved: {output_file}")

    print("\n" + rule)
    print("STEP 1 COMPLETE")
    print(rule)


if __name__ == '__main__':
    main()
