from .sd_motion import TemporalBlock
import torch



class SDXLMotionModel(torch.nn.Module):
    """Container holding the 15 ``TemporalBlock`` motion modules for SDXL.

    ``motion_modules`` stores the blocks in a fixed order: three groups of two
    (widths 320, 640, 1280), then three groups of three (widths 1280, 640,
    320). ``call_block_id`` maps a call index to a position in
    ``motion_modules``. Those call indices are consumed by code outside this
    class.
    """

    def __init__(self):
        super().__init__()
        # Channel width of each motion module, in storage order.
        widths = (
            320, 320,
            640, 640,
            1280, 1280,
            1280, 1280, 1280,
            640, 640, 640,
            320, 320, 320,
        )
        # Each block gets (8, width // 8, width, eps=1e-6); see TemporalBlock
        # for what these positional arguments mean.
        self.motion_modules = torch.nn.ModuleList(
            [TemporalBlock(8, width // 8, width, eps=1e-6) for width in widths]
        )
        # The n-th call index listed here selects motion_modules[n].
        call_ids = (0, 2, 7, 10, 15, 18, 25, 28, 31, 35, 38, 41, 44, 46, 48)
        self.call_block_id = {
            call_id: block_id for block_id, call_id in enumerate(call_ids)
        }

    def forward(self):
        # Intentionally a no-op; the blocks in ``motion_modules`` are
        # presumably invoked individually by the caller (not shown here).
        pass

    @staticmethod
    def state_dict_converter():
        """Return the converter that renames checkpoint keys for this model."""
        return SDMotionModelStateDictConverter()


class SDMotionModelStateDictConverter:
    """Rename motion-module checkpoint keys to ``motion_modules.<i>.*`` keys.

    Only keys starting with ``down_blocks.``, ``mid_block.`` or ``up_blocks.``
    are kept. Each key is grouped by its path up to and including the
    ``temporal_transformer`` component. Each distinct group receives the next
    ``motion_modules`` index, in down -> mid -> up order. All other keys are
    dropped.
    """

    # Sub-path after ``temporal_transformer`` (excluding the final key
    # component) -> corresponding sub-path under ``motion_modules.<i>``.
    _RENAME_DICT = {
        "norm": "norm",
        "proj_in": "proj_in",
        "transformer_blocks.0.attention_blocks.0.to_q": "transformer_blocks.0.attn1.to_q",
        "transformer_blocks.0.attention_blocks.0.to_k": "transformer_blocks.0.attn1.to_k",
        "transformer_blocks.0.attention_blocks.0.to_v": "transformer_blocks.0.attn1.to_v",
        "transformer_blocks.0.attention_blocks.0.to_out.0": "transformer_blocks.0.attn1.to_out",
        "transformer_blocks.0.attention_blocks.0.pos_encoder": "transformer_blocks.0.pe1",
        "transformer_blocks.0.attention_blocks.1.to_q": "transformer_blocks.0.attn2.to_q",
        "transformer_blocks.0.attention_blocks.1.to_k": "transformer_blocks.0.attn2.to_k",
        "transformer_blocks.0.attention_blocks.1.to_v": "transformer_blocks.0.attn2.to_v",
        "transformer_blocks.0.attention_blocks.1.to_out.0": "transformer_blocks.0.attn2.to_out",
        "transformer_blocks.0.attention_blocks.1.pos_encoder": "transformer_blocks.0.pe2",
        "transformer_blocks.0.norms.0": "transformer_blocks.0.norm1",
        "transformer_blocks.0.norms.1": "transformer_blocks.0.norm2",
        "transformer_blocks.0.ff.net.0.proj": "transformer_blocks.0.act_fn.proj",
        "transformer_blocks.0.ff.net.2": "transformer_blocks.0.ff",
        "transformer_blocks.0.ff_norm": "transformer_blocks.0.norm3",
        "proj_out": "proj_out",
    }

    def __init__(self):
        pass

    @staticmethod
    def _natural_key(name):
        """Sort key ordering numeric ``.``-separated components by value."""
        # Tag each component so an int is never compared against a str.
        return tuple(
            (0, int(part)) if part.isdecimal() else (1, part)
            for part in name.split(".")
        )

    def from_diffusers(self, state_dict):
        """Map diffusers-style motion-module keys to ``SDXLMotionModel`` keys.

        Args:
            state_dict: mapping from checkpoint parameter names to values.

        Returns:
            A new dict keyed ``motion_modules.<i>.<sub-path>``. Values are the
            original objects (not copied).

        Raises:
            ValueError: a down/mid/up key has no ``temporal_transformer``
                component.
            KeyError: the sub-path after ``temporal_transformer`` is not in
                the rename table.
        """
        name_list = []
        for group in ("down_blocks.", "mid_block.", "up_blocks."):
            # Numeric-aware order: plain string sorting would place e.g.
            # "motion_modules.10" before "motion_modules.2" and shuffle ids.
            name_list += sorted(
                (name for name in state_dict if name.startswith(group)),
                key=self._natural_key,
            )
        state_dict_ = {}
        last_prefix, module_id = "", -1
        for name in name_list:
            names = name.split(".")
            try:
                prefix_index = names.index("temporal_transformer") + 1
            except ValueError:
                raise ValueError(
                    f"expected a 'temporal_transformer' component in key {name!r}"
                ) from None
            # Sorting keeps every key of one module contiguous, so a prefix
            # change marks the start of the next module.
            prefix = ".".join(names[:prefix_index])
            if prefix != last_prefix:
                last_prefix = prefix
                module_id += 1
            middle_name = ".".join(names[prefix_index:-1])
            try:
                target = self._RENAME_DICT[middle_name]
            except KeyError:
                raise KeyError(
                    f"unrecognized motion-module path {middle_name!r} in key {name!r}"
                ) from None
            parts = ["motion_modules", str(module_id), target]
            # For pos_encoder entries the target name is the full key; the
            # final component of the source key is dropped.
            if "pos_encoder" not in names:
                parts.append(names[-1])
            state_dict_[".".join(parts)] = state_dict[name]
        return state_dict_

    def from_civitai(self, state_dict):
        """Same conversion as :meth:`from_diffusers`."""
        return self.from_diffusers(state_dict)
