"""Model building and configuration utilities."""

import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification, AutoConfig, AutoModel, AutoModelForCausalLM
from calib.multi_seq_calibrator import MultiSeqCalibratorConfig

from calib.config.args import TrainArgs

def build_tokenizer(model_name: str, max_seq_len: int = None):
    """Load the tokenizer for ``model_name`` and configure it for training.

    Args:
        model_name: Hugging Face model id or local path handed to
            ``AutoTokenizer.from_pretrained``.
        max_seq_len: Optional value stored as ``model_max_length``; the
            tokenizer's own default is kept when this is ``None``.

    Returns:
        The tokenizer, configured to pad on the right.
    """
    # NOTE: trust_remote_code=True allows the model repository to run its own
    # tokenizer code; only point this at trusted model sources.
    tok = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    tok.padding_side = "right"

    if max_seq_len is None:
        return tok

    tok.model_max_length = max_seq_len
    return tok


def build_base_model_config(
    model_name: str,
    attn_backend: str = "flex_attention",
    base_num_layers: int = None
):
    """Build the configuration for the frozen base model that produces hidden states.

    Args:
        model_name: Hugging Face model id or local path passed to
            ``AutoConfig.from_pretrained``.
        attn_backend: Attention implementation name stored on the config as
            ``_attn_implementation``.
        base_num_layers: If given, overrides ``num_hidden_layers`` on the
            config (presumably to load only a prefix of the base model's
            layers -- confirm against how the config is consumed).

    Returns:
        The config with ``use_cache`` disabled, the attention backend set
        (best-effort) and ``torch_dtype`` set to bfloat16 when a CUDA device
        supporting it is available, otherwise float16.
    """
    config = AutoConfig.from_pretrained(model_name)
    config.use_cache = False

    if base_num_layers is not None:
        config.num_hidden_layers = base_num_layers

    # Only the private `_attn_implementation` attribute is set. The assignment
    # is best-effort, but a failure is reported instead of silently ignored so
    # a misconfigured backend does not go unnoticed.
    try:
        setattr(config, "_attn_implementation", attn_backend)
    except Exception as err:
        print(
            f"Warning: could not set attention backend {attn_backend!r} "
            f"on base config: {err}"
        )

    # Check is_available() first: on some torch versions is_bf16_supported()
    # raises instead of returning False when no CUDA device is present.
    use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    config.torch_dtype = torch.bfloat16 if use_bf16 else torch.float16

    return config


def build_training_model_config(model_name, args, base_config=None):
    """Build the configuration for the trainable calibrator model.

    Loads ``MultiSeqCalibratorConfig`` from ``model_name``, copies the
    calibrator-specific options from ``args`` onto it, applies any
    architecture overrides given in ``args``, and sets the training-specific
    fields.

    Args:
        model_name: Name or path passed to
            ``MultiSeqCalibratorConfig.from_pretrained``.
        args: Training arguments; every attribute read below must exist on it.
        base_config: Config of the base model. Its ``hidden_size`` becomes
            ``input_embeds_size``. It may be ``None`` only if the loaded config
            already defines ``input_embeds_size``.

    Returns:
        The populated ``MultiSeqCalibratorConfig``.

    Raises:
        ValueError: If ``base_config`` is ``None`` and the loaded config has no
            ``input_embeds_size``.
    """
    config = MultiSeqCalibratorConfig.from_pretrained(model_name)
    config.use_cache = False
    config._attn_implementation = args.attn_backend
    # The model is fed embeddings (presumably the base model's hidden states)
    # rather than token ids.
    config.main_input_name = "inputs_embeds"

    # Calibrator options copied verbatim from args, in the original order.
    copied_fields = (
        "group_size",
        "mlp_hidden_size",
        "architecture",
        "max_context_len",
        "increment_position_ids",
        "agent_emb",
        "attn_types",
        "node_features",
        "bin_aggregate",
        "group_softmax",
        "sum_group_softmax",
        "attend_all_group_softmax",
        "late_group_softmax",
        "no_early_node_features_projection",
        "late_node_features_projection",
        "late_node_features_projection_norm",
        "sum_bin_aggregate",
        "causal_bin_aggregate",
    )
    for field in copied_fields:
        setattr(config, field, getattr(args, field))

    # Width of the incoming embeddings must match the base model's hidden size.
    if base_config is not None:
        config.input_embeds_size = base_config.hidden_size
    elif getattr(config, "input_embeds_size", None) is None:
        raise ValueError(
            "base_config is required to set input_embeds_size: the loaded "
            f"config from {model_name!r} does not define it"
        )

    # Architecture overrides: applied only when explicitly set in args.
    override_fields = (
        "hidden_size",
        "intermediate_size",
        "num_attention_heads",
        "num_key_value_heads",
        "num_hidden_layers",
        "hidden_act",
    )
    for field in override_fields:
        value = getattr(args, field)
        if value is not None:
            setattr(config, field, value)

    # Training-specific config
    config.num_labels = 1
    config.problem_type = "single_label_classification"
    config.classifier_dropout = args.dropout

    # Check is_available() first: on some torch versions is_bf16_supported()
    # raises instead of returning False when no CUDA device is present.
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
        config.torch_dtype = torch.bfloat16
        print("Using bf16")
    else:
        config.torch_dtype = torch.float16
        print("Using fp16")

    return config


