import os
import time
import torch
import torch.nn.functional as F
from torch import device, nn

from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer
from transformers import AutoTokenizer

from vllm import LLM, SamplingParams
from vllm.inputs import EmbedsPrompt

def _now():
    # Ensure correct timing on GPU
    if device.type == "cuda":
        torch.cuda.synchronize()
    return time.perf_counter()

class CoDiLAModel(nn.Module):
    """
    Diffusion Conditioned AR Model.

    A diffusion language model is run over the (partially masked) input; its
    per-position output distribution is projected into the input-embedding
    space of an autoregressive (AR) model.  For every masked block the
    projected embeddings are wrapped in <think> ... </think> marker embeddings
    and passed as an embedding prompt to a generator, which emits the tokens
    for that block.
    """

    # Hub identifiers used by `__init__` and `from_pretrained`.
    DIFFUSION_MODEL_ID = "Dream-org/Dream-Coder-v0-Instruct-7B"
    AR_TOKENIZER_ID = "Qwen/Qwen3-0.6B"

    def __init__(self, diffusion_model: nn.Module, ar_model: nn.Module, tokenizer, generator, block_size):
        """
        Args:
            diffusion_model: model called as ``diffusion_model(input_ids=...,
                use_cache=True)``; its ``.logits`` output and ``.dtype`` are read.
            ar_model: model whose input-embedding table is used to project the
                diffusion probabilities and to look up marker-token embeddings.
            tokenizer: tokenizer used for the EOS id check and for decoding.
            generator: object exposing ``generate(prompts, sampling_params=...)``
                (a vLLM ``LLM`` when built via `from_pretrained`).
            block_size: maximum number of tokens generated per block.
        """
        super().__init__()
        # --- Sub-models ---
        self.diffusion_model = diffusion_model
        self.ar_model = ar_model
        self.tokenizer = tokenizer
        self.generator = generator
        self.block_size = block_size
        self.ar_tokenizer = AutoTokenizer.from_pretrained(self.AR_TOKENIZER_ID, trust_remote_code=True)

        # Cast the AR embedding table to the diffusion model's dtype so the
        # probability @ embedding matmul in `generate` uses matching dtypes.
        embed_layer = self.ar_model.get_input_embeddings()
        embed_layer.to(diffusion_model.dtype)

        self.bot = 151667  # <think>
        self.eot = 151668  # </think>
        eos_id = 151643  # presumably the AR vocabulary's end-of-text id -- TODO confirm

        # Look the marker tokens up on whatever device the embedding table
        # lives on (instead of assuming "cuda"), and without autograd so the
        # cached rows are plain tensors that do not retain a graph.
        embed_device = embed_layer.weight.device
        with torch.no_grad():
            self.bot_embed = embed_layer(torch.tensor([self.bot], device=embed_device))
            self.eot_embed = embed_layer(torch.tensor([self.eot], device=embed_device))
            self.eos_embed = embed_layer(torch.tensor([eos_id], device=embed_device))

        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

    @staticmethod
    def _mean_step_entropy(step_logprobs_seq):
        """Return the mean per-step entropy of the returned log-probabilities.

        Each element of ``step_logprobs_seq`` is a mapping whose values expose a
        ``.logprob`` attribute.  Only the returned candidates are available, so
        each step's probabilities are renormalized over them before the entropy
        is taken.  An empty sequence yields ``0.0``.
        """
        step_entropies = []
        for step_logprobs in step_logprobs_seq:
            logps = torch.tensor(
                [lp.logprob for lp in step_logprobs.values()],
                dtype=torch.float32,
            )
            probs = torch.exp(logps)
            probs = probs / (probs.sum() + 1e-9)
            entropy = -(probs * torch.log(probs + 1e-9)).sum()
            step_entropies.append(entropy.item())

        if not step_entropies:
            return 0.0
        return float(sum(step_entropies) / len(step_entropies))

    def generate(self, input_ids, block_ranges, **kwargs):
        """
        Generate tokens for masked blocks using diffusion + AR synergy.

        Args:
            input_ids: (B, L) input IDs; only B == 1 is supported.
            block_ranges: list (one entry per sample) of lists of (start, end)
                spans; only ``block_ranges[0]`` is used.  Spans are sliced as
                ``[start-1 : end-1]`` on the diffusion output, i.e. shifted one
                position left -- presumably because the diffusion logits at
                position i predict token i+1 (TODO confirm).  ``start`` is
                therefore expected to be >= 1.
            **kwargs: accepted for call compatibility; currently unused.

        Returns:
            ``(result_ids, result_entropies)``: for each span, a CPU tensor of
            generated token ids and the mean per-step entropy of the
            generator's returned log-probabilities.

        Raises:
            NotImplementedError: if ``B > 1``.
            RuntimeError: if the generator output carries no log-probabilities.
        """
        batch_size, _ = input_ids.shape
        if batch_size > 1:
            raise NotImplementedError("Batch generation not supported yet.")

        t_total0 = _now()
        embed_weight = self.ar_model.get_input_embeddings().weight

        # =======================================================================
        # 1. Diffusion forward (Dream model)
        # =======================================================================
        with torch.inference_mode():
            t_diff0 = _now()
            diff_out = self.diffusion_model(
                input_ids=input_ids,
                use_cache=True,
            )
            diff_logits = diff_out.logits   # (B, L, vocab)
            t_diff1 = _now()

            # Soft embedding merging: expected AR embedding under the
            # diffusion distribution.  The probabilities are cut to the AR
            # vocabulary size (and not renormalized) so the matmul shapes agree.
            diff_probs = F.softmax(diff_logits, dim=-1)
            target_vocab_size = embed_weight.shape[0]
            projected = diff_probs[:, :, :target_vocab_size] @ embed_weight

        # -----------------------------------------------------------
        # 2. Build one embedding prompt per masked block
        # -----------------------------------------------------------
        all_prompt_embeds = []
        for (left, right) in block_ranges[0]:
            # clone(): the slice is a view of a tensor created under
            # inference_mode and is modified in place just below.
            block_embeds = projected[0, left-1:right-1, :].clone()

            # Positions where the diffusion model's top prediction is EOS get
            # the exact EOS embedding instead of the soft mixture.
            block_argmax = diff_logits[0, left-1:right-1, :].argmax(dim=-1)
            mask = block_argmax == self.tokenizer.eos_token_id
            if mask.any():
                block_embeds[mask] = self.eos_embed

            # <think> + block + </think>, using the pre-cached marker embeddings
            prompt_embeds = torch.cat([self.bot_embed, block_embeds, self.eot_embed], dim=0)
            all_prompt_embeds.append(EmbedsPrompt(prompt_embeds=prompt_embeds))

        # -----------------------------------------------------------
        # 3. AR generation for ALL blocks in one call
        # -----------------------------------------------------------
        sampling_params = SamplingParams(
            temperature=0.1,
            max_tokens=self.block_size,
            logprobs=20,
            top_p=0.8,
            top_k=-1,
            min_p=0.0,
        )

        t_ar0 = _now()
        gen_out = self.generator.generate(
            all_prompt_embeds,
            sampling_params=sampling_params,
        )
        t_ar1 = _now()

        # -----------------------------------------------------------
        # 4. Collect token ids and a per-block confidence score
        # -----------------------------------------------------------
        result_ids = []
        result_entropies = []
        for out in gen_out:
            seq = out.outputs[0]
            result_ids.append(torch.tensor(seq.token_ids, device="cpu"))

            # vLLM returns logprobs as a list (steps) of dicts {id: LogprobObj}
            if seq.logprobs is None:
                # If this hits, logprobs weren't enabled in SamplingParams
                raise RuntimeError("Logprobs not found in generation output. Ensure SamplingParams.logprobs > 0.")
            result_entropies.append(self._mean_step_entropy(seq.logprobs))

        t_total1 = _now()

        timings = {
            "diffusion_ms": (t_diff1 - t_diff0) * 1000.0,
            "ar_ms": (t_ar1 - t_ar0) * 1000.0,
            "total_ms": (t_total1 - t_total0) * 1000.0,
        }

        print(
            f"[DiCA.generate] diffusion={timings['diffusion_ms']:.2f}ms | "
            f"ar={timings['ar_ms']:.2f}ms | total={timings['total_ms']:.2f}ms",
            flush=True
        )

        return result_ids, result_entropies

    def generate_lowest_entropy(self, input_ids_tensor, block_ranges, prompt_len, max_blocks_per_step=10):
        """
        Block-wise filling: unmask blocks in order of lowest entropy (most confident) first.

        Each iteration runs `generate` on the first ``max_blocks_per_step``
        still-unfilled spans, commits the prediction of the span with the
        lowest entropy, and repeats until no span remains.

        Args:
            input_ids_tensor: (1, L) tensor of input IDs; row 0 is used.
            block_ranges: list whose first entry is the list of (start, end)
                spans to fill; ``end`` is treated as exclusive.
            prompt_len: number of leading tokens dropped from the decoded result.
            max_blocks_per_step: how many pending spans are scored per iteration.

        Returns:
            The decoded text of ``input_ids[prompt_len:]`` after all spans are filled.

        Raises:
            ValueError: if ``max_blocks_per_step`` is smaller than 1.
        """
        if max_blocks_per_step < 1:
            raise ValueError("max_blocks_per_step must be >= 1")

        torch.cuda.empty_cache()
        target_device = input_ids_tensor.device

        # Work with mutable Python list
        input_ids = input_ids_tensor[0].tolist()

        # Copy ranges so the caller's list is not consumed
        remaining = list(block_ranges[0])   # [(s,e), (s,e), ...]

        while remaining:
            # Score a bounded prefix of the pending spans; indices into this
            # prefix are also valid indices into `remaining`.
            gen_ids_list, entropies = self.generate(
                input_ids=torch.tensor([input_ids], device=target_device),
                block_ranges=[remaining[:max_blocks_per_step]]
            )

            # Find the block with lowest entropy (most confident prediction)
            min_entropy_idx = entropies.index(min(entropies))
            predicted_block = gen_ids_list[min_entropy_idx]
            s, e = remaining[min_entropy_idx]

            # Fill the block, never writing past its end: the generator may
            # emit up to `block_size` tokens, which can exceed the span length
            # and would otherwise overwrite neighbouring positions.
            # NOTE: if fewer than (e - s) tokens were generated, the trailing
            # positions of the span keep their previous ids.
            for offset, tok in enumerate(predicted_block[: e - s]):
                input_ids[s + offset] = int(tok)

            # Remove this block from remaining
            remaining.pop(min_entropy_idx)

        # Return decoded continuation after prompt_len
        return self.tokenizer.decode(input_ids[prompt_len:], skip_special_tokens=True)

    @classmethod
    def from_pretrained(cls, load_directory, block_size, **kwargs):
        """
        Build a fully initialized model, including the vLLM generator.

        Args:
            load_directory: directory containing an ``ar_model`` sub-directory
                with the AR checkpoint.
            block_size: forwarded to the constructor.
            **kwargs: forwarded to every ``from_pretrained`` call of the
                tokenizer and both Transformers models;
                ``trust_remote_code`` defaults to True.

        Returns:
            A `CoDiLAModel` instance.
        """
        ar_path = os.path.join(load_directory, "ar_model")
        kwargs.setdefault("trust_remote_code", True)

        # 1. Load tokenizer
        tokenizer = AutoTokenizer.from_pretrained(
            cls.DIFFUSION_MODEL_ID,
            **kwargs,
        )

        # 2. Load the diffusion and AR models with standard Transformers
        print("Loading Diffusion Model...")
        diffusion_model = AutoModel.from_pretrained(
            cls.DIFFUSION_MODEL_ID,
            **kwargs
        )

        print("Loading AR Model...")
        ar_model = AutoModelForCausalLM.from_pretrained(
            ar_path,
            **kwargs
        )

        # 3. Start the vLLM engine after the Transformers models are loaded
        generator = LLM(
            model=ar_path,
            tokenizer=cls.AR_TOKENIZER_ID,
            dtype="bfloat16",
            trust_remote_code=True,
            max_model_len=2048,
            gpu_memory_utilization=0.3,   # Keep this low (0.2-0.3) so 7B model fits
            enable_prompt_embeds=True,
        )

        return cls(diffusion_model, ar_model, tokenizer, generator, block_size)