#!/usr/bin/env python3
"""
Step 1: Prepare Distillation Data (Embedding-based)

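Extracts the teacher model's final hidden states (the transformer output before
the MLM head) for embedding distillation, and converts the teacher's annotated
input tokens and labels into sequence-only tokens/labels for training the
distilled model.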
"""

import sys
sys.path.append('..')

import os
import json
import argparse
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from pathlib import Path
from tqdm import tqdm
import h5py

from config import get_model_paths, get_model_config
from distill_config import get_distill_paths, DISTILL_CONFIG, TEACHER_CONFIG

sys.path.append('../2_train')
from model import annDNA


class TeacherDataset(Dataset):
    """Row-indexed view over the teacher's preprocessed ``.npy`` arrays.

    Loads ``{split}_input_ids.npy``, ``{split}_labels.npy`` and
    ``{split}_attention_mask.npy`` from ``data_dir`` with ``np.load`` (no
    memory-mapping, so the arrays are read fully into memory). Each item is a
    dict of ``torch.long`` tensors for the three arrays, plus the ``idx`` it
    was fetched with so per-batch results can be written back to the
    matching rows.
    """

    def __init__(self, data_dir, split='train'):
        self.input_ids = np.load(f"{data_dir}/{split}_input_ids.npy")
        self.labels = np.load(f"{data_dir}/{split}_labels.npy")
        self.attention_mask = np.load(f"{data_dir}/{split}_attention_mask.npy")
        n_rows = len(self.input_ids)
        print(f"Loaded {n_rows:,} {split} samples")

    def __len__(self):
        # The three arrays are assumed to share their first dimension.
        return len(self.input_ids)

    def __getitem__(self, idx):
        def row_as_long(array):
            return torch.tensor(array[idx], dtype=torch.long)

        item = {}
        item['input_ids'] = row_as_long(self.input_ids)
        item['labels'] = row_as_long(self.labels)
        item['attention_mask'] = row_as_long(self.attention_mask)
        item['idx'] = idx
        return item


class annDNAWithHidden(nn.Module):
    """Expose the hidden states of a wrapped annDNA instead of its MLM logits.

    Re-runs the embedding + transformer path using the wrapped model's own
    submodules (``token_embedding``, ``pos_embedding``, ``dropout``,
    ``layer_norm``, ``transformer``) and returns the transformer output.
    NOTE(review): this presumably mirrors ``annDNA.forward`` up to the MLM
    head — confirm against ``model.annDNA`` if that forward ever changes.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input_ids, attention_mask=None):
        """Return the transformer output for ``input_ids``.

        Args:
            input_ids: 2-D integer tensor ``(batch, seq_len)`` (unpacked below).
            attention_mask: optional tensor of the same shape; positions equal
                to 0 are passed to the transformer as padding.

        Returns:
            The transformer output — presumably ``(batch, seq_len, d_model)``
            if the encoder is batch-first; TODO confirm against annDNA.
        """
        core = self.model
        n_batch, n_pos = input_ids.shape

        # One row of 0..n_pos-1 per sequence, shared via expand (no copy).
        positions = torch.arange(n_pos, device=input_ids.device)
        positions = positions.unsqueeze(0).expand(n_batch, -1)

        summed = core.token_embedding(input_ids) + core.pos_embedding(positions)
        encoder_input = core.layer_norm(core.dropout(summed))

        # src_key_padding_mask expects True where a position must be ignored.
        padding_mask = None
        if attention_mask is not None:
            padding_mask = attention_mask == 0

        return core.transformer(encoder_input, src_key_padding_mask=padding_mask)


def load_teacher_model(model_name, device):
    """Rebuild the teacher annDNA from its best checkpoint, wrapped for hidden states.

    Args:
        model_name: key handed to ``get_model_config`` / ``get_model_paths``.
        device: torch device used as ``map_location`` for the checkpoint and
            as the destination of the wrapped model.

    Returns:
        ``(model, vocab)``: an ``annDNAWithHidden`` in eval mode on ``device``,
        and the vocab dict loaded from the model's ``vocab_file`` (its length
        sets the embedding vocabulary size).
    """
    config = get_model_config(model_name)
    paths = get_model_paths(model_name)

    with open(paths['vocab_file']) as vocab_fh:
        vocab = json.load(vocab_fh)

    # Architecture hyper-parameters, read in the same order as the constructor args.
    arch = {key: config[key] for key in ('d_model', 'nhead', 'num_layers', 'max_seq_len')}
    base_model = annDNA(vocab_size=len(vocab), **arch)

    # NOTE(review): torch.load unpickles the file — only point this at trusted checkpoints.
    checkpoint = torch.load(paths['best_model'], map_location=device)
    base_model.load_state_dict(checkpoint['model_state_dict'])

    model = annDNAWithHidden(base_model).to(device)
    model.eval()

    print(f"Loaded teacher model: {model_name}")
    print(f"  Vocab size: {len(vocab)}")
    print(f"  d_model: {config['d_model']}")

    return model, vocab


# Bases kept by the sequence-only vocabulary; teacher token names are matched
# on their first character against these.
_SEQ_ONLY_BASES = ('A', 'T', 'G', 'C', 'N')


def _seq_only_table(teacher_vocab, base_ids, fallback):
    """Build a dense array mapping each teacher token id to a distilled id.

    Entry ``i`` holds the distilled id for the teacher token whose id is ``i``.
    A name whose first character is a key of ``base_ids`` maps to that base's
    distilled id. Every other name, and every id not present in
    ``teacher_vocab``, maps to ``fallback``. When several names share an id, the
    last one in ``teacher_vocab`` wins, matching an inverted ``{id: name}`` dict.
    Negative ids are left at ``fallback``; vocab ids index an embedding table
    (``vocab_size=len(vocab)``), so none are expected.
    """
    size = max((tid for tid in teacher_vocab.values() if tid >= 0), default=-1) + 1
    table = np.full(size, fallback, dtype=np.int64)
    for name, tid in teacher_vocab.items():
        if tid >= 0:
            table[tid] = base_ids.get(name[:1], fallback)
    return table


def _apply_seq_only_table(ids, table, fallback):
    """Map every element of ``ids`` (any shape) through ``table``.

    Ids outside ``[0, len(table))`` (e.g. the ``-100`` ignore index) become
    ``fallback``. The result keeps the dtype and shape of ``ids``.
    """
    ids = np.asarray(ids)
    out = np.full_like(ids, fallback)
    valid = (ids >= 0) & (ids < table.size)
    out[valid] = table[ids[valid].astype(np.intp)]
    return out


def convert_tokens_to_seq_only(teacher_tokens, teacher_vocab, distilled_vocab):
    """
    Convert teacher tokens (with annotations) to sequence-only tokens.

    A teacher token whose name starts with A/T/G/C/N becomes that base's id in
    ``distilled_vocab``. ``<PAD>``, any other name, and ids missing from
    ``teacher_vocab`` become the distilled ``<PAD>`` id.
    NOTE(review): this means special tokens other than ``<PAD>`` (e.g. a mask
    token, if the teacher vocab has one) also collapse to ``<PAD>`` — confirm
    that is intended for the student inputs.

    Args:
        teacher_tokens: integer array of teacher token ids (any shape).
        teacher_vocab: dict mapping teacher token name -> id.
        distilled_vocab: dict mapping distilled token name -> id; must contain
            'A', 'T', 'G', 'C', 'N' and '<PAD>' (KeyError otherwise).

    Returns:
        Array with the same shape and dtype as ``teacher_tokens``.
    """
    base_ids = {base: distilled_vocab[base] for base in _SEQ_ONLY_BASES}
    pad_id = distilled_vocab['<PAD>']
    # Built once per call, then applied with one vectorized gather instead of
    # a Python-level loop per token.
    table = _seq_only_table(teacher_vocab, base_ids, pad_id)
    return _apply_seq_only_table(teacher_tokens, table, pad_id)


def convert_labels_to_seq_only(teacher_labels, teacher_vocab, distilled_vocab):
    """
    Convert teacher labels to sequence-only labels.
    Labels with -100 (ignore) remain -100.

    A label whose teacher token name starts with A/T/G/C/N becomes that base's
    id in ``distilled_vocab``. Any other label (including ``<PAD>`` and ids
    missing from ``teacher_vocab``) becomes -100.

    Args:
        teacher_labels: integer array of teacher label ids (any shape).
        teacher_vocab: dict mapping teacher token name -> id.
        distilled_vocab: dict mapping distilled token name -> id; must contain
            'A', 'T', 'G', 'C' and 'N' (KeyError otherwise).

    Returns:
        Array with the same shape and dtype as ``teacher_labels``.
    """
    ignore_index = -100
    base_ids = {base: distilled_vocab[base] for base in _SEQ_ONLY_BASES}
    table = _seq_only_table(teacher_vocab, base_ids, ignore_index)
    return _apply_seq_only_table(teacher_labels, table, ignore_index)


def _store_hidden_batch(hidden_ds, indices, hidden_np):
    """Write one batch of hidden states into ``hidden_ds`` at rows ``indices``.

    A consecutive run of indices (expected with the non-shuffled DataLoader
    in ``main``) is written with a single slice assignment. Any other pattern
    falls back to one write per sample.
    """
    if len(indices) == 0:
        return
    start = int(indices[0])
    stop = start + len(indices)
    if np.array_equal(indices, np.arange(start, stop)):
        hidden_ds[start:stop] = hidden_np
    else:
        for row, idx in enumerate(indices):
            hidden_ds[idx] = hidden_np[row]


def main():
    """Extract teacher hidden states and sequence-only inputs for one split.

    Writes into ``distill_paths['teacher_hidden']``:
      * ``{split}_hidden_states.h5``: dataset ``hidden_states`` of shape
        ``(num_samples, seq_len, d_model)`` stored as float16, with the three
        sizes as file attributes;
      * ``{split}_distilled_input_ids.npy`` / ``{split}_distilled_labels.npy``:
        teacher tokens/labels mapped to the distilled (sequence-only) vocab;
      * ``{split}_attention_mask.npy``: the teacher attention masks, copied.

    Raises:
        ValueError: if the split is empty, or if the teacher output does not
            have shape ``(seq_len, TEACHER_CONFIG['d_model'])`` per sample.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', default='0', help='GPU ID')
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--split', default='train', choices=['train', 'val'])
    args = parser.parse_args()

    # Device selection only takes effect if set before CUDA is initialised.
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    print(f"Using GPU: {args.gpu}")
    print(f"Device: {device}")  # reports a silent CPU fallback

    distill_paths = get_distill_paths()
    distill_paths['teacher_hidden'].mkdir(parents=True, exist_ok=True)
    distill_paths['distilled_model'].mkdir(parents=True, exist_ok=True)

    # Load teacher
    teacher_name = DISTILL_CONFIG['teacher_model']
    teacher, teacher_vocab = load_teacher_model(teacher_name, device)
    teacher_paths = get_model_paths(teacher_name)

    # Load distilled (sequence-only) vocab
    distilled_paths = get_model_paths(DISTILL_CONFIG['distilled_base'])
    with open(distilled_paths['vocab_file']) as f:
        distilled_vocab = json.load(f)

    # Load data
    dataset = TeacherDataset(teacher_paths['processed_dir'], args.split)
    num_samples = len(dataset)
    if num_samples == 0:
        # The HDF5 chunk shape below would have a zero dimension, which HDF5 rejects.
        raise ValueError(
            f"No {args.split} samples found in {teacher_paths['processed_dir']}"
        )

    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=4,
        pin_memory=(device.type == 'cuda'),  # pinning only helps host->GPU copies
    )

    # Output files
    output_dir = distill_paths['teacher_hidden']
    h5_file = output_dir / f'{args.split}_hidden_states.h5'
    distilled_input_file = output_dir / f'{args.split}_distilled_input_ids.npy'
    distilled_label_file = output_dir / f'{args.split}_distilled_labels.npy'
    attention_mask_file = output_dir / f'{args.split}_attention_mask.npy'

    # Prepare arrays
    seq_len = dataset.input_ids.shape[1]
    d_model = TEACHER_CONFIG['d_model']

    distilled_input_ids = np.zeros((num_samples, seq_len), dtype=np.int32)
    distilled_labels = np.zeros((num_samples, seq_len), dtype=np.int32)
    attention_masks = np.zeros((num_samples, seq_len), dtype=np.int32)

    print("\nExtracting teacher hidden states...")
    print(f"Samples: {num_samples:,}, Seq length: {seq_len}, d_model: {d_model}")

    # Create HDF5 for hidden states
    with h5py.File(h5_file, 'w') as hf:
        # Pre-allocate dataset
        hidden_ds = hf.create_dataset(
            'hidden_states',
            shape=(num_samples, seq_len, d_model),
            dtype=np.float16,
            chunks=(min(100, num_samples), seq_len, d_model)
        )

        hf.attrs['num_samples'] = num_samples
        hf.attrs['seq_len'] = seq_len
        hf.attrs['d_model'] = d_model

        # Process batches
        with torch.no_grad():
            for batch in tqdm(dataloader, desc=f"Processing {args.split}"):
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                indices = batch['idx'].numpy()

                # Get hidden states from teacher; cast to fp16 on the device so
                # the device->host copy is half the size.
                hidden_states = teacher(input_ids, attention_mask)
                hidden_np = hidden_states.half().cpu().numpy()

                if hidden_np.shape[1:] != (seq_len, d_model):
                    raise ValueError(
                        f"Teacher output per-sample shape {hidden_np.shape[1:]} does not "
                        f"match (seq_len, d_model) = ({seq_len}, {d_model}); check "
                        "TEACHER_CONFIG['d_model'] against the teacher's model config."
                    )

                _store_hidden_batch(hidden_ds, indices, hidden_np)

                # Convert tokens to seq-only (per row)
                teacher_inputs = batch['input_ids'].numpy()
                teacher_labels = batch['labels'].numpy()
                for row, idx in enumerate(indices):
                    distilled_input_ids[idx] = convert_tokens_to_seq_only(
                        teacher_inputs[row], teacher_vocab, distilled_vocab
                    )
                    distilled_labels[idx] = convert_labels_to_seq_only(
                        teacher_labels[row], teacher_vocab, distilled_vocab
                    )
                attention_masks[indices] = batch['attention_mask'].numpy()

    # Save distilled data
    np.save(distilled_input_file, distilled_input_ids)
    np.save(distilled_label_file, distilled_labels)
    np.save(attention_mask_file, attention_masks)

    print("\nSaved:")
    print(f"  Teacher hidden states: {h5_file}")
    print(f"  Distilled input_ids: {distilled_input_file}")
    print(f"  Distilled labels: {distilled_label_file}")
    print(f"  Attention masks: {attention_mask_file}")

    # Verify
    with h5py.File(h5_file, 'r') as hf:
        print("\nVerification:")
        print(f"  Hidden states shape: {hf['hidden_states'].shape}")
        print(f"  Hidden states dtype: {hf['hidden_states'].dtype}")
        sample = hf['hidden_states'][0, 0, :10]
        print(f"  Sample values: {sample}")


if __name__ == '__main__':
    main()
