#!/usr/bin/env python3
"""
Step 2: Prepare Training Dataset

Creates MLM (Masked Language Model) samples from tokenized chromosomes.
- annDNA-seq: Sequence only
- annDNA-struct: Sequence + Structure/GENCODE
- annDNA-full: Sequence + Structure + Regulation/ENCODE

"""

import sys
sys.path.append('..')

import argparse
import json
import numpy as np
from pathlib import Path
import random
from tqdm import tqdm

from config import get_model_config, get_model_paths, TRAIN_CHROMOSOMES, VAL_CHROMOSOME


# =============================================================================
# Preprocessing Parameters
# =============================================================================
# Probability that a non-special position (anything between <CLS> and <SEP>)
# is selected as an MLM prediction target in create_mlm_samples().
MASK_PROB = 0.15
# Of the selected positions, the fraction whose input token is replaced
# by '<MASK>'.
REPLACE_PROB = 0.8
# Of the selected positions, the fraction whose input token is replaced by a
# random non-special vocabulary token. The remaining selected positions
# (1 - REPLACE_PROB - RANDOM_PROB) keep their original input token.
RANDOM_PROB = 0.1
# Maximum fraction of tokens starting with 'N' that a window may contain;
# windows above this ratio are discarded by filter_n_windows().
N_THRESHOLD = 0.5


