from typing import Dict
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from einops import rearrange, reduce
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler

from diffusion_policy.model.common.normalizer import LinearNormalizer
from diffusion_policy.policy.base_image_policy import BaseImagePolicy
from diffusion_policy.model.diffusion.conditional_unet1d import ConditionalUnet1D
from diffusion_policy.model.diffusion.mask_generator import LowdimMaskGenerator
from diffusion_policy.model.vision.timm_obs_encoder import TimmObsEncoder
from diffusion_policy.common.pytorch_util import dict_apply
from torch.distributions import Categorical


class ExpertBuffer:
    """Fixed-capacity ring buffer of training samples routed to a single expert.

    Every key in ``self.buffer`` owns a pre-allocated tensor of shape
    ``(max_size, *row_shape)``. Rows are written starting at ``current_index``
    and wrap around once the buffer is full, overwriting the oldest rows.
    """

    def __init__(self, max_size=100, camera0_rgb_shape=(2,3,224,224), robot0_eef_pos_shape=(2,3),
                 robot0_eef_rot_axis_angle_shape=(2,6), robot0_gripper_width_shape=(2,1),
                 robot0_language_instruction_shape=(2,100), robot0_eef_rot_axis_angle_wrt_start_shape=(2,6),
                 action_shape=(16,10), device='cuda'):
        """
        Args:
            max_size: number of rows kept per key.
            *_shape: per-row shape of the corresponding key.
            device: device the storage tensors are allocated on (default ``'cuda'``,
                i.e. the current CUDA device, matching the previous ``.cuda()`` call).
        """
        self.max_size = max_size

        row_shapes = {
            'camera0_rgb': camera0_rgb_shape,
            'robot0_eef_pos': robot0_eef_pos_shape,
            'robot0_eef_rot_axis_angle': robot0_eef_rot_axis_angle_shape,
            'robot0_gripper_width': robot0_gripper_width_shape,
            'robot0_language_instruction': robot0_language_instruction_shape,
            'robot0_eef_rot_axis_angle_wrt_start': robot0_eef_rot_axis_angle_wrt_start_shape,
            'action': action_shape,
        }
        self.buffer = {
            key: torch.zeros((max_size, *shape), device=device)
            for key, shape in row_shapes.items()
        }

        # Next row to write; wraps modulo max_size.
        self.current_index = 0
        # Number of valid rows (saturates at max_size).
        self.data_count = 0

    def add(self, new_data):
        """Write a batch of samples into the ring buffer.

        Args:
            new_data: dict containing every key of ``self.buffer``; the leading
                dimension of each value is the batch size, which is read from
                ``new_data['camera0_rgb']``. Values are cast to the storage
                dtype/device.

        If the batch holds more than ``max_size`` rows, only its last
        ``max_size`` rows are kept (earlier ones would be overwritten anyway).
        """
        batch_size = new_data['camera0_rgb'].shape[0]
        if batch_size == 0:
            return

        kept = min(batch_size, self.max_size)
        skipped = batch_size - kept
        start = (self.current_index + skipped) % self.max_size
        # Destination rows, wrapping past the end of the buffer.
        rows = (torch.arange(kept) + start) % self.max_size

        for key, storage in self.buffer.items():
            values = new_data[key][skipped:]
            storage[rows] = values.to(device=storage.device, dtype=storage.dtype)

        self.data_count = min(self.data_count + batch_size, self.max_size)
        self.current_index = (self.current_index + batch_size) % self.max_size

    def get(self):
        """Return ``max_size`` rows per key.

        While the buffer is not yet full, rows are drawn uniformly with
        replacement from the valid rows; once full, the storage dict itself is
        returned (not a copy).

        Raises:
            RuntimeError: if nothing has been added yet.
        """
        if self.data_count == 0:
            raise RuntimeError("ExpertBuffer.get() called before any data was added")

        if self.data_count < self.max_size:
            indices = torch.multinomial(torch.ones(self.data_count), self.max_size, replacement=True)
            return {key: value[indices] for key, value in self.buffer.items()}

        return self.buffer


