import logging
import os
import random
from typing import Any, Dict, Tuple, Optional
from types import SimpleNamespace

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from specdecodes.models.utils.cache_utils import create_kv_cache
from specdecodes.models.utils.monkey_patch import apply_llama3_attention_monkey_patch, apply_qwen2_attention_monkey_patch
from specdecodes.models.generators.naive import NaiveGenerator
from .app_router import run_app

# Root-logger verbosity is read once, at import time, from the LOGLEVEL
# environment variable (case-insensitive); it falls back to INFO when unset.
LOGLEVEL = os.getenv("LOGLEVEL", "INFO").upper()
logging.basicConfig(level=LOGLEVEL)

class GeneratorPipelineBuilder:
    """
    Builder class to construct the generation pipeline.

    This class handles:
      - Torch configuration (precision, seeding, multi-GPU detection)
      - Loading the model and tokenizer (with device_map="auto" for multi-GPU)
      - Generating configuration dictionaries via the recipe
      - Applying quantization and offloading through the recipe (if applicable)
      - Building and conditionally compiling the generator pipeline

    The public instance attributes assigned in ``__init__`` are the
    configuration surface: callers (or subclasses) overwrite them before
    calling :meth:`build`, which returns
    ``(generator, tokenizer, past_kv, draft_past_kv)``.
    """

    def __init__(self):
        # Base configurations.
        self.vram_limit_gb: Optional[int] = None
        self.seed = 0
        self.device = "cuda:0"
        self.dtype = torch.float16
        self.load_in_8bit = False
        self.load_in_4bit = False

        # Multi-GPU flags (refreshed by configure_torch()).
        self.n_gpu = 0
        self.is_multigpu = False

        # Model paths.
        self.llm_path = "meta-llama/Llama-3.1-8B-Instruct"
        self.draft_model_path = None

        # Generation parameters.
        self.max_length = 2048
        self.do_sample = False
        self.temperature = 0.0
        # Batch size used to size the static KV cache in load_kv_cache().
        # It was previously read there without ever being initialised.
        self.batch_size = 1

        # Generator-specific configurations.
        self.generator_kwargs = {}
        self.draft_params = None

        # Attention implementation.
        self._attn_implementation = "sdpa"

        # Additional configurations.
        self.cache_implementation = "dynamic"
        self.warmup_iter = 0
        self.compile_mode = None

        # Recipe for quantization and offloading.
        self.recipe = None
        self.cpu_offload_gb: Optional[int] = None

        # Profiling and printing settings.
        self.generator_profiling: bool = True
        self.profiling_verbose: bool = True
        self.print_time: bool = True
        self.print_message: bool = True

        # Benchmarking/logging directories.
        self.out_dir: Optional[str] = None
        self.log_dir: str = "experiments"

    @property
    def args(self) -> SimpleNamespace:
        """
        Return a snapshot of the public configuration as a namespace.

        Every instance attribute whose name does not start with an underscore
        and whose value is not callable is copied into a fresh
        ``SimpleNamespace`` (attribute access, not a ``dict``).
        """
        public_attrs = {
            name: value
            for name, value in vars(self).items()
            if not name.startswith("_") and not callable(value)
        }
        return SimpleNamespace(**public_attrs)

    def configure_torch(self):
        """
        Set up torch configurations for reproducibility and performance.

        Seeds ``torch`` and ``random`` with ``self.seed``, then detects the
        available CUDA devices and updates ``self.device``, ``self.n_gpu`` and
        ``self.is_multigpu`` accordingly. In multi-GPU mode a ``"static"``
        cache implementation is downgraded to ``"dynamic"``.
        """
        torch.set_float32_matmul_precision('high')
        torch.manual_seed(self.seed)
        random.seed(self.seed)

        if not torch.cuda.is_available():
            logging.warning("No CUDA GPUs detected. Using CPU.")
            self.device = "cpu"
            self.n_gpu = 0
            self.is_multigpu = False
            return

        self.n_gpu = torch.cuda.device_count()
        # cuda:0 is the primary device (inputs / cache) in both GPU modes.
        self.device = "cuda:0"
        if self.n_gpu > 1:
            self.is_multigpu = True
            logging.info(f"Detected {self.n_gpu} GPUs. Running in multi-GPU mode. Primary device: {self.device}")
            if self.cache_implementation == "static":
                logging.warning("Static cache is not supported with multi-GPU. Forcing dynamic cache.")
                self.cache_implementation = "dynamic"
        else:
            self.n_gpu = 1
            self.is_multigpu = False
            logging.info(f"Detected 1 GPU. Running in single-GPU mode on {self.device}.")

    def _select_device_map(self):
        """
        Decide where the model weights are placed at load time.

        Returns:
            A ``(device_map, max_memory)`` pair. ``max_memory`` is ``None``
            unless ``self.vram_limit_gb`` is set and the model is loaded onto
            GPU(s) directly.
        """
        if self.recipe and self.recipe.offloader:
            logging.info("Recipe includes offloader. Loading model on CPU for offloading.")
            return "cpu", None

        if self.is_multigpu:
            logging.info("Multi-GPU detected. Using device_map='auto'.")
            max_memory = None
            if self.vram_limit_gb is not None:
                # Apply the same VRAM limit to every visible GPU.
                max_memory = {i: f"{self.vram_limit_gb}GiB" for i in range(self.n_gpu)}
                logging.info(f"Applying VRAM limit of {self.vram_limit_gb}GiB to all {self.n_gpu} GPUs.")
            return "auto", max_memory

        logging.info(f"Single-GPU mode. Loading model on {self.device}.")
        max_memory = None
        if self.vram_limit_gb is not None:
            target = torch.device(self.device)
            if target.type != "cpu":
                # A bare "cuda" has no explicit index; treat it as GPU 0.
                gpu_index = target.index or 0
                max_memory = {gpu_index: f"{self.vram_limit_gb}GiB"}
                logging.info(f"Applying VRAM limit of {self.vram_limit_gb}GiB to {self.device}.")
        return self.device, max_memory

    def load_model_and_tokenizer(
        self,
        model_path: str,
        rope_scaling: Optional[Dict[str, Any]] = None,
    ) -> "Tuple[AutoModelForCausalLM, AutoTokenizer]":
        """
        Load a model and tokenizer from the specified model path.
        Handles single-GPU, multi-GPU (auto device_map), and CPU offload.

        Args:
            model_path: Identifier or path handed to ``from_pretrained``.
            rope_scaling: Optional RoPE-scaling override. When ``None`` (the
                default) nothing is forwarded, so the value stored in the
                checkpoint's own configuration is kept.

        Returns:
            ``(model, tokenizer)``.

        Note:
            With ``load_in_4bit`` / ``load_in_8bit`` the model is always
            loaded with ``device_map="auto"``, regardless of the placement
            chosen for the unquantized path.
        """
        tokenizer = AutoTokenizer.from_pretrained(model_path)

        device_map, max_memory = self._select_device_map()

        load_kwargs = {
            "low_cpu_mem_usage": True,
            "max_memory": max_memory,
            "_attn_implementation": self._attn_implementation,
        }
        if rope_scaling is not None:
            load_kwargs["rope_scaling"] = rope_scaling

        if self.load_in_4bit:
            logging.info("Loading model in 4-bit mode (using BitsAndBytesConfig, device_map='auto').")
            load_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=self.dtype,
            )
            load_kwargs["device_map"] = "auto"
        elif self.load_in_8bit:
            logging.info("Loading model in 8-bit mode (using BitsAndBytesConfig, device_map='auto').")
            load_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
            load_kwargs["device_map"] = "auto"
        else:
            logging.info(f"Loading model in {self.dtype} mode.")
            load_kwargs["torch_dtype"] = self.dtype
            load_kwargs["device_map"] = device_map

        model = AutoModelForCausalLM.from_pretrained(model_path, **load_kwargs)
        return model, tokenizer

    def load_draft_model(self, target_model=None, tokenizer=None, draft_model_path=None):
        """
        Load a draft model if a draft model path is provided.

        This base implementation is a placeholder: it never loads anything and
        always returns ``None``; when ``draft_model_path`` is given it only
        logs a warning. The parameters are kept so that overriding
        implementations share the same signature.
        """
        if draft_model_path:
            # Draft model loading is not implemented here. For multi-GPU one
            # might load it on self.device or with device_map="auto".
            logging.warning("Draft model loading in multi-GPU setup is not fully implemented.")
        return None

    def load_kv_cache(self, target_model, draft_model):
        """
        Create the KV cache(s) for the target model and, if present, the draft model.

        Args:
            target_model: Model whose ``config`` / ``dtype`` size the static cache.
            draft_model: Optional draft model; ``None`` means no draft cache.

        Returns:
            ``(past_key_values, draft_past_key_values)``; the second element is
            ``None`` when ``draft_model`` is ``None``.

        Raises:
            ValueError: If a static cache is requested while ``max_length`` is
                ``None``, or while a draft model is used without ``draft_params``.
        """
        if self.cache_implementation != "static":
            # Create dynamic kv-cache (used for multi-GPU or by default).
            logging.info("Creating Dynamic KV Cache")
            past_key_values = create_kv_cache("dynamic")
            draft_past_key_values = create_kv_cache("dynamic") if draft_model is not None else None
            return past_key_values, draft_past_key_values

        # When reached via build(), this branch is single-GPU only, because
        # configure_torch() forces a dynamic cache in multi-GPU mode.
        if self.max_length is None:
            raise ValueError("max_length should be set for static cache.")

        max_cache_len = self.max_length
        if draft_model is not None:
            if self.draft_params is None:
                raise ValueError("draft_params must be set when a draft model is used with static cache.")
            # Extra room for the tokens submitted for verification.
            max_cache_len += self.draft_params.max_verify_tokens

        logging.info(f"Creating Static KV Cache with length {max_cache_len} on {self.device}")
        past_key_values = create_kv_cache(
            "static",
            max_cache_len=max_cache_len,
            max_batch_size=self.batch_size,
            config=target_model.config,
            device=self.device,
            dtype=target_model.dtype,
        )
        draft_past_key_values = None
        if draft_model is not None:
            draft_past_key_values = create_kv_cache(
                "static",
                max_cache_len=max_cache_len,
                max_batch_size=self.batch_size,
                config=draft_model.config,
                device=self.device,
                dtype=draft_model.dtype,
            )

        return past_key_values, draft_past_key_values

    def load_generator(self, target_model, tokenizer, draft_model=None):
        """
        Initialize the generator with the target model, tokenizer, and draft model.

        Before constructing the generator, an attention monkey patch is chosen
        from the (lower-cased) ``self.llm_path``: "llama-3" selects the Llama-3
        patch, "qwen2.5" the Qwen2 patch; otherwise only a warning is logged.
        """
        model_name = self.llm_path.lower()
        if "llama-3" in model_name:
            apply_llama3_attention_monkey_patch(self.cache_implementation)
        elif "qwen2.5" in model_name:
            apply_qwen2_attention_monkey_patch(self.cache_implementation)
        else:
            # No known attention patch for this model family.
            logging.warning("No specific attention monkey patch applied for this model.")

        return NaiveGenerator(
            target_model=target_model,
            tokenizer=tokenizer,
            draft_model=draft_model,
            draft_params=self.draft_params,
            cache_implementation=self.cache_implementation,
            profiling=self.generator_profiling,
            profiling_verbose=self.profiling_verbose,
            generator_kwargs=self.generator_kwargs,
        )

    def compile_generator(self, generator):
        """
        Compile the generator's forward methods with ``torch.compile``.

        Also switches torch into deterministic mode (deterministic algorithms,
        cuDNN deterministic, cuDNN benchmark off) and sets
        ``CUBLAS_WORKSPACE_CONFIG``. The target model's ``forward`` is always
        compiled; the draft model's only when the generator has one.
        """
        torch.use_deterministic_algorithms(True)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

        # Dynamic shapes are only requested for a dynamic cache; with a static
        # cache the whole graph is required to compile (fullgraph=True).
        dynamic = self.cache_implementation == "dynamic"
        compile_kwargs = dict(mode=self.compile_mode, dynamic=dynamic, fullgraph=not dynamic)
        print("Compiling generator...")
        generator.target_model.forward = torch.compile(generator.target_model.forward, **compile_kwargs)
        if getattr(generator, 'draft_model', None) is not None:
            generator.draft_model.forward = torch.compile(generator.draft_model.forward, **compile_kwargs)

    def post_process(self, generator, tokenizer, past_kv, draft_past_kv):
        """Hook called at the end of build_generator_pipeline(); a no-op here."""
        pass

    def build_models_and_tokenizer(self):
        """
        Build and return the main model, draft model, and tokenizer.

        Runs ``configure_torch()``, loads the target model/tokenizer and the
        (optional) draft model, then, if a recipe is set, asks it for
        per-model configurations and applies quantization followed by
        offloading (draft model first, then target model, in each phase).

        Returns:
            ``(model, draft_model, tokenizer)``.
        """
        self.configure_torch()  # Detects GPUs and may adjust device/cache settings.
        model, tokenizer = self.load_model_and_tokenizer(self.llm_path)
        draft_model = self.load_draft_model(model, tokenizer, self.draft_model_path)

        if self.recipe:
            target_config, draft_config = self.recipe.generate_configurations(
                target_model=model,
                draft_model=draft_model,
                max_length=self.max_length,
                cpu_offload_gb=self.cpu_offload_gb,
                dtype=self.dtype,
                device=self.device,  # Primary device only; the recipe is not told about multi-GPU.
            )

            # NOTE(review): the draft model is passed as `draft_model.model`
            # while the target is passed as `model` itself -- confirm this
            # asymmetry matches the draft wrapper's structure.
            use_draft_config = bool(draft_model and draft_config)

            # Phase 1: quantization.
            if use_draft_config and draft_config.get("quant_config"):
                self.recipe.apply_quantization(draft_model.model, draft_config["quant_config"], self.dtype, self.device)
            if target_config and target_config.get("quant_config"):
                self.recipe.apply_quantization(model, target_config["quant_config"], self.dtype, self.device)

            # Phase 2: offloading.
            if use_draft_config and draft_config.get("device_map"):
                self.recipe.apply_offloading(draft_model.model, draft_config["device_map"])
            if target_config and target_config.get("device_map"):
                self.recipe.apply_offloading(model, target_config["device_map"])

        return model, draft_model, tokenizer

    def build_generator_pipeline(self, model, draft_model, tokenizer):
        """
        Build the generator pipeline using pre-built model, draft_model, and tokenizer.

        Creates the KV cache(s), constructs the generator and puts it in eval
        mode, optionally compiles it, and finally invokes ``post_process``.
        Compilation only happens when ``compile_mode`` is set AND the run is
        single-GPU with a static cache; otherwise a warning is logged.

        Returns:
            ``(generator, tokenizer, past_kv, draft_past_kv)``.
        """
        past_kv, draft_past_kv = self.load_kv_cache(model, draft_model)

        generator = self.load_generator(model, tokenizer, draft_model)
        generator.eval()

        if self.compile_mode is not None:
            if not self.is_multigpu and self.cache_implementation == "static":
                logging.info(f"Applying torch.compile (mode: {self.compile_mode}) to generator.")
                self.compile_generator(generator)
            else:
                logging.warning(f"Skipping torch.compile. "
                                f"Multi-GPU: {self.is_multigpu}, "
                                f"Cache: {self.cache_implementation}. "
                                f"(Compile only runs with single-GPU and static cache)")

        self.post_process(generator, tokenizer, past_kv, draft_past_kv)

        return generator, tokenizer, past_kv, draft_past_kv

    def build(self):
        """
        Build the full generation pipeline from scratch.

        Returns:
            ``(generator, tokenizer, past_kv, draft_past_kv)`` as produced by
            ``build_generator_pipeline``.
        """
        model, draft_model, tokenizer = self.build_models_and_tokenizer()
        return self.build_generator_pipeline(model, draft_model, tokenizer)


if __name__ == "__main__":
    # Script entry point: hand a default-configured builder to the app router.
    default_builder = GeneratorPipelineBuilder()
    run_app(default_builder)
