#!/usr/bin/env python3
'''
Configuration for annDNA project
'''

import os
from pathlib import Path

# =============================================================================
# Root Directories
# =============================================================================
# Root of all project data. Override with the ANNDNA_DATA_ROOT environment
# variable; otherwise it defaults to the current working directory (the same
# location the former `Path("")` placeholder resolved to). Without a definition
# here the module raised NameError on import.
DATA_ROOT = Path(os.environ.get("ANNDNA_DATA_ROOT", ".")).expanduser()
RESULTS_ROOT = DATA_ROOT / "results"
WANDB_DIR = DATA_ROOT / "wandb"

# =============================================================================
# Model Configurations
# =============================================================================
def _transformer_config(name, description, vocab_size, *, nhead=12, num_layers=12, **extra):
    """Build the config dict for a locally trained transformer model.

    These models share d_model=768 and max_seq_len=1002; they differ in
    vocabulary size and, optionally, head count / depth. Any additional
    keyword arguments are appended after the standard keys.
    """
    config = {
        'name': name,
        'description': description,
        'vocab_size': vocab_size,
        'd_model': 768,
        'nhead': nhead,
        'num_layers': num_layers,
        'max_seq_len': 1002,
    }
    config.update(extra)
    return config


# Registry of model configurations, keyed by model name.
MODELS = {
    'full': _transformer_config(
        'full', 'Full model with GENCODE + ENCODE annotations', 272,
    ),
    'struct': _transformer_config(
        'struct', 'Model with GENCODE annotations only', 59,
    ),
    'seq': _transformer_config(
        'seq', 'Sequence-only baseline model', 10,
    ),
    # External HuggingFace model: no local vocab/architecture fields.
    'grover': {
        'name': 'grover',
        'description': 'GROVER baseline model from HuggingFace (PoetschLab/GROVER)',
        'model_type': 'huggingface',
        'hf_model_name': 'PoetschLab/GROVER',
        'd_model': 768,
        'max_seq_len': 512,
    },
    # Smaller student network that reuses the 'seq' vocabulary.
    'distilled': _transformer_config(
        'distilled',
        'Distilled student model (from full, using seq vocab)',
        10,
        nhead=4,
        num_layers=4,
        base_vocab='seq',
        # Optionally add model_path=Path('...') pointing at the
        # results/6_distillation/best_model.pt checkpoint.
    ),
}

# =============================================================================
# Dataset Paths
# =============================================================================
# All raw inputs live under <DATA_ROOT>/data/.
TRAITGYM_DATA = DATA_ROOT.joinpath("data", "traitgym")

# GRCh38 primary-assembly FASTA plus the GENCODE / ENCODE annotation files.
REFERENCE_GENOME = DATA_ROOT.joinpath(
    "data", "reference", "GRCh38.primary_assembly.genome.fasta"
)
ANNOTATION_PATHS = dict(
    gencode=DATA_ROOT.joinpath("data", "gencode", "gencode.v49.annotation.gtf.gz"),
    encode=DATA_ROOT.joinpath("data", "encode", "encodeCcreCombined.bed.gz"),
)

# =============================================================================
# Chromosome Parameters
# =============================================================================
# chr1..chr22 followed by chrX (chrY / chrM are not included).
CHROMOSOMES = [*(f'chr{n}' for n in range(1, 23)), 'chrX']

# chr22 is held out for validation; training uses every other chromosome.
VAL_CHROMOSOME = 'chr22'
TRAIN_CHROMOSOMES = [chrom for chrom in CHROMOSOMES if chrom != VAL_CHROMOSOME]

# NOTE(review): evaluation uses the same chromosome set as training
# (chr22 excluded) — confirm this is intended.
EVAL_CHROMOSOMES = list(TRAIN_CHROMOSOMES)

# Training on every chromosome, validation chromosome included.
TRAIN_CHROMOSOMES_FULL = list(CHROMOSOMES)

# Default random seeds
DEFAULT_RANDOM_SEEDS = [42, 123, 456, 789, 1111, 2222, 3333, 4444, 5555, 6666]

# =============================================================================
# Helper Functions
# =============================================================================
def get_model_config(model_name):
    """Return the configuration dict registered under ``model_name`` in MODELS.

    The dict is the registry entry itself, not a copy.

    Raises:
        ValueError: if ``model_name`` is not a key of MODELS.
    """
    try:
        return MODELS[model_name]
    except KeyError:
        known = list(MODELS.keys())
        raise ValueError(f"Unknown model: {model_name}. Available: {known}") from None


def _preprocess_paths(vocab_model):
    """Return the preprocessing output paths for the vocabulary of ``vocab_model``."""
    preprocess_dir = RESULTS_ROOT / '1_preprocess' / vocab_model
    processed_dir = preprocess_dir / 'processed'
    return {
        'tokens_dir': preprocess_dir / 'tokens',
        'processed_dir': processed_dir,
        'vocab_file': processed_dir / 'vocab.json',
    }


def get_model_paths(model_name):
    """
    Get all paths for a model

    Structure:
    - results/1_preprocess/{model}/tokens/     - tokenized chromosomes
    - results/1_preprocess/{model}/processed/  - MLM samples, vocab
    - results/2_train/{model}/model/           - checkpoints

    The 'distilled' model reuses the preprocessing outputs of its
    ``base_vocab`` model (default 'seq'). Its checkpoint is
    MODELS['distilled']['model_path'] when that key is set. Otherwise it
    defaults to results/6_distillation/best_model.pt.

    Args:
        model_name: model identifier. It is not validated against MODELS.

    Returns:
        dict with keys 'tokens_dir', 'processed_dir', 'vocab_file',
        'model_dir' and 'best_model' (pathlib.Path values). No directories
        are created.
    """
    if model_name == 'distilled':
        config = MODELS['distilled']
        # 'model_path' is commented out in MODELS by default. Indexing it
        # directly raised KeyError, so fall back to the distillation stage's
        # standard output location.
        model_path = config.get('model_path') or (
            RESULTS_ROOT / '6_distillation' / 'best_model.pt'
        )
        best_model = Path(model_path)  # tolerate a plain string in the config
        return {
            **_preprocess_paths(config.get('base_vocab', 'seq')),
            'model_dir': best_model.parent,
            'best_model': best_model,
        }

    train_dir = RESULTS_ROOT / '2_train' / model_name
    return {
        # Preprocess outputs
        **_preprocess_paths(model_name),
        # Training outputs
        'model_dir': train_dir / 'model',
        'best_model': train_dir / 'model' / 'best_model.pt',
    }


def get_benchmark_paths():
    """Return the directory layout of the benchmarking stage.

    Returns:
        dict mapping 'root' to results/4_benchmark, and each of 'variants',
        'tokens', 'embeddings', 'metrics', 'results' and 'plots' to the
        same-named subdirectory of that root. No directories are created.
    """
    root = RESULTS_ROOT / '4_benchmark'
    subdirs = ('variants', 'tokens', 'embeddings', 'metrics', 'results', 'plots')
    layout = {'root': root}
    layout.update((sub, root / sub) for sub in subdirs)
    return layout


def get_attention_paths():
    """Return the directory layout of the attention-analysis stage.

    Returns:
        dict mapping 'root' to results/5_attention, and each of 'attention',
        'results' and 'plots' to the same-named subdirectory of that root.
        No directories are created.
    """
    root = RESULTS_ROOT / '5_attention'
    layout = {'root': root}
    layout.update((sub, root / sub) for sub in ('attention', 'results', 'plots'))
    return layout
