from typing import Optional, Dict, Any
from dataclasses import replace
import numpy as np
import torch
import decode

from generator.base_generator import BaseGenerator, T2IRequest, T2IResult

class JanusGenerator(BaseGenerator):
    """Text-to-image token generator driving ``mmgpt`` with classifier-free guidance.

    Image tokens are sampled autoregressively, either in one pass
    (``generate_t2i``) or chunk by chunk (quarters / halves) so that later
    chunks are generated conditioned on the earlier ones. Rows of
    ``req.inputs_embeds`` are grouped per sample: 2-way guidance expects
    consecutive ``(uncond, cond)`` rows, 3-way guidance expects consecutive
    ``(uncond, cond, cond_modified)`` rows.
    """

    def __init__(self, mmgpt, max_batch_size):
        """
        Args:
            mmgpt: Model exposing ``device``, ``language_model.model``,
                ``gen_head``, ``prepare_gen_img_embeds`` and ``gen_vision_model``.
            max_batch_size: Maximum number of samples (not rows) processed in
                one sampling chain; a falsy value disables splitting.
        """
        device = mmgpt.device
        super().__init__(device=device)
        self.mmgpt = mmgpt
        self.max_batch_size = max_batch_size

    def _assert_ready(self, req: T2IRequest):
        """Reject sampling options this generator does not implement.

        Raises:
            AssertionError: if ``req.top_k`` or ``req.top_p`` is set.
        """
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        if req.top_k is not None or req.top_p is not None:
            raise AssertionError("JanusGenerator: top_k/top_p are not included")

    # ---- Padding of not-yet-generated token positions ----

    @staticmethod
    def _fill_repeat(tokens: torch.Tensor, start: int, src: torch.Tensor) -> None:
        """Fill ``tokens[:, start:]`` in place by tiling ``src`` along dim 1.

        The tiled sequence is truncated to fit, so tail lengths that are not
        an exact multiple of ``src``'s width are handled. Does nothing when
        there is nothing to fill or ``src`` is empty.
        """
        span = tokens.size(1) - start
        width = src.size(1)
        if span <= 0 or width == 0:
            return
        reps = -(-span // width)  # ceil(span / width)
        # ``repeat`` materialises a copy first, so ``src`` may be a view of ``tokens``.
        tokens[:, start:] = src.repeat(1, reps)[:, :span]

    def _pad_tail(self, tokens: torch.Tensor, keep: int, mode: str, repeat_len: int) -> torch.Tensor:
        """Overwrite ``tokens[:, keep:]`` in place according to ``mode``.

        Modes:
            ``"zero"``:   fill with token id 0.
            ``"mean"``:   fill each row with the mean of its ``tokens[:, :keep]``
                          ids, cast back to ``long``.
            ``"repeat"``: tile ``tokens[:, :repeat_len]`` across the tail.

        Returns:
            ``tokens`` (modified in place).

        Raises:
            RuntimeError: for an unknown ``mode``.
        """
        if mode == "zero":
            tokens[:, keep:] = 0
        elif mode == "mean":
            mean_vals = tokens[:, :keep].float().mean(dim=1, keepdim=True)
            tokens[:, keep:] = mean_vals.long()
        elif mode == "repeat":
            self._fill_repeat(tokens, keep, tokens[:, :repeat_len])
        else:
            raise RuntimeError(f"Set wrong padding mode: {mode}")
        return tokens

    def _pad_after_first_quarter(self, tokens: torch.Tensor, q: int, mode: str):
        """Pad ``tokens[:, q:]``; ``"repeat"`` tiles the first quarter."""
        return self._pad_tail(tokens, q, mode, repeat_len=q)

    def _pad_after_second_quarter(self, tokens: torch.Tensor, h: int, mode: str):
        """Pad ``tokens[:, h:]``; ``"repeat"`` tiles the first half."""
        return self._pad_tail(tokens, h, mode, repeat_len=h)

    def _pad_after_third_quarter(self, tokens: torch.Tensor, q: int, mode: str):
        """Pad ``tokens[:, 3q:]``; ``"mean"`` averages the first three quarters,
        ``"repeat"`` tiles the first quarter."""
        return self._pad_tail(tokens, 3 * q, mode, repeat_len=q)

    # ---- Stage configuration ----

    def _use_3way(self, req: T2IRequest, stage: str) -> bool:
        """Return ``req.cfg.three_way_cfg.<stage>`` (False if unset) and print the decision."""
        three_way = getattr(req.cfg, "three_way_cfg", None)
        if three_way is None:
            return False

        enabled = getattr(three_way, stage, False)

        print(f"Use 3 way for stage{stage}: {enabled}")
        return enabled

    def _chain_for_stage(self, req: T2IRequest, stage: str, given_tokens: Optional[torch.Tensor], steps: int) -> torch.Tensor:
        """Sample ``steps`` tokens for ``stage`` with 3-way CFG if enabled, else 2-way.

        3-way scales are read from ``req.cfg.three_way_cfg.<stage>_scale_1`` /
        ``<stage>_scale_2``.
        """
        if self._use_3way(req, stage):
            three_way = req.cfg.three_way_cfg
            s1 = getattr(three_way, f"{stage}_scale_1")
            s2 = getattr(three_way, f"{stage}_scale_2")
            return self._chain_3way(req, given_tokens=given_tokens, steps=steps, s1=s1, s2=s2)
        return self._chain_2way(req, given_tokens=given_tokens, steps=steps)

    # ---- Language-model stepping ----

    def _prefill_chain(self, req: T2IRequest, given_tokens: Optional[torch.Tensor], pair_dim: int):
        """Feed the prompt (plus any already generated image tokens) through the LM with KV caching.

        Args:
            req: Request whose ``inputs_embeds`` hold ``pair_dim`` consecutive
                rows per sample.
            given_tokens: ``(B, L)`` image token ids to condition on, or ``None``.
            pair_dim: Rows per sample (2 for 2-way CFG, 3 for 3-way).

        Returns:
            ``(outs, row_mask)``: the LM output (carrying ``past_key_values``)
            and the attention mask covering every position fed so far.
        """
        inputs_embeds = req.inputs_embeds
        attention_mask = req.attention_mask
        if attention_mask is None:
            # The batching wrappers accept a missing mask; treat every prompt
            # position as attended so the mask can still be extended per step.
            attention_mask = torch.ones(
                inputs_embeds.shape[:2], dtype=torch.long, device=inputs_embeds.device
            )

        if given_tokens is not None:
            # Every guidance row of a sample sees the same previously generated tokens.
            given_rows_tok = given_tokens.repeat_interleave(pair_dim, 0)          # (B*pair, L)
            given_rows_emb = self.mmgpt.prepare_gen_img_embeds(given_rows_tok)   # (B*pair, L, H)
            given_rows_mask = torch.ones_like(given_rows_emb[..., 0], dtype=attention_mask.dtype)
            row_prompt = torch.cat([inputs_embeds, given_rows_emb], dim=1)
            row_mask = torch.cat([attention_mask, given_rows_mask], dim=1)
        else:
            row_prompt = inputs_embeds
            row_mask = attention_mask

        outs = self.mmgpt.language_model.model(
            inputs_embeds=row_prompt,
            attention_mask=row_mask,
            use_cache=True,
            past_key_values=None,
            output_hidden_states=False,
        )
        return outs, row_mask

    def _append_step(self, row_mask: torch.Tensor, past_kv, pair_emb: torch.Tensor):
        """Feed ``pair_emb`` on top of the cached keys/values and extend the mask.

        Returns:
            ``(outs, new_mask)`` where ``new_mask`` is ``row_mask`` with one
            attended column appended per new position in ``pair_emb``.
        """
        new_mask = torch.cat(
            [row_mask,
             torch.ones((row_mask.size(0), pair_emb.size(1)),
                        dtype=row_mask.dtype, device=row_mask.device)],
            dim=1,
        )
        outs = self.mmgpt.language_model.model(
            inputs_embeds=pair_emb,
            attention_mask=new_mask,
            use_cache=True,
            past_key_values=past_kv,
            output_hidden_states=False,
        )
        return outs, new_mask

    def _sample_chain(self, req: T2IRequest, given_tokens: Optional[torch.Tensor], steps: int, pair_dim: int, fuse) -> torch.Tensor:
        """Autoregressively sample ``steps`` image tokens per sample.

        Args:
            fuse: Callable mapping the ``(B*pair_dim, V)`` last-position logits
                to guided ``(B, V)`` logits.

        Returns:
            ``(B, steps)`` long tensor of sampled token ids.
        """
        B = req.inputs_embeds.size(0) // pair_dim
        device = req.inputs_embeds.device
        new_tokens = torch.zeros((B, steps), dtype=torch.long, device=device)
        if steps == 0:
            return new_tokens  # nothing to sample; skip the prefill forward pass

        temp = req.temperature
        outs, row_mask = self._prefill_chain(req, given_tokens, pair_dim=pair_dim)
        for step in range(steps):
            last_logits = self.mmgpt.gen_head(outs.last_hidden_state[:, -1, :])  # (B*pair, V)
            fused = fuse(last_logits)
            probs = torch.softmax(fused / temp, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)                 # (B, 1)
            new_tokens[:, step] = next_token.squeeze(-1)

            # The LM output after the final token is never read, so skip that forward pass.
            if step + 1 < steps:
                rows_tok = next_token.repeat_interleave(pair_dim, 0).squeeze(-1)
                rows_emb = self.mmgpt.prepare_gen_img_embeds(rows_tok).unsqueeze(1)
                outs, row_mask = self._append_step(row_mask, outs.past_key_values, rows_emb)

        return new_tokens

    @torch.inference_mode()
    def _chain_2way_core(self, req: T2IRequest, given_tokens: Optional[torch.Tensor], steps: int) -> torch.Tensor:
        """Sample ``steps`` tokens with 2-way CFG over ``(uncond, cond)`` row pairs."""
        assert req.inputs_embeds.size(0) % 2 == 0, "2-way expects batch multiple of 2"
        df = getattr(decode, req.cfg.decode_func)
        scale = req.cfg.cfg_scale

        def fuse(logits):
            uncond, cond = logits[0::2, :], logits[1::2, :]
            return df(logit_cond=cond, logit_uncond=uncond, scale=scale)

        return self._sample_chain(req, given_tokens, steps, 2, fuse)

    @torch.inference_mode()
    def _chain_3way_core(self, req: T2IRequest, given_tokens: Optional[torch.Tensor], steps: int, s1: float, s2: float) -> torch.Tensor:
        """Sample ``steps`` tokens with 3-way CFG over ``(uncond, cond, cond_modified)`` rows."""
        assert req.inputs_embeds.size(0) % 3 == 0, "3-way expects batch multiple of 3"
        df = getattr(decode, req.cfg.three_way_cfg.decode_func)

        def fuse(logits):
            uncond, cond, cond_m = logits[0::3, :], logits[1::3, :], logits[2::3, :]
            return df(
                logit_cond_modified=cond_m,
                logit_cond=cond,
                logit_uncond=uncond,
                cfg_scale_1=s1, cfg_scale_2=s2,
            )

        return self._sample_chain(req, given_tokens, steps, 3, fuse)

    # ---- Wrappers: split the batch by max_batch_size ----

    def _run_in_batches(self, req: T2IRequest, given_tokens: Optional[torch.Tensor], pair_dim: int, run) -> torch.Tensor:
        """Call ``run(sub_req, sub_given)`` on chunks of at most ``max_batch_size`` samples.

        Each sample owns ``pair_dim`` consecutive rows of ``inputs_embeds`` /
        ``attention_mask`` and one row of ``given_tokens``. Results are
        concatenated along dim 0.
        """
        B = req.inputs_embeds.size(0) // pair_dim
        max_bs = getattr(self, "max_batch_size", None)

        if not max_bs or B <= max_bs:
            return run(req, given_tokens)

        parts = []
        for b0 in range(0, B, max_bs):
            b1 = min(B, b0 + max_bs)
            rows = slice(pair_dim * b0, pair_dim * b1)
            sub_req = replace(
                req,
                inputs_embeds=req.inputs_embeds[rows],
                attention_mask=(req.attention_mask[rows] if req.attention_mask is not None else None),
            )
            sub_given = given_tokens[b0:b1] if given_tokens is not None else None
            parts.append(run(sub_req, sub_given))
        return torch.cat(parts, dim=0)

    @torch.inference_mode()
    def _chain_2way(self, req: T2IRequest, given_tokens: Optional[torch.Tensor], steps: int) -> torch.Tensor:
        """2-way CFG sampling, split into chunks of ``max_batch_size`` samples."""
        assert req.inputs_embeds.size(0) % 2 == 0, "2-way expects batch multiple of 2"
        return self._run_in_batches(
            req, given_tokens, 2,
            lambda sub_req, sub_given: self._chain_2way_core(sub_req, sub_given, steps),
        )

    @torch.inference_mode()
    def _chain_3way(self, req: T2IRequest, given_tokens: Optional[torch.Tensor], steps: int, s1: float, s2: float) -> torch.Tensor:
        """3-way CFG sampling, split into chunks of ``max_batch_size`` samples."""
        assert req.inputs_embeds.size(0) % 3 == 0, "3-way expects batch multiple of 3"
        return self._run_in_batches(
            req, given_tokens, 3,
            lambda sub_req, sub_given: self._chain_3way_core(sub_req, sub_given, steps, s1, s2),
        )

    # ---- Public generation entry points ----

    @torch.inference_mode()
    def generate_t2i_first_quarter(self, req: T2IRequest) -> torch.Tensor:
        """Generate the first ``N // 4`` tokens (2-way CFG) and pad the rest per ``req.pad``."""
        self._assert_ready(req)
        N = int(req.image_token_num_per_image)
        q = N // 4

        q1 = self._chain_2way(req, given_tokens=None, steps=q)

        out = torch.zeros((q1.size(0), N), dtype=torch.long, device=q1.device)
        out[:, :q] = q1
        return self._pad_after_first_quarter(out, q, req.pad)

    @torch.inference_mode()
    def generate_t2i_second_quarter(self, req: T2IRequest, gen_tokens_q1: torch.Tensor) -> torch.Tensor:
        """Keep the first quarter of ``gen_tokens_q1``, generate the second, pad the rest."""
        self._assert_ready(req)
        N = int(req.image_token_num_per_image)
        q = N // 4
        h = N // 2

        past_q1 = gen_tokens_q1[:, :q].contiguous()
        q2 = self._chain_for_stage(req, "second_quarter", past_q1, steps=h - q)   # (B, h - q)

        out = torch.zeros((past_q1.size(0), N), dtype=torch.long, device=past_q1.device)
        out[:, :q] = past_q1
        out[:, q:h] = q2
        return self._pad_after_second_quarter(out, h, mode=req.pad)

    @torch.inference_mode()
    def generate_t2i_third_quarter(self, req: T2IRequest, gen_tokens_half: torch.Tensor) -> torch.Tensor:
        """Keep the first half of ``gen_tokens_half``, generate the third quarter, pad the rest."""
        self._assert_ready(req)
        N = int(req.image_token_num_per_image)
        q = N // 4
        h = N // 2
        t = 3 * q

        past_half = gen_tokens_half[:, :h].contiguous()
        q3 = self._chain_for_stage(req, "third_quarter", past_half, steps=t - h)  # (B, t - h)

        out = torch.zeros((past_half.size(0), N), dtype=torch.long, device=past_half.device)
        out[:, :h] = past_half
        out[:, h:t] = q3
        return self._pad_after_third_quarter(out, q, mode=req.pad)

    @torch.inference_mode()
    def generate_t2i_fourth_quarter(self, req: T2IRequest, gen_tokens_3q: torch.Tensor) -> torch.Tensor:
        """Keep the first ``3 * (N // 4)`` tokens and generate all remaining positions."""
        self._assert_ready(req)
        N = int(req.image_token_num_per_image)
        q = N // 4
        t = 3 * q

        past_3q = gen_tokens_3q[:, :t].contiguous()
        # Generate every remaining position (N - t), which equals q only when N % 4 == 0.
        q4 = self._chain_for_stage(req, "fourth_quarter", past_3q, steps=N - t)

        out = torch.zeros((past_3q.size(0), N), dtype=torch.long, device=past_3q.device)
        out[:, :t] = past_3q
        out[:, t:] = q4
        return out

    @torch.inference_mode()
    def generate_t2i_first_half(self, req: T2IRequest) -> torch.Tensor:
        """Run the first and second quarter stages; the tail is padded per ``req.pad``."""
        self._assert_ready(req)
        q1 = self.generate_t2i_first_quarter(req)
        return self.generate_t2i_second_quarter(req, gen_tokens_q1=q1)

    @torch.inference_mode()
    def generate_t2i_second_half(self, req: T2IRequest, gen_tokens_half: torch.Tensor) -> torch.Tensor:
        """Keep the first ``N // 2`` tokens and generate all remaining positions."""
        self._assert_ready(req)
        N = int(req.image_token_num_per_image)
        h = N // 2

        past_half = gen_tokens_half[:, :h].contiguous()
        rest = self._chain_for_stage(req, "second_half", past_half, steps=N - h)  # (B, N - h)

        out = torch.zeros((past_half.size(0), N), dtype=torch.long, device=past_half.device)
        out[:, :h] = past_half
        out[:, h:] = rest
        return out

    @torch.inference_mode()
    def generate_t2i(self, req: T2IRequest) -> torch.Tensor:
        """Generate all ``N`` image tokens in a single 2-way CFG chain; returns ``(B, N)``."""
        self._assert_ready(req)
        N = int(req.image_token_num_per_image)
        return self._chain_2way(req, given_tokens=None, steps=N)

    @torch.inference_mode()
    def generate_t2i_same_setup(self, req: T2IRequest) -> torch.Tensor:
        """Run the first quarter, second quarter and second half stages in sequence."""
        self._assert_ready(req)
        q1 = self.generate_t2i_first_quarter(req)
        half = self.generate_t2i_second_quarter(req, gen_tokens_q1=q1)
        return self.generate_t2i_second_half(req, gen_tokens_half=half)

    @torch.inference_mode()
    def image_decode(self, req: T2IRequest, token_ids: torch.Tensor) -> np.ndarray:
        """Decode image token ids into pixel arrays.

        Returns:
            Float NumPy array with axes reordered by ``transpose(0, 2, 3, 1)``
            (channels last, assuming ``decode_code`` returns ``(B, C, H, W)`` —
            TODO confirm), values rescaled by ``(x + 1) / 2 * 255`` and clipped
            to ``[0, 255]``. The result is not cast to ``uint8``.
        """
        grid = req.img_size // req.patch_size
        # NOTE(review): 8 is presumably the codebook embedding dim expected by decode_code — confirm.
        dec = self.mmgpt.gen_vision_model.decode_code(
            token_ids.to(dtype=torch.int),
            shape=[token_ids.size(0), 8, grid, grid],
        )

        dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
        return np.clip((dec + 1) / 2 * 255, 0, 255)