class ExpertBuffers:
    """A list of independent ``ExpertBuffer`` instances, one per expert.

    Every buffer shares the same capacity and per-key row shapes; ``add`` and
    ``get`` forward to the buffer selected by ``expert_idx``.
    """

    def __init__(self, num_experts, max_size=100, camera0_rgb_shape=(2,3,224,224), robot0_eef_pos_shape=(2,3),
                 robot0_eef_rot_axis_angle_shape=(2,6), robot0_gripper_width_shape=(2,1),
                 robot0_language_instruction_shape=(2,100), robot0_eef_rot_axis_angle_wrt_start_shape=(2,6),
                 action_shape=(16,10)):
        shared_config = dict(
            max_size=max_size,
            camera0_rgb_shape=camera0_rgb_shape,
            robot0_eef_pos_shape=robot0_eef_pos_shape,
            robot0_eef_rot_axis_angle_shape=robot0_eef_rot_axis_angle_shape,
            robot0_gripper_width_shape=robot0_gripper_width_shape,
            robot0_language_instruction_shape=robot0_language_instruction_shape,
            robot0_eef_rot_axis_angle_wrt_start_shape=robot0_eef_rot_axis_angle_wrt_start_shape,
            action_shape=action_shape,
        )
        self.expert_buffers = []
        for _ in range(num_experts):
            self.expert_buffers.append(ExpertBuffer(**shared_config))

    def add(self, expert_idx, new_data):
        """Insert ``new_data`` into the buffer owned by expert ``expert_idx``."""
        target_buffer = self.expert_buffers[expert_idx]
        target_buffer.add(new_data)

    def get(self, expert_idx):
        """Return the sampled contents of the buffer owned by expert ``expert_idx``."""
        target_buffer = self.expert_buffers[expert_idx]
        return target_buffer.get()



