#!/usr/bin/env python3
"""
Step 1: Prepare Benchmark Data

Load variant data and extract tokens from pre-tokenized chromosomes.

Datasets:
- mendelian: TraitGym mendelian diseases
- complex: TraitGym complex traits
- eqtl: GTEx eQTL variants
- clinvar: ClinVar pathogenic variants

Models:
- seq: Sequence only
- struct: Sequence + Structure (GENCODE)
- full: Sequence + Structure + Regulation (ENCODE)
- grover: Uses seq tokens
- distilled: Uses seq tokens
"""

import sys
sys.path.append('..')

import numpy as np
import pandas as pd
from pathlib import Path
from tqdm import tqdm
import argparse
import json
from multiprocessing import Pool

from config import get_model_paths, get_benchmark_paths, TRAITGYM_DATA

# Paths
BENCHMARK_PATHS = get_benchmark_paths()
TOKENS_DIR = BENCHMARK_PATHS['tokens']
# Created at import time so the save calls in main() always have a target.
TOKENS_DIR.mkdir(exist_ok=True, parents=True)

# Benchmark data paths
# Root of the raw benchmark inputs: load_eqtl_variants() reads the `vcf/`
# directory under it and load_clinvar_variants() reads `clinvar.parquet`.
# NOTE(review): this constant used to be commented out, which made both
# loaders fail with NameError. The relative default below is a placeholder;
# confirm the real data/benchmark location. A wrong path is harmless: the
# loaders print a warning and return no variants.
BENCHMARK_DATA_DIR = Path('data/benchmark')

# Datasets
DATASETS = ['mendelian', 'complex', 'eqtl', 'clinvar']

# Models
MODELS = ['seq', 'struct', 'full']

# Ablation parameters
# Window length in tokens (special tokens are added on top of this).
WINDOW_SIZES = [250, 500, 750]
# Fraction of the window that lies upstream of the variant token.
VARIANT_POSITIONS = [0.25, 0.5, 0.75]

# Exclude chr22 (validation chromosome)
EXCLUDE_CHROMOSOMES = ['chr22']


def is_snp(ref, alt):
    """Return True when both alleles are exactly one character long.

    Args:
        ref: Reference allele, normally a string.
        alt: Alternate allele, normally a string.

    Returns:
        True if ``len(ref) == 1`` and ``len(alt) == 1``; False otherwise,
        including when either allele has no length at all (e.g. ``None`` or
        a float NaN coming from a missing table cell).
    """
    try:
        return len(ref) == 1 and len(alt) == 1
    except TypeError:
        # Missing values are not sized; treat them as "not a SNP" instead of
        # aborting the whole load.
        return False


def get_base_from_token(token_string):
    """Return the part of a token that precedes its first underscore.

    Callers in this file treat that leading field as the nucleotide base.
    A token without any underscore is returned unchanged.
    """
    leading_field, _, _ = token_string.partition('_')
    return leading_field


def replace_base_in_token(token_string, new_base):
    """Swap the first underscore-separated field of a token for `new_base`.

    Every field after the first underscore is kept as-is, in order.
    """
    # Drop the old leading field; keep whatever followed it.
    _, *trailing_fields = token_string.split('_')
    return '_'.join([new_base, *trailing_fields])


def get_token_id(token_string, vocab):
    """Look up a token ID, degrading gracefully when the token is unknown.

    Lookup order: the full token, then its leading field (as returned by
    get_base_from_token), then the '<UNK>' entry, then the '<PAD>' entry,
    and finally the literal 0.
    """
    if token_string not in vocab:
        bare_base = get_base_from_token(token_string)
        if bare_base not in vocab:
            pad_or_zero = vocab.get('<PAD>', 0)
            return vocab.get('<UNK>', pad_or_zero)
        return vocab[bare_base]
    return vocab[token_string]


# =============================================================================
# Dataset Loaders
# =============================================================================

