#!/usr/bin/env python3
"""
Step 1: Genome Tokenization

Tokenizes genome sequences with optional annotations.
"""

import sys
sys.path.append('..')

import argparse
import pandas as pd
import gzip
import numpy as np
import json
from pathlib import Path
from collections import defaultdict
from tqdm import tqdm
import time

from config import get_model_config, get_model_paths, REFERENCE_GENOME, ANNOTATION_PATHS


def load_gencode(path, chromosome=None):
    """Load feature records from a gzip-compressed GENCODE GTF file.

    Coordinates are taken from the file unchanged (GTF is already 1-based).
    Comment lines (starting with '#'), lines with fewer than 9 tab-separated
    fields and 'Selenocysteine' records are skipped.

    Args:
        path: Path to the gzip-compressed GTF file.
        chromosome: If given (and non-empty), keep only records whose first
            field equals this name exactly.

    Returns:
        DataFrame with columns 'chromosome', 'start', 'end', 'feature_type'
        and 'source' (always 'gencode'). The columns are present even when
        no record matches, so callers always see the same schema.
    """
    columns = ['chromosome', 'start', 'end', 'feature_type', 'source']
    rows = []

    with gzip.open(path, 'rt') as f:
        for line in f:
            if line.startswith('#'):
                continue

            fields = line.strip().split('\t')
            if len(fields) < 9:
                continue

            chr_name, feature_type = fields[0], fields[2]
            if chromosome and chr_name != chromosome:
                continue
            if feature_type == 'Selenocysteine':
                continue

            rows.append({
                'chromosome': chr_name,
                'start': int(fields[3]),
                'end': int(fields[4]),
                'feature_type': feature_type,
                'source': 'gencode',
            })

    # Passing `columns` keeps the schema stable for an empty result; without
    # it pd.DataFrame([]) has no columns at all.
    return pd.DataFrame(rows, columns=columns)


def load_encode(path, chromosome=None):
    """Load an ENCODE BED file and convert it to 1-based coordinates.

    The file is read as headerless, tab-separated text with the 15 column
    names listed below. BED starts are 0-based, so 'start' is shifted by +1;
    'end' is left as read.

    Args:
        path: Path to the BED file (anything `pandas.read_csv` accepts).
        chromosome: If given (and non-empty), keep only rows whose
            'chromosome' column equals this name.

    Returns:
        DataFrame with columns 'chromosome', 'start', 'end', 'ucsc_label'
        and 'source' (always 'encode').
    """
    cols = ['chromosome', 'start', 'end', 'name', 'score', 'strand',
            'thick_start', 'thick_end', 'rgb', 'ccre', 'encode_label',
            'z_score', 'ucsc_label', 'accession', 'description']

    df = pd.read_csv(path, sep='\t', names=cols, header=None)

    if chromosome:
        df = df[df['chromosome'] == chromosome]

    # Work on an explicit copy: assigning columns on the filtered frame is
    # chained assignment and triggers SettingWithCopyWarning on pandas
    # versions without Copy-on-Write.
    result = df[['chromosome', 'start', 'end', 'ucsc_label']].copy()
    result['start'] = result['start'] + 1
    result['source'] = 'encode'
    return result


def create_intervals_optimized(annotations):
    """Split annotations into non-overlapping intervals with a sweep line.

    Each annotation is a mapping with integer 'start' and 'end' keys that are
    treated as a closed range [start, end]. The sweep runs on the half-open
    form [start, end + 1), which means that:

    * single-base annotations (start == end) produce an interval instead of
      being dropped;
    * neighbouring intervals never share a position.

    Annotations whose 'end' is smaller than their 'start' are ignored, since
    they would otherwise stay active for the rest of the sweep.

    Args:
        annotations: Sequence of annotation mappings.

    Returns:
        List of (start, end, overlapping) tuples ordered by position, where
        start/end are closed coordinates and `overlapping` lists (in input
        order) the annotations covering every position of that interval.
        Stretches covered by no annotation are omitted.
    """
    if not annotations:
        return []

    # (position, is_start, annotation index); the end event sits one past the
    # last covered base.
    events = []
    for idx, ann in enumerate(annotations):
        if ann['end'] < ann['start']:
            continue
        events.append((ann['start'], 1, idx))
        events.append((ann['end'] + 1, 0, idx))

    # Intervals are only emitted when the position changes, so the relative
    # order of events sharing a position does not affect the result.
    events.sort()

    intervals = []
    active = set()
    last_pos = None

    for pos, is_start, idx in tqdm(events, desc="Creating intervals", leave=False):
        if active and pos != last_pos:
            # Everything in `active` covers [last_pos, pos - 1].
            covering = [annotations[i] for i in sorted(active)]
            intervals.append((last_pos, pos - 1, covering))

        if is_start:
            active.add(idx)
        else:
            active.discard(idx)

        last_pos = pos

    return intervals