class DiffusionUnetTimmPolicy_di(BaseImagePolicy):
    """Diffusion (or flow-matching) UNet image policy with a mixture of experts.

    The observation encoder produces a flat global conditioning vector that
    drives a ``ConditionalUnet1D``; the expert used for each sample is selected
    through the model's ``use_expert_i`` argument. When ``di`` is enabled, an MLP
    (``gating_network``) scores every sample of a batch for every expert. These
    scores decide which samples each expert trains on, and the gating network
    is itself trained from the experts' per-sample losses.
    """

    # Observation keys stored in the per-expert replay buffers (compute_loss_moe).
    _BUFFER_OBS_KEYS = (
        'camera0_rgb',
        'robot0_eef_pos',
        'robot0_eef_rot_axis_angle',
        'robot0_gripper_width',
        'robot0_language_instruction',
        'robot0_eef_rot_axis_angle_wrt_start',
    )

    def __init__(self,
                 shape_meta: dict,
                 noise_scheduler: DDPMScheduler,
                 obs_encoder: TimmObsEncoder,
                 num_inference_steps=None,
                 obs_as_global_cond=True,
                 diffusion_step_embed_dim=256,
                 down_dims=(256, 512, 1024),
                 kernel_size=5,
                 n_groups=8,
                 cond_predict_scale=True,
                 input_pertub=0.1,
                 inpaint_fixed_action_prefix=False,
                 train_diffusion_n_samples=1,
                 num_experts=5,
                 di=True,
                 gamma=10,
                 beta=0.1,
                 fix_timestep=False,
                 encoder_nograd=False,
                 num_samples_per_expert=4,
                 buffer_max_size=16,
                 use_flow_matching=False,
                 # parameters passed to step
                 **kwargs
                 ):
        """
        Args:
            shape_meta: must provide ``shape_meta['action']['shape']`` (1-D) and
                ``shape_meta['action']['horizon']``.
            noise_scheduler: DDPM scheduler used for training noise and sampling.
            obs_encoder: encoder whose flattened ``output_shape()`` is the size of
                the global conditioning vector.
            num_inference_steps: sampling steps; defaults to the scheduler's
                ``num_train_timesteps``.
            input_pertub: scale of the extra DDPM-IP input perturbation noise.
            train_diffusion_n_samples: number of noise samples trained per
                observation encoding.
            num_experts: number of experts in the UNet / gating network.
            di: enable the gating network, per-expert buffers and related state.
            gamma: weight of the gating term in the training loss.
            beta: weight of the responsibility / entropy terms of the gating loss.
            fix_timestep: when scoring samples for the gating loss (DDPM only),
                use timestep ``num_train_timesteps // 2`` instead of a random one.
            encoder_nograd: run the observation encoder without gradients in
                ``compute_loss_moe``.
            num_samples_per_expert: samples routed to each expert per step.
            buffer_max_size: capacity of each per-expert replay buffer.
            use_flow_matching: train with the flow-matching objective instead of
                DDPM noise/sample prediction.
            **kwargs: forwarded to ``noise_scheduler.step`` during sampling.
        """
        super().__init__()

        # parse shapes
        action_shape = shape_meta['action']['shape']
        assert len(action_shape) == 1
        action_dim = action_shape[0]
        action_horizon = shape_meta['action']['horizon']
        # The encoder output is flattened into the global conditioning vector.
        obs_feature_dim = np.prod(obs_encoder.output_shape())

        # create diffusion model
        assert obs_as_global_cond
        input_dim = action_dim
        global_cond_dim = obs_feature_dim

        self.use_di = bool(di)
        if di:
            self.num_experts = num_experts
            self.gamma = gamma
            self.beta = beta
            self.fix_timestep = fix_timestep
            self.encoder_nograd = encoder_nograd
            self.num_samples_per_expert = num_samples_per_expert
            self.buffer_max_size = buffer_max_size

            # The gating MLP consumes the same flat conditioning vector the UNet
            # receives, so its input width is the full flattened feature size.
            self.gating_network = nn.Sequential(
                nn.Linear(int(obs_feature_dim), 1024),
                nn.ReLU(),
                nn.Linear(1024, 512),
                nn.ReLU(),
                nn.Linear(512, 256),
                nn.ReLU(),
                nn.Linear(256, num_experts),
            )

            # Uniform prior over experts: log(1 / num_experts), shape (1, num_experts).
            self._log_weights = -torch.log(torch.ones(self.num_experts) * self.num_experts)[None, :]

            # NOTE(review): observation row shapes are still hard-coded; only the
            # action shape is derived from shape_meta.
            self.data_buffer = ExpertBuffers(
                num_experts=self.num_experts,
                max_size=self.buffer_max_size,
                camera0_rgb_shape=(2, 3, 224, 224),
                robot0_eef_pos_shape=(2, 3),
                robot0_eef_rot_axis_angle_shape=(2, 6),
                robot0_gripper_width_shape=(2, 1),
                robot0_language_instruction_shape=(2, 100),
                robot0_eef_rot_axis_angle_wrt_start_shape=(2, 6),
                action_shape=(action_horizon, action_dim),
            )

        model = ConditionalUnet1D(
            input_dim=input_dim,
            local_cond_dim=None,
            global_cond_dim=global_cond_dim,
            diffusion_step_embed_dim=diffusion_step_embed_dim,
            down_dims=down_dims,
            kernel_size=kernel_size,
            n_groups=n_groups,
            cond_predict_scale=cond_predict_scale,
            num_experts=num_experts
        )

        self.obs_encoder = obs_encoder
        self.model = model
        self.noise_scheduler = noise_scheduler
        self.normalizer = LinearNormalizer()
        self.obs_feature_dim = obs_feature_dim
        self.action_dim = action_dim
        self.action_horizon = action_horizon  # used for training
        self.obs_as_global_cond = obs_as_global_cond
        self.input_pertub = input_pertub
        self.inpaint_fixed_action_prefix = inpaint_fixed_action_prefix
        self.train_diffusion_n_samples = int(train_diffusion_n_samples)
        self.kwargs = kwargs

        if num_inference_steps is None:
            num_inference_steps = noise_scheduler.config.num_train_timesteps
        self.num_inference_steps = num_inference_steps

        self.use_flow_matching = use_flow_matching

    # ========= flow-matching helpers ============
    def sample_noise(self, shape, device):
        """Standard-normal float32 noise of ``shape`` on ``device``."""
        return torch.normal(
            mean=0.0,
            std=1.0,
            size=shape,
            dtype=torch.float32,
            device=device,
        )

    def sample_beta(self, alpha, beta, bsize, device):
        """Draw ``bsize`` values in (0, 1) as ``U1^(1/alpha) / (U1^(1/alpha) + U2^(1/beta))``.

        NOTE(review): this is Johnk's construction without its rejection step,
        so the result is only approximately Beta(alpha, beta).
        """
        gamma1 = torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / alpha)
        gamma2 = torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / beta)
        return gamma1 / (gamma1 + gamma2)

    def sample_time(self, bsize, device):
        """Flow-matching times in [0.001, 1.0], skewed by ``sample_beta(1.5, 1.0)``."""
        time_beta = self.sample_beta(1.5, 1.0, bsize, device)
        time = time_beta * 0.999 + 0.001
        return time.to(dtype=torch.float32, device=device)

    # ========= shared training helpers ============
    def _repeat_samples(self, x):
        """Repeat each row ``train_diffusion_n_samples`` times (no-op when it is 1).

        Each copy later receives its own noise/timestep, so one encoder pass
        trains several diffusion steps.
        """
        if self.train_diffusion_n_samples == 1:
            return x
        return torch.repeat_interleave(x, repeats=self.train_diffusion_n_samples, dim=0)

    @staticmethod
    def _expert_ids(batch_size, use_expert_i, device):
        """A ``(batch_size,)`` long tensor filled with the expert index ``use_expert_i``."""
        return torch.full((batch_size,), int(use_expert_i), dtype=torch.long, device=device)

    def _sample_training_pair(self, trajectory, fixed_timestep=False, flow_matching=None):
        """Corrupt ``trajectory`` and return ``(noisy_trajectory, timesteps, target)``.

        Args:
            trajectory: normalized action batch; indexed as (B, T, D) by the
                flow-matching branch.
            fixed_timestep: DDPM only - use ``num_train_timesteps // 2`` for
                every sample instead of a random timestep.
            flow_matching: objective to use; defaults to ``self.use_flow_matching``.

        Raises:
            ValueError: for an unsupported scheduler ``prediction_type``.
        """
        if flow_matching is None:
            flow_matching = self.use_flow_matching
        batch_size = trajectory.shape[0]
        device = trajectory.device

        if flow_matching:
            noise = self.sample_noise(trajectory.shape, device)
            timesteps = self.sample_time(batch_size, device)
            time_expanded = timesteps[:, None, None]
            noisy_trajectory = time_expanded * noise + (1 - time_expanded) * trajectory
            return noisy_trajectory, timesteps, noise - trajectory

        noise = torch.randn(trajectory.shape, device=device)
        # input perturbation by adding additional noise to alleviate exposure bias
        # reference: https://github.com/forever208/DDPM-IP
        noise_new = noise + self.input_pertub * torch.randn(trajectory.shape, device=device)

        num_train_timesteps = self.noise_scheduler.config.num_train_timesteps
        if fixed_timestep:
            timesteps = torch.full((batch_size,), num_train_timesteps // 2,
                                   dtype=torch.long, device=device)
        else:
            timesteps = torch.randint(0, num_train_timesteps, (batch_size,), device=device).long()

        # forward diffusion process
        noisy_trajectory = self.noise_scheduler.add_noise(trajectory, noise_new, timesteps)

        pred_type = self.noise_scheduler.config.prediction_type
        if pred_type == 'epsilon':
            target = noise
        elif pred_type == 'sample':
            target = trajectory
        else:
            raise ValueError(f"Unsupported prediction type {pred_type}")
        return noisy_trajectory, timesteps, target

    def _elementwise_loss(self, pred, target):
        """Element-wise MSE, discounted along dim 1 (the horizon) when flow matching."""
        loss = F.mse_loss(pred, target, reduction='none')
        if self.use_flow_matching:
            discount = 0.95
            weights = discount ** torch.arange(loss.size(1), device=loss.device)
            loss = loss * weights.view(1, -1, 1)
        return loss

    # ========= inference  ============
    def conditional_sample(self,
                           condition_data,
                           condition_mask,
                           local_cond=None,
                           global_cond=None,
                           generator=None,
                           # keyword arguments to scheduler.step
                           **kwargs
                           ):
        """Run the scheduler's reverse process, re-imposing ``condition_data`` where masked."""
        model = self.model
        scheduler = self.noise_scheduler

        trajectory = torch.randn(
            size=condition_data.shape,
            dtype=condition_data.dtype,
            device=condition_data.device,
            generator=generator)

        # set step values
        scheduler.set_timesteps(self.num_inference_steps)

        for t in scheduler.timesteps:
            # 1. apply conditioning
            trajectory[condition_mask] = condition_data[condition_mask]

            # 2. predict model output
            model_output = model(trajectory, t,
                                 local_cond=local_cond, global_cond=global_cond)

            # 3. compute previous sample: x_t -> x_t-1
            trajectory = scheduler.step(
                model_output, t, trajectory,
                generator=generator,
                **kwargs
            ).prev_sample

        # finally make sure conditioning is enforced
        trajectory[condition_mask] = condition_data[condition_mask]

        return trajectory

    def predict_action(self, obs_dict: Dict[str, torch.Tensor], fixed_action_prefix: torch.Tensor = None) -> Dict[
        str, torch.Tensor]:
        """
        obs_dict: must include "obs" key
        fixed_action_prefix: unnormalized action prefix
        result: must include "action" key
        """
        assert 'past_action' not in obs_dict  # not implemented yet
        # normalize input
        nobs = self.normalizer.normalize(obs_dict)
        B = next(iter(nobs.values())).shape[0]

        # condition through global feature
        global_cond = self.obs_encoder(nobs)

        # empty data for action
        cond_data = torch.zeros(size=(B, self.action_horizon, self.action_dim), device=self.device, dtype=self.dtype)
        cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)

        if fixed_action_prefix is not None and self.inpaint_fixed_action_prefix:
            n_fixed_steps = fixed_action_prefix.shape[1]
            cond_data[:, :n_fixed_steps] = fixed_action_prefix
            cond_mask[:, :n_fixed_steps] = True
            cond_data = self.normalizer['action'].normalize(cond_data)

        # run sampling
        nsample = self.conditional_sample(
            condition_data=cond_data,
            condition_mask=cond_mask,
            local_cond=None,
            global_cond=global_cond,
            **self.kwargs)

        # unnormalize prediction
        assert nsample.shape == (B, self.action_horizon, self.action_dim)
        action_pred = self.normalizer['action'].unnormalize(nsample)

        return {
            'action': action_pred,
            'action_pred': action_pred
        }

    # ========= training  ============
    def set_normalizer(self, normalizer: LinearNormalizer):
        """Copy the state of ``normalizer`` into this policy's normalizer."""
        self.normalizer.load_state_dict(normalizer.state_dict())

    def expert_forward(self, sampled_batch, use_expert_i):
        """Scalar training loss of expert ``use_expert_i`` on a raw (unnormalized) batch.

        ``sampled_batch`` holds ``'obs'`` (dict) and ``'action'``; both are
        normalized and the observations encoded here.
        """
        nobs = self.normalizer.normalize(sampled_batch['obs'])
        nactions = self.normalizer['action'].normalize(sampled_batch['action'])
        assert self.obs_as_global_cond
        global_cond = self._repeat_samples(self.obs_encoder(nobs))
        nactions = self._repeat_samples(nactions)

        noisy_trajectory, timesteps, target = self._sample_training_pair(nactions)
        pred = self.model(
            noisy_trajectory,
            timesteps,
            local_cond=None,
            global_cond=global_cond,
            use_expert_i=self._expert_ids(nactions.shape[0], use_expert_i, nactions.device)
        )

        loss = self._elementwise_loss(pred, target)
        return reduce(loss, 'b ... -> b (...)', 'mean').mean()

    def gating_forward(self, nactions, global_cond, use_expert_i):
        """Per-sample loss ``(B,)`` of expert ``use_expert_i`` on normalized actions.

        Used to score samples for the gating loss; honours ``fix_timestep``
        in the DDPM branch.
        """
        noisy_trajectory, timesteps, target = self._sample_training_pair(
            nactions, fixed_timestep=self.fix_timestep)
        pred = self.model(
            noisy_trajectory,
            timesteps,
            local_cond=None,
            global_cond=global_cond,
            use_expert_i=self._expert_ids(nactions.shape[0], use_expert_i, nactions.device)
        )

        loss = self._elementwise_loss(pred, target)
        return reduce(loss, 'b ... -> b', 'mean')  # (B,)

    def expert_forward_together(self, global_cond, nactions, use_expert_i):
        """Scalar loss where ``use_expert_i`` is a per-sample tensor of expert indices."""
        noisy_trajectory, timesteps, target = self._sample_training_pair(nactions)
        pred = self.model(
            noisy_trajectory,
            timesteps,
            local_cond=None,
            global_cond=global_cond,
            use_expert_i=use_expert_i
        )

        loss = self._elementwise_loss(pred, target)
        return reduce(loss, 'b ... -> b (...)', 'mean').mean()

    def log_cmp_m_ctxt_densities(self, log_gating_probs):
        """``logsumexp`` over experts of ``log_gating_probs + log prior``; shape (B,)."""
        exp_arg = log_gating_probs + self._log_weights.to(log_gating_probs.device)
        return torch.logsumexp(exp_arg, dim=1)

    def log_resps(self, log_gating_probs):
        """Bayes-normalized log responsibilities, shape (B, num_experts).

        ``log_gating_probs + log prior - logsumexp_over_experts(...)``, so each
        row exponentiates to a distribution over experts.
        """
        log_marg_ctxt_densities = self.log_cmp_m_ctxt_densities(log_gating_probs)
        return (log_gating_probs + self._log_weights.to(log_gating_probs.device)
                - log_marg_ctxt_densities[:, None])

    def sample_indices_per_expert(self, expert_gating: torch.Tensor):
        """Sample batch indices for every expert.

        expert_gating: (B, num_experts) non-negative weights; column ``e`` is the
            sampling distribution over the batch for expert ``e``.
        return: long tensor (num_experts, num_samples_per_expert). Sampling is
            without replacement unless the batch has fewer than
            ``num_samples_per_expert`` rows.
        """
        batch_size = expert_gating.shape[0]
        replacement = self.num_samples_per_expert > batch_size
        # multinomial treats each row as a distribution: one row per expert.
        return torch.multinomial(expert_gating.t(), self.num_samples_per_expert,
                                 replacement=replacement)

    def gather_expert_data(self, batch: dict, indices: torch.Tensor):
        """Select rows of ``batch`` for each expert, concatenated in expert order.

        batch: dict with an ``'obs'`` sub-dictionary and ``'action'``.
        indices: (num_experts, num_samples_per_expert).
        return: dict with leading dim num_experts * num_samples_per_expert.
        """
        flat_indices = indices.reshape(-1)
        return {
            'obs': {key: value[flat_indices] for key, value in batch['obs'].items()},
            'action': batch['action'][flat_indices],
        }

    def compute_loss_together(self, batch):
        """Joint expert + gating loss; returns ``(loss, gating_loss, expert_loss)``."""
        assert 'valid_mask' not in batch
        nobs = self.normalizer.normalize(batch['obs'])
        nactions = self.normalizer['action'].normalize(batch['action'])
        assert self.obs_as_global_cond

        global_cond = self.obs_encoder(nobs)

        gating_logits = self.gating_network(global_cond)
        # NOTE: normalized over the batch axis (dim=0): each expert column is a
        # distribution over the samples of this batch.
        log_gating_probs = torch.log_softmax(gating_logits, dim=0)  # (B, num_experts)

        # Route samples to experts according to their gating distributions.
        sampled_indices = self.sample_indices_per_expert(torch.exp(log_gating_probs))
        flat_indices = sampled_indices.reshape(-1)  # expert-major order
        global_cond_expert = global_cond[flat_indices]
        nactions_expert = nactions[flat_indices]

        n_routed = self.num_experts * self.num_samples_per_expert
        assert global_cond_expert.shape[0] == n_routed
        assert nactions_expert.shape[0] == n_routed

        use_expert_i = torch.arange(
            self.num_experts, device=log_gating_probs.device
        ).repeat_interleave(self.num_samples_per_expert)  # (n_routed,)

        expert_loss = self.expert_forward_together(global_cond_expert, nactions_expert, use_expert_i)

        # Gating update: expert losses and responsibilities are treated as constants.
        with torch.no_grad():
            log_resp_values = self.log_resps(log_gating_probs)  # (B, num_experts)

        gating_terms = []
        for expert_idx in range(self.num_experts):
            with torch.no_grad():
                sample_loss = self.gating_forward(
                    nactions=nactions, global_cond=global_cond, use_expert_i=expert_idx)  # (B,)

            log_p = log_gating_probs[:, expert_idx]
            entropy = -log_p
            gating_terms.append((torch.exp(log_p) * (
                sample_loss
                - self.beta * log_resp_values[:, expert_idx]
                - self.beta * entropy)).mean())

        gating_loss = torch.stack(gating_terms).mean()
        loss = expert_loss + self.gamma * gating_loss
        return loss, gating_loss, expert_loss

    def compute_loss_moe(self, batch, use_expert_i):
        """Train one expert from its replay buffer plus its gating column.

        Returns ``(loss, gating_loss, expert_loss)``.
        """
        assert 'valid_mask' not in batch
        nobs = self.normalizer.normalize(batch['obs'])
        nactions = self.normalizer['action'].normalize(batch['action'])
        assert self.obs_as_global_cond

        if self.encoder_nograd:
            with torch.no_grad():
                global_cond = self.obs_encoder(nobs)
        else:
            global_cond = self.obs_encoder(nobs)

        global_cond = self._repeat_samples(global_cond)
        nactions = self._repeat_samples(nactions)

        gating_logits = self.gating_network(global_cond)
        # Normalized over the batch axis: a distribution over samples per expert.
        log_gating_probs = torch.log_softmax(gating_logits, dim=0)  # (B, num_experts)

        expert_gating = torch.exp(log_gating_probs[:, use_expert_i])  # (B,)
        sampled_indices = Categorical(expert_gating).sample((self.num_samples_per_expert,))
        # Rows were expanded with repeat_interleave; map back to rows of the raw batch.
        batch_indices = sampled_indices // self.train_diffusion_n_samples

        new_data = {key: batch['obs'][key][batch_indices] for key in self._BUFFER_OBS_KEYS}
        new_data['action'] = batch['action'][batch_indices]
        self.data_buffer.add(use_expert_i, new_data)

        sampled_data = self.data_buffer.get(use_expert_i)
        sampled_batch = {
            'obs': {key: sampled_data[key] for key in self._BUFFER_OBS_KEYS},
            'action': sampled_data['action'],
        }

        expert_loss = self.expert_forward(sampled_batch=sampled_batch, use_expert_i=use_expert_i)

        # Gating update: losses and responsibilities are treated as constants.
        with torch.no_grad():
            log_resp_values = self.log_resps(log_gating_probs)  # (B, num_experts)
            sample_loss = self.gating_forward(
                nactions=nactions, global_cond=global_cond, use_expert_i=use_expert_i)  # (B,)

        log_p = log_gating_probs[:, use_expert_i]
        entropy = -log_p
        gating_loss = (torch.exp(log_p) * (
            self.gamma * sample_loss
            - self.beta * log_resp_values[:, use_expert_i]
            - self.beta * entropy)).mean()

        loss = gating_loss + expert_loss
        return loss, gating_loss, expert_loss

    def compute_loss_org(self, batch, use_expert_i):
        """Plain DDPM loss without gating; returns ``(loss, zero, loss)``.

        Always uses the DDPM objective, regardless of ``use_flow_matching``;
        ``use_expert_i`` is not used.
        """
        assert 'valid_mask' not in batch
        nobs = self.normalizer.normalize(batch['obs'])
        nactions = self.normalizer['action'].normalize(batch['action'])
        assert self.obs_as_global_cond
        global_cond = self._repeat_samples(self.obs_encoder(nobs))
        nactions = self._repeat_samples(nactions)

        noisy_trajectory, timesteps, target = self._sample_training_pair(
            nactions, flow_matching=False)

        # Predict the noise residual (or the sample, per prediction_type)
        pred = self.model(
            noisy_trajectory,
            timesteps,
            local_cond=None,
            global_cond=global_cond
        )

        loss = F.mse_loss(pred, target, reduction='none')
        loss = reduce(loss, 'b ... -> b (...)', 'mean').mean()

        # Zero placeholder for the gating loss, on the same device/dtype as loss.
        return loss, loss.new_zeros(1), loss

    def compute_datasets_Z(self, batch):
        """Per-expert sum over the batch of ``exp(gating_logits)``; shape (num_experts,)."""
        assert 'valid_mask' not in batch
        nobs = self.normalizer.normalize(batch['obs'])
        assert self.obs_as_global_cond

        global_cond = self._repeat_samples(self.obs_encoder(nobs))

        gating_scores = torch.exp(self.gating_network(global_cond))
        return torch.sum(gating_scores, dim=0)

    def forward(self, batch, use_expert_i):
        """Training entry point: the joint expert + gating loss of ``compute_loss_together``.

        ``use_expert_i`` is accepted for interface compatibility with
        ``compute_loss_moe`` / ``compute_loss_org`` but is not used here.
        """
        return self.compute_loss_together(batch)

        
        