import torch
import torch.nn as nn


class MPE_Soft4(nn.Module):
    """Soft top-k mixture-of-experts projector.

    A bias-free linear gate scores ``num_experts`` two-layer MLP experts for
    every token. Each token is routed to its ``num_experts_per_tok``
    highest-scoring experts, and the experts' outputs are mixed using the
    softmax of the selected gate scores.

    Args:
        config: object exposing ``mm_hidden_size`` (input feature size) and
            ``hidden_size`` (output feature size).
        num_experts: number of experts in the pool (default 10).
        num_experts_per_tok: experts selected per token (default 4).

    Raises:
        ValueError: if ``num_experts_per_tok`` is not in ``[1, num_experts]``.
    """

    def __init__(self, config, num_experts=10, num_experts_per_tok=4):
        super().__init__()

        if not 1 <= num_experts_per_tok <= num_experts:
            raise ValueError(
                f"num_experts_per_tok must be in [1, {num_experts}], "
                f"got {num_experts_per_tok}"
            )

        self.num_experts = num_experts
        self.num_experts_per_tok = num_experts_per_tok
        self.in_dim = config.mm_hidden_size
        self.out_dim = config.hidden_size

        # Each expert is a two-layer MLP (Linear -> GELU -> Linear).
        # Experts are registered before the gate, matching the original
        # parameter order (and therefore RNG consumption at init).
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(config.mm_hidden_size, config.hidden_size),
                nn.GELU(),
                nn.Linear(config.hidden_size, config.hidden_size)
            ) for _ in range(num_experts)
        ])
        self.gate = nn.Linear(config.mm_hidden_size, num_experts, bias=False)

    def forward(self, inputs):
        """Project ``inputs`` through the top-k mixture of experts.

        Args:
            inputs: tensor whose last dimension is ``mm_hidden_size``. Any
                number of leading dimensions is accepted, e.g. ``(N, D)`` or
                ``(B, T, D)``.

        Returns:
            Tensor with the same leading dimensions as ``inputs`` and a last
            dimension of ``hidden_size``.
        """
        lead_shape = inputs.shape[:-1]
        # Flatten every leading dimension into a single token axis for routing.
        tokens = inputs.contiguous().view(-1, self.in_dim)

        gate_logits = self.gate(tokens)
        weights, selected_experts = torch.topk(
            gate_logits, self.num_experts_per_tok
        )
        # Renormalize over the selected experts only.
        weights = torch.nn.functional.softmax(weights, dim=-1)

        results = torch.zeros(
            (tokens.shape[0], self.out_dim),
            device=tokens.device,
            dtype=tokens.dtype,
        )

        for i, expert in enumerate(self.experts):
            # Rows routed to expert i, and the top-k slot expert i occupies in
            # each of those rows. topk returns distinct indices per row, so
            # token_idx has no duplicates and the indexed += is well-defined.
            token_idx, slot = torch.where(selected_experts == i)
            results[token_idx] += (
                weights[token_idx, slot, None] * expert(tokens[token_idx])
            )
            # No torch.cuda.empty_cache() here. It only returns cached blocks
            # to the driver, forcing fresh allocations on the next expert or
            # step, and it never changes the computed values.

        return results.view(*lead_shape, self.out_dim)

    @property
    def config(self):
        return {"mm_projector_type": "mpe_soft4"}


class MPE_res(nn.Module):
    """Resolution-routed projector with one two-layer MLP expert per resolution.

    ``forward`` applies exactly one of four experts, chosen by the
    ``resolution`` argument (27, 18, 14 or 11). Any other value prints a
    message and falls back to the expert used for 14.
    """

    # (resolution, expert index) pairs, compared against ``resolution`` with
    # ``==`` in this order.
    _RESOLUTION_TO_EXPERT = ((27, 0), (18, 1), (14, 2), (11, 3))
    # Expert used when the resolution is not listed above (that of 14).
    _FALLBACK_EXPERT = 2

    def __init__(self, config):
        super().__init__()

        def make_expert():
            # Two-layer MLP with the same structure as the original projector.
            return nn.Sequential(
                nn.Linear(config.mm_hidden_size, config.hidden_size),
                nn.GELU(),
                nn.Linear(config.hidden_size, config.hidden_size),
            )

        # Four experts, one per supported resolution.
        self.experts = nn.ModuleList(make_expert() for _ in range(4))

    def forward(self, inputs, resolution=14):
        """Run ``inputs`` through the expert assigned to ``resolution``.

        Args:
            inputs: tensor whose last dimension is ``mm_hidden_size``.
            resolution: one of 27, 18, 14, 11. Other values fall back to 14.

        Returns:
            The selected expert's output.
        """
        chosen = next(
            (idx for res, idx in self._RESOLUTION_TO_EXPERT if resolution == res),
            None,
        )
        if chosen is None:
            print(f"Unsupported resolution: {resolution}, using default resolution 14")
            chosen = self._FALLBACK_EXPERT
        return self.experts[chosen](inputs)

    @property
    def config(self):
        return {"mm_projector_type": "mpe_res"}
    

