import torch
from tqdm.auto import tqdm
from types import SimpleNamespace
from typing import Tuple
from diffusion.respace import SpacedDiffusion
from mask_generator import VideoMaskGenerator



def sample_unconditional(
    model: torch.nn.Module,
    data_loader: torch.utils.data.DataLoader,
    args: SimpleNamespace,
    mask_generator: VideoMaskGenerator,
    diffusion_sampler: SpacedDiffusion,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Sample time-frequency data with a diffusion model, one batch at a time.

    For every batch ``(x, tc)`` yielded by ``data_loader``:

    1. A mask is drawn from ``mask_generator``.
    2. ``diffusion_sampler.p_sample_loop`` is run from Gaussian noise under
       mixed-precision autocast.
    3. The result is blended with the input. Where ``mask`` is 1 the generated
       values are used; where it is 0 the original ``x`` is kept.

    Args:
        model (torch.nn.Module): The diffusion model to sample from. It is
            switched to eval mode and its ``forward`` is given to the sampler.
        data_loader (torch.utils.data.DataLoader): Yields ``(x, tc)`` pairs.
            ``x`` must be 5-D: its axes 1 and 2 are swapped for the sampler and
            swapped back afterwards. ``tc`` is only used when
            ``args.time_covariates`` is truthy.
        args (SimpleNamespace): Must provide ``device`` (e.g. ``"cuda"``,
            ``"cuda:0"``, ``"cpu"`` or a ``torch.device``), ``mask_choice`` and
            ``time_covariates``.
        mask_generator (VideoMaskGenerator): Called as
            ``mask_generator(batch_size, device, idx=args.mask_choice)``.
        diffusion_sampler (SpacedDiffusion): Sampler whose ``p_sample_loop``
            generates the samples.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: ``(generated, original)``, both on the
        CPU. Each is the concatenation of all batches along dim 0 and has the
        same shape as the concatenated inputs.

    Raises:
        RuntimeError: If ``data_loader`` yields no batches.
    """
    model.eval()
    generated_timefreq_data = []
    original_timefreq_data = []

    # torch.amp.autocast expects a bare device *type* ("cuda", "cpu", ...).
    # A device string with an index such as "cuda:0" is not accepted there.
    device_type = torch.device(args.device).type

    # Only query CUDA capabilities when actually sampling on CUDA.
    # bfloat16 is the CPU autocast default; any other backend uses float16.
    if device_type == "cuda":
        use_bf16 = torch.cuda.is_bf16_supported()
    else:
        use_bf16 = device_type == "cpu"
    dtype = torch.bfloat16 if use_bf16 else torch.float16

    with torch.no_grad():
        for x, tc in tqdm(data_loader, total=len(data_loader), desc="Sampling"):
            x = x.to(device=args.device, dtype=torch.float)
            # Separate copy of the input for the returned originals. x itself is
            # handed to the sampler as raw_x, which might modify it in place.
            raw_x = x.clone()

            mask = mask_generator(x.shape[0], args.device, idx=args.mask_choice)
            # The sampler works on a layout with x's axes 1 and 2 swapped.
            z = torch.randn_like(x).permute(0, 2, 1, 3, 4)
            if args.time_covariates:
                model_kwargs = {"tc": tc.to(args.device, dtype=dtype)}
            else:
                model_kwargs = None

            with torch.amp.autocast(device_type=device_type, dtype=dtype):
                samples = diffusion_sampler.p_sample_loop(
                    model.forward,
                    z.shape,
                    z,
                    model_kwargs=model_kwargs,
                    clip_denoised=True,
                    progress=False,
                    device=args.device,
                    raw_x=x,
                    mask=mask,
                )

            # Assuming samples is shaped like z, bring x's axis 2 to the front
            # of both tensors so they line up. Then blend them with the mask
            # (presumably broadcastable to that layout; TODO confirm against
            # VideoMaskGenerator) and restore x's original axis order.
            samples = samples.permute(1, 0, 2, 3, 4) * mask + x.permute(
                2, 0, 1, 3, 4
            ) * (1 - mask)
            samples = samples.permute(1, 2, 0, 3, 4)
            assert (
                x.shape == samples.shape
            ), f"Shape mismatch: {x.shape} vs {samples.shape}"

            # Move each batch off the accelerator immediately so device memory
            # does not grow with the size of the dataset.
            generated_timefreq_data.append(samples.cpu())
            original_timefreq_data.append(raw_x.cpu())

    if not generated_timefreq_data:
        raise RuntimeError("data_loader yielded no batches; nothing to sample.")

    return (
        torch.cat(generated_timefreq_data, dim=0),
        torch.cat(original_timefreq_data, dim=0),
    )
