from dataclasses import dataclass, field
from typing import Optional, Dict, Any
import numpy as np
import torch
import abc

@dataclass
class T2IRequest:
    """Inputs and sampling options for one text-to-image generation call.

    This is the single definition of the request type; field names, order
    and defaults are part of the constructor interface and must not change.

    Attributes:
        inputs_embeds: Text embeddings of shape (B*pair, T_text, C). With
            2-way guidance pair=2 (uncond/cond); with 3-way guidance pair=3.
        attention_mask: Mask of shape (B*pair, T_text), or None.
        image_token_num_per_image: Number of image tokens per image. The
            default 576 equals (img_size // patch_size) ** 2 for the default
            img_size=384 and patch_size=16.
        img_size: Image side length (presumably pixels; confirm with the
            decoder that consumes it).
        patch_size: Patch side length, in the same unit as img_size.
        temperature: Sampling temperature; interpretation is left to the
            generator implementation.
        top_k: Optional top-k sampling cutoff; None leaves it unset.
        top_p: Optional top-p sampling cutoff; None leaves it unset.
        pad: Padding mode for tokens not generated yet. Documented values
            are "zero", "mean" and "repeat"; the value is not validated here.
        cfg: Opaque guidance configuration; its schema is defined by the
            generator implementation and is not visible in this module.
    """

    inputs_embeds: torch.Tensor
    attention_mask: Optional[torch.Tensor] = None
    image_token_num_per_image: int = 576
    img_size: int = 384
    patch_size: int = 16
    temperature: float = 1.0
    top_k: Optional[int] = None
    top_p: Optional[float] = None
    pad: str = "repeat"  # "zero" | "mean" | "repeat"
    cfg: Optional[Any] = None


@dataclass
class T2IResult:
    """Output container for a text-to-image generation call.

    Attributes:
        images_tokens: Generated image token ids (presumably shape (B, N),
            matching what the generator methods return; confirm with caller).
        images_uint8: Optional decoded images as a NumPy array; None when
            the tokens have not been decoded.
    """

    images_tokens: torch.Tensor
    images_uint8: Optional[np.ndarray] = None


class BaseGenerator(abc.ABC):
    """Abstract interface for autoregressive text-to-image token generators.

    Concrete subclasses must implement every method below. Three generation
    granularities are exposed: the whole image in one call, quarter by
    quarter, and half by half. A separate method decodes token ids to images.

    NOTE(review): ``@torch.inference_mode()`` decorates the abstract
    function objects only. An overriding method in a subclass replaces that
    function, so the override does not run under inference mode unless it
    applies the decorator (or context manager) itself.
    """

    def __init__(self, device: str = "cuda"):
        # Only the device identifier is recorded; nothing is allocated here.
        self.device = device

    # ======== Whole-image generation (token ids only) ========
    @abc.abstractmethod
    @torch.inference_mode()
    def generate_t2i(self, req: T2IRequest) -> torch.Tensor:
        """Run a single autoregressive pass producing all N image tokens.

        Returns:
            Token ids of shape (B, N).
        """
        raise NotImplementedError

    @abc.abstractmethod
    @torch.inference_mode()
    def generate_t2i_same_setup(self, req: T2IRequest) -> torch.Tensor:
        """Run a single autoregressive pass producing all N image tokens.

        NOTE(review): the documented contract is the same as
        ``generate_t2i``; what "same setup" changes is not visible in this
        module and should be confirmed against an implementation.

        Returns:
            Token ids of shape (B, N).
        """
        raise NotImplementedError

    # ======== Generation in four stages (quarters) ========
    @abc.abstractmethod
    @torch.inference_mode()
    def generate_t2i_first_quarter(self, req: T2IRequest) -> torch.Tensor:
        """Produce the first quarter (Q1) of the image tokens.

        Returns:
            A (B, N) tensor in which only Q1 is generated; the remaining
            positions are padded according to ``req.pad``.
        """
        raise NotImplementedError

    @abc.abstractmethod
    @torch.inference_mode()
    def generate_t2i_second_quarter(self, req: T2IRequest, gen_tokens_q1: torch.Tensor) -> torch.Tensor:
        """Produce the second quarter (Q2), conditioned on a given Q1.

        Returns:
            A (B, N) tensor whose top half is filled, the rest padded.
        """
        raise NotImplementedError

    @abc.abstractmethod
    @torch.inference_mode()
    def generate_t2i_third_quarter(self, req: T2IRequest, gen_tokens_half: torch.Tensor) -> torch.Tensor:
        """Produce the third quarter (Q3), conditioned on the first half.

        Returns:
            A (B, N) tensor with three quarters filled, the rest padded.
        """
        raise NotImplementedError

    @abc.abstractmethod
    @torch.inference_mode()
    def generate_t2i_fourth_quarter(self, req: T2IRequest, gen_tokens_3q: torch.Tensor) -> torch.Tensor:
        """Produce the last quarter (Q4), conditioned on the first 3/4.

        Returns:
            The complete (B, N) tensor of token ids.
        """
        raise NotImplementedError

    # ======== Generation in two stages (halves) ========
    @abc.abstractmethod
    @torch.inference_mode()
    def generate_t2i_first_half(self, req: T2IRequest) -> torch.Tensor:
        """Produce the first half (Q1 then Q2) in one continuous chain.

        Returns:
            A (B, N) tensor whose top half is filled, the rest padded.
        """
        raise NotImplementedError

    @abc.abstractmethod
    @torch.inference_mode()
    def generate_t2i_second_half(self, req: T2IRequest, gen_tokens_half: torch.Tensor) -> torch.Tensor:
        """Produce the second half (Q3 then Q4) in one chain, given the first half.

        Returns:
            The complete (B, N) tensor of token ids.
        """
        raise NotImplementedError

    # ======== Token ids -> images ========
    @abc.abstractmethod
    @torch.inference_mode()
    def image_decode(self, req: T2IRequest, token_ids: torch.Tensor) -> np.ndarray:
        """Turn (B, N) token ids into images.

        Returns:
            The decoded images as a NumPy ndarray.
        """
        raise NotImplementedError