from math import e
from .app_router import run_app
from .base_builder import GeneratorPipelineBuilder

import torch
from transformers import AutoConfig
from specdecodes.models.generators.naive_mb import NaiveGenerator

class NaiveBuilder(GeneratorPipelineBuilder):
    """Pipeline builder that wraps a target LLM in a ``NaiveGenerator``.

    ``__init__`` only records configuration. Models and the tokenizer are
    loaded in ``build_models_and_tokenizer`` and the generator is assembled
    in ``load_generator``.

    NOTE(review): ``profiling_verbose``, ``draft_model_path`` and
    ``cpu_offload_gb`` are read by the methods below but are not assigned in
    this class; presumably ``GeneratorPipelineBuilder`` provides them --
    confirm against the base class.
    """

    def __init__(self):
        """Set the default configuration for this builder."""
        super().__init__()
        # Base configurations.
        self.vram_limit_gb = None
        self.seed = 0
        self.device = "cuda:0"
        self.dtype = torch.bfloat16
        self.limit_min_output = False
        self.max_length = 1024 * 64
        self.batch_size = 1
        # For pg-19
        # self.limit_min_output = True
        # self.min_length = 1024 * 16
        # self.max_new_tokens = 512

        # Model paths.
        self.llm_path = "meta-llama/Llama-3.1-8B-Instruct"
        # self.llm_path = "meta-llama/Llama-3.2-1B-Instruct"
        # self.llm_path = "Qwen/Qwen2.5-32B-Instruct"
        # self.llm_path = "Qwen/Qwen2.5-14B-Instruct"
        # self.llm_path = "Qwen/Qwen2.5-72B-Instruct"
        # self.llm_path = "Qwen/Qwen2.5-0.5B-Instruct"

        # Generation parameters.
        self.do_sample = False
        self.temperature = 0

        # Generator-specific configurations (forwarded verbatim to the
        # generator as ``generator_kwargs``).
        self.generator_kwargs = {
            "prefill_chunk_size": 2048,
        }

        # Recipe for quantization and offloading (None disables both).
        self.recipe = None

        # Additional configurations.
        self.cache_implementation = "dynamic"
        self.warmup_iter = 3
        self.compile_mode = None

        # Attention implementation.
        self._attn_implementation = "flash_attention_2"

        # Profiling
        self.generator_profiling = True

    def load_generator(self, target_model, tokenizer, draft_model=None):
        """Create the ``NaiveGenerator`` for the given models.

        Args:
            target_model: Model passed to the generator as ``target_model``.
            tokenizer: Tokenizer passed to the generator.
            draft_model: Optional draft model, forwarded unchanged.

        Returns:
            The constructed ``NaiveGenerator`` with ``batch_size`` set from
            this builder.
        """
        generator = NaiveGenerator(
            target_model=target_model,
            tokenizer=tokenizer,
            draft_model=draft_model,
            device=self.device,
            dtype=self.dtype,
            do_sample=self.do_sample,
            temperature=self.temperature,
            profiling_verbose=self.profiling_verbose,
            limit_min_output=self.limit_min_output,
            generator_kwargs=self.generator_kwargs,
        )
        # batch_size is assigned after construction rather than passed to the
        # constructor.
        generator.batch_size = self.batch_size
        return generator

    def build_models_and_tokenizer(self):
        """Build and return the main model, draft model, and tokenizer.

        Steps, in order: configure torch, derive the RoPE scaling to load the
        target model with, load the target and draft models, optionally apply
        the quantization/offloading recipe, then make sure the tokenizer has
        a pad token and that the embedding matrices can index every token.

        Returns:
            A ``(model, draft_model, tokenizer)`` tuple.
        """
        self.configure_torch()  # This now detects GPUs

        config = AutoConfig.from_pretrained(self.llm_path, trust_remote_code=True)
        # Start from whatever RoPE scaling the checkpoint's config declares.
        target_rope_scaling = getattr(config, "rope_scaling", None)
        if "qwen2" in self.llm_path.lower():
            # For Qwen2-family paths, switch to YaRN scaling when the
            # requested context is longer than the checkpoint's native one.
            original_max_pos = config.max_position_embeddings
            factor = self.max_length / original_max_pos
            if factor > 1.0:
                target_rope_scaling = {
                    "type": "yarn",
                    "factor": factor,
                    "original_max_position_embeddings": original_max_pos,
                }

        model, tokenizer = self.load_model_and_tokenizer(self.llm_path, rope_scaling=target_rope_scaling)
        draft_model = self.load_draft_model(model, tokenizer, self.draft_model_path)
        # Applies whenever any RoPE scaling is in effect, including scaling
        # that was already present in the checkpoint's own config.
        if target_rope_scaling is not None:
            model.config.max_position_embeddings = self.max_length

        if self.recipe:
            target_config, draft_config = self.recipe.generate_configurations(
                target_model=model,
                draft_model=draft_model,
                max_length=self.max_length,
                cpu_offload_gb=self.cpu_offload_gb,
                dtype=self.dtype,
                device=self.device,  # Note: self.device is primary, but recipe might need to know about multi-GPU
            )

            # Apply recipe (quantization/offloading)
            # This logic should be robust to device_map="auto" if the recipe
            # correctly iterates over model modules.
            if draft_model and draft_config and draft_config.get("quant_config"):
                self.recipe.apply_quantization(draft_model.model, draft_config["quant_config"], self.dtype, self.device)
            if target_config and target_config.get("quant_config"):
                self.recipe.apply_quantization(model, target_config["quant_config"], self.dtype, self.device)

            if draft_model and draft_config and draft_config.get("device_map"):
                self.recipe.apply_offloading(draft_model.model, draft_config["device_map"])
            if target_config and target_config.get("device_map"):
                self.recipe.apply_offloading(model, target_config["device_map"])

        # Add a dedicated pad token (so pad_token_id != eos_token_id)
        if tokenizer.pad_token is None:
            tokenizer.add_special_tokens({"pad_token": "<|pad|>"})
        tokenizer.padding_side = "left"

        # Only grow the embeddings when the tokenizer actually needs more
        # rows; an unconditional resize would shrink a model whose embedding
        # matrix is larger than the tokenizer and would needlessly
        # re-allocate weights the recipe has already processed.
        self._grow_embeddings_to_fit(model, tokenizer)
        if draft_model:
            self._grow_embeddings_to_fit(draft_model.model, tokenizer)
        return model, draft_model, tokenizer

    @staticmethod
    def _grow_embeddings_to_fit(model, tokenizer):
        """Resize ``model``'s token embeddings only if the tokenizer outgrew them.

        Args:
            model: Object exposing ``get_input_embeddings()`` and
                ``resize_token_embeddings()``.
            tokenizer: Tokenizer whose ``len()`` is the required number of
                embedding rows.
        """
        required_rows = len(tokenizer)
        current_rows = model.get_input_embeddings().weight.shape[0]
        if required_rows > current_rows:
            model.resize_token_embeddings(required_rows)  # IMPORTANT after adding tokens

if __name__ == "__main__":
    # Script entry point: hand a default-configured builder to the app runner.
    builder = NaiveBuilder()
    run_app(builder)