def process_chromosome_annotations(gencode_path, encode_path, chromosome, use_encode=True):
    """Build a per-interval table of 0/1 annotation flags for one chromosome.

    GENCODE records (and, when `use_encode` is true, ENCODE records) for
    `chromosome` are loaded, cut into intervals by
    `create_intervals_optimized`, and each interval becomes one row with
    'start', 'end' and one 0/1 column per known feature. Features without a
    matching column are ignored.

    Args:
        gencode_path: Path passed to `load_gencode`.
        encode_path: Path passed to `load_encode` (only read if `use_encode`).
        chromosome: Chromosome name to process.
        use_encode: Whether to include ENCODE records and 'ucsc_*' columns.

    Returns:
        DataFrame with one row per interval (empty if there are no intervals).
    """
    print(f"Processing annotations for {chromosome}...")

    annotations = [
        {
            'start': record['start'],
            'end': record['end'],
            'source': 'gencode',
            'feature_type': record['feature_type'],
            'ucsc_label': None
        }
        for _, record in load_gencode(gencode_path, chromosome).iterrows()
    ]

    if use_encode:
        annotations.extend(
            {
                'start': record['start'],
                'end': record['end'],
                'source': 'encode',
                'feature_type': None,
                'ucsc_label': record['ucsc_label']
            }
            for _, record in load_encode(encode_path, chromosome).iterrows()
        )

    # Flag columns every output row starts with (all zero).
    flag_columns = [
        'gencode_gene',
        'gencode_transcript',
        'gencode_exon',
        'gencode_CDS',
        'gencode_UTR',
        'gencode_start_codon',
        'gencode_stop_codon',
    ]
    if use_encode:
        flag_columns += ['ucsc_prom', 'ucsc_enhP', 'ucsc_enhD', 'ucsc_CTCF', 'ucsc_K4m3']

    table = []
    for start, end, overlapping in create_intervals_optimized(annotations):
        record = {'start': start, 'end': end}
        record.update(dict.fromkeys(flag_columns, 0))

        for ann in overlapping:
            if ann['source'] == 'gencode' and ann['feature_type']:
                column = f"gencode_{ann['feature_type']}"
            elif ann['source'] == 'encode' and ann['ucsc_label']:
                column = f"ucsc_{ann['ucsc_label']}"
            else:
                continue

            # Only features that already have a column are recorded.
            if column in record:
                record[column] = 1

        table.append(record)

    return pd.DataFrame(table)


