import importlib
import os
import random
import sys
from os.path import basename, dirname, isfile, splitext

import numpy as np
import torch
from torch import nn
from transformers import AutoModelForCausalLM


def set_seed(seed: int = 8):
    '''
    Seed every random number generator this module relies on.

    The same `seed` is handed to Python's `random`, NumPy's global RNG and
    PyTorch (plus the CUDA generators when CUDA is available), and cuDNN is
    switched to deterministic mode.
    '''
    # CPU-side generators, seeded in a fixed order.
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    # GPU generators: the current device first, then every device (multi-GPU).
    cuda_ready = torch.cuda.is_available()
    if cuda_ready:
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    # Deterministic cuDNN kernels and no autotuning (may cost performance).
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    

def load_module(module_reference, name):
    '''
    Import a module and return the attribute `name` from it.

    `module_reference` is either the path to an existing `.py` file or the
    dotted name of an importable module.

    For a file path, that exact file is loaded, even if another module with
    the same base name has already been imported. The file's directory is
    moved to the front of `sys.path` so its sibling imports resolve. The
    module is cached in `sys.modules` under its base name when that name is
    free; a second call for the same file reuses the cached module. If the
    name is already taken by a different module, the existing entry is left
    untouched and the file is executed without being registered.

    Raises:
        ImportError: the module cannot be located or loaded.
        AttributeError: the module has no attribute `name`.
    '''
    # Anything that is not an existing .py file is treated as a module name.
    if not (isfile(module_reference) and module_reference.endswith('.py')):
        return getattr(importlib.import_module(module_reference), name)

    # Import the helpers by name: a local `import importlib.util` would turn
    # `importlib` into a function-local and break the lookup above.
    from importlib.util import module_from_spec, spec_from_file_location

    mod_file = os.path.abspath(module_reference)
    mod_name = splitext(basename(mod_file))[0]
    mod_path = dirname(mod_file)

    # Keep the directory first on sys.path (same lookup precedence as a
    # plain insert) without accumulating a duplicate entry on every call.
    if mod_path in sys.path:
        sys.path.remove(mod_path)
    sys.path.insert(0, mod_path)

    cached = sys.modules.get(mod_name)
    cached_file = getattr(cached, '__file__', None)
    if cached_file and os.path.abspath(cached_file) == mod_file:
        # Same file as before: reuse it instead of executing it again.
        module = cached
    else:
        spec = spec_from_file_location(mod_name, mod_file)
        if spec is None or spec.loader is None:
            raise ImportError(f'Cannot load module from `{mod_file}`')
        module = module_from_spec(spec)

        # Never clobber an unrelated module that already owns this name.
        register = mod_name not in sys.modules
        if register:
            sys.modules[mod_name] = module
        try:
            spec.loader.exec_module(module)
        except BaseException:
            # Do not leave a half-initialised module behind.
            if register:
                sys.modules.pop(mod_name, None)
            raise

    return getattr(module, name)


def all_token_ids_strict(
    tokenizer,
    include_special: bool = True,
    *,
    include_added: bool = False,
) -> list[int]:
    '''
    Return the sorted list of token IDs selected from `tokenizer`.

    - Base vocab  : [0, tokenizer.vocab_size), always included
    - Added tokens: tokenizer.get_added_vocab().values(), only when
                    `include_added` is True
    - Special     : tokenizer.all_special_ids, unioned in when
                    `include_special` is True, removed otherwise

    Special IDs are handled last, so `include_special=False` removes them
    even when they fall inside the base range or the added vocabulary.
    A tokenizer without `all_special_ids` or `get_added_vocab` is treated
    as having no such tokens.
    '''
    ids = set(range(tokenizer.vocab_size))

    if include_added:
        get_added_vocab = getattr(tokenizer, 'get_added_vocab', None)
        if get_added_vocab is not None:
            ids.update(get_added_vocab().values())

    special_ids = set(getattr(tokenizer, 'all_special_ids', []))
    if include_special:
        ids |= special_ids
    else:
        ids -= special_ids  # specials stay excluded even if include_added=True

    return sorted(ids)