def load_traitgym_variants(dataset_name):
    """Read `<dataset_name>_traits.parquet` from TRAITGYM_DATA and keep its SNPs.

    Column names are looked up with fallbacks (chr/chrom/chromosome,
    pos/position, ref/REF, alt/ALT, label/pathogenic). Rows without a
    chromosome, rows on a chromosome listed in EXCLUDE_CHROMOSOMES and rows
    that are not SNPs are dropped.

    Returns:
        A list of dicts with keys chrom, pos, ref, alt, label; an empty list
        when the parquet file does not exist.
    """
    print(f"\nLoading TraitGym {dataset_name}...")

    source = TRAITGYM_DATA / f'{dataset_name}_traits.parquet'
    if not source.exists():
        print(f"  Warning: File not found: {source}")
        return []

    frame = pd.read_parquet(source)
    print(f"  Total rows: {len(frame):,}")

    snps = []
    progress = tqdm(frame.iterrows(), total=len(frame), desc="  Parsing", leave=False)
    for _, record in progress:
        # Resolve the chromosome from the least to the most preferred column
        # name, so 'chr' wins over 'chrom', which wins over 'chromosome'.
        raw_chrom = record.get('chromosome')
        raw_chrom = record.get('chrom', raw_chrom)
        raw_chrom = record.get('chr', raw_chrom)
        if raw_chrom is None:
            continue

        chrom = str(raw_chrom)
        chrom = chrom if chrom.startswith('chr') else f'chr{chrom}'
        if chrom in EXCLUDE_CHROMOSOMES:
            continue

        fallback_pos = record.get('position', 0)
        pos = int(record.get('pos', fallback_pos))

        fallback_ref = record.get('REF', '')
        ref = str(record.get('ref', fallback_ref))

        fallback_alt = record.get('ALT', '')
        alt = str(record.get('alt', fallback_alt))

        if is_snp(ref, alt):
            fallback_label = record.get('pathogenic', 0)
            snps.append({
                'chrom': chrom,
                'pos': pos,
                'ref': ref,
                'alt': alt,
                'label': int(record.get('label', fallback_label)),
            })

    print(f"  Valid SNPs: {len(snps):,}")
    return snps


def _read_vcf_snps(vcf_path, label):
    """Parse one VCF file and return its SNP records, each tagged with `label`.

    Header lines (starting with '#') and lines with fewer than five
    tab-separated fields are skipped, as are variants on a chromosome listed
    in EXCLUDE_CHROMOSOMES and variants that are not SNPs.
    """
    records = []
    with open(vcf_path) as handle:
        for line in handle:
            if line.startswith('#'):
                continue
            fields = line.strip().split('\t')
            if len(fields) < 5:
                continue
            # VCF column order: CHROM, POS, ID, REF, ALT
            chrom, pos, _, ref, alt = fields[:5]
            if not chrom.startswith('chr'):
                chrom = f'chr{chrom}'
            if chrom in EXCLUDE_CHROMOSOMES:
                continue
            if not is_snp(ref, alt):
                continue
            records.append({
                'chrom': chrom,
                'pos': int(pos),
                'ref': ref,
                'alt': alt,
                'label': label,
            })
    return records


def load_eqtl_variants():
    """Load eQTL variants from the per-tissue VCF files.

    Every `*_pos.vcf` file under ``BENCHMARK_DATA_DIR / 'vcf'`` contributes
    label-1 variants; the matching `data_gtex_fine_vcf_<tissue>_neg.vcf`
    file, when present, contributes label-0 variants.

    Returns:
        A list of dicts with keys chrom, pos, ref, alt, label; an empty list
        when the VCF directory does not exist.
    """
    print("\nLoading eQTL variants...")

    vcf_dir = BENCHMARK_DATA_DIR / 'vcf'
    if not vcf_dir.exists():
        print(f"  Warning: VCF directory not found: {vcf_dir}")
        return []

    variants = []
    # glob() order depends on the filesystem; sort so the output order is
    # reproducible across runs and machines.
    pos_files = sorted(vcf_dir.glob('*_pos.vcf'))

    for pos_file in tqdm(pos_files, desc="  Loading tissues", leave=False):
        # The glob guarantees the stem ends with '_pos'; strip only that
        # suffix so a tissue name containing '_pos' elsewhere stays intact.
        tissue = pos_file.stem[:-len('_pos')].replace('data_gtex_fine_vcf_', '')
        neg_file = vcf_dir / f'data_gtex_fine_vcf_{tissue}_neg.vcf'

        # Positive variants
        variants.extend(_read_vcf_snps(pos_file, 1))

        # Negative variants
        if neg_file.exists():
            variants.extend(_read_vcf_snps(neg_file, 0))

    print(f"  Total: {len(variants):,}")
    return variants