class MPE_Hard4(nn.Module):
    """Hard-routed mixture of ten MLP experts with fixed mixing weights.

    Four experts are selected from the input's shape (see ``_route``). Their
    outputs are blended with weights 0.4 / 0.2 / 0.2 / 0.2, so expert 0,
    which belongs to every group, always contributes 40%.
    """

    # Expert groups, indexed by the value returned from ``_route``.
    _EXPERT_GROUPS = (
        (0, 1, 2, 3),  # at most 3 frames
        (0, 3, 4, 5),  # 729 = 27**2 tokens per frame
        (0, 5, 6, 7),  # 324 = 18**2 tokens per frame
        (0, 7, 8, 9),  # anything else (e.g. 196 = 14**2)
    )
    # Mixing weight for each position within the selected group.
    _MIX_WEIGHTS = (0.4, 0.2, 0.2, 0.2)

    def __init__(self, config):
        super().__init__()
        # Ten experts, each a two-layer MLP mirroring the original projector.
        self.experts = nn.ModuleList(
            nn.Sequential(
                nn.Linear(config.mm_hidden_size, config.hidden_size),
                nn.GELU(),
                nn.Linear(config.hidden_size, config.hidden_size),
            )
            for _ in range(10)
        )

    @staticmethod
    def _route(num_frames, tokens_per_frame):
        """Return the index into ``_EXPERT_GROUPS`` for the given input sizes."""
        if num_frames <= 3:
            return 0
        if tokens_per_frame == 729:
            return 1
        if tokens_per_frame == 324:
            return 2
        return 3

    def forward(self, inputs):
        """Blend the outputs of the expert group chosen from ``inputs``' shape.

        Args:
            inputs: tensor with at least two dimensions. Dim 0 is read as the
                frame count and dim 1 as tokens per frame. Presumably
                ``(frames, tokens, mm_hidden_size)``; TODO confirm with the
                caller.

        Returns:
            Weighted sum of the four selected experts' outputs.
        """
        group = self._EXPERT_GROUPS[self._route(inputs.size(0), inputs.size(1))]
        weighted = [
            weight * self.experts[idx](inputs)
            for weight, idx in zip(self._MIX_WEIGHTS, group)
        ]
        return torch.stack(weighted).sum(dim=0)

    @property
    def config(self):
        return {"mm_projector_type": "mpe_hard4"}
    

class MPE_Hard4_Auto(nn.Module):
    def __init__(self, config):
        super().__init__() # 创建10个专家，每个专家为两层MLP（与原投影层结构一致）
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(config.mm_hidden_size, config.hidden_size),
                nn.GELU(),
                nn.Linear(config.hidden_size, config.hidden_size)
            ) for _ in range(10)  # 10个专家
        ])
        # 定义可学习权重（作为参数），按照 [0.4, 0.2, 0.2, 0.2] 初始化

        initial_weights = torch.tensor([[0.4, 0.2, 0.2, 0.2], 
                                        [0.4, 0.2, 0.2, 0.2], 
                                        [0.4, 0.2, 0.2, 0.2], 
                                        [0.4, 0.2, 0.2, 0.2]] , dtype=torch.float32)
        self.weights = nn.Parameter(initial_weights)  # 将权重设置为可学习参数
        
    
    def forward(self, inputs):
        # 根据输入帧数选择专家组合
        frames = inputs.size(0)  # 获取帧数（假设x形状为[frames, feature_dim]）
        frame_shape = inputs.size(1)
        # import pdb; pdb.set_trace()
        if frames <= 3: # <= 3 images
            weight_indice = 0
            expert_indices = [0, 1, 2, 3]
        elif frame_shape == 729: # 27**2 = 729
            weight_indice = 1
            expert_indices = [0, 3, 4, 5]
        elif frame_shape == 324: # 18**2 = 324
            weight_indice = 2
            expert_indices = [0, 5, 6, 7]
        else: # 14**2 = 196
            weight_indice = 3
            expert_indices = [0, 7, 8, 9]
        
        # 提取选中的专家并计算输出
        selected_experts = [self.experts[i] for i in expert_indices]
        expert_outputs = [expert(inputs) for expert in selected_experts]
        
        # 确保权重的归一化
        normalized_weights = torch.nn.functional.softmax(self.weights[weight_indice], dim=0)
        
        # 计算加权输出
        weighted_outputs = [w * output for w, output in zip(normalized_weights, expert_outputs)]
        output = torch.sum(torch.stack(weighted_outputs), dim=0)        
        # # 计算加权输出
        # weights = [0.4, 0.2, 0.2, 0.2]  # expert[0] 40% others 20%
        # weighted_outputs = [w * output for w, output in zip(weights, expert_outputs)]
        # output = torch.sum(torch.stack(weighted_outputs), dim=0)
        
        return output

    @property
    def config(self):
        return {"mm_projector_type": "mpe_hard4_auto"}


# class Conf():
#     def __init__(self):
#         self.hidden_size = 1280
#         self.mm_hidden_size = 1152
        
# if __name__ == "__main__":
    
#     config = Conf()
#     test_layer = MPE_Soft4(config)
#     test_embedding = torch.zeros([30, 729, 1152])
#     print(test_layer(test_embedding).shape)
#     test_input = []