class GenomeTokenizer:
    """Convert chromosome sequences into integer token IDs.

    A token is the base letter, optionally followed by the names of the
    GENCODE / ENCODE features active at that position, joined with '_'
    (e.g. 'A_gene_exon_CTCF'). Token IDs are assigned in order of first
    appearance and shared across all chromosomes processed by one instance.
    """

    # Bases that keep their own token; every other character becomes 'N'.
    _VALID_BASES = frozenset('ATGCN')

    def __init__(self, fasta_path, output_dir, use_gencode=True, use_encode=True):
        """Set up the tokenizer and create `output_dir` if it is missing.

        Args:
            fasta_path: Path to the reference FASTA file.
            output_dir: Directory for token arrays, vocabulary and stats.
            use_gencode: Include GENCODE feature names in tokens. When false,
                tokens are plain bases and ENCODE flags are ignored too.
            use_encode: Include ENCODE feature names in tokens.
        """
        self.fasta_path = fasta_path
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.use_gencode = use_gencode
        self.use_encode = use_encode

        self.vocabulary = {}
        self.vocab_counter = 0
        self.token_stats = defaultdict(int)

        self.gencode_features = ['gene', 'transcript', 'exon', 'CDS', 'UTR',
                                'start_codon', 'stop_codon']
        self.encode_features = ['prom', 'enhP', 'enhD', 'CTCF', 'K4m3']

    def create_token(self, base, gencode_flags, encode_flags):
        """Create token string from base and annotation flags.

        Args:
            base: Single base character; anything outside A/T/G/C/N becomes 'N'.
            gencode_flags: 0/1 flags aligned with `self.gencode_features`.
            encode_flags: 0/1 flags aligned with `self.encode_features`.

        Returns:
            The base alone when `use_gencode` is false; otherwise the base
            followed by the active GENCODE names and (if `use_encode`) the
            active ENCODE names, joined with '_'.
        """
        if base not in self._VALID_BASES:
            base = 'N'

        if not self.use_gencode:
            return base

        parts = [base]
        parts.extend(feat for feat, flag in zip(self.gencode_features, gencode_flags)
                     if flag == 1)
        if self.use_encode:
            parts.extend(feat for feat, flag in zip(self.encode_features, encode_flags)
                         if flag == 1)

        return '_'.join(parts)

    def get_or_add_token_id(self, token):
        """Get token ID, add to vocabulary if new; also counts the occurrence."""
        if token not in self.vocabulary:
            self.vocabulary[token] = self.vocab_counter
            self.vocab_counter += 1

        self.token_stats[token] += 1
        return self.vocabulary[token]

    def load_fasta_sequence(self, chromosome):
        """Read one record from the FASTA file.

        The record is selected by comparing `chromosome` with the first
        whitespace-delimited word of each '>' header, so 'chr1' does not
        match '>chr10' or '>chr1_KI270706v1_random'.

        Returns:
            The upper-cased sequence of the first matching record, or '' if
            no header matches.
        """
        chunks = []
        in_target = False

        with open(self.fasta_path, 'r') as f:
            for line in f:
                if line.startswith('>'):
                    if in_target:
                        # Next record begins: the requested one is complete.
                        break
                    header = line[1:].split()
                    in_target = bool(header) and header[0] == chromosome
                    continue

                if in_target:
                    chunks.append(line.strip().upper())

        # Join once instead of growing a string line by line.
        return ''.join(chunks)

    def _base_only_tokens(self, sequence):
        """Token IDs for `sequence` with every annotation flag switched off."""
        no_gencode = [0] * len(self.gencode_features)
        no_encode = [0] * len(self.encode_features)
        return [
            self.get_or_add_token_id(self.create_token(base, no_gencode, no_encode))
            for base in tqdm(sequence, desc="Creating base tokens")
        ]

    def process_chromosome(self, chromosome):
        """Process single chromosome to generate tokens.

        Loads the sequence, tokenizes it (with annotations when
        `use_gencode` is set and annotations exist) and saves the IDs as an
        int32 array to '<output_dir>/<chromosome>_tokens.npy'. Prints an
        error and returns without writing if the chromosome is not found.
        """
        print(f"\nProcessing tokens for {chromosome}...")
        start_time = time.time()

        sequence = self.load_fasta_sequence(chromosome)
        if not sequence:
            print(f"Error: {chromosome} not found in FASTA")
            return

        print(f"Sequence length: {len(sequence):,} bp")

        if not self.use_gencode:
            print("Creating sequence-only tokens...")
            tokens = self._base_only_tokens(sequence)
        else:
            annotations_df = process_chromosome_annotations(
                ANNOTATION_PATHS['gencode'],
                ANNOTATION_PATHS['encode'],
                chromosome,
                use_encode=self.use_encode
            )

            if annotations_df.empty:
                print("No annotations found, creating base-only tokens...")
                tokens = self._base_only_tokens(sequence)
            else:
                print("Creating annotation-aware tokens...")
                tokens = self.create_tokens_with_annotations(sequence, annotations_df)

        tokens_array = np.array(tokens, dtype=np.int32)
        output_path = self.output_dir / f"{chromosome}_tokens.npy"
        np.save(output_path, tokens_array)

        elapsed = time.time() - start_time
        print(f"Saved {len(tokens):,} tokens for {chromosome} ({elapsed:.1f}s)")
        print(f"Unique tokens so far: {len(self.vocabulary):,}")

    def create_tokens_with_annotations(self, sequence, annotations_df):
        """Create tokens using sequence and annotations.

        Args:
            sequence: Chromosome sequence as a string.
            annotations_df: DataFrame with 'start' and 'end' (1-based,
                inclusive) plus one 0/1 column 'gencode_<feature>' per
                GENCODE feature and, when `use_encode` is set, one
                'ucsc_<feature>' column per ENCODE feature. Rows may overlap;
                flags are OR-ed per position.

        Returns:
            List with one token ID per base of `sequence`.
        """
        seq_length = len(sequence)
        n_gencode = len(self.gencode_features)
        n_encode = len(self.encode_features)

        # Column k of `flag_columns` is stored in bit k of the per-base mask:
        # GENCODE features first, ENCODE features right after them.
        flag_columns = [f'gencode_{feat}' for feat in self.gencode_features]
        if self.use_encode:
            flag_columns += [f'ucsc_{feat}' for feat in self.encode_features]

        print("Building annotation lookup...")
        # One small integer per base instead of a dict entry holding two
        # Python lists per annotated base keeps memory bounded by the
        # sequence length.
        mask_dtype = np.min_scalar_type((1 << (n_gencode + n_encode)) - 1)
        flag_bits = np.zeros(seq_length, dtype=mask_dtype)

        rows = annotations_df[['start', 'end'] + flag_columns].itertuples(
            index=False, name=None)
        for start, end, *flags in tqdm(rows, total=len(annotations_df),
                                       desc="Processing annotations"):
            bits = 0
            for bit, flag in enumerate(flags):
                if flag == 1:
                    bits |= 1 << bit

            # 1-based inclusive [start, end] -> 0-based slice [start - 1, end),
            # clipped to the sequence.
            lo = max(start - 1, 0)
            hi = min(end, seq_length)
            if bits and lo < hi:
                flag_bits[lo:hi] |= mask_dtype.type(bits)

        print("Generating tokens...")
        tokens = []
        token_cache = {}  # (base, bits) -> token string
        for base, bits in tqdm(zip(sequence, flag_bits.tolist()),
                               total=seq_length, desc="Creating tokens"):
            key = (base, bits)
            token = token_cache.get(key)
            if token is None:
                gencode_flags = [(bits >> i) & 1 for i in range(n_gencode)]
                encode_flags = [(bits >> (n_gencode + j)) & 1 for j in range(n_encode)]
                token = self.create_token(base, gencode_flags, encode_flags)
                token_cache[key] = token

            tokens.append(self.get_or_add_token_id(token))

        return tokens

    def save_vocabulary_and_stats(self):
        """Save vocabulary and statistics.

        Writes 'vocabulary.json' (token -> ID) and 'stats.json' (totals,
        enabled feature lists and per-token counts) into `output_dir`.
        """
        vocab_path = self.output_dir / "vocabulary.json"
        with open(vocab_path, 'w') as f:
            json.dump(self.vocabulary, f, indent=2)

        total_tokens = sum(self.token_stats.values())
        stats = {
            "total_tokens": total_tokens,
            "vocabulary_size": len(self.vocabulary),
            "gencode_features": self.gencode_features if self.use_gencode else [],
            "encode_features": self.encode_features if self.use_encode else [],
            "token_frequencies": dict(self.token_stats)
        }

        stats_path = self.output_dir / "stats.json"
        with open(stats_path, 'w') as f:
            json.dump(stats, f, indent=2)

        print(f"\nVocabulary size: {len(self.vocabulary):,}")
        print(f"Total tokens generated: {total_tokens:,}")