def build_base_model(
    model_name: str,
    config,
    device
):
    """Load the frozen causal LM and return its decoder body and output head.

    Args:
        model_name: Name or path passed to ``AutoModelForCausalLM.from_pretrained``.
        config: Config to load with; its ``torch_dtype`` attribute (if present)
            selects the weight dtype.
        device: Device on which the whole model is placed
            (``device_map={"": device}``).

    Returns:
        ``(decoder, lm_head)``: the transformer body without the output head,
        and the output projection. The model is in eval mode and none of its
        parameters require gradients.
    """
    print(f"Building base model: {model_name}")
    causal_lm = AutoModelForCausalLM.from_pretrained(
        model_name,
        config=config,
        torch_dtype=getattr(config, "torch_dtype", None),
        device_map={"": device},
    )
    # The base model is a frozen feature extractor: eval mode and no gradients.
    causal_lm.eval()
    causal_lm.requires_grad_(False)

    # Many HF causal LMs expose `.model` / `.lm_head`, but not all do (e.g.
    # GPT-2 keeps its body under `.transformer`), so fall back to the generic
    # PreTrainedModel accessors.
    decoder = getattr(causal_lm, "model", None)
    if decoder is None:
        decoder = causal_lm.base_model
    lm_head = getattr(causal_lm, "lm_head", None)
    if lm_head is None:
        lm_head = causal_lm.get_output_embeddings()

    return decoder, lm_head


def build_training_model(
    config,
    device,
    load_model_from: str = None,
    args: TrainArgs = None,
    lm_head = None,
):
    """Build the trainable model, either from a checkpoint or from ``config``.

    Args:
        config: Training config (see ``build_training_model_config``). When
            loading from a checkpoint, only its ``torch_dtype`` and
            ``_attn_implementation`` are applied to the loaded model.
        device: Target device for the model.
        load_model_from: Optional checkpoint path; when falsy, a freshly
            initialised model is built from ``config``.
        args: Training arguments providing ``wm_group_size`` and ``wm_type``.
            When ``None``, both attributes are set to ``None`` on the model.
        lm_head: Output head attached as ``model.lm_head`` when
            ``config.architecture == "probs"``.

    Returns:
        The model, with ``main_input_name`` set to ``"inputs_embeds"``.
    """
    # Auto class resolved from the config type (presumably a custom model
    # registered for MultiSeqCalibratorConfig -- confirm registration).
    model_cls = AutoModelForTokenClassification

    if load_model_from:
        print(f"Loading training model from checkpoint: {load_model_from}")
        model = model_cls.from_pretrained(
            load_model_from,
            torch_dtype=getattr(config, "torch_dtype", None),
            device_map={"": device},
        )
        # The checkpoint's stored config wins, except for the attention backend.
        model.config._attn_implementation = config._attn_implementation
    else:
        print(f"Building custom training model from scratch with {config.num_hidden_layers} layers")
        # Create custom model from config (not pretrained weights)
        model = model_cls.from_config(config=config)
        model = model.to(device)

    if config.architecture == "probs":
        if lm_head is None:
            print("Warning: architecture 'probs' requested but no lm_head was provided")
        # NOTE(review): assigning an nn.Module registers it as a submodule, so
        # its (frozen) weights become part of this model's parameters and
        # state_dict -- confirm that is intended for checkpointing.
        model.lm_head = lm_head

    # `args` defaults to None; reading attributes unconditionally would crash.
    if args is None:
        model.wm_group_size = None
        model.wm_type = None
    else:
        model.wm_group_size = args.wm_group_size
        model.wm_type = args.wm_type
    model.main_input_name = "inputs_embeds"

    return model


def build_models_and_tokenizer(
    args,
    base_device,
    training_device,
    save_dict,
) -> tuple:
    """
    Build base model, training model, and tokenizer with all configurations applied.

    Name resolution: ``args.model_name`` and ``args.base_model_name`` each fall
    back to the ``model_name`` recorded in ``save_dict["args"]``;
    ``args.tokenizer_name`` falls back to the resolved base model name.

    Args:
        args: Run arguments (model names, attention backends, sizes, etc.).
        base_device: Device for the frozen base model.
        training_device: Device for the training model.
        save_dict: Mapping whose ``"args"`` entry has a ``model_name``
            attribute. It is only consulted when a model name is missing
            from ``args``.

    Returns:
        Tuple of (base_model, lm_head, training_model, tokenizer)
    """
    # Read the saved model name only when args leaves a name unspecified.
    saved_model_name = None
    if args.model_name is None or args.base_model_name is None:
        saved_model_name = save_dict["args"].model_name
    training_model_name = args.model_name if args.model_name is not None else saved_model_name
    base_model_name = args.base_model_name if args.base_model_name is not None else saved_model_name
    tokenizer_name = args.tokenizer_name if args.tokenizer_name is not None else base_model_name

    # Build tokenizer
    tokenizer = build_tokenizer(tokenizer_name, args.max_seq_len)

    # Build base config and model
    base_config = build_base_model_config(base_model_name, args.base_attn_backend, args.base_num_layers)
    base_model, lm_head = build_base_model(base_model_name, base_config, base_device)

    # The training config is sized from base_config, and the base model's
    # lm_head is handed to the training model.
    training_config = build_training_model_config(training_model_name, args, base_config)
    training_model = build_training_model(
        training_config, training_device, args.load_model_from, args, lm_head
    )

    # Set pad token ID
    training_model.config.pad_token_id = tokenizer.pad_token_id

    return base_model, lm_head, training_model, tokenizer
