import torch
import torch.nn as nn
from collections import OrderedDict
from diffusers.models.modeling_utils import ModelMixin
from diffusers.configuration_utils import ConfigMixin
from diffusers.loaders import FromOriginalControlnetMixin
from typing import Any, Dict, List, Optional, Tuple, Union


def conv_nd(dims, *args, **kwargs):
    """
    Build the convolution layer matching the requested dimensionality.

    `dims` selects nn.Conv1d (1), nn.Conv2d (2) or nn.Conv3d (3); all other
    positional and keyword arguments are forwarded unchanged to that
    constructor. Any other value of `dims` raises ValueError.
    """
    candidates = ((1, nn.Conv1d), (2, nn.Conv2d), (3, nn.Conv3d))
    for rank, conv_cls in candidates:
        if dims == rank:
            return conv_cls(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")


def avg_pool_nd(dims, *args, **kwargs):
    """
    Build the average-pooling layer matching the requested dimensionality.

    `dims` selects nn.AvgPool1d (1), nn.AvgPool2d (2) or nn.AvgPool3d (3);
    all other positional and keyword arguments are forwarded unchanged to
    that constructor. Any other value of `dims` raises ValueError.
    """
    candidates = ((1, nn.AvgPool1d), (2, nn.AvgPool2d), (3, nn.AvgPool3d))
    for rank, pool_cls in candidates:
        if dims == rank:
            return pool_cls(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")

def get_parameter_dtype(parameter: torch.nn.Module):
    """
    Return the dtype of the first tensor found on the module `parameter`.

    Lookup order:
      1. the first registered parameter,
      2. the first registered buffer,
      3. the first plain tensor attribute stored in the ``__dict__`` of the
         module or one of its submodules.

    Returns ``None`` when the module holds no tensor at all.
    """
    first_param = next(iter(parameter.parameters()), None)
    if first_param is not None:
        return first_param.dtype

    first_buffer = next(iter(parameter.buffers()), None)
    if first_buffer is not None:
        return first_buffer.dtype

    # Fallback for modules that expose tensors only as plain attributes. The
    # original comment attributes this to torch.nn.DataParallel in
    # PyTorch 1.5; previously it lived in an `except StopIteration` branch
    # that could never run and referenced an undefined name `Tensor`.
    def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, torch.Tensor]]:
        return [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]

    gen = parameter._named_members(get_members_fn=find_tensor_attributes)
    first_tuple = next(gen, None)
    if first_tuple is None:
        return None
    return first_tuple[1].dtype

