import time
import torch
import torch.nn as nn
from modules_2.stable_audio_tools.models import create_model_from_config
from modules_2.stable_audio_tools.models.autoencoders import AudioAutoencoder, StreamingOobleckDecoder, OobleckDecoder
from modules_2.stable_audio_tools.models.utils import load_ckpt_state_dict
from modules_2.stable_audio_tools.utils.torch_common import copy_state_dict
from modules_2.stable_audio_tools.models.streaming_utils import CUDAGraphed, replace_all_deconvs_with_fast
import torch.nn.utils as nn_utils

def _unwrap_compiled(mod: torch.nn.Module) -> torch.nn.Module:
    """Return ``mod._orig_mod`` when that attribute exists, otherwise ``mod`` itself.

    NOTE(review): ``_orig_mod`` is presumably the attribute a ``torch.compile``
    wrapper uses to keep the original module -- confirm against the torch
    version in use.
    """
    try:
        return mod._orig_mod
    except AttributeError:
        # Not a wrapper (or the wrapper does not expose the original module).
        return mod

def _has_streaming_iface(mod: torch.nn.Module) -> bool:
    """Tell whether ``mod`` or its unwrapped module has a ``streaming`` attribute.

    Both the object as given and the result of ``_unwrap_compiled(mod)`` are
    inspected; a hit on either one is enough.
    """
    candidates = (_unwrap_compiled(mod), mod)
    return any(hasattr(candidate, "streaming") for candidate in candidates)

def strip_weight_norm_streaming_(module: torch.nn.Module):
    """Remove weight normalisation from wrapped ``conv`` / ``convtr`` layers, in place.

    Walks every sub-module of ``module`` (including ``module`` itself). For each
    one that carries a ``conv`` or ``convtr`` attribute which is itself a module
    with a ``weight`` attribute, ``torch.nn.utils.remove_weight_norm`` is
    applied to that layer.

    Args:
        module: Root module to process; it is modified in place.

    Returns:
        The same ``module`` object, to allow call chaining.
    """
    for sub in module.modules():
        for attr in ("conv", "convtr"):
            layer = getattr(sub, attr, None)
            # Skip missing attributes and anything that is not a weighted layer.
            if not isinstance(layer, torch.nn.Module) or not hasattr(layer, "weight"):
                continue
            try:
                nn_utils.remove_weight_norm(layer)
            except ValueError:
                # remove_weight_norm raises ValueError when the layer has no
                # weight-norm hook registered; such layers are left untouched.
                # Any other exception is a real failure and is not hidden.
                pass
    return module


def convert_decoder_to_streaming_inplace(
    ae: "AudioAutoencoder",
    device: str | torch.device = "cuda",
    dtype: torch.dtype | None = None,
    requires_grad: bool = False
):
    """Swap ``ae.decoder`` for a ``StreamingOobleckDecoder`` built from it.

    The current decoder must be a causal ``OobleckDecoder``. A streaming decoder
    is constructed with the same hyper-parameters, initialised through
    ``initialize_from_conversion``, stripped of weight norm, moved to
    ``device`` / ``dtype``, put in eval mode and assigned back to ``ae.decoder``.

    Args:
        ae: Autoencoder whose ``decoder`` attribute is replaced in place.
        device: Target device for the new decoder.
        dtype: Target dtype. When ``None``, the dtype of the original decoder's
            first parameter is used (``torch.float32`` if it has no parameters).
        requires_grad: Passed to ``requires_grad_`` on the new decoder.

    Raises:
        AssertionError: If the decoder is not an ``OobleckDecoder`` or is not causal.
    """
    source = ae.decoder
    assert isinstance(source, OobleckDecoder), "decoder must be OobleckDecoder"
    assert source.causal, "StreamingOobleckDecoder supports only causal decoders"

    # Mirror the original decoder's configuration; causality is forced on.
    streaming = StreamingOobleckDecoder(
        out_channels=source.out_channels,
        channels=source.channels,
        latent_dim=source.latent_dim,
        c_mults=source.input_c_mults,
        strides=source.strides,
        use_snake=source.use_snake,
        antialias_activation=source.antialias_activation,
        use_nearest_upsample=source.use_nearest_upsample,
        final_tanh=source.final_tanh,
        causal=True,
    )
    streaming.initialize_from_conversion(source)
    strip_weight_norm_streaming_(streaming)

    if dtype is None:
        # Inherit the dtype from the source decoder when the caller gave none.
        first_param = next(source.parameters(), None)
        dtype = torch.float32 if first_param is None else first_param.dtype

    streaming = streaming.to(device=device, dtype=dtype)
    streaming = streaming.eval()
    ae.decoder = streaming.requires_grad_(requires_grad)



