from dataclasses import dataclass
from typing import Optional, Tuple, Union
import torch

from diffusers import DDIMScheduler, UNet2DModel, DDIMPipeline
from diffusers.utils import BaseOutput
from torch import nn
from torch.nn import functional as F
@dataclass
class UNet2DOutput(BaseOutput):
    """
    Container for the result of a [`UNet2DModel`] forward pass.

    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Hidden states produced by the final layer of the model.
    """

    # The model's prediction tensor.
    sample: torch.FloatTensor



class DeepSupervisonUNet2DModel(UNet2DModel):
    """
    [`UNet2DModel`] variant that also emits deep-supervision predictions.

    For every up block except the last, a 1x1 convolution (one entry of
    ``self.heads``) projects that block's output to the channel count of the
    last up block. Each projected feature map is resized to the final spatial
    size and passed through the model's shared output layers
    (``conv_norm_out`` -> ``conv_act`` -> ``conv_out``), exactly like the
    final prediction.

    The ``_NUM_DROPPED_AUX_OUTPUTS`` coarsest auxiliary predictions are not
    returned. Their heads are still created so parameter names (and therefore
    checkpoints) do not depend on that setting.
    """

    # How many of the coarsest auxiliary predictions ``forward`` leaves out.
    _NUM_DROPPED_AUX_OUTPUTS = 2

    def __init__(
        self,
        sample_size: int | Tuple[int, int] | None = None,
        in_channels: int = 3,
        out_channels: int = 3,
        center_input_sample: bool = False,
        time_embedding_type: str = "positional",
        freq_shift: int = 0,
        flip_sin_to_cos: bool = True,
        # The three tuple defaults below previously were `...`, which was
        # forwarded to the parent verbatim and is not a valid value.
        # NOTE(review): intended to mirror diffusers' UNet2DModel defaults —
        # confirm against the installed diffusers version.
        down_block_types: Tuple[str] = (
            "DownBlock2D",
            "AttnDownBlock2D",
            "AttnDownBlock2D",
            "AttnDownBlock2D",
        ),
        up_block_types: Tuple[str] = (
            "AttnUpBlock2D",
            "AttnUpBlock2D",
            "AttnUpBlock2D",
            "UpBlock2D",
        ),
        block_out_channels: Tuple[int] = (224, 448, 672, 896),
        layers_per_block: int = 2,
        mid_block_scale_factor: float = 1,
        downsample_padding: int = 1,
        downsample_type: str = "conv",
        upsample_type: str = "conv",
        act_fn: str = "silu",
        attention_head_dim: int | None = 8,
        norm_num_groups: int = 32,
        norm_eps: float = 0.00001,
        resnet_time_scale_shift: str = "default",
        add_attention: bool = True,
        class_embed_type: str | None = None,
        num_class_embeds: int | None = None,
    ):
        """
        Build the parent U-Net, then add one 1x1 projection head per up block
        except the last. All arguments are forwarded unchanged to
        [`UNet2DModel`].
        """
        # Forward by keyword so the call keeps working if the parent's
        # signature gains or reorders parameters.
        super().__init__(
            sample_size=sample_size,
            in_channels=in_channels,
            out_channels=out_channels,
            center_input_sample=center_input_sample,
            time_embedding_type=time_embedding_type,
            freq_shift=freq_shift,
            flip_sin_to_cos=flip_sin_to_cos,
            down_block_types=down_block_types,
            up_block_types=up_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            mid_block_scale_factor=mid_block_scale_factor,
            downsample_padding=downsample_padding,
            downsample_type=downsample_type,
            upsample_type=upsample_type,
            act_fn=act_fn,
            attention_head_dim=attention_head_dim,
            norm_num_groups=norm_num_groups,
            norm_eps=norm_eps,
            resnet_time_scale_shift=resnet_time_scale_shift,
            add_attention=add_attention,
            class_embed_type=class_embed_type,
            num_class_embeds=num_class_embeds,
        )
        # heads[i] maps up_blocks[i]'s output channels to the last up block's
        # channels, so every auxiliary feature fits the shared output layers.
        last_channel = self.up_blocks[-1].resnets[-1].out_channels
        self.heads = nn.ModuleList(
            nn.Conv2d(block.resnets[-1].out_channels, last_channel, kernel_size=1)
            for block in list(self.up_blocks)[:-1]
        )

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        class_labels: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> list[torch.Tensor]:
        r"""
        Run the U-Net and return deep-supervision predictions.

        Args:
            sample (`torch.FloatTensor`):
                The noisy input tensor with the following shape `(batch, channel, height, width)`.
            timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
            class_labels (`torch.FloatTensor`, *optional*, defaults to `None`):
                Optional class labels for conditioning. Their embeddings will be summed with the timestep
                embeddings. Required when the model has a class embedding.
            return_dict (`bool`, *optional*, defaults to `True`):
                Accepted for signature compatibility with [`UNet2DModel`] but ignored: a plain list is
                always returned.

        Returns:
            `list[torch.Tensor]`: the retained auxiliary predictions in up-block order, followed by the
            final prediction as the last element. The final prediction is always present. Every tensor
            has the spatial size of the last up block's output.

        Raises:
            ValueError: if the model has a class embedding and `class_labels` is `None`.
        """
        # 0. center input if necessary
        if self.config.center_input_sample:
            sample = 2 * sample - 1.0

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device)

        t_emb = self.time_proj(timesteps)

        # timesteps does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        t_emb = t_emb.to(dtype=self.dtype)
        emb = self.time_embedding(t_emb)

        if self.class_embedding is not None:
            if class_labels is None:
                raise ValueError("class_labels should be provided when doing class conditioning")

            if self.config.class_embed_type == "timestep":
                class_labels = self.time_proj(class_labels)

            class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
            emb = emb + class_emb

        # 2. pre-process
        skip_sample = sample
        sample = self.conv_in(sample)

        # 3. down
        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:
            if hasattr(downsample_block, "skip_conv"):
                sample, res_samples, skip_sample = downsample_block(
                    hidden_states=sample, temb=emb, skip_sample=skip_sample
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

            down_block_res_samples += res_samples

        # 4. mid
        sample = self.mid_block(sample, emb)

        # 5. up — record each up block's OUTPUT so up_outputs[i] has the
        # channel count heads[i] was built for.
        skip_sample = None
        up_outputs = []
        for upsample_block in self.up_blocks:
            num_res = len(upsample_block.resnets)
            res_samples = down_block_res_samples[-num_res:]
            down_block_res_samples = down_block_res_samples[:-num_res]
            if hasattr(upsample_block, "skip_conv"):
                sample, skip_sample = upsample_block(sample, res_samples, emb, skip_sample)
            else:
                sample = upsample_block(sample, res_samples, emb)
            up_outputs.append(sample)

        # NOTE(review): skip_sample returned by skip-type up blocks is not
        # used in post-processing below — confirm that is intended.

        # Project the retained auxiliary features (the coarsest ones are
        # skipped, as their predictions are not returned) and append the
        # final feature map, which needs no projection.
        head_pairs = list(zip(self.heads, up_outputs[:-1]))[self._NUM_DROPPED_AUX_OUTPUTS:]
        features = [head(feat) for head, feat in head_pairs]
        features.append(sample)

        # 6. post-process: resize to the final resolution, then apply the
        # shared output layers as a chain (norm -> act -> conv).
        out_size = sample.shape[-2:]
        predictions = []
        for feat in features:
            if feat.shape[-2:] != out_size:
                feat = F.interpolate(feat, size=out_size, mode="bilinear", align_corners=False)
            pred = self.conv_norm_out(feat)
            pred = self.conv_act(pred)
            pred = self.conv_out(pred)
            predictions.append(pred)
        return predictions