from typing import Optional, Dict, Any, List
from dataclasses import replace
import numpy as np
import torch
import decode

from generator.base_generator import BaseGenerator, T2IRequest, T2IResult

from torch.nn import functional as F
from llamagen.autoregressive.models.generate import top_k_top_p_filtering

class LlamaGenGenerator(BaseGenerator):
    def __init__(self, gpt, vq_model, latent_size, max_batch_size, device):
        """Store the token model, the VQ decoder and the batching limit.

        Args:
            gpt: Autoregressive model; the sampling helpers of this class call
                it as ``gpt(idx, cond_idx=..., input_pos=...)``.
            vq_model: Object exposing ``decode_code`` (used by ``image_decode``).
            latent_size: Side length of the square latent grid.
            max_batch_size: Largest number of images handled per chunk by the
                ``_chain_*`` helpers; a falsy value disables chunking.
            device: Forwarded to the base-class constructor.
        """
        super().__init__(device=device)
        self.gpt, self.vq_model = gpt, vq_model
        self.max_batch_size = max_batch_size
        # Shape handed to ``vq_model.decode_code``.  The 8 is presumably the
        # codebook embedding dimension -- TODO confirm against the VQ config.
        self.qz_shape = [1, 8] + [latent_size] * 2
        
    def _assert_ready(self, req: T2IRequest):
        assert req.top_k is not None and req.top_p is not None
    
    def _pad_after_first_quarter(self, tokens: torch.Tensor, q: int, mode: str):
        B, N = tokens.shape
        if mode == "zero":
            tokens[:, q:] = 0
        elif mode == "mean":
            mean_vals = tokens[:, :q].float().mean(dim=1, keepdim=True)
            tokens[:, q:] = mean_vals.long()
        elif mode == "repeat":
            tokens[:, q:2*q] = tokens[:, :q]
            tokens[:, 2*q:3*q] = tokens[:, :q]
            tokens[:, 3*q:4*q] = tokens[:, :q]
        else:
            raise RuntimeError(f"Set wrong padding mode: {mode}")
        
        return tokens
        
    def _pad_after_second_quarter(self, tokens: torch.Tensor, h: int, mode: str):
        B, N = tokens.shape
        if mode == "zero":
            tokens[:, h:] = 0
        elif mode == "mean":
            mean_vals = tokens[:, :h].float().mean(dim=1, keepdim=True)
            tokens[:, h:] = mean_vals.long()
        elif mode == "repeat":
            tokens[:, h:] = tokens[:, :h]
        else:
            raise RuntimeError(f"Set wrong padding mode: {mode}")
        
        return tokens

    def _pad_after_third_quarter(self, tokens: torch.Tensor, q: int, mode: str):
        B, N = tokens.shape
        t = 3 * q
        if mode == "zero":
            tokens[:, t:] = 0
        elif mode == "mean":
            mean_vals = tokens[:, :t].float().mean(dim=1, keepdim=True)
            tokens[:, t:] = mean_vals.long()
        elif mode == "repeat":
            tokens[:, t:] = tokens[:, :q]
        else:
            raise RuntimeError(f"Set wrong padding mode: {mode}")
            
        return tokens
    
    def _use_3way(self, req: T2IRequest, stage: str) -> bool:
        cc = getattr(req.cfg, "three_way_cfg", None)
        
        if cc is None:
            return False
        
        v = getattr(cc, stage, False)
        
        print(f"Use 3 way for stage{stage}: {v}")
        return v
    
    @torch.inference_mode()
    def _sample(self, logit, temperature, top_k, top_p):
        logit = logit / max(temperature, 1e-8)
        
        if top_k > 0 or top_p < 1.0:
            logit = top_k_top_p_filtering(logit, top_k=top_k, top_p=top_p)
            
        probs = F.softmax(logit, dim=-1)
        nxt   = torch.multinomial(probs, num_samples=1)
        
        return nxt
        
    
    @torch.inference_mode()
    def _prefill_2_way(self, inputs_embeds: torch.Tensor, input_pos: torch.Tensor, df, temperature: float = 1.0, top_k: int = None, top_p: float=1.0, scale: float=1.0) -> torch.Tensor:
        logits, _ = self.gpt(idx=None, cond_idx=inputs_embeds, input_pos=input_pos)
        if logits.dim() == 3:
            logits = logits[:, -1, :]
        
        uncond = logits[0::2, :]
        cond   = logits[1::2, :]
        
        fused = df(logit_cond=cond, logit_uncond=uncond, scale=float(scale))
        
        return self._sample(fused, temperature=temperature, top_k=top_k, top_p=top_p)
        
    @torch.inference_mode()
    def _prefill_3_way(self, inputs_embeds: torch.Tensor, input_pos: torch.Tensor, df, temperature: float = 1.0, top_k: int = None, top_p: float=1.0, s1: float = 1.0, s2: float = 1.0) -> torch.Tensor:
        logits, _ = self.gpt(idx=None, cond_idx=inputs_embeds, input_pos=input_pos)
        if logits.dim() == 3:
            logits = logits[:, -1, :]
    
        uncond = logits[0::3, :]
        cond   = logits[1::3, :]
        cond_m = logits[2::3, :]
        
        fused = df(
            logit_cond_modified=cond_m,
            logit_cond=cond,
            logit_uncond=uncond,
            cfg_scale_1=s1,
            cfg_scale_2=s2,
        )
        
        return self._sample(fused, temperature=temperature, top_k=top_k, top_p=top_p)
        
    @torch.inference_mode()
    def _decode_n_tokens(self, cur_token: torch.Tensor, input_pos: torch.Tensor, steps: int, df, temperature: float, top_k: int, top_p: float, scale: float) -> List[torch.Tensor]:
        new_tokens = []
        for i in range(steps):
            with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
                cur_token_combined = cur_token.repeat_interleave(2, 0)
                logits, _ = self.gpt(cur_token_combined, cond_idx=None, input_pos=input_pos)
                if logits.dim() == 3:
                    logits = logits[:, -1, :]
            
                uncond = logits[0::2, :]
                cond   = logits[1::2, :]
                fused = df(logit_cond=cond, logit_uncond=uncond, scale=float(scale))
                
                next_token = self._sample(fused, temperature=temperature, top_k=top_k, top_p=top_p)
                
                input_pos += 1
                new_tokens.append(next_token.clone())
                cur_token = next_token.view(-1, 1)
    
        return new_tokens
        
    @torch.inference_mode()
    def _decode_n_tokens_3_way(self, cur_token: torch.Tensor, input_pos: torch.Tensor, steps: int, df, temperature: float, top_k: int, top_p: float, s1: float, s2: float) -> List[torch.Tensor]:
        new_tokens = []
        for i in range(steps):
            with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
                cur_token_combined = cur_token.repeat_interleave(3, 0)
                logits, _ = self.gpt(cur_token_combined, cond_idx=None, input_pos=input_pos)
                if logits.dim() == 3:
                    logits = logits[:, -1, :]
            
                uncond = logits[0::3, :]
                cond   = logits[1::3, :]
                cond_m = logits[2::3, :]
                
                fused = df(
                    logit_cond_modified=cond_m,
                    logit_cond=cond,
                    logit_uncond=uncond,
                    cfg_scale_1=s1,
                    cfg_scale_2=s2,
                )
                
                next_token = self._sample(fused, temperature=temperature, top_k=top_k, top_p=top_p)
                
                input_pos += 1
                new_tokens.append(next_token.clone())
                cur_token = next_token.view(-1, 1)
    
        return new_tokens
    
    @torch.inference_mode()
    def _chain_2way_core(self, req: T2IRequest, given_tokens: Optional[torch.Tensor], steps: int) -> torch.Tensor:
        """Sample ``steps`` new tokens per image with two-way guidance, in one batch.

        ``req.inputs_embeds`` holds interleaved (uncond, cond) rows, so its
        batch size is twice the number of images.  When ``given_tokens`` is
        provided, the prefix plus all but the last given token are pushed
        through the model first and decoding resumes from the last given token.
        Otherwise the prefix alone is prefilled and the first new token is
        sampled from it.

        Args:
            req: Request providing ``inputs_embeds``, ``attention_mask`` (may be
                ``None``), ``cfg.decode_func``, ``cfg.cfg_scale``,
                ``temperature``, ``top_k`` and ``top_p``.
            given_tokens: Optional tokens already generated, one row per image.
            steps: Number of new tokens to sample per image.

        Returns:
            ``(B, steps)`` long tensor holding only the newly sampled tokens.
        """
        B2 = req.inputs_embeds.size(0)
        assert B2 % 2 == 0, "2-way expects batch multiple of 2"
        B = B2 // 2
        device = req.inputs_embeds.device

        # Nothing to sample: the code below would otherwise call torch.cat on
        # an empty list (or write a token into a zero-width slot).
        if steps == 0:
            return torch.zeros((B, 0), dtype=torch.long, device=device)

        df    = getattr(decode, req.cfg.decode_func)
        scale = req.cfg.cfg_scale
        temp  = req.temperature
        top_k = req.top_k
        top_p = req.top_p

        T = req.inputs_embeds.size(1)

        if given_tokens is not None:
            T_generated = given_tokens.size(1)
            T_max = T + T_generated + steps
            given_pair_tok = given_tokens.repeat_interleave(2, 0)
            input_pos_full = torch.arange(0, T + T_generated - 1, device=device, dtype=torch.int)
        else:
            T_max = T + steps

        with torch.device(device):
            self.gpt.setup_caches(max_batch_size=B2, max_seq_length=T_max, dtype=self.gpt.tok_embeddings.weight.dtype)

        # ``_chain_2way`` explicitly allows ``attention_mask`` to be None, so
        # only narrow the causal mask when a prefix mask is actually supplied.
        if req.attention_mask is not None:
            prefix_mask_bool = (req.attention_mask[:, :T] != 0).unsqueeze(1)
            self.gpt.causal_mask[:, :, :T] &= prefix_mask_bool

        # Switch the diagonal back on after masking (presumably so no position
        # ends up unable to attend to itself -- TODO confirm).
        S = self.gpt.causal_mask.size(1)
        self.gpt.causal_mask |= torch.eye(S, S, device=device, dtype=torch.bool)

        new_tokens = torch.zeros((B, steps), dtype=torch.long, device=device)
        if given_tokens is not None:
            # Replay prefix + given tokens (minus the last) to fill the caches.
            idx = given_pair_tok[:, :-1]
            _ = self.gpt(idx=idx, cond_idx=req.inputs_embeds, input_pos=input_pos_full)
            start_token = given_tokens[:, -1:].contiguous().long()
            input_pos   = torch.tensor([T + T_generated - 1], device=device, dtype=torch.int)

            n_tokens = self._decode_n_tokens(
                start_token,
                input_pos,
                steps,
                df,
                temperature=temp,
                top_k=top_k,
                top_p=top_p,
                scale=float(scale),
            )
            new_tokens[:, :] = torch.cat(n_tokens, dim=1).long()
        else:
            input_pos = torch.arange(0, T, device=device)
            next_token = self._prefill_2_way(req.inputs_embeds, input_pos, df, temperature=temp, top_k=top_k, top_p=top_p, scale=scale)
            new_tokens[:, 0:1] = next_token

            # With a single step the prefill already produced everything;
            # torch.cat on the empty decode result would raise.
            if steps > 1:
                input_pos = torch.tensor([T], device=device, dtype=torch.int)

                n_tokens = self._decode_n_tokens(
                    next_token,
                    input_pos,
                    steps - 1,
                    df,
                    temperature=temp,
                    top_k=top_k,
                    top_p=top_p,
                    scale=float(scale),
                )

                new_tokens[:, 1:] = torch.cat(n_tokens, dim=1)

        return new_tokens
        
    @torch.inference_mode()
    def _chain_3way_core(self, req: T2IRequest, given_tokens: Optional[torch.Tensor], steps: int, s1: float, s2: float) -> torch.Tensor:
        """Sample ``steps`` new tokens per image with three-way guidance, in one batch.

        ``req.inputs_embeds`` holds interleaved (uncond, cond, modified-cond)
        rows, so its batch size is three times the number of images.  When
        ``given_tokens`` is provided, the prefix plus all but the last given
        token are pushed through the model first and decoding resumes from the
        last given token.  Otherwise the prefix alone is prefilled and the
        first new token is sampled from it.

        Args:
            req: Request providing ``inputs_embeds``, ``attention_mask`` (may be
                ``None``), ``cfg.three_way_cfg.decode_func``, ``temperature``,
                ``top_k`` and ``top_p``.
            given_tokens: Optional tokens already generated, one row per image.
            steps: Number of new tokens to sample per image.
            s1: First guidance scale, passed to the fusion function as
                ``cfg_scale_1``.
            s2: Second guidance scale, passed as ``cfg_scale_2``.

        Returns:
            ``(B, steps)`` long tensor holding only the newly sampled tokens.
        """
        B3 = req.inputs_embeds.size(0)
        assert B3 % 3 == 0, "3-way expects batch multiple of 3"
        B = B3 // 3
        device = req.inputs_embeds.device

        # Nothing to sample: the code below would otherwise call torch.cat on
        # an empty list (or write a token into a zero-width slot).
        if steps == 0:
            return torch.zeros((B, 0), dtype=torch.long, device=device)

        df    = getattr(decode, req.cfg.three_way_cfg.decode_func)
        temp  = req.temperature
        top_k = req.top_k
        top_p = req.top_p

        T = req.inputs_embeds.size(1)

        if given_tokens is not None:
            T_generated = given_tokens.size(1)
            T_max = T + T_generated + steps
            given_triple_tok = given_tokens.repeat_interleave(3, 0)
            input_pos_full = torch.arange(0, T + T_generated - 1, device=device, dtype=torch.int)
        else:
            T_max = T + steps

        with torch.device(device):
            self.gpt.setup_caches(max_batch_size=B3, max_seq_length=T_max, dtype=self.gpt.tok_embeddings.weight.dtype)

        # ``_chain_3way`` explicitly allows ``attention_mask`` to be None, so
        # only narrow the causal mask when a prefix mask is actually supplied.
        if req.attention_mask is not None:
            prefix_mask_bool = (req.attention_mask[:, :T] != 0).unsqueeze(1)
            self.gpt.causal_mask[:, :, :T] &= prefix_mask_bool

        # Switch the diagonal back on after masking (presumably so no position
        # ends up unable to attend to itself -- TODO confirm).
        S = self.gpt.causal_mask.size(1)
        self.gpt.causal_mask |= torch.eye(S, S, device=device, dtype=torch.bool)

        new_tokens = torch.zeros((B, steps), dtype=torch.long, device=device)
        if given_tokens is not None:
            # Replay prefix + given tokens (minus the last) to fill the caches.
            idx = given_triple_tok[:, :-1]
            _ = self.gpt(idx=idx, cond_idx=req.inputs_embeds, input_pos=input_pos_full)
            start_token = given_tokens[:, -1:].contiguous().long()
            input_pos   = torch.tensor([T + T_generated - 1], device=device, dtype=torch.int)

            n_tokens = self._decode_n_tokens_3_way(
                start_token,
                input_pos,
                steps,
                df,
                temperature=temp,
                top_k=top_k,
                top_p=top_p,
                s1=s1,
                s2=s2,
            )
            new_tokens[:, :] = torch.cat(n_tokens, dim=1).long()
        else:
            input_pos = torch.arange(0, T, device=device)
            next_token = self._prefill_3_way(req.inputs_embeds, input_pos, df, temperature=temp, top_k=top_k, top_p=top_p, s1=s1, s2=s2)
            new_tokens[:, 0:1] = next_token

            # With a single step the prefill already produced everything;
            # torch.cat on the empty decode result would raise.
            if steps > 1:
                input_pos = torch.tensor([T], device=device, dtype=torch.int)

                n_tokens = self._decode_n_tokens_3_way(
                    next_token,
                    input_pos,
                    steps - 1,
                    df,
                    temperature=temp,
                    top_k=top_k,
                    top_p=top_p,
                    s1=s1,
                    s2=s2,
                )

                new_tokens[:, 1:] = torch.cat(n_tokens, dim=1)

        return new_tokens
    
    @torch.inference_mode()
    def _chain_2way(self, req: T2IRequest, given_tokens: Optional[torch.Tensor], steps: int) -> torch.Tensor:
        """Run ``_chain_2way_core``, splitting the batch when it exceeds ``max_batch_size``.

        The request's ``inputs_embeds`` / ``attention_mask`` rows come in
        (uncond, cond) pairs, so image ``i`` owns rows ``2 * i`` and
        ``2 * i + 1``; ``given_tokens`` has one row per image.  Chunk results
        are concatenated along the batch dimension.
        """
        total_rows = req.inputs_embeds.size(0)
        assert total_rows % 2 == 0, "2-way expects batch multiple of 2"
        n_images = total_rows // 2
        limit = getattr(self, "max_batch_size", None)

        # No limit configured, or the whole batch fits: a single pass suffices.
        if not limit or n_images <= limit:
            return self._chain_2way_core(req, given_tokens, steps)

        chunks = []
        for start in range(0, n_images, limit):
            stop = min(n_images, start + limit)
            rows = slice(2 * start, 2 * stop)
            chunk_req = replace(
                req,
                inputs_embeds=req.inputs_embeds[rows],
                attention_mask=(None if req.attention_mask is None else req.attention_mask[rows]),
            )
            chunk_given = None if given_tokens is None else given_tokens[start:stop]
            chunks.append(self._chain_2way_core(chunk_req, chunk_given, steps))
        return torch.cat(chunks, dim=0)

    @torch.inference_mode()
    def _chain_3way(self, req: T2IRequest, given_tokens: Optional[torch.Tensor], steps: int, s1: float, s2: float) -> torch.Tensor:
        """Run ``_chain_3way_core``, splitting the batch when it exceeds ``max_batch_size``.

        The request's ``inputs_embeds`` / ``attention_mask`` rows come in
        triples, so image ``i`` owns rows ``3 * i`` to ``3 * i + 2``;
        ``given_tokens`` has one row per image.  Chunk results are concatenated
        along the batch dimension.
        """
        total_rows = req.inputs_embeds.size(0)
        assert total_rows % 3 == 0, "3-way expects batch multiple of 3"
        n_images = total_rows // 3
        limit = getattr(self, "max_batch_size", None)

        # No limit configured, or the whole batch fits: a single pass suffices.
        if not limit or n_images <= limit:
            return self._chain_3way_core(req, given_tokens, steps, s1, s2)

        chunks = []
        for start in range(0, n_images, limit):
            stop = min(n_images, start + limit)
            rows = slice(3 * start, 3 * stop)
            chunk_req = replace(
                req,
                inputs_embeds=req.inputs_embeds[rows],
                attention_mask=(None if req.attention_mask is None else req.attention_mask[rows]),
            )
            chunk_given = None if given_tokens is None else given_tokens[start:stop]
            chunks.append(self._chain_3way_core(chunk_req, chunk_given, steps, s1, s2))
        return torch.cat(chunks, dim=0)
    
    @torch.inference_mode()
    def generate_t2i_first_quarter(self, req: T2IRequest) -> torch.Tensor:
        """Sample the first quarter of the image tokens and pad the rest.

        Returns:
            ``(B, N)`` long tensor with ``N = req.image_token_num_per_image``.
            The first ``N // 4`` columns are freshly sampled with two-way
            guidance; the remaining columns are filled by
            ``_pad_after_first_quarter`` according to ``req.pad``.
        """
        self._assert_ready(req)
        total = int(req.image_token_num_per_image)
        quarter = total // 4

        sampled = self._chain_2way(req, given_tokens=None, steps=quarter)

        canvas = torch.zeros((sampled.size(0), total), dtype=torch.long, device=sampled.device)
        canvas[:, :quarter] = sampled
        return self._pad_after_first_quarter(canvas, quarter, req.pad)
    
    @torch.inference_mode()
    def generate_t2i_second_quarter(self, req: T2IRequest, gen_tokens_q1: torch.Tensor) -> torch.Tensor:
        """Continue from the first quarter and sample tokens up to the halfway point.

        Args:
            req: Generation request; ``req.cfg.three_way_cfg.second_quarter``
                selects three-way guidance (with its ``second_quarter_scale_1``
                / ``second_quarter_scale_2``), otherwise two-way guidance is used.
            gen_tokens_q1: Tensor whose first ``N // 4`` columns are the tokens
                generated so far; later columns are ignored.

        Returns:
            ``(B, N)`` long tensor: columns ``[:N // 2]`` hold real tokens, the
            rest is filled by ``_pad_after_second_quarter`` according to
            ``req.pad``.
        """
        self._assert_ready(req)
        N = int(req.image_token_num_per_image)
        q = N // 4
        h = N // 2
        # Sample exactly as many tokens as the slot out[:, q:h] can hold.  This
        # equals q when N is a multiple of 4; using q unconditionally would
        # mis-size (or silently broadcast into) the slot for other N.
        span = h - q

        past_q1 = gen_tokens_q1[:, :q].contiguous()

        if self._use_3way(req, "second_quarter"):
            s1 = req.cfg.three_way_cfg.second_quarter_scale_1
            s2 = req.cfg.three_way_cfg.second_quarter_scale_2
            q2 = self._chain_3way(req, given_tokens=past_q1, steps=span, s1=s1, s2=s2)  # (B, span)
        else:
            q2 = self._chain_2way(req, given_tokens=past_q1, steps=span)  # (B, span)
        out = torch.zeros((past_q1.size(0), N), dtype=torch.long, device=past_q1.device)
        out[:, :q]  = past_q1
        out[:, q:h] = q2

        out = self._pad_after_second_quarter(out, h, mode=req.pad)
        return out
    
    @torch.inference_mode()
    def generate_t2i_third_quarter(self, req: T2IRequest, gen_tokens_half: torch.Tensor) -> torch.Tensor:
        """Continue from the first half and sample tokens up to the three-quarter point.

        Args:
            req: Generation request; ``req.cfg.three_way_cfg.third_quarter``
                selects three-way guidance (with its ``third_quarter_scale_1``
                / ``third_quarter_scale_2``), otherwise two-way guidance is used.
            gen_tokens_half: Tensor whose first ``N // 2`` columns are the
                tokens generated so far; later columns are ignored.

        Returns:
            ``(B, N)`` long tensor: columns ``[:3 * (N // 4)]`` hold real
            tokens, the rest is filled by ``_pad_after_third_quarter``
            according to ``req.pad``.
        """
        self._assert_ready(req)
        N = int(req.image_token_num_per_image)
        q = N // 4
        h = N // 2
        t = 3 * q
        # Width of the slot out[:, h:t].  It equals q when N is a multiple of
        # 4, but is smaller (possibly empty) otherwise, so sample exactly that
        # many tokens instead of a fixed q.
        span = t - h

        past_half = gen_tokens_half[:, :h].contiguous()

        out = torch.zeros((past_half.size(0), N), dtype=torch.long, device=past_half.device)
        out[:, :h] = past_half
        if span > 0:
            if self._use_3way(req, "third_quarter"):
                s1 = req.cfg.three_way_cfg.third_quarter_scale_1
                s2 = req.cfg.three_way_cfg.third_quarter_scale_2
                q3 = self._chain_3way(req, given_tokens=past_half, steps=span, s1=s1, s2=s2)  # (B, span)
            else:
                q3 = self._chain_2way(req, given_tokens=past_half, steps=span)  # (B, span)
            out[:, h:t] = q3
        out = self._pad_after_third_quarter(out, q, mode=req.pad)
        return out
    
    @torch.inference_mode()
    def generate_t2i_fourth_quarter(self, req: T2IRequest, gen_tokens_3q: torch.Tensor) -> torch.Tensor:
        """Continue from the first three quarters and sample all remaining tokens.

        Args:
            req: Generation request; ``req.cfg.three_way_cfg.fourth_quarter``
                selects three-way guidance (with its ``fourth_quarter_scale_1``
                / ``fourth_quarter_scale_2``), otherwise two-way guidance is used.
            gen_tokens_3q: Tensor whose first ``3 * (N // 4)`` columns are the
                tokens generated so far; later columns are ignored.

        Returns:
            ``(B, N)`` long tensor of image tokens with no padding left.
        """
        self._assert_ready(req)
        N = int(req.image_token_num_per_image)
        q = N // 4
        t = 3 * q
        # Sample everything that is still missing, as ``generate_t2i_second_half``
        # does with ``N - h``.  A fixed q would not fill out[:, t:] when N is
        # not a multiple of 4.
        remaining = N - t

        past_3q = gen_tokens_3q[:, :t].contiguous()

        if self._use_3way(req, "fourth_quarter"):
            s1 = req.cfg.three_way_cfg.fourth_quarter_scale_1
            s2 = req.cfg.three_way_cfg.fourth_quarter_scale_2
            q4 = self._chain_3way(req, given_tokens=past_3q, steps=remaining, s1=s1, s2=s2)
        else:
            q4 = self._chain_2way(req, given_tokens=past_3q, steps=remaining)
        out = torch.zeros((past_3q.size(0), N), dtype=torch.long, device=past_3q.device)
        out[:, :t] = past_3q
        out[:, t:] = q4

        return out
    
    @torch.inference_mode()
    def generate_t2i_first_half(self, req: T2IRequest) -> torch.Tensor:
        """Sample the first half of the image tokens and pad the rest.

        Returns:
            ``(B, N)`` long tensor with ``N = req.image_token_num_per_image``.
            The first ``N // 2`` columns are freshly sampled with two-way
            guidance; the remaining columns are filled by
            ``_pad_after_second_quarter`` according to ``req.pad``.
        """
        self._assert_ready(req)
        total = int(req.image_token_num_per_image)
        half = total // 2

        sampled = self._chain_2way(req, given_tokens=None, steps=half)

        canvas = torch.zeros((sampled.size(0), total), dtype=torch.long, device=sampled.device)
        canvas[:, :half] = sampled
        return self._pad_after_second_quarter(canvas, half, req.pad)

    @torch.inference_mode()
    def generate_t2i_second_half(self, req: T2IRequest, gen_tokens_half: torch.Tensor) -> torch.Tensor:
        """Continue from the first half and sample all remaining tokens.

        Three-way guidance is used when ``req.cfg.three_way_cfg.second_half``
        is set (with its ``second_half_scale_1`` / ``second_half_scale_2``),
        otherwise two-way guidance.  Only the first ``N // 2`` columns of
        ``gen_tokens_half`` are read.

        Returns:
            ``(B, N)`` long tensor of image tokens with no padding left.
        """
        self._assert_ready(req)
        total = int(req.image_token_num_per_image)
        half = total // 2

        prefix = gen_tokens_half[:, :half].contiguous()
        if not self._use_3way(req, "second_half"):
            tail = self._chain_2way(req, given_tokens=prefix, steps=total - half)  # (B, h)
        else:
            s1 = req.cfg.three_way_cfg.second_half_scale_1
            s2 = req.cfg.three_way_cfg.second_half_scale_2
            tail = self._chain_3way(req, given_tokens=prefix, steps=total - half, s1=s1, s2=s2)  # (B, h)

        canvas = torch.zeros((prefix.size(0), total), dtype=torch.long, device=prefix.device)
        canvas[:, :half] = prefix
        canvas[:, half:] = tail
        return canvas
    
    @torch.inference_mode()
    def generate_t2i(self, req: T2IRequest) -> torch.Tensor:
        """Sample all ``req.image_token_num_per_image`` tokens in one two-way pass."""
        self._assert_ready(req)
        total = int(req.image_token_num_per_image)
        return self._chain_2way(req, given_tokens=None, steps=total)  # (B, N)
    
    @torch.inference_mode()
    def generate_t2i_same_setup(self, req: T2IRequest) -> torch.Tensor:
        """Generate a full image in three staged calls that share one request.

        The stages are: first quarter, second quarter, then the entire second
        half in a single step.
        """
        self._assert_ready(req)
        first_quarter = self.generate_t2i_first_quarter(req)
        first_half = self.generate_t2i_second_quarter(req, gen_tokens_q1=first_quarter)
        return self.generate_t2i_second_half(req, gen_tokens_half=first_half)
       
    @torch.inference_mode() 
    def image_decode(self, req: T2IRequest, token_ids: torch.Tensor) -> np.ndarray:            
        dec = np.zeros((token_ids.size(0), req.img_size, req.img_size, 3), dtype=np.uint8)
        
        for i, image_tokens in enumerate(token_ids):
            img = self.vq_model.decode_code([image_tokens], self.qz_shape)[0]
            img = img.to(torch.float32).cpu().numpy().transpose(1, 2, 0)
                    
            norm_img = (img + 1) / 2 * 255
            norm_img = norm_img.clip(0, 255)

            dec[i] = norm_img

        return dec