# =============================================================================
# Utility Functions (from utils.py)
# =============================================================================
def load_vocab(vocab_file):
    """Load a vocabulary from a JSON file.

    Args:
        vocab_file: Path (str or path-like) of the JSON vocabulary file.

    Returns:
        The decoded JSON content. Callers in this file treat it as a dict
        mapping token strings to integer IDs.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    # JSON text is UTF-8 by specification; read it as such instead of relying
    # on the platform's locale encoding (which is not UTF-8 everywhere).
    with open(vocab_file, 'r', encoding='utf-8') as f:
        return json.load(f)


def add_special_tokens(vocab):
    """Ensure '<CLS>', '<SEP>' and '<MASK>' are present in the vocabulary.

    The vocabulary is modified in place and also returned. Tokens that are
    already present keep their existing ID.

    Each missing token receives the first unused ID at or above
    ``len(vocab)``. For a vocabulary whose IDs are exactly ``0..n-1`` this is
    simply ``len(vocab)``; when the IDs are not contiguous from zero (for
    example when 0 is reserved) it avoids handing out an ID that another
    token already uses.

    Args:
        vocab: Dict mapping token strings to integer IDs.

    Returns:
        The same ``vocab`` dict, now containing the three special tokens.
    """
    special_tokens = ('<CLS>', '<SEP>', '<MASK>')
    used_ids = set(vocab.values())

    for token in special_tokens:
        if token in vocab:
            continue
        # Start from len(vocab) and step past any ID that is already taken,
        # so two tokens can never share an ID.
        candidate = len(vocab)
        while candidate in used_ids:
            candidate += 1
        vocab[token] = candidate
        used_ids.add(candidate)

    return vocab


def load_chromosome_tokens(tokens_dir, chrom):
    """Read the saved token array of one chromosome and return it as a list.

    Args:
        tokens_dir: Directory containing the ``<chrom>_tokens.npy`` files.
        chrom: Chromosome name used as the file-name prefix.

    Returns:
        The array content converted with ``ndarray.tolist()``.

    Raises:
        FileNotFoundError: If ``<tokens_dir>/<chrom>_tokens.npy`` is missing.
    """
    npy_path = Path(tokens_dir).joinpath(f"{chrom}_tokens.npy")
    if npy_path.exists():
        # NOTE(review): allow_pickle=True lets NumPy unpickle object arrays,
        # which can execute arbitrary code; only load token files that come
        # from a trusted tokenization step.
        token_array = np.load(npy_path, allow_pickle=True)
        return token_array.tolist()
    raise FileNotFoundError(f"Token file not found: {npy_path}")


def count_n_bases(window_tokens, id_to_token):
    """Return how many tokens in a window decode to a string starting with 'N'.

    Args:
        window_tokens: Iterable of token IDs.
        id_to_token: Dict mapping token IDs to token strings. IDs that are
            missing from it are treated as the empty string and therefore
            never counted.

    Returns:
        Number of token IDs whose token string starts with 'N'.
    """
    decoded = (id_to_token.get(tid, '') for tid in window_tokens)
    return sum(1 for text in decoded if text.startswith('N'))


def filter_n_windows(windows, id_to_token, n_threshold=0.5):
    """
    Keep only the windows whose share of N tokens is at most n_threshold

    A token counts as "N" when its string form starts with 'N' (as decided
    by count_n_bases). Empty windows are dropped: they contain nothing to
    train on, and computing a ratio for them would divide by zero.

    Args:
        windows: List of token windows
        id_to_token: Dictionary mapping token IDs to token strings
        n_threshold: Maximum ratio of N tokens allowed (default: 0.5)

    Returns:
        Filtered list of windows, in their original order
    """
    filtered = []
    for window in windows:
        window_len = len(window)
        # Compare the length explicitly instead of using truthiness so that
        # NumPy-array windows are handled as well as plain lists.
        if window_len == 0:
            continue
        n_ratio = count_n_bases(window, id_to_token) / window_len
        if n_ratio <= n_threshold:
            filtered.append(window)
    return filtered


# =============================================================================
# Dataset Preparation
# =============================================================================
class DatasetPreparator:
    """Build masked-language-model (MLM) arrays from tokenized chromosomes.

    For each chromosome the token list is cut into non-overlapping windows of
    ``window_size`` tokens, windows with too many N tokens are dropped, and
    every remaining window is wrapped as ``<CLS> tokens <SEP>`` and randomly
    masked. The resulting arrays and the vocabulary are written to
    ``output_dir``.
    """

    def __init__(self, tokens_dir, output_dir, vocab, window_size=1000):
        """Store the configuration and create the output directory.

        Args:
            tokens_dir: Directory holding the ``<chrom>_tokens.npy`` files.
            output_dir: Directory the dataset files are written to; created
                (including parents) if it does not exist.
            vocab: Dict mapping token strings to IDs. It must contain
                '<CLS>', '<SEP>' and '<MASK>' before samples are created.
            window_size: Number of tokens per window, excluding the
                '<CLS>' and '<SEP>' tokens added around it.

        Raises:
            ValueError: If ``window_size`` is smaller than 1.
        """
        # Fail early: a zero window size would later surface as an obscure
        # range() error and a negative one as a silently empty dataset.
        if window_size < 1:
            raise ValueError(
                f"window_size must be a positive integer, got {window_size}")

        self.tokens_dir = Path(tokens_dir)
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

        self.vocab = vocab
        self.id_to_token = {v: k for k, v in vocab.items()}
        self.window_size = window_size

    def create_windows(self, chrom):
        """Cut a chromosome's tokens into non-overlapping full-size windows.

        Args:
            chrom: Chromosome name, passed to load_chromosome_tokens().

        Returns:
            List of windows, each holding exactly ``window_size`` tokens.
            Trailing tokens that do not fill a whole window are discarded.
        """
        tokens = load_chromosome_tokens(self.tokens_dir, chrom)
        size = self.window_size

        # Stepping by `size` and stopping at len(tokens) - size + 1 only ever
        # produces complete windows, so no per-window length check is needed.
        windows = [tokens[start:start + size]
                   for start in range(0, len(tokens) - size + 1, size)]

        print(f"{chrom}: {len(windows):,} windows created")
        return windows

    def create_mlm_samples(self, all_windows):
        """Turn token windows into masked MLM input/label arrays.

        Every window becomes one row ``<CLS> tokens <SEP>``. Each position
        between the two special tokens is selected with probability
        MASK_PROB; a selected position keeps its original token as the label
        and its input token is replaced by '<MASK>' (REPLACE_PROB), replaced
        by a random non-special vocabulary token (RANDOM_PROB), or left
        unchanged (otherwise).

        Args:
            all_windows: Sequence of token-ID windows (lists or arrays).

        Returns:
            Tuple ``(input_ids, labels, attention_mask)`` of int32 arrays of
            shape ``(len(all_windows), window_size + 2)``. ``labels`` is -100
            at every position that was not selected (presumably the ignore
            index of the training loss - confirm against the training code).
            ``attention_mask`` is 1 over the positions covered by the
            sequence and 0 elsewhere.
        """
        max_len = self.window_size + 2  # CLS + tokens + SEP
        num_samples = len(all_windows)

        input_ids = np.zeros((num_samples, max_len), dtype=np.int32)
        # Labels start out as "ignore everywhere"; only selected positions
        # are filled in below.
        labels = np.full((num_samples, max_len), -100, dtype=np.int32)
        attention_mask = np.zeros((num_samples, max_len), dtype=np.int32)

        # Special-token IDs do not change per sample; look them up once.
        cls_id = self.vocab['<CLS>']
        sep_id = self.vocab['<SEP>']
        mask_id = self.vocab['<MASK>']

        special_tokens = {'<CLS>', '<SEP>', '<MASK>'}
        random_token_ids = [tid for token, tid in self.vocab.items()
                            if token not in special_tokens]

        for idx, token_ids in enumerate(tqdm(all_windows, desc="Creating MLM samples")):
            # Unpack rather than use list "+": with a NumPy-array window,
            # "+" would add the special-token ID to every element instead of
            # concatenating.
            seq = [cls_id, *token_ids, sep_id]
            seq_len = len(seq)

            input_ids[idx, :seq_len] = seq
            attention_mask[idx, :seq_len] = 1

            # Positions 0 and seq_len - 1 hold <CLS>/<SEP> and are never
            # selected, so their labels stay at -100.
            for i in range(1, seq_len - 1):
                if random.random() >= MASK_PROB:
                    continue

                labels[idx, i] = seq[i]
                r = random.random()
                if r < REPLACE_PROB:
                    input_ids[idx, i] = mask_id
                elif r < REPLACE_PROB + RANDOM_PROB:
                    input_ids[idx, i] = random.choice(random_token_ids)
                # else: keep the original input token

        return input_ids, labels, attention_mask

    def process_chromosome(self, chrom):
        """Window, filter and mask a single chromosome.

        Args:
            chrom: Chromosome name.

        Returns:
            ``(input_ids, labels, attention_mask)`` as produced by
            create_mlm_samples(), or ``(None, None, None)`` when the
            chromosome yields no windows or none pass the N filter.
        """
        print(f"\n=== Processing {chrom} ===")

        windows = self.create_windows(chrom)
        if not windows:
            return None, None, None

        filtered_windows = filter_n_windows(windows, self.id_to_token, N_THRESHOLD)
        filtered_count = len(windows) - len(filtered_windows)
        print(f"Filtered {filtered_count:,} windows with >{N_THRESHOLD*100}% N bases")
        print(f"Remaining: {len(filtered_windows):,} windows")

        if not filtered_windows:
            print(f"No valid windows for {chrom}")
            return None, None, None

        return self.create_mlm_samples(filtered_windows)

    def prepare_dataset(self):
        """Run the full pipeline and write the dataset to ``output_dir``.

        Writes ``train_{input_ids,labels,attention_mask}.npy`` built from
        TRAIN_CHROMOSOMES, ``val_{input_ids,labels,attention_mask}.npy``
        built from VAL_CHROMOSOME, and ``vocab.json``. A split that yields no
        samples is not written; a warning is printed instead.
        """
        print("=== Dataset Preparation ===")
        print(f"Window size: {self.window_size}")
        print(f"N threshold: {N_THRESHOLD}")
        print(f"Mask probability: {MASK_PROB}")

        # Training data: every chromosome listed in TRAIN_CHROMOSOMES
        print(f"\n=== Processing Training Data ({len(TRAIN_CHROMOSOMES)} chromosomes) ===")
        train_input_ids = []
        train_labels = []
        train_attention_mask = []

        for chrom in TRAIN_CHROMOSOMES:
            input_ids, labels, attention_mask = self.process_chromosome(chrom)
            if input_ids is not None:
                train_input_ids.append(input_ids)
                train_labels.append(labels)
                train_attention_mask.append(attention_mask)

        if train_input_ids:
            # Stack the per-chromosome arrays along the sample axis.
            train_input_ids = np.concatenate(train_input_ids, axis=0)
            train_labels = np.concatenate(train_labels, axis=0)
            train_attention_mask = np.concatenate(train_attention_mask, axis=0)

            np.save(self.output_dir / "train_input_ids.npy", train_input_ids)
            np.save(self.output_dir / "train_labels.npy", train_labels)
            np.save(self.output_dir / "train_attention_mask.npy", train_attention_mask)

            print(f"\nTraining data: {train_input_ids.shape[0]:,} samples")
        else:
            # Make the missing output visible instead of finishing silently.
            print("\nWarning: no training samples produced; train_*.npy files were not written")

        # Validation data: the single chromosome named by VAL_CHROMOSOME
        print(f"\n=== Processing Validation Data ({VAL_CHROMOSOME}) ===")
        val_input_ids, val_labels, val_attention_mask = self.process_chromosome(VAL_CHROMOSOME)

        if val_input_ids is not None:
            np.save(self.output_dir / "val_input_ids.npy", val_input_ids)
            np.save(self.output_dir / "val_labels.npy", val_labels)
            np.save(self.output_dir / "val_attention_mask.npy", val_attention_mask)

            print(f"Validation data: {val_input_ids.shape[0]:,} samples")
        else:
            print("Warning: no validation samples produced; val_*.npy files were not written")

        # Save vocab with special tokens
        vocab_path = self.output_dir / "vocab.json"
        with open(vocab_path, 'w') as f:
            json.dump(self.vocab, f, indent=2)

        print("\n=== Dataset Preparation Complete ===")
        print(f"Vocabulary size: {len(self.vocab)}")
        print(f"Data saved to: {self.output_dir}")


def main():
    """Command-line entry point: build the MLM dataset for one model variant.

    Parses ``--model``, looks up that variant's configuration and paths,
    loads the tokenizer vocabulary from ``<tokens_dir>/vocabulary.json``,
    makes sure the special tokens are present, and runs DatasetPreparator.
    """
    parser = argparse.ArgumentParser(description='Step 2: Prepare Training Dataset')
    parser.add_argument('--model', required=True,
                        choices=['seq', 'struct', 'full'],
                        help='Model variant to prepare data for: '
                             'seq (sequence only), '
                             'struct (sequence + structure/GENCODE), '
                             'full (sequence + structure + regulation/ENCODE)')
    args = parser.parse_args()

    config = get_model_config(args.model)
    paths = get_model_paths(args.model)

    # The model's maximum sequence length includes the <CLS> and <SEP>
    # tokens that are added around every window.
    window_size = config['max_seq_len'] - 2

    print(f"=== Preparing Dataset for {config['name']} ===")
    print(f"Tokens: {paths['tokens_dir']}")
    print(f"Output: {paths['processed_dir']}")

    # Coerce to Path so the join works whether the config stores the
    # directory as a str or as a Path.
    vocab_file = Path(paths['tokens_dir']) / 'vocabulary.json'
    vocab = load_vocab(vocab_file)
    vocab = add_special_tokens(vocab)

    preparator = DatasetPreparator(paths['tokens_dir'], paths['processed_dir'], vocab, window_size)
    preparator.prepare_dataset()


# Run the pipeline only when this file is executed as a script, so importing
# the module (e.g. to reuse DatasetPreparator) has no side effects.
if __name__ == "__main__":
    main()