def num_layers(model_id: str):
    '''
    Look up the layer count recorded for a supported model ID.

    Raises NotImplementedError when `model_id` is not in the table.
    '''
    depth_table = {
        # TinyStories
        'roneneldan/TinyStories-1M': 8,
        'roneneldan/TinyStories-8M': 8,
        'roneneldan/TinyStories-33M': 4,
        # GPT-2
        'openai-community/gpt2': 12,
        'openai-community/gpt2-medium': 24,
        'openai-community/gpt2-large': 36,
        # Gemma 3
        'google/gemma-3-1b-pt': 26,
        'google/gemma-3-4b-pt': 34,
        'google/gemma-3-12b-pt': 48,
        # Phi / Mistral / Llama
        'microsoft/Phi-4-mini-instruct': 32,
        'mistralai/Mistral-7B-v0.1': 32,
        'meta-llama/Llama-3.1-8B': 32,
        # Qwen
        'Qwen/Qwen2.5-0.5B': 24,
    }

    depth = depth_table.get(model_id)
    if depth is None:
        raise NotImplementedError(f'Model ID `{model_id}` is not supported!')
    return depth

def _get_attr_by_path(root, path: str):
    """Walk `path` one dot-separated name at a time, starting at `root`.

    Returns the object reached after the last name. A name that is missing
    (including an empty segment) raises AttributeError from `getattr`.
    """
    node, remaining = root, path
    while True:
        # Peel off the next name; `dot` is empty once no separator is left.
        attr, dot, remaining = remaining.partition('.')
        node = getattr(node, attr)
        if not dot:
            return node

def replace_last_norm(model_id: str, model: nn.Module):
    '''
    Swap the model's final normalization layer for `nn.Identity`.

    The attribute path of that layer is looked up by `model_id`. The layer
    found there is stored on `model.norm_bk` before it is replaced.
    Calling this again on an already-patched model is a no-op, so the
    stored original is never overwritten by the identity placeholder.

    NOTE(review): if `model` is an `nn.Module`, assigning `norm_bk` also
    registers the original layer as a submodule of `model` -- confirm that
    is intended.

    Raises:
        NotImplementedError: `model_id` is not in the table.
        AttributeError: the recorded path does not exist on `model`.
    '''
    attr_name = {
        'roneneldan/TinyStories-1M': 'transformer.ln_f',
        'roneneldan/TinyStories-8M': 'transformer.ln_f',
        'roneneldan/TinyStories-33M': 'transformer.ln_f',

        'openai-community/gpt2': 'transformer.ln_f',
        'openai-community/gpt2-medium': 'transformer.ln_f',
        'openai-community/gpt2-large': 'transformer.ln_f',

        'google/gemma-3-1b-pt': 'model.norm',
        'google/gemma-3-4b-pt': 'model.norm',
        'google/gemma-3-12b-pt': 'model.norm',

        'microsoft/Phi-4-mini-instruct': 'model.norm',
        'mistralai/Mistral-7B-v0.1': 'model.norm',
        'meta-llama/Llama-3.1-8B': 'model.norm',

        'Qwen/Qwen2.5-0.5B': 'model.norm',
    }

    if model_id not in attr_name:
        raise NotImplementedError(f'Model ID `{model_id}` is not supported!')

    # Split 'a.b.leaf' into the parent path 'a.b' and the attribute 'leaf';
    # a path without dots yields an empty parent path (the model itself).
    parent_path, _, leaf = attr_name[model_id].rpartition('.')
    parent = _get_attr_by_path(model, parent_path) if parent_path else model
    current = getattr(parent, leaf)

    # Already patched: keep the existing backup instead of replacing it
    # with the identity layer installed by the earlier call.
    if isinstance(current, nn.Identity) and hasattr(model, 'norm_bk'):
        return

    model.norm_bk = current
    setattr(parent, leaf, nn.Identity())