import logging
from .app_router import run_app
from .base_builder import GeneratorPipelineBuilder

import torch
from transformers import AutoConfig
import json 
import os 
from specdecodes.models.utils.utils import DraftParams, build_layer_device_map
from specdecodes.models.draft_models.targetKV_seq_sd_mb import TargetKVSDDraftModel
from specdecodes.models.generators.targetKV_seq_sd_mb import TargetKVSDGenerator
from specdecodes.models.utils.cache_utils import create_kv_cache

class TargetKVSDBuilder(GeneratorPipelineBuilder):
    """Pipeline builder for target-KV-compressed speculative decoding.

    Loads a target LLM and a smaller draft model, reads the pre-computed
    important-layer selection from ``SRL_path`` and the important-head
    indices from ``SRH_path``, and wires everything into a
    ``TargetKVSDGenerator``.
    """

    def __init__(self):
        super().__init__()
        # Base configurations.
        self.vram_limit_gb = None
        self.seed = 0
        self.device = "cuda:0"
        self.dtype = torch.bfloat16
        self.limit_min_output = False
        self.max_length = 1024 * 64
        self.batch_size = 1
        # For pg-19
        # self.limit_min_output = True
        # self.min_length = 1024 * 16
        # self.max_new_tokens = 512

        # Model paths (alternatives kept for quick switching).
        self.llm_path = "meta-llama/Llama-3.1-8B-Instruct"
        # self.llm_path = "Qwen/Qwen2.5-32B-Instruct"
        # self.llm_path = "Qwen/Qwen2.5-14B-Instruct"
        # self.llm_path = "Qwen/Qwen2.5-72B-Instruct"
        # self.llm_path = "Qwen/Qwen2.5-0.5B-Instruct"
        # self.llm_path = "Qwen/Qwen2.5-1.5B-Instruct"

        self.draft_model_path = "meta-llama/Llama-3.2-1B-Instruct"
        # self.draft_model_path = "Qwen/Qwen2.5-0.5B-Instruct"

        # Generation parameters.
        self.do_sample = False
        self.temperature = 0

        # Generator-specific configurations.
        # "Draft_Select_size" is derived later in build_models_and_tokenizer().
        self.generator_kwargs = {
            "prefill_chunk_size": 2048,
            "Target_KV_size": 512,
            "lag_size": 1,
            "window_size": 16,  # always keep the most recent window_size tokens in the kv-cache
            "sink_size": 4,     # always keep the first sink_size tokens in the kv-cache
            "SRH_path": "specdecodes/models/utils/compresskv/head_scores",
            "SRL_path": "specdecodes/models/utils/compresskv/layer_scores",

            # --- SR head select parameters ---
            "SRH_head_num": 4,  # None means all heads are important heads

            # --- CompressKV parameters ---
            # Target_History_Only and Draft_Only must not both be True
            # (enforced in build_models_and_tokenizer()).
            "Target_History_Only": False,
            "Draft_Only": False,

            # --- Selection mode ---
            # 'all_layers_heads': Use all layers and all heads
            # 'all_layers_important_heads': Use all layers but only important heads
            # 'important_layers_heads': Use important layers and important heads (default)
            "selection_mode": "important_layers_heads",

            # other parameters
            # Snapshot only; load_generator() re-syncs it with self.batch_size.
            "batch_size": self.batch_size,
        }

        self.draft_params = DraftParams(
            temperature=1,
            max_depth=3,
            topk_len=1,
            generator_kwargs=self.generator_kwargs,
        )

        # Recipe for quantization and offloading.
        self.recipe = None

        # Additional configurations.
        self.cache_implementation = "dynamic"
        self.warmup_iter = 3
        self.compile_mode = None

        # Attention implementation.
        self._attn_implementation = "flash_attention_2"

        # Profiling.
        self.generator_profiling = True

    def compile_generator(self, generator):
        """Compile the generator's compressed-KV target forward with ``torch.compile``.

        Only ``generator.compressKV_target_model_forward`` is compiled. The
        plain target and draft forwards are left in eager mode.
        """
        generator.compressKV_target_model_forward = torch.compile(
            generator.compressKV_target_model_forward,
            mode=self.compile_mode,
            dynamic=False,
            fullgraph=True,
        )

    def build_generator_pipeline(self, model, draft_model, tokenizer):
        """Build the generator pipeline from a pre-built model, draft model and tokenizer.

        Returns:
            A tuple ``(generator, tokenizer, past_kv, draft_past_kv)``.
        """
        past_kv, draft_past_kv = self.load_kv_cache(model, draft_model)

        generator = self.load_generator(model, tokenizer, draft_model)
        generator.eval()

        if self.compile_mode is not None:
            if not self.is_multigpu:
                logging.info("Applying torch.compile (mode: %s) to generator.", self.compile_mode)
                self.compile_generator(generator)
            else:
                # Only the multi-GPU condition gates compilation here.
                logging.warning(
                    "Skipping torch.compile (mode: %s): compile only runs in single-GPU mode.",
                    self.compile_mode,
                )

        self.post_process(generator, tokenizer, past_kv, draft_past_kv)

        return generator, tokenizer, past_kv, draft_past_kv

    def load_draft_model(self, target_model, tokenizer, draft_model_path, rope_scaling):
        """Load the ``TargetKVSDDraftModel`` bound to *target_model*.

        Args:
            target_model: Target model passed through to the draft wrapper.
            tokenizer: Supplies ``eos_token_id`` for the draft model.
            draft_model_path: Hugging Face hub id or local path of the draft model.
            rope_scaling: RoPE scaling dict (or ``None``) forwarded to ``from_pretrained``.
        """
        return TargetKVSDDraftModel.from_pretrained(
            draft_model_path,
            target_model=target_model,
            torch_dtype=self.dtype,
            device_map=self.device,
            eos_token_id=tokenizer.eos_token_id,
            attn_implementation=self._attn_implementation,
            rope_scaling=rope_scaling,
        )

    def _head_idx_path(self, model_path):
        """Return the important-head index JSON path for *model_path* under ``SRH_path``."""
        model_name = model_path.split("/")[-1]
        return f"{self.generator_kwargs['SRH_path']}/{model_name}_head_idx.json"

    def load_generator(self, target_model, tokenizer, draft_model=None):
        """Create the ``TargetKVSDGenerator`` and configure its important layers and heads.

        First, a static cache of ``Target_KV_size + max_verify_tokens`` entries
        is allocated for the compressed target KV. Then layers and heads are
        selected according to ``generator_kwargs['selection_mode']``:

        * ``'all_layers_heads'``: every layer; ``SRH_head_num`` is forced to
          ``None`` (all heads).
        * ``'all_layers_important_heads'``: every layer, important heads only.
        * ``'important_layers_heads'`` (default): the leader/follower layers
          loaded by :meth:`build_models_and_tokenizer`, important heads only.
          Unknown modes fall back to this one with a warning.

        Args:
            target_model: The target (verifier) model.
            tokenizer: Tokenizer shared by both models.
            draft_model: The draft model. It is required despite the ``None`` default.

        Returns:
            The configured ``TargetKVSDGenerator``.

        Raises:
            ValueError: If ``draft_model`` is ``None``.
        """
        if draft_model is None:
            raise ValueError("TargetKVSDBuilder.load_generator requires a draft_model.")

        # self.batch_size is authoritative: it is also assigned to
        # generator.batch_size below and used by load_kv_cache(). Refresh the
        # snapshot taken in __init__ so the static target cache matches it.
        self.generator_kwargs["batch_size"] = self.batch_size

        selection_mode = self.generator_kwargs.get("selection_mode", "important_layers_heads")
        if selection_mode == "all_layers_heads":
            # Set before the generator is constructed, so the generator sees
            # "all heads" however early it reads SRH_head_num.
            self.generator_kwargs["SRH_head_num"] = None
            self.draft_params.generator_kwargs["SRH_head_num"] = None

        target_kv_size = self.generator_kwargs.get("Target_KV_size", 512)
        layer_device_map = build_layer_device_map(target_model) if self.is_multigpu else None
        target_key_values = create_kv_cache(
            "static",
            max_cache_len=target_kv_size + self.draft_params.max_verify_tokens,
            max_batch_size=self.generator_kwargs["batch_size"],
            config=target_model.config,
            layer_device_map=layer_device_map,
            device=self.device,
            dtype=target_model.dtype,
        )

        generator = TargetKVSDGenerator(
            target_model=target_model,
            tokenizer=tokenizer,
            draft_model=draft_model,
            draft_params=self.draft_params,
            cache_implementation=self.cache_implementation,
            profiling=self.generator_profiling,
            profiling_verbose=self.profiling_verbose,
            limit_min_output=self.limit_min_output,
            target_key_values=target_key_values,
            generator_kwargs=self.generator_kwargs,
        )
        generator.batch_size = self.batch_size

        print(f"\n{'='*60}")
        print(f"Selection Mode: {selection_mode}")
        print(f"{'='*60}")

        if selection_mode in ("all_layers_heads", "all_layers_important_heads"):
            target_layers = list(range(target_model.config.num_hidden_layers))
            draft_layers = list(range(draft_model.model.config.num_hidden_layers))
        else:
            if selection_mode != "important_layers_heads":
                logging.warning(
                    "Unknown selection_mode %r; falling back to 'important_layers_heads'.",
                    selection_mode,
                )
            target_layers = self.important_layers["leader_layers"]
            draft_layers = self.important_layers["follower_layers"]

        generator.set_important_layers(target_layers)
        generator.set_important_head_idx(self._head_idx_path(self.llm_path))
        generator.draft_model.set_important_layers(draft_layers)
        generator.draft_model.set_important_head_idx(self._head_idx_path(self.draft_model_path))

        print(f"{'='*60}\n")
        return generator

    def load_kv_cache(self, target_model, draft_model):
        """Create the target and draft KV caches.

        With ``cache_implementation == "static"``, each cache holds
        ``max_length`` entries, plus ``max_verify_tokens`` when a draft model
        is present. Any other setting yields dynamic caches.

        Returns:
            ``(past_key_values, draft_past_key_values)``. The draft cache is
            ``None`` when there is no draft model.

        Raises:
            ValueError: If a static cache is requested but ``max_length`` is ``None``.
        """
        if self.cache_implementation != "static":
            # Dynamic kv-cache (used for multi-GPU or by default).
            logging.info("Creating Dynamic KV Cache")
            past_key_values = create_kv_cache("dynamic")
            draft_past_key_values = create_kv_cache("dynamic") if draft_model is not None else None
            return past_key_values, draft_past_key_values

        # Per the original note, only single-GPU runs reach here (via configure_torch()).
        if self.max_length is None:
            raise ValueError("max_length should be set for static cache.")
        max_cache_len = self.max_length
        if draft_model is not None:
            max_cache_len += self.draft_params.max_verify_tokens

        logging.info("Creating Static KV Cache with length %s on %s", max_cache_len, self.device)
        past_key_values = create_kv_cache(
            "static",
            max_cache_len=max_cache_len,
            max_batch_size=self.batch_size,
            config=target_model.config,
            device=self.device,
            dtype=target_model.dtype,
        )
        draft_past_key_values = None
        if draft_model is not None:
            draft_past_key_values = create_kv_cache(
                "static",
                max_cache_len=max_cache_len,
                max_batch_size=self.batch_size,
                config=draft_model.config,
                device=self.device,
                dtype=draft_model.dtype,
            )
        return past_key_values, draft_past_key_values

    def _compute_draft_select_size(self):
        """Return ``Draft_Select_size`` according to the CompressKV flags.

        * ``Target_History_Only``: ``sink_size + window_size``.
        * ``Draft_Only``: ``Target_KV_size``.
        * Otherwise: ``important_layers['draft_size']`` rescaled linearly by
          ``Target_KV_size / 512``. The results file appears to be calibrated
          for a 512-token target budget.
        """
        kw = self.generator_kwargs
        if kw["Target_History_Only"]:
            return int(kw["sink_size"] + kw["window_size"])
        if kw["Draft_Only"]:
            return int(kw["Target_KV_size"])
        # Multiply before dividing, so integer inputs give an exact result
        # instead of dividing by the rounded float 512 / Target_KV_size.
        return int(self.important_layers["draft_size"] * kw["Target_KV_size"] / 512)

    def _resolve_rope_scaling(self, model_path):
        """Return the RoPE scaling config to load *model_path* with.

        The model's own ``rope_scaling`` is used by default. For Qwen2 models
        whose ``max_position_embeddings`` is below ``self.max_length``, YaRN
        scaling is substituted with ``factor = max_length / max_position_embeddings``.
        """
        config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
        rope_scaling = getattr(config, "rope_scaling", None)
        if "qwen2" in model_path.lower():
            original_max_pos = config.max_position_embeddings
            factor = self.max_length / original_max_pos
            if factor > 1.0:
                rope_scaling = {
                    "type": "yarn",
                    "factor": factor,
                    "original_max_position_embeddings": original_max_pos,
                }
        return rope_scaling

    def build_models_and_tokenizer(self):
        """Build and return the main model, the draft model and the tokenizer.

        This also loads ``self.important_layers`` from the ``SRL_path``
        results file and sets ``generator_kwargs['Draft_Select_size']``.

        Returns:
            ``(model, draft_model, tokenizer)``.

        Raises:
            ValueError: If ``Target_History_Only`` and ``Draft_Only`` are both True.
        """
        kw = self.generator_kwargs
        if kw["Target_History_Only"] and kw["Draft_Only"]:
            raise ValueError("Target_History_Only and Draft_Only cannot both be True.")

        self.configure_torch()  # Also detects available GPUs.

        # Load the pre-computed important layers for this (target, draft) pair.
        llm_name = self.llm_path.split("/")[-1]
        draft_name = self.draft_model_path.split("/")[-1]
        filename = f"{kw['SRL_path']}/{llm_name}_{draft_name}_SRH_optimized.json"
        with open(filename, "r", encoding="utf-8") as f:
            layer_results = json.load(f)
        self.important_layers = layer_results[f"depth_{self.draft_params.max_depth}"][f"lag_{kw['lag_size']}"]
        kw["Draft_Select_size"] = self._compute_draft_select_size()

        target_rope_scaling = self._resolve_rope_scaling(self.llm_path)
        draft_rope_scaling = self._resolve_rope_scaling(self.draft_model_path)

        model, tokenizer = self.load_model_and_tokenizer(self.llm_path, rope_scaling=target_rope_scaling)
        draft_model = self.load_draft_model(model, tokenizer, self.draft_model_path, rope_scaling=draft_rope_scaling)

        if target_rope_scaling is not None:
            model.config.max_position_embeddings = self.max_length
        if draft_rope_scaling is not None:
            draft_model.config.max_position_embeddings = self.max_length

        if self.recipe:
            target_config, draft_config = self.recipe.generate_configurations(
                target_model=model,
                draft_model=draft_model,
                max_length=self.max_length,
                cpu_offload_gb=self.cpu_offload_gb,
                dtype=self.dtype,
                device=self.device,  # Note: self.device is primary, but recipe might need to know about multi-GPU
            )

            # Apply the recipe (quantization first, then offloading).
            if draft_model is not None and draft_config and draft_config.get("quant_config"):
                self.recipe.apply_quantization(draft_model.model, draft_config["quant_config"], self.dtype, self.device)
            if target_config and target_config.get("quant_config"):
                self.recipe.apply_quantization(model, target_config["quant_config"], self.dtype, self.device)

            if draft_model is not None and draft_config and draft_config.get("device_map"):
                self.recipe.apply_offloading(draft_model.model, draft_config["device_map"])
            if target_config and target_config.get("device_map"):
                self.recipe.apply_offloading(model, target_config["device_map"])

        # Add a dedicated pad token (so pad_token_id != eos_token_id).
        if tokenizer.pad_token is None:
            tokenizer.add_special_tokens({"pad_token": "<|pad|>"})
        tokenizer.padding_side = "left"
        # Resize unconditionally, so both models' vocab dimension equals
        # len(tokenizer). NOTE(review): this can also shrink embeddings larger
        # than the tokenizer; presumably intended to keep target/draft logits
        # aligned. Confirm before changing.
        model.resize_token_embeddings(len(tokenizer))
        if draft_model is not None:
            draft_model.model.resize_token_embeddings(len(tokenizer))
        return model, draft_model, tokenizer

if __name__ == "__main__":
    # Script entry point: hand a freshly configured builder to the app runner.
    builder = TargetKVSDBuilder()
    run_app(builder)