from ..criterions.base_criterion import BaseCriterion
import torch
from transformers import PreTrainedModel
from absl import logging

from ..modules.load_pretrained import load_pretrained_model, load_tokenizer, load_image_processor
from ..modules.mllm import MLLM

def build_models(_config):
    """Load one pretrained model for every enabled role.

    The roles considered are ``'drf'`` and ``'tgt'`` (presumably the draft
    and target models -- confirm against the config schema). A role whose
    config entry is ``None`` is skipped.

    Args:
        _config: Mapping-like configuration. ``_config[role]`` is checked
            for ``None`` and the whole config is forwarded to
            ``load_pretrained_model``.

    Returns:
        dict: ``role -> model`` for the enabled roles; each model has been
        moved with ``.cuda()`` and switched with ``.eval()``.
    """
    logging.info("[Trainer] Build models")
    return {
        role: load_pretrained_model(_config, role=role).cuda().eval()
        for role in ('drf', 'tgt')
        if _config[role] is not None
    }

def build_tokenizers(_config):
    """Load one tokenizer for every enabled role.

    The roles considered are ``'drf'`` and ``'tgt'`` (presumably the draft
    and target models -- confirm against the config schema). A role whose
    config entry is ``None`` is skipped.

    Args:
        _config: Mapping-like configuration. ``_config[role]`` is checked
            for ``None``; ``_config['max_target_length']`` is forwarded to
            ``load_tokenizer`` for each enabled role.

    Returns:
        tuple: ``(tokenizers, eos_token_id, pad_token_id)`` where
        ``tokenizers`` maps ``role -> tokenizer``. The two ids are the ones
        reported by ``load_tokenizer`` for the *last* enabled role (``'tgt'``
        wins over ``'drf'`` when both are enabled). When no role is enabled
        both ids are ``None``.
    """
    logging.info("[Trainer] Build tokenizers")
    tokenizers = {}
    # Bind the ids up front so the return statement is well-defined even if
    # every role is disabled (otherwise it raises UnboundLocalError).
    eos_token_id = None
    pad_token_id = None
    for role in ['drf', 'tgt']:
        if _config[role] is None:
            continue
        # Each enabled role overwrites the ids; the last one is returned.
        tokenizers[role], eos_token_id, pad_token_id = load_tokenizer(
            _config,
            max_target_length=_config['max_target_length'],
            role=role,
        )
    return tokenizers, eos_token_id, pad_token_id

def build_image_processors(_config):
    """Load one image processor for every enabled role.

    The roles considered are ``'drf'`` and ``'tgt'``. A role whose config
    entry is ``None`` is skipped.

    Args:
        _config: Mapping-like configuration. ``_config[role]`` is checked
            for ``None`` and the whole config is forwarded to
            ``load_image_processor``.

    Returns:
        dict: ``role -> image processor`` for the enabled roles.
    """
    logging.info("[Trainer] Build image processors")
    return {
        role: load_image_processor(_config, role=role)
        for role in ('drf', 'tgt')
        if _config[role] is not None
    }

def get_decoding_class(_config):
    """Return the decoding class selected by ``_config['decoding']``.

    Supported values:
        ``'ard'`` -> ``AutoregressiveDecoding``
        ``'sd'``  -> ``SpeculativeDecoding``

    The implementing module is imported only for the selected strategy.

    Args:
        _config: Mapping-like configuration with a ``'decoding'`` entry.

    Returns:
        The decoding class itself (not an instance).

    Raises:
        ValueError: If ``_config['decoding']`` is neither ``'ard'`` nor
            ``'sd'``.
    """
    logging.info("[Trainer] Build decoding")
    decoding = _config['decoding']

    if decoding == 'ard':
        from ..modules._autoregressive_decoding import AutoregressiveDecoding
        return AutoregressiveDecoding

    if decoding == 'sd':
        from ..modules._speculative_decoding import SpeculativeDecoding
        return SpeculativeDecoding

    raise ValueError(f"Invalid decoding type: {decoding}")

def get_criterion(_config):
    """Construct the criterion for this run.

    Args:
        _config: Configuration object, forwarded unchanged to
            ``BaseCriterion``.

    Returns:
        A new ``BaseCriterion`` instance.
    """
    criterion = BaseCriterion(_config)
    return criterion

def warmup_generation(model: PreTrainedModel, tokenizer, warmup_steps=10):
    """Run a few throwaway generations, then release cached CUDA memory.

    A fixed short prompt is tokenized, moved to ``model.device`` and passed
    to ``model.generate`` ``warmup_steps`` times; the outputs are discarded.

    Args:
        model: Model whose ``generate`` method is exercised.
        tokenizer: Callable that encodes the prompt with
            ``return_tensors="pt"``; its result must support ``.to(device)``
            and ``**`` unpacking.
        warmup_steps: Number of ``generate`` calls to make.

    Returns:
        None.
    """
    prompt = "Hello, my dog is cute"
    encoded = tokenizer(prompt, return_tensors="pt").to(model.device)
    logging.info(f"[Trainer] Warmup steps: {warmup_steps}")

    for _step in range(warmup_steps):
        # Output is intentionally dropped; only the call itself matters here.
        model.generate(**encoded)

    # Free cached allocator blocks left behind by the warmup generations.
    torch.cuda.empty_cache()