from typing import Any, Dict, Optional, Tuple, Union

import torch
from diffusers import UNet2DConditionModel
from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput
from torch import distributed as dist

from .base_model import BaseModel
from ..utils import DistriConfig


class NaivePatchUNet(BaseModel):  # for Patch Parallelism
    """Naive patch-parallel wrapper around a ``UNet2DConditionModel``.

    With more than one rank, each rank calls the wrapped UNet on one spatial
    slice of ``sample`` (split along dim 2 for ``"row"``, dim 3 for ``"col"``,
    alternating per call for ``"alternate"``); the per-rank outputs are then
    all-gathered and concatenated back into a full-size tensor. With a single
    rank the wrapped model is called directly on the full input.
    """

    def __init__(self, model: UNet2DConditionModel, distri_config: DistriConfig):
        assert isinstance(model, UNet2DConditionModel)
        super().__init__(model, distri_config)

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        class_labels: Optional[torch.Tensor] = None,
        timestep_cond: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
        down_block_additional_residuals: Optional[Tuple[torch.Tensor, ...]] = None,
        mid_block_additional_residual: Optional[torch.Tensor] = None,
        down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor, ...]] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        return_dict: bool = True,
        record: bool = False,
    ):
        """Run the UNet, either eagerly or by replaying a recorded CUDA graph.

        Args:
            sample: 4-D input tensor, unpacked as ``(b, c, h, w)``.
            timestep: Python scalar, 0-dim tensor, or tensor indexed by batch.
            encoder_hidden_states: Conditioning tensor, indexed by batch.
            added_cond_kwargs: Optional dict of batch-indexed tensors forwarded
                to the wrapped model. The caller's dict is never modified.
            class_labels, timestep_cond, attention_mask, cross_attention_kwargs,
            down_block_additional_residuals, mid_block_additional_residual,
            down_intrablock_additional_residuals, encoder_attention_mask:
                Accepted for signature compatibility only; each must be ``None``.
            return_dict: If true, wrap the result in ``UNet2DConditionOutput``;
                otherwise return a 1-tuple.
            record: If true, always take the eager path and afterwards store
                the inputs as ``self.static_inputs`` (only if not yet set) and
                call ``self.synchronize()``.

        Returns:
            ``UNet2DConditionOutput(sample=output)`` or ``(output,)``.

        Raises:
            AssertionError: If an unsupported argument is not ``None``, if the
                batch size is not 2 when the batch is split across ranks, or
                if an input's shape differs from the recorded static input.
            NotImplementedError: If ``split_scheme`` is not one of ``"row"``,
                ``"col"`` or ``"alternate"`` on the multi-rank eager path.
        """
        distri_config = self.distri_config
        b, c, h, w = sample.shape
        assert (
            class_labels is None
            and timestep_cond is None
            and attention_mask is None
            and cross_attention_kwargs is None
            and down_block_additional_residuals is None
            and mid_block_additional_residual is None
            and down_intrablock_additional_residuals is None
            and encoder_attention_mask is None
        ), "NaivePatchUNet only supports sample, timestep, encoder_hidden_states and added_cond_kwargs"

        # All of these are None (asserted above); they are still forwarded so
        # the wrapped model is called exactly as before.
        passthrough = {
            "class_labels": class_labels,
            "timestep_cond": timestep_cond,
            "attention_mask": attention_mask,
            "cross_attention_kwargs": cross_attention_kwargs,
            "down_block_additional_residuals": down_block_additional_residuals,
            "mid_block_additional_residual": mid_block_additional_residual,
            "down_intrablock_additional_residuals": down_intrablock_additional_residuals,
            "encoder_attention_mask": encoder_attention_mask,
        }
        cfg_split = distri_config.do_classifier_free_guidance and distri_config.split_batch

        if distri_config.use_cuda_graph and not record:
            if distri_config.world_size > 1 and cfg_split:
                assert b == 2, "splitting the batch across ranks requires batch size 2"
                sample, timestep, encoder_hidden_states, added_cond_kwargs = self._cfg_batch_slice(
                    sample, timestep, encoder_hidden_states, added_cond_kwargs
                )
            self._copy_to_static_inputs(sample, timestep, encoder_hidden_states, added_cond_kwargs)

            # The "alternate" scheme uses two graphs, selected by call parity.
            graph_idx = self.counter % 2 if distri_config.split_scheme == "alternate" else 0
            self.cuda_graphs[graph_idx].replay()
            output = self.static_outputs[graph_idx]
        else:
            if distri_config.world_size == 1:
                output = self._run_unet(sample, timestep, encoder_hidden_states, added_cond_kwargs, passthrough)
            else:
                if cfg_split:
                    assert b == 2, "splitting the batch across ranks requires batch size 2"
                    sample, timestep, encoder_hidden_states, added_cond_kwargs = self._cfg_batch_slice(
                        sample, timestep, encoder_hidden_states, added_cond_kwargs
                    )
                split_dim = self._current_split_dim()
                # Keep `sample` un-patched: it is what gets recorded as the
                # static input below, and the replay path copies un-patched
                # samples into it.
                local_patch = self._take_local_patch(sample, split_dim)
                output = self._run_unet(
                    local_patch, timestep, encoder_hidden_states, added_cond_kwargs, passthrough
                )
                output = self._all_gather_patches(output, (b, c, h, w), split_dim, cfg_split)
            if record:
                if self.static_inputs is None:
                    self.static_inputs = {
                        "sample": sample,
                        "timestep": timestep,
                        "encoder_hidden_states": encoder_hidden_states,
                        "added_cond_kwargs": added_cond_kwargs,
                    }
                self.synchronize()

        if return_dict:
            output = UNet2DConditionOutput(sample=output)
        else:
            output = (output,)

        self.counter += 1
        return output

    def _cfg_batch_slice(self, sample, timestep, encoder_hidden_states, added_cond_kwargs):
        """Keep only the batch row at ``distri_config.batch_idx()`` of every input."""
        batch_idx = self.distri_config.batch_idx()
        row = slice(batch_idx, batch_idx + 1)
        if torch.is_tensor(timestep) and timestep.ndim > 0:
            timestep = timestep[row]
        if added_cond_kwargs is not None:
            # Build a new dict so the caller's dict is left untouched.
            added_cond_kwargs = {key: value[row] for key, value in added_cond_kwargs.items()}
        return sample[row], timestep, encoder_hidden_states[row], added_cond_kwargs

    def _copy_to_static_inputs(self, sample, timestep, encoder_hidden_states, added_cond_kwargs):
        """Copy the current inputs into the recorded static input tensors."""
        static_inputs = self.static_inputs

        assert static_inputs["sample"].shape == sample.shape
        static_inputs["sample"].copy_(sample)

        static_timestep = static_inputs["timestep"]
        if torch.is_tensor(timestep) and timestep.ndim > 0:
            assert static_timestep.shape == timestep.shape
            static_timestep.copy_(timestep)
        else:
            # Python scalar or 0-dim tensor: broadcast it over the whole buffer.
            static_timestep.fill_(timestep)

        assert static_inputs["encoder_hidden_states"].shape == encoder_hidden_states.shape
        static_inputs["encoder_hidden_states"].copy_(encoder_hidden_states)

        if added_cond_kwargs is not None:
            for key, value in added_cond_kwargs.items():
                static_value = static_inputs["added_cond_kwargs"][key]
                assert static_value.shape == value.shape
                static_value.copy_(value)

    def _current_split_dim(self) -> int:
        """Return the tensor dim to split along for this call (2 or 3)."""
        scheme = self.distri_config.split_scheme
        if scheme == "row":
            return 2
        if scheme == "col":
            return 3
        if scheme == "alternate":
            return 2 if self.counter % 2 == 0 else 3
        raise NotImplementedError(f"Unsupported split_scheme: {scheme!r}")

    def _take_local_patch(self, sample, split_dim):
        """Return the slice of ``sample`` at ``distri_config.split_idx()`` along ``split_dim``."""
        distri_config = self.distri_config
        n_batch, channels, height, width = sample.shape
        n_patches = distri_config.n_device_per_batch
        if split_dim == 2:
            return sample.view(n_batch, channels, n_patches, -1, width)[:, :, distri_config.split_idx()]
        assert split_dim == 3
        return sample.view(n_batch, channels, height, n_patches, -1)[..., distri_config.split_idx(), :]

    def _run_unet(self, sample, timestep, encoder_hidden_states, added_cond_kwargs, passthrough):
        """Call the wrapped model and return the first element of its tuple output."""
        return self.model(
            sample,
            timestep,
            encoder_hidden_states,
            added_cond_kwargs=added_cond_kwargs,
            return_dict=False,
            **passthrough,
        )[0]

    def _all_gather_patches(self, output, full_shape, split_dim, cfg_split):
        """All-gather every rank's output and concatenate into ``self.output_buffer``.

        ``self.output_buffer`` and ``self.buffer_list`` are allocated on first
        use and reused on later calls. When ``cfg_split`` is true, the first
        ``n_device_per_batch`` gathered outputs form batch row 0 and the
        remaining ones form batch row 1.
        """
        distri_config = self.distri_config
        if self.output_buffer is None:
            self.output_buffer = torch.empty(full_shape, device=output.device, dtype=output.dtype)
        flat_output = output.contiguous().view(-1)
        if self.buffer_list is None:
            self.buffer_list = [torch.empty_like(flat_output) for _ in range(distri_config.world_size)]
        dist.all_gather(self.buffer_list, flat_output, async_op=False)
        patches = [buffer.view(output.shape) for buffer in self.buffer_list]
        if cfg_split:
            per_batch = distri_config.n_device_per_batch
            torch.cat(patches[:per_batch], dim=split_dim, out=self.output_buffer[0:1])
            torch.cat(patches[per_batch:], dim=split_dim, out=self.output_buffer[1:2])
        else:
            torch.cat(patches, dim=split_dim, out=self.output_buffer)
        return self.output_buffer

    @property
    def add_embedding(self):
        """Expose the wrapped model's ``add_embedding`` attribute."""
        return self.model.add_embedding
