from typing import Any, Dict, List, Optional, Union
from dataclasses import dataclass
from einops import rearrange

import torch
import torch.nn as nn
from diffusers import UNet2DConditionModel

from src.utils.project import HarmonicEmbedding, process_cameras

@dataclass
class RenderConfig:
    """Bundle of rendering hyper-parameters with their default values.

    NOTE(review): nothing in this part of the file reads these fields, so the
    per-field notes below are inferred from the names only and must be
    confirmed against the code that consumes this config.
    """

    # presumably the near bound used when sampling along a ray -- TODO confirm
    ray_start: float = 0.5
    # presumably the far bound used when sampling along a ray -- TODO confirm
    ray_end: float = 2.8
    # presumably the side length (in pixels) of the rendered map -- TODO confirm
    resolution: int = 32
    # presumably the number of samples taken per ray -- TODO confirm
    n_samples: int = 8
    # presumably switches depth sampling to disparity space -- TODO confirm
    disparity_space_sampling: bool = False
    # presumably the extent of the volume the features cover -- TODO confirm
    box_warp: float = 1.0


class MVModel(nn.Module):
    """Drive the blocks of a wrapped UNet over a stack of ``m`` views.

    ``forward`` walks ``unet.down_blocks``, ``unet.mid_block`` and
    ``unet.up_blocks`` by hand (instead of calling ``unet.forward``) so that
    a ``meta`` dict can be handed to every attention layer and an optional
    ``triplane_decoder`` can be applied after selected up blocks.

    Attributes:
        unet: The wrapped UNet, or ``None`` until one is assigned.
        triplane_decoder: Empty ``nn.ModuleList`` by default; the decoder
            branch in ``forward`` only runs when its length is non-zero.
        insert_up_layers: Indices of the up blocks after which the decoder
            is applied. Defaults to an empty tuple.
    """

    # Class-level default so that ``forward`` never hits an AttributeError when
    # ``triplane_decoder`` has been populated but this was not set; assigning
    # ``self.insert_up_layers`` on an instance overrides it as usual.
    insert_up_layers = ()

    def __init__(
        self,
        unet: Optional["UNet2DConditionModel"] = None
    ):
        """Store ``unet`` and create an empty ``triplane_decoder``."""
        super().__init__()
        self.unet = unet
        self.triplane_decoder = nn.ModuleList()

    def set_use_memory_efficient_attention_xformers(
        self, valid: bool, attention_op=None
    ) -> None:
        """Forward the xformers setting to every descendant that supports it.

        Every descendant module (this module itself excluded) that exposes a
        ``set_use_memory_efficient_attention_xformers`` method has it called
        with ``(valid, attention_op)``; traversal is depth-first and continues
        below modules that handled the call.
        """

        def fn_recursive_set_mem_eff(module: torch.nn.Module) -> None:
            if hasattr(module, "set_use_memory_efficient_attention_xformers"):
                module.set_use_memory_efficient_attention_xformers(valid, attention_op)

            for child in module.children():
                fn_recursive_set_mem_eff(child)

        # ``children()`` only yields registered ``nn.Module`` instances, so no
        # extra type check is needed here.
        for module in self.children():
            fn_recursive_set_mem_eff(module)

    def forward(
        self,
        latents: torch.Tensor,
        timestep: torch.Tensor,
        prompt_embd: torch.Tensor,
        meta: Dict[str, Any],
    ) -> torch.Tensor:
        """Run the wrapped UNet on multi-view latents.

        Args:
            latents: 5-D input laid out as ``(b, m, c, h, w)``.
            timestep: Timesteps; flattened before embedding. The resulting
                embedding is regrouped as ``(b, m, c)``, so it must hold
                ``b * m`` entries. Only the entry of view 0 is kept per batch
                element.
            prompt_embd: 4-D conditioning laid out as ``(b, m, l, c)``; it is
                flattened to ``(b * m, l, c)`` for the attention layers.
            meta: Dict passed to every attention layer as ``meta=``.
                ``meta["video_length"]`` is overwritten with ``m`` (the
                caller's dict is mutated). ``meta["target_cameras"]`` is read
                only when the triplane decoder branch runs.

        Returns:
            The output of ``unet.conv_out`` rearranged from ``(b, c, m, h, w)``
            to ``(b, m, c, h, w)``.
        """
        # Unpacking all five axes also enforces that ``latents`` is 5-D.
        _, m, _, _, _ = latents.shape
        meta["video_length"] = m

        hidden_states = rearrange(latents, "b m c h w -> b c m h w")
        prompt_embd = rearrange(prompt_embd, "b m l c -> (b m) l c")

        # 1. process timesteps
        timestep = timestep.reshape(-1)
        t_emb = self.unet.time_proj(timestep)
        emb = self.unet.time_embedding(t_emb)
        # One embedding per batch element: keep the one belonging to view 0.
        emb = rearrange(emb, "(b m) c -> b m c", m=m)[:, 0, :]

        # NOTE(review): hidden_states is 5-D (b, c, m, h, w) at this point, so
        # the unet's layers are presumably multi-view/video-aware variants
        # rather than stock 2-D layers -- confirm.
        hidden_states = self.unet.conv_in(hidden_states)

        # 2. unet
        # a. downsample -- collect every intermediate output as a skip tensor
        down_block_res_samples = (hidden_states,)
        for downsample_block in self.unet.down_blocks:
            if getattr(downsample_block, "has_cross_attention", False):
                for resnet, attn in zip(
                    downsample_block.resnets, downsample_block.attentions
                ):
                    hidden_states = resnet(hidden_states, emb)
                    hidden_states = attn(
                        hidden_states, encoder_hidden_states=prompt_embd, meta=meta
                    ).sample
                    down_block_res_samples += (hidden_states,)
            else:
                for resnet in downsample_block.resnets:
                    hidden_states = resnet(hidden_states, emb)
                    down_block_res_samples += (hidden_states,)

            if downsample_block.downsamplers is not None:
                for downsample in downsample_block.downsamplers:
                    hidden_states = downsample(hidden_states)
                # A single skip tensor is recorded after all downsamplers ran.
                down_block_res_samples += (hidden_states,)

        # b. mid -- first resnet, then alternating (attention, resnet) pairs
        mid_block = self.unet.mid_block
        hidden_states = mid_block.resnets[0](hidden_states, emb)
        for attn, resnet in zip(mid_block.attentions, mid_block.resnets[1:]):
            hidden_states = attn(
                hidden_states, encoder_hidden_states=prompt_embd, meta=meta
            ).sample
            hidden_states = resnet(hidden_states, emb)

        # c. upsample
        for i, upsample_block in enumerate(self.unet.up_blocks):
            # Take as many skip tensors off the end as this block has resnets;
            # they are consumed last-in-first-out below.
            n_resnets = len(upsample_block.resnets)
            res_samples = down_block_res_samples[-n_resnets:]
            down_block_res_samples = down_block_res_samples[:-n_resnets]

            if getattr(upsample_block, "has_cross_attention", False):
                for resnet, attn in zip(
                    upsample_block.resnets, upsample_block.attentions
                ):
                    res_hidden_states = res_samples[-1]
                    res_samples = res_samples[:-1]
                    hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
                    hidden_states = resnet(hidden_states, emb)
                    hidden_states = attn(
                        hidden_states, encoder_hidden_states=prompt_embd, meta=meta
                    ).sample
            else:
                for resnet in upsample_block.resnets:
                    res_hidden_states = res_samples[-1]
                    res_samples = res_samples[:-1]
                    hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
                    hidden_states = resnet(hidden_states, emb)

            # NOTE(review): ``triplane_decoder`` is created as an
            # ``nn.ModuleList``, which defines no ``forward``; this call only
            # works if it has been replaced by a callable module that also
            # supports ``len()`` -- confirm how it is meant to be populated.
            if len(self.triplane_decoder) > 0 and i in self.insert_up_layers:
                hidden_states = self.triplane_decoder(hidden_states, meta["target_cameras"], emb)

            if upsample_block.upsamplers is not None:
                for upsample in upsample_block.upsamplers:
                    hidden_states = upsample(hidden_states)

        # 3. post-process
        sample = self.unet.conv_norm_out(hidden_states)
        sample = self.unet.conv_act(sample)
        sample = self.unet.conv_out(sample)
        sample = rearrange(sample, "b c m h w -> b m c h w", m=m)
        return sample