def load_clinvar_variants():
    """Load ClinVar SNPs from ``BENCHMARK_DATA_DIR / 'clinvar.parquet'``.

    The table must provide the columns chrom, pos, ref, alt and label.
    Variants on a chromosome listed in EXCLUDE_CHROMOSOMES and variants that
    are not SNPs are dropped.

    Returns:
        A list of dicts with keys chrom, pos, ref, alt, label; an empty list
        when the parquet file does not exist.

    Raises:
        KeyError: If one of the required columns is missing.
    """
    print("\nLoading ClinVar variants...")

    parquet_file = BENCHMARK_DATA_DIR / 'clinvar.parquet'
    if not parquet_file.exists():
        print(f"  Warning: File not found: {parquet_file}")
        return []

    df = pd.read_parquet(parquet_file)

    # Iterate plain tuples of just the needed columns; this avoids building a
    # pandas Series per row the way iterrows() does.
    needed = df[['chrom', 'pos', 'ref', 'alt', 'label']]
    rows = needed.itertuples(index=False, name=None)

    variants = []
    for chrom, pos, ref, alt, label in tqdm(rows, total=len(df), desc="  Parsing", leave=False):
        chrom = str(chrom)
        if not chrom.startswith('chr'):
            chrom = f'chr{chrom}'
        if chrom in EXCLUDE_CHROMOSOMES:
            continue

        # Normalise alleles to str, as load_traitgym_variants does, so a
        # missing value (None/NaN) is rejected by is_snp instead of crashing.
        ref, alt = str(ref), str(alt)
        if not is_snp(ref, alt):
            continue

        variants.append({
            'chrom': chrom,
            'pos': int(pos),
            'ref': ref,
            'alt': alt,
            'label': int(label),
        })

    print(f"  Total: {len(variants):,}")
    return variants


def load_dataset(dataset_name):
    """Dispatch to the loader registered for `dataset_name` and return its result.

    Raises:
        KeyError: If `dataset_name` is not one of the registered datasets.
    """
    # Each entry pairs a loader with the positional arguments it needs.
    dispatch = {
        'mendelian': (load_traitgym_variants, ('mendelian',)),
        'complex': (load_traitgym_variants, ('complex',)),
        'eqtl': (load_eqtl_variants, ()),
        'clinvar': (load_clinvar_variants, ()),
    }
    loader, loader_args = dispatch[dataset_name]
    return loader(*loader_args)


# =============================================================================
# Token Extraction
# =============================================================================

