import logging
from .app_router import run_app
from .base_builder import GeneratorPipelineBuilder

import torch
from transformers import AutoConfig
from specdecodes.models.utils.utils import DraftParams
from specdecodes.models.utils.cache_utils import create_kv_cache
from specdecodes.models.draft_models.classic_seq_sd_mb import ClassicSDDraftModel
from specdecodes.models.generators.classic_seq_sd_mb import ClassicSDGenerator

class ClassicSDBuilder(GeneratorPipelineBuilder):
    """Builder that pairs a target LLM with a draft model for ``ClassicSDGenerator``.

    The constructor only sets configuration attributes; models, tokenizer,
    generator and KV caches are created by the ``load_*`` / ``build_*``
    methods below. Attributes read here but not assigned in ``__init__``
    (``load_in_8bit``, ``load_in_4bit``, ``profiling_verbose``,
    ``cpu_offload_gb``) and the methods ``configure_torch`` /
    ``load_model_and_tokenizer`` are expected to come from
    ``GeneratorPipelineBuilder``.
    """

    def __init__(self):
        super().__init__()
        # Base configurations.
        self.vram_limit_gb = None
        self.seed = 0
        self.device = "cuda:0"
        self.dtype = torch.bfloat16
        self.limit_min_output = False
        self.max_length = 1024 * 64
        self.batch_size = 1
        # For pg-19
        # self.limit_min_output = True
        # self.min_length = 1024 * 16
        # self.max_new_tokens = 512

        # Model paths.
        self.llm_path = "meta-llama/Llama-3.1-8B-Instruct"
        # self.llm_path = "Qwen/Qwen2.5-32B-Instruct"
        # self.llm_path = "Qwen/Qwen2.5-14B-Instruct"
        # self.llm_path = "Qwen/Qwen2.5-7B-Instruct"
        # self.llm_path = "Qwen/Qwen2.5-0.5B-Instruct"
        # self.llm_path = "Qwen/Qwen2.5-1.5B-Instruct"

        self.draft_model_path = "meta-llama/Llama-3.2-1B-Instruct"
        # self.draft_model_path = "Qwen/Qwen2.5-0.5B-Instruct"

        # Generation parameters.
        self.do_sample = False
        self.temperature = 0

        # Generator-specific configurations.
        # NOTE: the same dict object is handed to both DraftParams and the
        # generator, so later mutations are visible to both.
        self.generator_kwargs = {
            "prefill_chunk_size": 2048,
        }
        self.draft_params = DraftParams(
            temperature=1,
            max_depth=3,
            topk_len=1,
            generator_kwargs=self.generator_kwargs,
        )

        # Recipe for quantization and offloading.
        self.recipe = None

        # Additional configurations.
        self.cache_implementation = "dynamic"
        self.warmup_iter = 3
        self.compile_mode = None

        # Attention implementation.
        self._attn_implementation = "flash_attention_2"

        # Profiling.
        self.generator_profiling = True

    def load_draft_model(self, target_model, tokenizer, draft_model_path, rope_scaling):
        """Load the draft model with ``ClassicSDDraftModel.from_pretrained``.

        Args:
            target_model: Already-loaded target model, forwarded to the draft model.
            tokenizer: Tokenizer whose ``eos_token_id`` is forwarded to the draft model.
            draft_model_path: Checkpoint path or hub id of the draft model.
            rope_scaling: RoPE scaling configuration to forward, or ``None``.

        Returns:
            The loaded ``ClassicSDDraftModel``.
        """
        return ClassicSDDraftModel.from_pretrained(
            draft_model_path,
            target_model=target_model,
            torch_dtype=self.dtype,
            device_map=self.device,
            eos_token_id=tokenizer.eos_token_id,
            _attn_implementation=self._attn_implementation,
            _load_in_8bit=self.load_in_8bit,
            _load_in_4bit=self.load_in_4bit,
            rope_scaling=rope_scaling,
        )

    def load_generator(self, target_model, tokenizer, draft_model=None):
        """Create the ``ClassicSDGenerator`` for the given models.

        Args:
            target_model: The target model.
            tokenizer: Tokenizer shared with the generator.
            draft_model: Optional draft model; passed through unchanged.

        Returns:
            The generator, with ``batch_size`` assigned as an attribute after
            construction (it is not a constructor argument here).
        """
        generator = ClassicSDGenerator(
            target_model=target_model,
            tokenizer=tokenizer,
            draft_model=draft_model,
            draft_params=self.draft_params,
            cache_implementation=self.cache_implementation,
            profiling=self.generator_profiling,
            profiling_verbose=self.profiling_verbose,
            limit_min_output=self.limit_min_output,
            generator_kwargs=self.generator_kwargs,
        )
        generator.batch_size = self.batch_size
        return generator

    def load_kv_cache(self, target_model, draft_model):
        """Create KV caches for the target model and, if present, the draft model.

        Args:
            target_model: Target model; its ``config`` and ``dtype`` size the static cache.
            draft_model: Draft model or ``None``.

        Returns:
            ``(past_key_values, draft_past_key_values)``; the second element is
            ``None`` when ``draft_model`` is ``None``.

        Raises:
            ValueError: If a static cache is requested while ``max_length`` is ``None``.
        """
        has_draft = draft_model is not None

        if self.cache_implementation != "static":
            # Create dynamic kv-cache (used for multi-GPU or by default)
            logging.info("Creating Dynamic KV Cache")
            past_key_values = create_kv_cache("dynamic")
            draft_past_key_values = create_kv_cache("dynamic") if has_draft else None
            return past_key_values, draft_past_key_values

        # This branch will only be taken in single-GPU mode due to configure_torch()
        if self.max_length is None:
            raise ValueError("max_length should be set for static cache.")

        # With a draft model the cache is enlarged by max_verify_tokens on top
        # of max_length; both caches share the same length.
        max_cache_len = self.max_length
        if has_draft:
            max_cache_len += self.draft_params.max_verify_tokens

        logging.info(
            "Creating Static KV Cache with length %s on %s", max_cache_len, self.device
        )
        past_key_values = create_kv_cache(
            "static",
            max_cache_len=max_cache_len,
            max_batch_size=self.batch_size,
            config=target_model.config,
            device=self.device,
            dtype=target_model.dtype,
        )
        draft_past_key_values = None
        if has_draft:
            draft_past_key_values = create_kv_cache(
                "static",
                max_cache_len=max_cache_len,
                max_batch_size=self.batch_size,
                config=draft_model.config,
                device=self.device,
                dtype=draft_model.dtype,
            )
        return past_key_values, draft_past_key_values

    def _rope_scaling_for(self, model_path):
        """Return the RoPE scaling config to use when loading ``model_path``.

        Starts from the checkpoint's own ``rope_scaling`` (``None`` if absent).
        For paths containing ``"qwen2"`` (case-insensitive), a YaRN config is
        substituted when ``max_length`` exceeds the checkpoint's
        ``max_position_embeddings``. When ``max_length`` is ``None`` no
        extension is computed.
        """
        config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
        rope_scaling = getattr(config, "rope_scaling", None)

        if self.max_length is None or "qwen2" not in model_path.lower():
            return rope_scaling

        original_max_pos = config.max_position_embeddings
        factor = self.max_length / original_max_pos
        if factor > 1.0:
            rope_scaling = {
                "type": "yarn",
                "factor": factor,
                "original_max_position_embeddings": original_max_pos,
            }
        return rope_scaling

    def _apply_recipe(self, model, draft_model):
        """Apply the configured recipe's quantization, then offloading, to both models."""
        target_config, draft_config = self.recipe.generate_configurations(
            target_model=model,
            draft_model=draft_model,
            max_length=self.max_length,
            cpu_offload_gb=self.cpu_offload_gb,
            dtype=self.dtype,
            device=self.device,  # Note: self.device is primary, but recipe might need to know about multi-GPU
        )
        has_draft = draft_model is not None

        # Apply recipe (quantization/offloading)
        # This logic should be robust to device_map="auto" if the recipe
        # correctly iterates over model modules.
        if has_draft and draft_config and draft_config.get("quant_config"):
            self.recipe.apply_quantization(
                draft_model.model, draft_config["quant_config"], self.dtype, self.device
            )
        if target_config and target_config.get("quant_config"):
            self.recipe.apply_quantization(
                model, target_config["quant_config"], self.dtype, self.device
            )

        if has_draft and draft_config and draft_config.get("device_map"):
            self.recipe.apply_offloading(draft_model.model, draft_config["device_map"])
        if target_config and target_config.get("device_map"):
            self.recipe.apply_offloading(model, target_config["device_map"])

    def build_models_and_tokenizer(self):
        """
        Build and return the main model, draft model, and tokenizer.

        Returns:
            ``(model, draft_model, tokenizer)``. The tokenizer is left-padded
            and gets a ``<|pad|>`` token when it has no pad token.
        """
        self.configure_torch()  # This now detects GPUs

        target_rope_scaling = self._rope_scaling_for(self.llm_path)
        draft_rope_scaling = self._rope_scaling_for(self.draft_model_path)

        model, tokenizer = self.load_model_and_tokenizer(
            self.llm_path, rope_scaling=target_rope_scaling
        )
        draft_model = self.load_draft_model(
            model, tokenizer, self.draft_model_path, rope_scaling=draft_rope_scaling
        )

        # NOTE(review): this override fires for any non-None rope scaling,
        # including scaling that came from the checkpoint itself - confirm
        # that is intended. Skipped when max_length is unset so None is never
        # written into the model config.
        if self.max_length is not None:
            if target_rope_scaling is not None:
                model.config.max_position_embeddings = self.max_length
            if draft_rope_scaling is not None:
                draft_model.config.max_position_embeddings = self.max_length

        if self.recipe:
            self._apply_recipe(model, draft_model)

        # Add a dedicated pad token (so pad_token_id != eos_token_id)
        if tokenizer.pad_token is None:
            tokenizer.add_special_tokens({"pad_token": "<|pad|>"})
        tokenizer.padding_side = "left"

        # IMPORTANT after adding tokens.
        # NOTE(review): the resize runs even when no token was added, so both
        # embedding tables always end up with len(tokenizer) rows.
        vocab_size = len(tokenizer)
        model.resize_token_embeddings(vocab_size)
        if draft_model is not None:
            draft_model.model.resize_token_embeddings(vocab_size)
        return model, draft_model, tokenizer



if __name__ == "__main__":
    # Script entry point: construct the builder and hand it to run_app.
    builder = ClassicSDBuilder()
    run_app(builder)