class MoEMLP(nn.Module):
    """Soft mixture-of-experts over small MLPs.

    Experts come in two groups:

    * one frozen expert per entry of ``model_paths``, initialised from a packed
      checkpoint whose ``layer1``/``layer2``/``layer3`` tensors each store
      ``[weight | bias]`` (bias in the last column);
    * ``num_new_experts`` freshly initialised, trainable experts.

    A gating MLP produces a softmax distribution over all experts, and the
    output is the gate-weighted sum of the expert outputs.

    Args:
        model_paths: Checkpoint paths for the frozen experts.
        num_new_experts: Number of additional trainable experts.
        in_dim: Size of the input's last dimension.
        hidden_dim: Hidden width of each expert MLP.
        out_dim: Output size of each expert (and of the mixture).

    Raises:
        ValueError: If there would be no experts at all.
    """

    def __init__(self, model_paths, num_new_experts=5, in_dim=384,
                 hidden_dim=256, out_dim=64):
        super().__init__()
        self.pre_trained_mlps = nn.ModuleList()
        for path in model_paths:
            mlp = self._make_expert(in_dim, hidden_dim, out_dim)
            # Load to CPU: the freshly built parameters live on CPU, and the
            # caller moves the whole module afterwards. This also works on
            # machines without a GPU.
            # NOTE(review): torch.load unpickles arbitrary objects; only load
            # checkpoints from trusted sources.
            packed_weights = torch.load(path, map_location='cpu')
            with torch.no_grad():
                # Linear layers sit at indices 0, 2, 4 of the Sequential.
                for layer_idx, key in zip((0, 2, 4), ('layer1', 'layer2', 'layer3')):
                    packed = packed_weights[key]
                    mlp[layer_idx].weight.copy_(packed[:, :-1])
                    mlp[layer_idx].bias.copy_(packed[:, -1])
            mlp.requires_grad_(False)
            self.pre_trained_mlps.append(mlp)

        self.new_mlps = nn.ModuleList(
            self._make_expert(in_dim, hidden_dim, out_dim)
            for _ in range(num_new_experts)
        )

        # Same module objects as above, in gate order. Kept as a registered
        # attribute so existing state_dict keys remain loadable.
        self.all_mlps = nn.ModuleList([*self.pre_trained_mlps, *self.new_mlps])
        num_experts = len(self.all_mlps)
        if num_experts == 0:
            raise ValueError("MoEMLP needs at least one expert")

        # One gate probability per expert, so any number of paths works.
        self.gate = nn.Sequential(
            nn.Linear(in_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, num_experts),
            nn.Softmax(dim=-1),
        )

    @staticmethod
    def _make_expert(in_dim, hidden_dim, out_dim):
        """Build one expert: Linear-ReLU-Linear-ReLU-Linear."""
        return nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim),
        )

    def forward(self, x):
        """Mix the expert outputs with the gate's softmax weights.

        Args:
            x: Tensor whose last dimension is ``in_dim``; any leading
                dimensions are allowed.

        Returns:
            Tensor of shape ``x.shape[:-1] + (out_dim,)``.
        """
        gate_prob = self.gate(x)  # (..., E)
        expert_outputs = torch.stack([mlp(x) for mlp in self.all_mlps], dim=0)  # (E, ..., out_dim)
        # Put the expert axis first and add a trailing axis so the weights
        # broadcast over output features; valid for any input rank.
        weights = gate_prob.movedim(-1, 0).unsqueeze(-1)  # (E, ..., 1)
        return (expert_outputs * weights).sum(dim=0)