def process_chromosome(args):
    """Worker: build reference/alternate token windows for one chromosome.

    Args:
        args: Tuple ``(chrom, chrom_variants, token_dir, vocab, id_to_token,
            window_size, var_position)``. `chrom_variants` is a list of
            variant dicts (reads 'pos' and 'alt'); `token_dir` holds
            ``<chrom>_tokens.npy``; `vocab` maps token string to ID and
            `id_to_token` is its inverse; `var_position` is the fraction of
            the window that lies before the variant token.

    Returns:
        ``(ref_tokens_list, alt_tokens_list, valid_variants)``: parallel
        lists holding, per kept variant, the reference window and the
        alternate window (each wrapped in the '<CLS>'/'<SEP>' IDs) and the
        variant dict itself. All three are empty when the token file is
        missing. Variants whose window would cross either end of the
        chromosome are skipped.

    Raises:
        ValueError: If `var_position` places the variant outside the window.
    """
    chrom, chrom_variants, token_dir, vocab, id_to_token, window_size, var_position = args

    var_idx = int(window_size * var_position)
    # A negative or too-large index would silently wrap around or overflow
    # when indexing the window below, so reject it up front.
    if not 0 <= var_idx < window_size:
        raise ValueError(
            f"var_position={var_position} puts the variant at index {var_idx}, "
            f"outside a window of {window_size} tokens"
        )
    upstream = var_idx
    downstream = window_size - var_idx

    token_file = Path(token_dir) / f'{chrom}_tokens.npy'
    if not token_file.exists():
        return [], [], []

    # Memory-map the chromosome: only the requested windows are read from
    # disk instead of every worker loading the whole array.
    chr_tokens = np.load(token_file, mmap_mode='r')
    n_tokens = len(chr_tokens)

    cls_id = vocab.get('<CLS>')
    sep_id = vocab.get('<SEP>')

    ref_tokens_list = []
    alt_tokens_list = []
    valid_variants = []

    for variant in chrom_variants:
        pos_0 = variant['pos'] - 1  # 1-based position -> 0-based array index
        start = pos_0 - upstream
        end = pos_0 + downstream

        if start < 0 or end > n_tokens:
            continue

        # np.array() materialises the slice as an ordinary in-memory array,
        # detached from the read-only memory map.
        ref_window = np.array(chr_tokens[start:end])
        alt_window = ref_window.copy()

        # NOTE(review): the token's base is not compared with variant['ref'];
        # a coordinate/assembly mismatch would go unnoticed here.
        var_token_str = id_to_token.get(int(ref_window[var_idx]), 'N')

        # Swap only the base; annotation fields of the token are preserved.
        alt_token_str = replace_base_in_token(var_token_str, variant['alt'])
        alt_window[var_idx] = get_token_id(alt_token_str, vocab)

        ref_tokens_list.append(np.concatenate([[cls_id], ref_window, [sep_id]]))
        alt_tokens_list.append(np.concatenate([[cls_id], alt_window, [sep_id]]))
        valid_variants.append(variant)

    return ref_tokens_list, alt_tokens_list, valid_variants


def extract_tokens(variants, model_name, window_size, var_position, n_workers=8):
    """Build reference/alternate token matrices for `variants` under one model.

    The model's token directory and vocabulary come from get_model_paths().
    Variants are grouped by chromosome and each group is handed to
    process_chromosome in a worker pool of `n_workers` processes.

    Returns:
        ``(ref_tokens, alt_tokens, valid_variants)``. The two arrays are
        int32 with one row per kept variant; `valid_variants` lists those
        variants in the same row order. When nothing is kept, two empty
        arrays and an empty list are returned.

    Raises:
        ValueError: If the vocabulary lacks '<CLS>' or '<SEP>'.
    """
    paths = get_model_paths(model_name)
    token_dir = str(paths['tokens_dir'])

    with open(paths['vocab_file']) as vocab_handle:
        vocab = json.load(vocab_handle)

    id_to_token = {token_id: token for token, token_id in vocab.items()}

    if not ('<CLS>' in vocab and '<SEP>' in vocab):
        raise ValueError(f"Missing special tokens in vocab")

    # Bucket variants per chromosome: one bucket becomes one worker task.
    by_chrom = {}
    for record in variants:
        by_chrom.setdefault(record['chrom'], []).append(record)

    tasks = [
        (chrom, bucket, token_dir, vocab, id_to_token, window_size, var_position)
        for chrom, bucket in by_chrom.items()
    ]

    with Pool(n_workers) as pool:
        progress = tqdm(
            pool.imap(process_chromosome, tasks),
            total=len(tasks),
            desc=f"    {model_name}",
            leave=False
        )
        per_chrom = list(progress)

    # Flatten the per-chromosome results while keeping the three lists aligned.
    all_refs, all_alts, kept = [], [], []
    for chrom_refs, chrom_alts, chrom_kept in per_chrom:
        all_refs.extend(chrom_refs)
        all_alts.extend(chrom_alts)
        kept.extend(chrom_kept)

    if not all_refs:
        return np.array([]), np.array([]), []

    ref_tokens = np.array(all_refs, dtype=np.int32)
    alt_tokens = np.array(all_alts, dtype=np.int32)
    return ref_tokens, alt_tokens, kept