class Downsample(nn.Module):
    """
    Downsampling layer implemented either as a strided convolution or as
    average pooling.

    :param channels: number of channels expected in the input.
    :param use_conv: if truthy, downsample with a kernel-3 strided
                     convolution; otherwise use average pooling.
    :param dims: 1, 2 or 3 for 1D/2D/3D signals. For 3D the stride is
                 (1, 2, 2), so only the inner two dimensions are reduced.
    :param out_channels: number of output channels; falls back to `channels`
                         when falsy. Must equal `channels` when pooling.
    :param padding: padding passed to the convolution (unused when pooling).
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        # The first axis of a 3D signal is left at full resolution.
        stride = (1, 2, 2) if dims == 3 else 2
        if not use_conv:
            # Pooling cannot change the channel count.
            assert self.channels == self.out_channels
            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
        else:
            self.op = conv_nd(
                dims,
                self.channels,
                self.out_channels,
                3,
                stride=stride,
                padding=padding,
            )

    def forward(self, x):
        """Apply the downsampling op; `x` must have `channels` on axis 1."""
        assert x.shape[1] == self.channels
        return self.op(x)


class ResnetBlock(nn.Module):
    """
    Two-convolution residual block with optional downsampling.

    Layout of `forward`: optional `Downsample` -> optional `in_conv` ->
    `block1` -> ReLU -> `block2`, added to a shortcut. The shortcut is
    `skep(x)` when `sk` equals False and the identity otherwise. `block2` is
    zero-initialised, so a freshly built block outputs only the shortcut.

    :param in_c: channels of the incoming tensor.
    :param out_c: channels produced by the block.
    :param down: when equal to True, downsample the input by 2 first.
    :param ksize: kernel size of `in_conv`, `block2` and `skep`.
    :param sk: when equal to False, use a learned shortcut convolution and
               always create `in_conv`.
    :param use_conv: forwarded to `Downsample`.
    """

    def __init__(self, in_c, out_c, down, ksize=3, sk=False, use_conv=True):
        super().__init__()
        pad = ksize // 2
        # Explicit `== False` / `== True` comparisons are kept on purpose so
        # exactly the same argument values toggle each branch as before.
        needs_in_conv = in_c != out_c or sk == False
        self.in_conv = nn.Conv2d(in_c, out_c, ksize, 1, pad) if needs_in_conv else None
        self.block1 = nn.Conv2d(out_c, out_c, 3, 1, 1)
        self.act = nn.ReLU()
        self.block2 = zero_module(nn.Conv2d(out_c, out_c, ksize, 1, pad))
        self.skep = nn.Conv2d(out_c, out_c, ksize, 1, pad) if sk == False else None

        self.down = down
        if self.down == True:
            # Downsampling runs before `in_conv`, hence on `in_c` channels.
            self.down_opt = Downsample(in_c, use_conv=use_conv)

    def forward(self, x):
        """Run the block on `x` and return residual + shortcut."""
        if self.down == True:
            x = self.down_opt(x)
        if self.in_conv is not None:
            x = self.in_conv(x)

        residual = self.block2(self.act(self.block1(x)))
        shortcut = x if self.skep is None else self.skep(x)
        return residual + shortcut

def zero_module(module):
    """Set every parameter of `module` to zero in place and return `module`."""
    with torch.no_grad():
        for param in module.parameters():
            param.zero_()
    return module

class Adapter_XL(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
    """
    Stack of `ResnetBlock`s that maps a batch of frame sequences to a list
    of per-block feature maps.

    :param channels: output channels of each stage.
    :param nums_rb: number of `ResnetBlock`s in each stage (indexed like
                    `channels`).
    :param down_layer: global block indices (counted across all stages) at
                       which the block downsamples its input.
    :param cin: channels seen by `conv_in`, i.e. after `PixelUnshuffle(8)`.
    :param ksize: kernel size forwarded to every `ResnetBlock`.
    :param sk: forwarded to every `ResnetBlock`.
    :param use_conv: forwarded to every `ResnetBlock`.

    The default list arguments are only read, never mutated, here.
    """

    def __init__(self, channels=[320, 640, 1280], nums_rb=[4, 3, 6], down_layer=[3, 6, 9], cin=192, ksize=3, sk=False, use_conv=True):
        super().__init__()
        self.unshuffle = nn.PixelUnshuffle(8)
        self.channels = channels
        self.nums_rb = nums_rb
        blocks = []
        for stage, stage_channels in enumerate(channels):
            for block_idx in range(nums_rb[stage]):
                # The first block of every stage after the first receives
                # the previous stage's channel count. This used to be
                # hard-coded for stages 1 and 2 only.
                if stage > 0 and block_idx == 0:
                    in_channels = channels[stage - 1]
                else:
                    in_channels = stage_channels
                blocks.append(
                    ResnetBlock(
                        in_channels,
                        stage_channels,
                        down=len(blocks) in down_layer,
                        ksize=ksize,
                        sk=sk,
                        use_conv=use_conv,
                    )
                )
        self.body = nn.ModuleList(blocks)
        self.conv_in = nn.Conv2d(cin, channels[0], 3, 1, 1)

    @property
    def dtype(self) -> torch.dtype:
        """
        `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
        """
        return get_parameter_dtype(self)

    def forward(self, x):
        """
        Extract features from `x`.

        :param x: 5-D tensor laid out as (batch, frames, channels, and two
                  spatial axes); batch and frames are merged before the
                  convolutions.
        :return: `(down_block_res_samples, mid_block_res_sample)` where the
                 first item is the list of outputs of every block except the
                 last and the second item is the last block's output.
        """
        bz, f, c, h, w = x.shape
        x = x.reshape(bz * f, c, h, w)
        x = self.unshuffle(x)
        x = self.conv_in(x)

        features = []
        for block in self.body:
            x = block(x)
            features.append(x)
        return features[:-1], features[-1]
    
from diffusers.models.resnet import SpatioTemporalResBlock
from diffusers.models.embeddings import Timesteps, TimestepEmbedding


class SVD_Adaptor(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
    """
    Time-conditioned adapter built from `SpatioTemporalResBlock`s, each
    followed by a zero-initialised 3x3 convolution that produces the
    exported feature map.

    :param channels: output channels of each stage.
    :param nums_rb: number of blocks in each stage (indexed like `channels`).
    :param down_layer: global block indices for which `down=True` is passed
                       to the block.
    :param cin: channels seen by `conv_in`, i.e. after `PixelUnshuffle(8)`.
    :param ksize: accepted for signature compatibility; not used here.
    :param sk: accepted for signature compatibility; not used here.
    :param use_conv: accepted for signature compatibility; not used here.
    """

    def __init__(self, channels=[320, 640, 1280], nums_rb=[4, 3, 6], down_layer=[3, 6, 9], cin=192, ksize=3, sk=False, use_conv=True):
        super().__init__()
        self.channels = channels
        self.nums_rb = nums_rb

        # Time embedding
        time_embed_dim = channels[0] * 4
        addition_time_embed_dim = 256
        projection_class_embeddings_input_dim = 768
        self.time_proj = Timesteps(channels[0], True, downscale_freq_shift=0)
        self.time_embedding = TimestepEmbedding(channels[0], time_embed_dim)
        self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
        self.add_time_proj = Timesteps(addition_time_embed_dim, True, downscale_freq_shift=0)

        self.unshuffle = nn.PixelUnshuffle(8)
        blocks = []
        zero_convs = []
        for index, stage_channels in enumerate(channels):
            for block_num in range(nums_rb[index]):
                # The first block of every stage after the first receives
                # the previous stage's channel count. This used to be
                # hard-coded for stages 1 and 2 only.
                if index > 0 and block_num == 0:
                    in_channels = channels[index - 1]
                else:
                    in_channels = stage_channels
                # NOTE(review): `down=` must be accepted by the installed
                # SpatioTemporalResBlock - confirm against the diffusers
                # build this project uses.
                block = SpatioTemporalResBlock(
                    in_channels=in_channels,
                    out_channels=stage_channels,
                    temb_channels=time_embed_dim,
                    down=len(blocks) in down_layer,
                )
                # Zero-initialised so every exported feature starts at 0.
                conv_out = zero_module(nn.Conv2d(stage_channels, stage_channels, 3, 1, 1))
                zero_convs.append(conv_out)
                blocks.append(block)
        self.body = nn.ModuleList(blocks)
        self.zero_conv = nn.ModuleList(zero_convs)
        self.conv_in = nn.Conv2d(cin, channels[0], 3, 1, 1)

    def get_time_embed(self, timesteps, added_time_ids, batch_size, num_frames):
        """
        Build the per-frame conditioning embedding.

        :param timesteps: tensor that can be expanded to `(batch_size,)`.
        :param added_time_ids: tensor that is flattened, projected and then
            reshaped to `(batch_size, -1)`; the result must match the
            768-wide input of `add_embedding` (presumably 3 ids per batch
            item given the 256-wide projection - confirm).
        :param batch_size: number of sequences in the batch.
        :param num_frames: number of frames each embedding row is repeated for.
        :return: sum of the timestep and added-id embeddings, with each row
            repeated `num_frames` times along dim 0.
        """
        timesteps = timesteps.expand(batch_size)
        emb = self.time_embedding(self.time_proj(timesteps))

        time_embeds = self.add_time_proj(added_time_ids.flatten())
        time_embeds = time_embeds.reshape((batch_size, -1)).to(emb.dtype)
        emb = emb + self.add_embedding(time_embeds)
        # One copy per frame so rows line up with the (batch * frames) layout.
        return emb.repeat_interleave(num_frames, dim=0)

    def forward(self, x, timesteps, added_time_ids):
        """
        Extract zero-conv features from `x`.

        :param x: 5-D tensor laid out as (batch, frames, channels, and two
                  spatial axes).
        :param timesteps: see `get_time_embed`.
        :param added_time_ids: see `get_time_embed`.
        :return: `(down_block_res_samples, mid_block_res_sample)`: the list
                 of all features except the last, and the last feature.
        """
        bz, f, c, h, w = x.shape
        # All-zero indicator created directly with x's dtype and device.
        image_only_indicator = torch.zeros(bz, f, dtype=x.dtype, device=x.device)
        temb = self.get_time_embed(timesteps, added_time_ids, bz, f)

        x = x.reshape(bz * f, c, h, w)
        x = self.unshuffle(x)
        x = self.conv_in(x)

        features = []
        for block, zero_conv in zip(self.body, self.zero_conv):
            x = block(x, temb, image_only_indicator)
            features.append(zero_conv(x))
        return features[:-1], features[-1]

from diffusers.models.transformer_2d import Transformer2DModel
from torch.nn import functional as F
class ReferenceEncoder(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
    """
    Small convolutional encoder for a conditioning frame sequence.

    `conv_in` is followed, for each consecutive pair in
    `block_out_channels`, by a same-width 3x3 convolution and a stride-2 3x3
    convolution that changes the width; `conv_out` then projects to
    `conditioning_embedding_channels`. SiLU is applied after `conv_in` and
    after every block convolution, but not after `conv_out`.

    :param conditioning_embedding_channels: channels of the returned embedding.
    :param conditioning_channels: channels of the input frames.
    :param block_out_channels: channel widths of the successive stages.
    """

    def __init__(
        self,
        conditioning_embedding_channels: int = 320,
        conditioning_channels: int = 3,
        block_out_channels: Tuple[int, ...] = (16, 32, 96, 256)
    ):
        super().__init__()

        self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
        self.blocks = nn.ModuleList([])

        for channel_in, channel_out in zip(block_out_channels[:-1], block_out_channels[1:]):
            self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
            # The stride-2 convolution halves the spatial resolution.
            self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))
        self.conv_out = nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)

    def forward(self, conditioning):
        """
        Encode `conditioning`.

        :param conditioning: 5-D tensor laid out as
            (batch, frames, channels, height, width).
        :return: embedding whose leading dimension is batch * frames.
        """
        # Frames are folded into the batch dimension and encoded
        # independently; no temporal mixing happens here.
        batch_size, frames, channels, height, width = conditioning.size()
        # reshape (not view) so non-contiguous inputs are accepted, matching
        # the other adapters in this file.
        conditioning = conditioning.reshape(batch_size * frames, channels, height, width)

        embedding = F.silu(self.conv_in(conditioning))
        for block in self.blocks:
            embedding = F.silu(block(embedding))
        return self.conv_out(embedding)
    
class SVD_Adaptor_Referance(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
    """
    Time-conditioned adapter like `SVD_Adaptor`, with one
    `Transformer2DModel` per stage that is applied, together with a
    `reference` input, after the first block of that stage.

    :param channels: output channels of each stage.
    :param nums_rb: number of blocks in each stage (indexed like `channels`).
    :param down_layer: global block indices for which `down=True` is passed
                       to the block.
    :param cin: channels seen by `conv_in`, i.e. after `PixelUnshuffle(8)`.
    :param ksize: accepted for signature compatibility; not used here.
    :param sk: accepted for signature compatibility; not used here.
    :param use_conv: accepted for signature compatibility; not used here.
    """

    def __init__(self, channels=[320, 640, 1280], nums_rb=[4, 3, 6], down_layer=[3, 6, 9], cin=192, ksize=3, sk=False, use_conv=True):
        super().__init__()
        self.channels = channels
        self.nums_rb = nums_rb

        # Time embedding
        time_embed_dim = channels[0] * 4
        addition_time_embed_dim = 256
        projection_class_embeddings_input_dim = 768
        self.time_proj = Timesteps(channels[0], True, downscale_freq_shift=0)
        self.time_embedding = TimestepEmbedding(channels[0], time_embed_dim)
        self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
        self.add_time_proj = Timesteps(addition_time_embed_dim, True, downscale_freq_shift=0)

        self.unshuffle = nn.PixelUnshuffle(8)
        blocks = []
        zero_convs = []
        cross_attens = []
        for index, stage_channels in enumerate(channels):
            cross_attens.append(
                Transformer2DModel(in_channels=stage_channels, out_channels=stage_channels)
            )
            for block_num in range(nums_rb[index]):
                # The first block of every stage after the first receives
                # the previous stage's channel count. This used to be
                # hard-coded for stages 1 and 2 only.
                if index > 0 and block_num == 0:
                    in_channels = channels[index - 1]
                else:
                    in_channels = stage_channels
                # NOTE(review): `down=` must be accepted by the installed
                # SpatioTemporalResBlock - confirm against the diffusers
                # build this project uses.
                block = SpatioTemporalResBlock(
                    in_channels=in_channels,
                    out_channels=stage_channels,
                    temb_channels=time_embed_dim,
                    down=len(blocks) in down_layer,
                )
                # Zero-initialised so every exported feature starts at 0.
                conv_out = zero_module(nn.Conv2d(stage_channels, stage_channels, 3, 1, 1))
                zero_convs.append(conv_out)
                blocks.append(block)

        self.body = nn.ModuleList(blocks)
        self.cross_atten = nn.ModuleList(cross_attens)
        self.zero_conv = nn.ModuleList(zero_convs)
        self.conv_in = nn.Conv2d(cin, channels[0], 3, 1, 1)
        # NOTE: no ReferenceEncoder is instantiated; `reference` is consumed
        # as given by the cross-attention modules.

    def get_time_embed(self, timesteps, added_time_ids, batch_size, num_frames):
        """
        Build the per-frame conditioning embedding.

        :param timesteps: tensor that can be expanded to `(batch_size,)`.
        :param added_time_ids: tensor that is flattened, projected and then
            reshaped to `(batch_size, -1)`; the result must match the
            768-wide input of `add_embedding` (presumably 3 ids per batch
            item given the 256-wide projection - confirm).
        :param batch_size: number of sequences in the batch.
        :param num_frames: number of frames each embedding row is repeated for.
        :return: sum of the timestep and added-id embeddings, with each row
            repeated `num_frames` times along dim 0.
        """
        timesteps = timesteps.expand(batch_size)
        emb = self.time_embedding(self.time_proj(timesteps))

        time_embeds = self.add_time_proj(added_time_ids.flatten())
        time_embeds = time_embeds.reshape((batch_size, -1)).to(emb.dtype)
        emb = emb + self.add_embedding(time_embeds)
        # One copy per frame so rows line up with the (batch * frames) layout.
        return emb.repeat_interleave(num_frames, dim=0)

    def forward(self, x, reference, timesteps, added_time_ids):
        """
        Extract zero-conv features from `x`, attending to `reference`.

        :param x: 5-D tensor laid out as (batch, frames, channels, and two
                  spatial axes).
        :param reference: passed unchanged as the second positional argument
            of each stage's `Transformer2DModel` (presumably its
            encoder hidden states - confirm against diffusers).
        :param timesteps: see `get_time_embed`.
        :param added_time_ids: see `get_time_embed`.
        :return: `(down_block_res_samples, mid_block_res_sample)`: the list
                 of all features except the last, and the last feature.
        """
        bz, f, c, h, w = x.shape
        # All-zero indicator created directly with x's dtype and device.
        image_only_indicator = torch.zeros(bz, f, dtype=x.dtype, device=x.device)
        temb = self.get_time_embed(timesteps, added_time_ids, bz, f)

        x = x.reshape(bz * f, c, h, w)
        x = self.unshuffle(x)
        x = self.conv_in(x)

        features = []
        num_layer = 0
        for index in range(len(self.channels)):
            for block_num in range(self.nums_rb[index]):
                x = self.body[num_layer](x, temb, image_only_indicator)
                if block_num == 0:
                    # Cross-attention runs once per stage, right after the
                    # stage's first block.
                    x = self.cross_atten[index](x, reference)['sample']
                features.append(self.zero_conv[num_layer](x))
                num_layer += 1
        return features[:-1], features[-1]
         
if __name__ == '__main__':
    # Smoke test: build the adapter, report its size and run one forward pass.
    model = SVD_Adaptor(cin=3*64)
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"模型参数总数: {total_params}")

    batch_size = 1
    num_frames = 16
    image = torch.zeros((batch_size, num_frames, 3, 320, 576))
    # forward() takes (x, timesteps, added_time_ids).
    # `timesteps` must be expandable to (batch_size,), so a 1-element tensor
    # is used instead of the previous 512-element one.
    timesteps = torch.zeros(1)
    # Each id is projected to 256 values and add_embedding expects 768 per
    # batch item, hence 3 ids per item.
    added_time_ids = torch.zeros(batch_size, 3)
    with torch.no_grad():
        down_block_res_samples, mid_block_res_sample = model(image, timesteps, added_time_ids)
    # forward() returns (list of features, last feature), not one tensor.
    for feature in down_block_res_samples:
        print(feature.shape)
    print(mid_block_res_sample.shape)