class OnlineCausalWaveformDecoder:
    """Decode latent tokens into waveform samples one step at a time.

    On construction the decoder's ``streaming(batch_size)`` context is entered
    and, when requested and running on CUDA, the per-step decoder call is
    wrapped in the project's ``CUDAGraphed`` helper. Each :meth:`push_token`
    call de-normalises one latent token, decodes it and appends the produced
    samples to ``self.audio``. Call :meth:`finalize` to leave the streaming
    context; :meth:`reset_buffers` starts a fresh stream.
    """

    def __init__(
        self,
        stage1_model: "AudioAutoencoder",
        z_mean: torch.Tensor,   # [1,1,D]
        z_std: torch.Tensor,    # [1,1,D]
        *,
        batch_size: int,
        total_T: int,
        device: torch.device,
        use_cuda_graph: bool = True,
        dtype_latents: torch.dtype = torch.bfloat16,
        compile_warmup_steps: int = 3,
    ):
        """Set up buffers, enter streaming mode and prepare the step function.

        Args:
            stage1_model: Autoencoder whose ``decoder`` exposes ``streaming``
                (directly or on its ``_orig_mod``).
            z_mean: Latent mean used for de-normalisation.
            z_std: Latent std; tokens are scaled by ``2 * z_std``.
            batch_size: Batch size of every pushed token.
            total_T: Number of latent steps the audio buffer is sized for.
            device: Device for statistics and the audio buffer.
            use_cuda_graph: Wrap the step in ``CUDAGraphed`` when on CUDA.
            dtype_latents: Dtype in which de-normalisation is computed.
            compile_warmup_steps: ``warmup_steps`` passed to ``CUDAGraphed``.

        Raises:
            AssertionError: If the decoder has no streaming interface.
        """
        # Set first so finalize() / __del__ are safe even if setup fails below.
        self._stream_cm = None
        self._step_graphed: CUDAGraphed | None = None
        self._step = self._step_no_graph

        dec_wrapped = stage1_model.decoder
        dec_base = _unwrap_compiled(dec_wrapped)
        assert _has_streaming_iface(dec_wrapped), \
            "decoder must expose a .streaming(batch_size) interface (StreamingOobleckDecoder or compiled wrapper)."

        self.ae = stage1_model
        self.dec = dec_wrapped
        # The object that actually owns .streaming(): the wrapper if it has it,
        # otherwise the unwrapped module.
        self._dec_stream_owner = dec_wrapped if hasattr(dec_wrapped, "streaming") else dec_base

        self.dec_dtype = next(self.dec.parameters()).dtype
        self.lat_dtype = dtype_latents
        self.z_mean = z_mean.to(device=device, dtype=dtype_latents)
        self.z_std = z_std.to(device=device, dtype=dtype_latents)
        # z_std is already on the target device/dtype, so no further cast is needed.
        self.z_scale = 2.0 * self.z_std

        self.B = int(batch_size)
        self.T = int(total_T)
        self.device = device
        self.ratio = int(self.ae.downsampling_ratio)
        self.n_ch = int(self.ae.out_channels)

        self.use_cuda_graph = bool(use_cuda_graph)
        self.compile_warmup_steps = int(compile_warmup_steps)
        self.audio = self._alloc_audio()
        self.emitted_samples = 0

        self._enter_streaming_and_capture_graph()

    def _alloc_audio(self) -> torch.Tensor:
        """Allocate a zeroed ``[B, n_ch, T * ratio]`` output buffer in the decoder dtype."""
        return torch.zeros(
            self.B, self.n_ch, self.T * self.ratio,
            device=self.device, dtype=self.dec_dtype,
        )

    def _enter_streaming(self) -> None:
        """Enter the decoder's streaming context for batch size ``self.B``."""
        assert self._stream_cm is None, "Streaming already active"
        cm = self._dec_stream_owner.streaming(self.B)
        cm.__enter__()
        # Record the context only once it has really been entered, so a failed
        # __enter__ is never paired with a later __exit__.
        self._stream_cm = cm

    def _exit_streaming(self) -> None:
        """Leave the streaming context if one is active; otherwise do nothing."""
        if self._stream_cm is not None:
            self._stream_cm.__exit__(None, None, None)
            self._stream_cm = None

    def _step_no_graph(self, lat_step: torch.Tensor):
        """Run the decoder directly on one latent step."""
        return self.dec(lat_step)

    def _capture_graph(self) -> None:
        """Select the step function: ``CUDAGraphed`` on CUDA, plain call otherwise."""
        dev = self.device
        on_cuda = (isinstance(dev, torch.device) and dev.type == "cuda") \
            or (isinstance(dev, str) and dev.startswith("cuda"))
        if not (self.use_cuda_graph and on_cuda and torch.cuda.is_available()):
            self._step = self._step_no_graph
            self._step_graphed = None
            return

        if self._step_graphed is None:
            self._step_graphed = CUDAGraphed(
                self._step_no_graph,
                warmup_steps=self.compile_warmup_steps,
                disable=False,
            )
        else:
            self._step_graphed.reset(warmup_steps=self.compile_warmup_steps)

        # NOTE(review): this dummy step runs while the streaming context is
        # active -- confirm it does not perturb the decoder's streaming state.
        with torch.inference_mode():
            dummy = torch.zeros(self.B, self.ae.latent_dim, 1,
                                device=self.device, dtype=self.dec_dtype)
            self._step_graphed(dummy)
            torch.cuda.synchronize()

        self._step = self._step_graphed

    def _enter_streaming_and_capture_graph(self) -> None:
        """Enter streaming mode, then (re)build the step function."""
        self._enter_streaming()
        self._capture_graph()

    @torch.no_grad()
    def finalize(self) -> None:
        """Fall back to the plain step, reset any graphed step and leave streaming."""
        self._step = self._step_no_graph
        if self._step_graphed is not None:
            self._step_graphed.reset(warmup_steps=0)
        self._exit_streaming()

    @torch.no_grad()
    def push_token(self, a_t: torch.Tensor) -> int:
        """Decode one latent token and append the samples to ``self.audio``.

        Args:
            a_t: Normalised latent token of shape ``[B, D]``. It is not modified.

        Returns:
            Number of samples (last dimension) produced by this step.

        Raises:
            AssertionError: If the token or decoder output shape is unexpected.
            RuntimeError: If the output would not fit in the audio buffer.
        """
        B, D = a_t.shape
        assert B == self.B and D == self.ae.latent_dim, "token shape mismatch"

        # [B,D] -> [B,1,D]. copy=True guarantees a private tensor: without it,
        # when a_t already has lat_dtype the result is a view of a_t and the
        # in-place ops below would overwrite the caller's data.
        den = a_t.unsqueeze(1).to(self.lat_dtype, copy=True)
        den.mul_(self.z_scale)
        den.add_(self.z_mean)
        # [B,1,D] -> [B,D,1] in the decoder's dtype (no-op cast when equal).
        lat_step = den.transpose(1, 2).to(self.dec_dtype).contiguous()

        y = self._step(lat_step)
        assert y.dim() == 3 and y.shape[0] == self.B and y.shape[1] == self.n_ch
        step_len = int(y.shape[-1])

        start = self.emitted_samples
        stop = start + step_len
        capacity = int(self.audio.shape[-1])
        if stop > capacity:
            # Fail loudly: slicing past the end would truncate silently and
            # either raise an opaque copy_ error or drop samples.
            raise RuntimeError(
                f"audio buffer overflow: need {stop} samples but buffer holds {capacity}"
            )
        self.audio[:, :, start:stop].copy_(y)
        self.emitted_samples = stop
        return step_len

    @torch.no_grad()
    def reset_buffers(self, total_T: int):
        """Start a new stream with a fresh audio buffer sized for ``total_T`` steps.

        The buffer uses the decoder dtype, the same as the one allocated in
        ``__init__``.
        """
        self.T = int(total_T)
        self.audio = self._alloc_audio()
        self.emitted_samples = 0
        self._exit_streaming()
        self._enter_streaming_and_capture_graph()

    def __del__(self):
        # Best-effort cleanup: a destructor must never raise, and attributes may
        # be missing if construction failed part-way.
        try:
            self.finalize()
        except Exception:
            pass