def main():
    """Command-line entry point: tokenise every requested benchmark combination.

    For each dataset, window size, variant position and model selected on the
    command line ('all' expands to the module-level defaults), the tokens and
    labels are written to ``<prefix>_tokens.npz`` and the kept variants to
    ``<prefix>_metadata.csv`` inside TOKENS_DIR. A failure for one model is
    reported with its traceback and does not stop the remaining work.
    """
    parser = argparse.ArgumentParser(description='Step 1: Prepare Benchmark Data')
    parser.add_argument('--dataset', default='all', choices=['all', *DATASETS])
    parser.add_argument('--model', default='all', choices=['all', *MODELS])
    parser.add_argument('--window', default='all')
    parser.add_argument('--position', default='all')
    parser.add_argument('--workers', type=int, default=8)
    args = parser.parse_args()

    rule = "=" * 70
    print(rule)
    print("STEP 1: PREPARE BENCHMARK DATA")
    print(rule)

    # A concrete value narrows the sweep to that single setting.
    datasets = [args.dataset] if args.dataset != 'all' else DATASETS
    models = [args.model] if args.model != 'all' else MODELS
    windows = [int(args.window)] if args.window != 'all' else WINDOW_SIZES
    positions = [float(args.position)] if args.position != 'all' else VARIANT_POSITIONS

    print(f"Datasets: {datasets}")
    print(f"Models: {models}")
    print(f"Windows: {windows}")
    print(f"Positions: {positions}")
    print(f"Output: {TOKENS_DIR}")

    for dataset_name in datasets:
        print(f"\n{'='*70}")
        print(f"Dataset: {dataset_name.upper()}")
        print(rule)

        variants = load_dataset(dataset_name)
        if not variants:
            print("  No variants loaded, skipping")
            continue

        all_labels = [v['label'] for v in variants]
        n_pos = all_labels.count(1)
        n_neg = all_labels.count(0)
        print(f"  Variants: {len(variants):,} (pos={n_pos:,}, neg={n_neg:,})")

        for window in windows:
            for position in positions:
                print(f"\n  Window: {window}, Position: {position}")

                for model_name in models:
                    try:
                        ref_tokens, alt_tokens, valid_variants = extract_tokens(
                            variants, model_name, window, position, args.workers
                        )

                        if len(ref_tokens) == 0:
                            print(f"    {model_name}: No valid tokens")
                            continue

                        labels = np.array(
                            [kept['label'] for kept in valid_variants],
                            dtype=np.int32,
                        )

                        # File stem, e.g. position 0.25 -> "pos025"
                        pos_str = "pos" + str(position).replace(".", "")
                        output_prefix = f"{model_name}_{dataset_name}_w{window}_{pos_str}"

                        # Save tokens
                        tokens_path = TOKENS_DIR / f"{output_prefix}_tokens.npz"
                        np.savez_compressed(
                            tokens_path,
                            ref_tokens=ref_tokens,
                            alt_tokens=alt_tokens,
                            labels=labels,
                        )

                        # Save metadata
                        metadata_path = TOKENS_DIR / f"{output_prefix}_metadata.csv"
                        pd.DataFrame(valid_variants).to_csv(metadata_path, index=False)

                        print(f"    {model_name}: {len(ref_tokens):,} variants saved")

                    except Exception as err:
                        # Report and move on so one bad model does not abort the sweep.
                        print(f"    {model_name}: Error - {err}")
                        import traceback
                        traceback.print_exc()

    print("\n" + rule)
    print("STEP 1 COMPLETE")
    print(rule)


if __name__ == '__main__':
    main()