class DiffusionUnetHybridImagePolicy(BasePolicy):
    """Diffusion policy conditioned on DINOv2-small image features and agent_pos.

    Observation encoding: each frame's CLS token from ``facebook/dinov2-small``
    is passed through :class:`MoEMLP`. The result (``feature_dim`` values per
    step) is concatenated with the normalized ``agent_pos``. Over the first
    ``n_obs_steps`` steps, this is flattened into the global conditioning vector
    of a :class:`ConditionalUnet1D` denoiser over actions.

    Only ``obs_as_global_cond=True`` is implemented. The other mode raises
    ``NotImplementedError`` in :meth:`predict_action` and :meth:`compute_loss`.

    Args:
        shape_meta: Must provide ``['action']['shape']`` and
            ``['obs']['agent_pos']['shape']``.
        noise_scheduler: Scheduler used to add training noise and to sample.
        horizon: Length of the predicted action trajectory.
        n_action_steps: Number of actions returned per prediction.
        n_obs_steps: Number of observation steps used for conditioning.
        num_inference_steps: Sampling steps. Defaults to the scheduler's
            ``num_train_timesteps``.
        obs_as_global_cond: Condition via the global vector (only mode
            implemented).
        diffusion_step_embed_dim, down_dims, kernel_size, n_groups,
        condition_type, use_down_condition, use_mid_condition,
        use_up_condition: Forwarded to ``ConditionalUnet1D``.
        frozen: If True, DINOv2 parameters get ``requires_grad=False``.
        feature_dim: Per-step image feature size. Must equal the MoE output
            size.
        moe_model_paths: Checkpoints for the frozen MoE experts. ``None`` keeps
            the placeholder list ``["path1", ..., "path5"]``.
        **kwargs: Accepted and ignored.

    NOTE(review): compute_loss consumes precomputed ``batch['obs']['feature']``
    and never runs DINOv2, so DINOv2 weights get no gradients during training
    regardless of ``frozen``.
    """

    def __init__(self,
            shape_meta: dict,
            noise_scheduler: DDPMScheduler,
            horizon,
            n_action_steps,
            n_obs_steps,
            num_inference_steps=None,
            obs_as_global_cond=True,
            diffusion_step_embed_dim=256,
            down_dims=(256,512,1024),
            kernel_size=5,
            n_groups=8,
            condition_type="film",
            use_down_condition=True,
            use_mid_condition=True,
            use_up_condition=True,
            frozen=False,
            feature_dim=64,
            moe_model_paths=None,
            **kwargs):

        super().__init__()
        self.feature_dim = feature_dim

        if moe_model_paths is None:
            # Placeholder paths kept for backward compatibility; pass real
            # checkpoints through moe_model_paths.
            moe_model_paths = ["path1", "path2", "path3", "path4", "path5"]
        # NOTE: device is hard-coded; constructing this policy requires CUDA.
        self.moe_mlp = MoEMLP(list(moe_model_paths)).to('cuda')

        from transformers import AutoImageProcessor, AutoModel
        self.dinov2_model = AutoModel.from_pretrained('facebook/dinov2-small').to('cuda').eval()
        self.dinov2_processor = AutoImageProcessor.from_pretrained('facebook/dinov2-small')

        if frozen:
            for param in self.dinov2_model.parameters():
                param.requires_grad = False

        action_dim = shape_meta['action']['shape'][0]
        robot_state_dim = shape_meta['obs']['agent_pos']['shape'][0]
        obs_feature_dim = self.feature_dim + robot_state_dim
        input_dim = action_dim + obs_feature_dim
        global_cond_dim = None
        if obs_as_global_cond:
            # Observations enter only through the global conditioning vector.
            input_dim = action_dim
            global_cond_dim = obs_feature_dim * n_obs_steps

        self.model = ConditionalUnet1D(
            input_dim=input_dim,
            local_cond_dim=None,
            global_cond_dim=global_cond_dim,
            diffusion_step_embed_dim=diffusion_step_embed_dim,
            down_dims=down_dims,
            kernel_size=kernel_size,
            n_groups=n_groups,
            condition_type=condition_type,
            use_down_condition=use_down_condition,
            use_mid_condition=use_mid_condition,
            use_up_condition=use_up_condition,
        )
        self.noise_scheduler = noise_scheduler
        self.mask_generator = LowdimMaskGenerator(
            action_dim=action_dim,
            obs_dim=0 if obs_as_global_cond else obs_feature_dim,
            max_n_obs_steps=n_obs_steps,
            fix_obs_steps=True,
            action_visible=False
        )
        self.normalizer = LinearNormalizer()
        self.horizon = horizon
        self.obs_feature_dim = obs_feature_dim
        self.action_dim = action_dim
        self.n_action_steps = n_action_steps
        self.n_obs_steps = n_obs_steps
        self.obs_as_global_cond = obs_as_global_cond
        if num_inference_steps is None:
            num_inference_steps = noise_scheduler.config.num_train_timesteps
        self.num_inference_steps = num_inference_steps

    def conditional_sample(self, condition_data, condition_mask, local_cond=None, global_cond=None, generator=None, **kwargs):
        """Run the reverse diffusion loop starting from Gaussian noise.

        Entries where ``condition_mask`` is True are overwritten with
        ``condition_data`` before every denoising step and once at the end.
        Extra ``kwargs`` are forwarded to ``scheduler.step``.
        """
        model = self.model
        scheduler = self.noise_scheduler
        trajectory = torch.randn(size=condition_data.shape, dtype=condition_data.dtype, device=condition_data.device, generator=generator)
        scheduler.set_timesteps(self.num_inference_steps)

        for t in scheduler.timesteps:
            trajectory[condition_mask] = condition_data[condition_mask]
            model_output = model(trajectory, t, local_cond=local_cond, global_cond=global_cond)
            trajectory = scheduler.step(model_output, t, trajectory, generator=generator, **kwargs).prev_sample

        trajectory[condition_mask] = condition_data[condition_mask]
        return trajectory

    def get_image_feature(self, feature):
        """Map per-frame input features (last dim 384) through the MoE head."""
        return self.moe_mlp(feature)

    def _encode_images(self, img):
        """Encode a ``(B, T, ...)`` image batch into ``(B, T, -1)`` MoE features.

        Frames may be channels-first (size 3 at dim 2) or channels-last. They
        are multiplied by 255 and cast to uint8 for the DINOv2 processor.
        Assumes float pixels in [0, 1] — TODO confirm.
        """
        if img.shape[2] == 3:
            img = img.permute(0, 1, 3, 4, 2)  # -> (B, T, H, W, C)
        batch_size, n_steps = img.shape[:2]
        frames = img.reshape(-1, *img.shape[2:]) * 255
        np_frames = frames.cpu().numpy().astype(np.uint8)
        # Send inputs to wherever DINOv2 lives instead of a fixed 'cuda'.
        dino_input = self.dinov2_processor(images=np_frames, return_tensors="pt").to(self.dinov2_model.device)
        cls_tokens = self.dinov2_model(**dino_input)[0][:, 0, :]  # token 0 of the first output
        feature = self.get_image_feature(cls_tokens.unsqueeze(1))
        return feature.reshape(batch_size, n_steps, -1)

    def _build_global_cond(self, agent_pos, feature, batch_size):
        """Concatenate normalized agent_pos and image features per step.

        Uses the first ``n_obs_steps`` steps and flattens the result to
        ``(batch_size, -1)``.
        """
        feature = feature[:, :self.n_obs_steps].reshape(batch_size, -1, self.feature_dim)
        global_cond = torch.cat([agent_pos[:, :self.n_obs_steps], feature], dim=-1)
        return global_cond.reshape(batch_size, -1)

    def predict_action(self, obs_dict):
        """Sample an action trajectory from ``obs_dict['img']`` and ``obs_dict['agent_pos']``.

        Returns:
            dict with ``'action'`` (``n_action_steps`` steps starting at step
            ``n_obs_steps - 1``) and the full unnormalized ``'action_pred'``.

        Raises:
            NotImplementedError: If ``obs_as_global_cond`` is False.
        """
        agent_pos = self.normalizer.normalize({'agent_pos': obs_dict['agent_pos']})['agent_pos']
        img = obs_dict['img']
        B = img.shape[0]
        T = self.horizon
        Da = self.action_dim
        To = self.n_obs_steps

        if not self.obs_as_global_cond:
            raise NotImplementedError("Not implemented yet")

        # Only the first To frames condition the model, so encode just those.
        feature = self._encode_images(img[:, :To])
        global_cond = self._build_global_cond(agent_pos, feature, B)

        cond_data = torch.zeros(size=(B, T, Da), device=self.device, dtype=self.dtype)
        cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)

        nsample = self.conditional_sample(cond_data, cond_mask, global_cond=global_cond)
        naction_pred = nsample[..., :Da]
        action_pred = self.normalizer['action'].unnormalize(naction_pred)
        start = To - 1
        end = start + self.n_action_steps
        action = action_pred[:, start:end]

        return {'action': action, 'action_pred': action_pred}

    def set_normalizer(self, normalizer: LinearNormalizer):
        """Copy the state of ``normalizer`` into this policy's normalizer."""
        self.normalizer.load_state_dict(normalizer.state_dict())

    def compute_loss(self, batch):
        """Compute the masked denoising MSE loss for one batch.

        Uses precomputed per-frame features from ``batch['obs']['feature']``
        (last dim must fit the MoE input). These are presumably DINOv2 CLS
        features — confirm against the dataset.

        Returns:
            ``(loss, {'bc_loss': float})``.

        Raises:
            NotImplementedError: If ``obs_as_global_cond`` is False.
            ValueError: For an unsupported scheduler ``prediction_type``.
        """
        agent_pos = self.normalizer.normalize({'agent_pos': batch['obs']['agent_pos']})['agent_pos']
        nactions = self.normalizer['action'].normalize(batch['action'])
        batch_size = nactions.shape[0]

        if not self.obs_as_global_cond:
            raise NotImplementedError("Not implemented yet")

        obs_feature = batch['obs']['feature'][:, :self.n_obs_steps]
        feature = self.get_image_feature(obs_feature)
        # Reshape with the sliced tensor's own (B, To): the image tensor may
        # hold more frames than n_obs_steps.
        feature = feature.reshape(obs_feature.shape[0], obs_feature.shape[1], -1)
        global_cond = self._build_global_cond(agent_pos, feature, batch_size)

        trajectory = nactions
        cond_data = trajectory

        condition_mask = self.mask_generator(trajectory.shape)
        noise = torch.randn(trajectory.shape, device=trajectory.device)
        timesteps = torch.randint(0, self.noise_scheduler.config.num_train_timesteps, (batch_size,), device=trajectory.device).long()
        noisy_trajectory = self.noise_scheduler.add_noise(trajectory, noise, timesteps)
        loss_mask = ~condition_mask
        noisy_trajectory[condition_mask] = cond_data[condition_mask]
        pred = self.model(noisy_trajectory, timesteps, global_cond=global_cond)

        pred_type = self.noise_scheduler.config.prediction_type
        if pred_type == 'epsilon':
            target = noise
        elif pred_type == 'sample':
            target = trajectory
        else:
            raise ValueError(f"Unsupported prediction type {pred_type}")

        loss = F.mse_loss(pred, target, reduction='none')
        loss = loss * loss_mask.type(loss.dtype)
        loss = loss.mean()

        return loss, {'bc_loss': loss.item()}

    def forward(self, batch):
        """Alias for :meth:`compute_loss`."""
        return self.compute_loss(batch)