import os
import torch
from torch import nn
import torch.nn.functional as F
from transformers.modeling_outputs import ModelOutput

from transformers import AutoTokenizer


class CoDiLAModel(nn.Module):
    """
    Diffusion Conditioned AR Model.

    The diffusion model scores every position of ``input_ids``; the softmax of
    those scores is turned into "soft" token embeddings by multiplying it with
    the AR model's input-embedding matrix.  For each block the AR model is fed
    ``[<think>, soft block embeddings, </think>]`` and is either trained to
    continue with the block's gold tokens (``forward``) or asked to generate
    them (``_generate`` / ``dummy_generate``).
    """

    def __init__(self, diffusion_model: nn.Module, ar_model: nn.Module, tokenizer, ar_tokenizer=None):
        """
        Args:
            diffusion_model: model called as ``model(input_ids=...)`` whose
                output exposes ``.logits``; must also expose ``.dtype``.
            ar_model: causal LM exposing ``get_input_embeddings()``,
                ``generate()`` and accepting ``inputs_embeds``.
            tokenizer: tokenizer whose ``pad_token_id`` is handed to
                ``ar_model.generate``; its pad token falls back to EOS.
            ar_tokenizer: tokenizer used to decode generated ids in
                ``dummy_generate``.  When ``None`` (the default) the
                "Qwen/Qwen3-0.6B" tokenizer is loaded, as before.
        """
        super().__init__()
        # --- Sub-models ---
        self.diffusion_model = diffusion_model
        self.ar_model = ar_model
        self.tokenizer = tokenizer
        if ar_tokenizer is None:
            ar_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B", trust_remote_code=True)
        self.ar_tokenizer = ar_tokenizer

        # The soft embeddings are a matmul between diffusion probabilities and
        # the AR embedding matrix, so both operands must share one dtype.
        # ``nn.Module.to`` converts the embedding layer in place.
        ar_embeddings = self.ar_model.get_input_embeddings()
        ar_embeddings.to(diffusion_model.dtype)
        # Number of rows of the AR embedding table; the diffusion distribution
        # is truncated to this many vocabulary entries before the matmul.
        self.ar_target_vocab_size = ar_embeddings.weight.shape[0]

        self.bot = 151667  # <think>
        self.eot = 151668  # </think>

        # Ensure pad token is defined
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

    # ------------------------------------------------------------------
    # Private helpers shared by forward() and _generate()
    # ------------------------------------------------------------------
    def _soft_embeddings(self, input_ids):
        """Return the diffusion model's soft embeddings in AR embedding space.

        Runs the diffusion model, softmaxes its logits over the vocabulary and
        multiplies the first ``ar_target_vocab_size`` probabilities with the AR
        input-embedding matrix.
        NOTE(review): this assumes both models share token ids over that
        vocabulary range - confirm.
        """
        diff_logits = self.diffusion_model(
            input_ids=input_ids,
            output_hidden_states=True
        ).logits
        diff_probs = F.softmax(diff_logits, dim=-1)
        ar_weight = self.ar_model.get_input_embeddings().weight
        return diff_probs[:, :, :self.ar_target_vocab_size] @ ar_weight

    def _delimiter_embeds(self, device):
        """Return the ``(<think>, </think>)`` embeddings, each of shape (1, D)."""
        delimiter_ids = torch.tensor([self.bot, self.eot], device=device)
        both = self.ar_model.get_input_embeddings()(delimiter_ids)
        return both[:1], both[1:]

    @staticmethod
    def _check_span(left, right):
        """Raise ``ValueError`` unless ``1 <= left < right``."""
        if left >= right:
            raise ValueError("Invalid block span")
        if left < 1:
            # The soft embeddings of a block are read from positions
            # left-1 .. right-2; with left == 0 that slice starts at -1 and
            # is silently empty.
            raise ValueError(
                f"Invalid block span ({left}, {right}): the block start must be >= 1"
            )

    # ------------------------------------------------------------------
    # Training
    # ------------------------------------------------------------------
    def forward(
        self,
        input_ids, block_ranges, block_mask_indices, labels,
        **kwargs
    ) -> ModelOutput:
        """Compute the AR loss over every flagged block.

        Args:
            input_ids: (B, L) ids fed to the diffusion model.
            block_ranges: per sample, an iterable of ``(left, right)`` spans.
            block_mask_indices: per sample, a tensor of flags aligned with
                ``block_ranges``; only spans with a truthy flag are trained on.
            labels: tensor indexed as ``labels[b, left:right]`` to obtain the
                gold ids of a block.
            **kwargs: accepted and ignored.

        Returns:
            ``ModelOutput`` with a single ``loss`` entry: the mean over blocks
            of each block's mean cross-entropy, or ``None`` when no block is
            flagged.

        Raises:
            ValueError: if ``block_ranges`` / ``block_mask_indices`` is
                ``None`` or a flagged span does not satisfy
                ``1 <= left < right``.
        """
        # Validate before paying for the diffusion forward pass.
        if block_ranges is None or block_mask_indices is None:
            raise ValueError("block_ranges or block_mask_indices is None.")

        device = input_ids.device
        batch_size, _ = input_ids.shape

        # -----------------------------------------------------------
        # 1. Diffusion forward (no gradients flow through this step)
        # -----------------------------------------------------------
        with torch.no_grad():
            projected = self._soft_embeddings(input_ids)

        # -----------------------------------------------------------
        # 2. Build one AR prompt per flagged block
        # -----------------------------------------------------------
        embed = self.ar_model.get_input_embeddings()
        bot_embed, eot_embed = self._delimiter_embeds(device)

        prompts = []         # (T_i, D) tensors
        gold_per_block = []  # (L_gold_i,) tensors
        prefix_lens = []     # length of [BoT, block, EoT] for each prompt

        for b in range(batch_size):
            for flag, (left, right) in zip(block_mask_indices[b].tolist(), block_ranges[b]):
                if not flag:
                    continue
                self._check_span(left, right)

                # Soft embeddings are taken one position to the left of the
                # label span (presumably because the diffusion logits at
                # position i predict token i+1 - confirm against the model).
                block_embeds = projected[b, left - 1:right - 1, :]
                gold_ids = labels[b, left:right]

                prompts.append(
                    torch.cat([bot_embed, block_embeds, eot_embed, embed(gold_ids)], dim=0)
                )
                gold_per_block.append(gold_ids)
                # Use the block's real length so the loss slice stays aligned
                # even if the block and gold slices differ in length.
                prefix_lens.append(1 + block_embeds.size(0) + 1)

        if not prompts:
            # Nothing to train on in this batch.
            return ModelOutput(loss=None)

        # -----------------------------------------------------------
        # 3. Right-pad all prompts to the same length
        # -----------------------------------------------------------
        # pad_sequence pads with zeros in the prompts' own dtype, so a
        # low-precision batch is not promoted to float32.
        padded = nn.utils.rnn.pad_sequence(prompts, batch_first=True)  # (N_blocks, max_len, D)
        lengths = torch.tensor([p.size(0) for p in prompts], device=device)
        positions = torch.arange(padded.size(1), device=device)
        attn_masks = (positions.unsqueeze(0) < lengths.unsqueeze(1)).long()  # (N_blocks, max_len)

        # -----------------------------------------------------------
        # 4. AR forward pass for ALL blocks
        # -----------------------------------------------------------
        ar_out = self.ar_model(
            inputs_embeds=padded,
            attention_mask=attn_masks,
            output_hidden_states=True
        )
        logits = ar_out.logits  # (N_blocks, max_len, V)

        # -----------------------------------------------------------
        # 5. Compute losses per block
        # -----------------------------------------------------------
        all_losses = []
        for i, (prefix_len, gold_ids) in enumerate(zip(prefix_lens, gold_per_block)):
            # The logit at the EoT position (prefix_len - 1) predicts the
            # first gold token, and so on for the following positions.
            first = prefix_len - 1
            block_logits = logits[i, first:first + gold_ids.size(0), :]
            all_losses.append(
                F.cross_entropy(block_logits, gold_ids.to(block_logits.device))
            )

        return ModelOutput(
            loss=torch.stack(all_losses).mean(),
        )

    # ------------------------------------------------------------------
    # Inference
    # ------------------------------------------------------------------
    @torch.inference_mode()
    def _generate(self, input_ids, block_ranges, max_new_tokens=1, **kwargs):
        """
        Generate method for evaluation callbacks (Only to check everything is working fine, not optimized).
        Args:
            input_ids: (B, L) original input IDs; only B <= 1 is supported
            block_ranges: list of lists of (start, end) tuples for each block in each sample
            max_new_tokens: lower bound on the per-block generation budget; each
                block may generate ``max(max_new_tokens, end - start)`` tokens
            **kwargs: forwarded to ``ar_model.generate``
        Returns:
            ``(result_ids, result_entropies)``: per block, the tensor of
            generated token ids and the mean entropy (float) of the
            generation-step distributions (0.0 when no scores are returned).
        Raises:
            NotImplementedError: if B > 1.
            ValueError: if a span does not satisfy ``1 <= start < end``.
        """
        device = input_ids.device
        batch_size, _ = input_ids.shape
        if batch_size > 1:
            raise NotImplementedError("Batch generation not supported yet.")

        # -----------------------------------------------------------
        # 1. Diffusion forward -> soft embeddings in AR space
        # -----------------------------------------------------------
        projected = self._soft_embeddings(input_ids)

        # -----------------------------------------------------------
        # 2. Build one prompt (and token budget) per block
        # -----------------------------------------------------------
        bot_embed, eot_embed = self._delimiter_embeds(device)

        prompts = []  # (prompt_embeds, budget) pairs
        for b in range(batch_size):
            for (left, right) in block_ranges[b]:
                self._check_span(left, right)
                # Budget is per block; a long block must not inflate the
                # budget of the shorter ones.
                budget = max(max_new_tokens, right - left)
                prompt_embeds = torch.cat(
                    [bot_embed, projected[b, left - 1:right - 1, :], eot_embed],
                    dim=0
                )
                prompts.append((prompt_embeds, budget))

        # -----------------------------------------------------------
        # 3. AR generation, one block at a time
        # -----------------------------------------------------------
        result_ids = []
        result_entropies = []

        for prompt_embeds, budget in prompts:
            gen_out = self.ar_model.generate(
                inputs_embeds=prompt_embeds.unsqueeze(0),   # (1, T, D)
                attention_mask=torch.ones(
                    1, prompt_embeds.size(0), dtype=torch.long, device=device
                ),                                          # (1, T)
                max_new_tokens=budget,                      # generate same amount as block
                do_sample=False,                            # greedy to match argmax
                pad_token_id=self.tokenizer.pad_token_id,   # explicitly set to silence warning
                output_scores=True,
                return_dict_in_generate=True,
                **kwargs
            )

            gen_ids = gen_out.sequences[0, :budget]
            result_ids.append(gen_ids)

            # Compute average entropy of the per-step distributions
            if gen_out.scores and len(gen_ids) > 0:
                step_logits = torch.cat(gen_out.scores, dim=0)[:len(gen_ids)].float()
                log_probs = F.log_softmax(step_logits, dim=-1)
                probs = log_probs.exp()
                # Entries with zero probability (logit == -inf) contribute 0,
                # not 0 * -inf = NaN.
                p_log_p = torch.where(probs > 0, probs * log_probs, torch.zeros_like(probs))
                avg_entropy = (-p_log_p.sum(dim=-1)).mean().item()
            else:
                avg_entropy = 0.0

            result_entropies.append(avg_entropy)

        return result_ids, result_entropies

    def dummy_generate(self, input_ids_tensor, block_ranges, prompt_len):
        """
        Suboptimal version just for sanity checking.

        Fills the blocks of the first sample one after another: each block is
        generated from the sequence as updated by the previous blocks, and the
        prediction is written into the block's span.

        Args:
            input_ids_tensor: (B, L) ids; only row 0 is used.
            block_ranges: per sample, the (start, end) spans; only
                ``block_ranges[0]`` is used.
            prompt_len: number of leading tokens dropped before decoding.
        Returns:
            ``self.ar_tokenizer.decode`` applied to the filled sequence after
            ``prompt_len`` (special tokens skipped).
        """
        device = input_ids_tensor.device

        # Work with mutable Python list
        token_ids = input_ids_tensor[0].tolist()

        for start, end in list(block_ranges[0]):
            # Only the current block is needed: the diffusion pass depends on
            # the token ids alone and every block is generated independently.
            gen_ids_list, _ = self._generate(
                input_ids=torch.tensor([token_ids], device=device),
                block_ranges=[[(start, end)]]
            )

            # Clip to the span so a long prediction cannot overwrite tokens
            # that follow the block.
            for offset, tok in enumerate(gen_ids_list[0][:end - start]):
                token_ids[start + offset] = int(tok)

        # Return decoded continuation after prompt_len
        return self.ar_tokenizer.decode(token_ids[prompt_len:], skip_special_tokens=True)

    # Allow HF Trainer to enable gradient checkpointing
    def gradient_checkpointing_enable(self, **kwargs):
        """Enable gradient checkpointing on both sub-models (kwargs are forwarded)."""
        self.diffusion_model.gradient_checkpointing_enable(**kwargs)
        self.ar_model.gradient_checkpointing_enable(**kwargs)

    def gradient_checkpointing_disable(self):
        """Disable gradient checkpointing on both sub-models."""
        self.diffusion_model.gradient_checkpointing_disable()
        self.ar_model.gradient_checkpointing_disable()

    def save_pretrained(self, save_directory):
        """Save the AR model under ``<save_directory>/ar_model``.

        Only the AR model is written; the diffusion model and the tokenizers
        are not saved by this method.
        """
        # exist_ok avoids a race when several processes save concurrently.
        os.makedirs(save_directory, exist_ok=True)

        print(f"Saving AR model to {save_directory}/ar_model")
        self.ar_model.save_pretrained(os.path.join(save_directory, "ar_model"))