def main():
    """Command-line entry point: tokenize chr1-22, chrX and chrY for one model.

    The required --model option ('seq', 'struct' or 'full') selects which
    annotation sources the tokenizer uses; vocabulary and statistics are
    saved after all chromosomes have been processed.
    """
    parser = argparse.ArgumentParser(description='Step 1: Genome Tokenization')
    parser.add_argument('--model', required=True, choices=['seq', 'struct', 'full'])
    model = parser.parse_args().model

    config = get_model_config(model)
    paths = get_model_paths(model)

    print(f"=== Step 1: Tokenization for {model.upper()} ===")
    print(f"Description: {config['description']}")
    print(f"Output: {paths['tokens_dir']}")

    chromosomes = [f'chr{i}' for i in range(1, 23)]
    chromosomes += ['chrX', 'chrY']

    # model -> (banner, use_gencode, use_encode)
    modes = {
        'full': ("\nUsing GENCODE + ENCODE annotations", True, True),
        'struct': ("\nUsing GENCODE annotations only", True, False),
        'seq': ("\nUsing sequence only (no annotations)", False, False),
    }
    banner, with_gencode, with_encode = modes[model]
    print(banner)
    tokenizer = GenomeTokenizer(REFERENCE_GENOME, paths['tokens_dir'],
                                use_gencode=with_gencode, use_encode=with_encode)

    for name in chromosomes:
        tokenizer.process_chromosome(name)

    tokenizer.save_vocabulary_and_stats()
    print("\n=== Tokenization Complete! ===")

# Run the tokenization pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
