from typing import Dict, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, reduce
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler

from diffusion_policy.model.common.normalizer import LinearNormalizer
from diffusion_policy.policy.base_image_policy import BaseImagePolicy
from diffusion_policy.model.diffusion.transformer_for_action_diffusion import TransformerForActionDiffusion
from diffusion_policy.common.pytorch_util import dict_apply
from diffusion_policy.model.vision.transformer_obs_encoder import TransformerObsEncoder
from diffusion_policy.model.diffusion.moe import MoeArgs, MoeLayer
from diffusion_policy.model.diffusion.DiT_for_action_diffusion import *
from torch.distributions import Categorical
import matplotlib.pyplot as plt
import logging
logger = logging.getLogger(__name__)

class ExpertBuffer:
    """Fixed-capacity ring buffer of raw observation/action samples for one expert.

    Each key is stored in a pre-allocated tensor whose leading dimension is
    ``max_size``. Samples are written in arrival order; once the buffer is
    full, the oldest samples are overwritten.
    """

    def __init__(self, max_size=100, camera0_rgb_shape=(2, 3, 224, 224), robot0_eef_pos_shape=(2, 3),
                 robot0_eef_rot_axis_angle_shape=(2, 6), robot0_gripper_width_shape=(2, 1),
                 robot0_language_instruction_shape=(2, 100), robot0_eef_rot_axis_angle_wrt_start_shape=(2, 6),
                 action_shape=(16, 10), device='cuda'):
        """Allocate zero-filled storage (default dtype) for every key.

        Args:
            max_size: number of samples the buffer holds; must be positive.
            camera0_rgb_shape ... action_shape: per-sample shape (without the
                leading sample dimension) of the corresponding key.
            device: where storage is allocated. The default ``'cuda'`` is the
                current CUDA device, i.e. the same placement as ``.cuda()``.

        Raises:
            ValueError: if ``max_size`` is not positive (write positions are
                computed modulo ``max_size``).
        """
        if max_size <= 0:
            raise ValueError(f"max_size must be positive, got {max_size}")
        self.max_size = max_size

        per_sample_shapes = {
            'camera0_rgb': camera0_rgb_shape,
            'robot0_eef_pos': robot0_eef_pos_shape,
            'robot0_eef_rot_axis_angle': robot0_eef_rot_axis_angle_shape,
            'robot0_gripper_width': robot0_gripper_width_shape,
            'robot0_language_instruction': robot0_language_instruction_shape,
            'robot0_eef_rot_axis_angle_wrt_start': robot0_eef_rot_axis_angle_wrt_start_shape,
            'action': action_shape,
        }
        self.buffer = {
            key: torch.zeros((max_size, *shape), device=device)
            for key, shape in per_sample_shapes.items()
        }

        # Next slot to write; wraps modulo max_size.
        self.current_index = 0
        # Number of valid samples stored, capped at max_size.
        self.data_count = 0

    def add(self, new_data):
        """Append a batch of samples, overwriting the oldest ones when full.

        Sample ``i`` of the batch lands in slot
        ``(current_index + i) % max_size``. If the batch is larger than the
        buffer, only its last ``max_size`` samples are kept.

        Args:
            new_data: mapping containing every key of ``self.buffer``. Each
                value's leading dimension is the batch size, which is read
                from ``new_data['camera0_rgb']``. Extra keys are ignored.
        """
        batch_size = new_data['camera0_rgb'].shape[0]
        if batch_size == 0:
            return

        # Samples older than the newest max_size would be overwritten anyway.
        skip = max(0, batch_size - self.max_size)
        kept = batch_size - skip
        start = (self.current_index + skip) % self.max_size
        # How many kept samples fit before the end of storage; the rest wrap to slot 0.
        head = min(kept, self.max_size - start)

        for key, storage in self.buffer.items():
            chunk = new_data[key][skip:]
            storage[start:start + head] = chunk[:head]
            if head < kept:
                storage[:kept - head] = chunk[head:]

        self.data_count = min(self.data_count + batch_size, self.max_size)
        self.current_index = (self.current_index + batch_size) % self.max_size

    def get(self):
        """Return a batch of exactly ``max_size`` stored samples.

        While the buffer is not yet full, ``max_size`` slots are drawn
        uniformly with replacement from the filled ones. Once full, the
        storage tensors themselves are returned without a copy, so later
        ``add`` calls overwrite them in place.

        Raises:
            RuntimeError: if nothing has been added yet.
        """
        if self.data_count == 0:
            raise RuntimeError("ExpertBuffer.get() called on an empty buffer")
        if self.data_count >= self.max_size:
            return dict(self.buffer)

        indices = torch.multinomial(torch.ones(self.data_count), self.max_size, replacement=True)
        return {key: storage[indices] for key, storage in self.buffer.items()}


class ExpertBuffers:
    """A collection of independent ``ExpertBuffer`` instances, one per expert.

    Every per-expert buffer is built with the same capacity and per-key sample
    shapes; buffers are addressed by integer expert index.
    """

    def __init__(self, num_experts, max_size=100, camera0_rgb_shape=(2, 3, 224, 224), robot0_eef_pos_shape=(2, 3),
                 robot0_eef_rot_axis_angle_shape=(2, 6), robot0_gripper_width_shape=(2, 1),
                 robot0_language_instruction_shape=(2, 100), robot0_eef_rot_axis_angle_wrt_start_shape=(2, 6),
                 action_shape=(16, 10)):
        shared_config = dict(
            max_size=max_size,
            camera0_rgb_shape=camera0_rgb_shape,
            robot0_eef_pos_shape=robot0_eef_pos_shape,
            robot0_eef_rot_axis_angle_shape=robot0_eef_rot_axis_angle_shape,
            robot0_gripper_width_shape=robot0_gripper_width_shape,
            robot0_language_instruction_shape=robot0_language_instruction_shape,
            robot0_eef_rot_axis_angle_wrt_start_shape=robot0_eef_rot_axis_angle_wrt_start_shape,
            action_shape=action_shape,
        )
        self.expert_buffers = []
        for _ in range(num_experts):
            self.expert_buffers.append(ExpertBuffer(**shared_config))

    def add(self, expert_idx, new_data):
        """Append ``new_data`` to the buffer of expert ``expert_idx``."""
        target = self.expert_buffers[expert_idx]
        target.add(new_data)

    def get(self, expert_idx):
        """Return a sampled batch from the buffer of expert ``expert_idx``."""
        target = self.expert_buffers[expert_idx]
        return target.get()


class DiffusionTransformerTimmPolicy_di(BaseImagePolicy):
    def __init__(self,
                 shape_meta: dict,
                 noise_scheduler: DDPMScheduler,
                 obs_encoder: TransformerObsEncoder,
                 num_inference_steps=None,
                 input_pertub=0.1,
                 # arch
                 n_layer=7,
                 n_head=8,
                 n_emb=768,
                 p_drop_attn=0.1,
                 use_adaLN=False,
                 rms_norm=False,
                 num_experts=3,
                 di=True,
                 gamma=10,
                 beta=0.1,
                 fix_timestep=False,
                 encoder_nograd=False,
                 num_samples_per_expert=4,
                 buffer_max_size=16,
                 # parameters passed to step
                 **kwargs):
        """Build the MoE action-diffusion policy.

        Args:
            shape_meta: must provide ``['action']['shape']`` (1-D) and
                ``['action']['horizon']``.
            noise_scheduler: DDPM scheduler used for training noise.
            obs_encoder: its ``output_shape()`` must end in ``n_emb``. The
                second-to-last entry is taken as the number of observation tokens.
            num_inference_steps: defaults to the scheduler's
                ``num_train_timesteps``.
            input_pertub: scale of the extra noise added to the training
                noise (input perturbation, DDPM-IP).
            n_layer, n_head, n_emb, p_drop_attn: transformer settings. Only
                ``n_emb`` is used on the ``use_adaLN`` path; the others are not
                passed to the DiT model.
            use_adaLN: use ``DiT_MoE_B_8_for_action`` instead of
                ``TransformerForActionDiffusion``.
            rms_norm: accepted but not used by this constructor.
            num_experts: number of MoE experts.
            di: enable the curriculum machinery. This adds the gating network,
                a uniform log prior over experts, and per-expert replay buffers.
            gamma, beta: weights of the gating loss terms (see ``compute_loss``).
            fix_timestep: stored as an attribute only.
            encoder_nograd: run the obs encoder under ``no_grad`` in ``compute_loss``.
            num_samples_per_expert: samples pushed into an expert's buffer per step.
            buffer_max_size: capacity of each per-expert buffer.
            **kwargs: stored as ``self.kwargs``.
        """
        super().__init__()

        # parse shapes
        action_shape = shape_meta['action']['shape']
        assert len(action_shape) == 1
        action_dim = action_shape[0]
        action_horizon = shape_meta['action']['horizon']

        obs_shape = obs_encoder.output_shape()
        assert obs_shape[-1] == n_emb
        obs_tokens = obs_shape[-2]

        # Stored unconditionally. The MoE layers below read num_experts even
        # when di=False; previously this raised AttributeError in that mode.
        self.use_di = bool(di)
        self.num_experts = num_experts
        self.gamma = gamma
        self.beta = beta
        self.fix_timestep = fix_timestep
        self.encoder_nograd = encoder_nograd
        self.num_samples_per_expert = num_samples_per_expert
        self.buffer_max_size = buffer_max_size

        if self.use_di:
            # Scores each (flattened) observation-token set with one logit per expert.
            self.gating_network = nn.Sequential(
                nn.Linear(obs_tokens * n_emb, 1024),
                nn.ReLU(),
                nn.Linear(1024, 512),
                nn.ReLU(),
                nn.Linear(512, 256),
                nn.ReLU(),
                nn.Linear(256, num_experts),
            )

            # Uniform prior over experts: log(1 / num_experts), shape (1, num_experts).
            self._log_weights = -torch.log(torch.ones(self.num_experts) * self.num_experts)[None, :]

            # NOTE(review): these per-sample shapes are hard-coded and must match
            # the dataset (e.g. action (16, 10)); consider deriving them from shape_meta.
            self.data_buffer = ExpertBuffers(num_experts=self.num_experts,
                                             max_size=self.buffer_max_size,
                                             camera0_rgb_shape=(2, 3, 224, 224),
                                             robot0_eef_pos_shape=(2, 3),
                                             robot0_eef_rot_axis_angle_shape=(2, 6),
                                             robot0_gripper_width_shape=(2, 1),
                                             robot0_language_instruction_shape=(2, 100),
                                             robot0_eef_rot_axis_angle_wrt_start_shape=(2, 6),
                                             action_shape=(16, 10))

        moe = MoeArgs(num_shared_experts=0,
                      num_experts=self.num_experts,
                      num_experts_per_tok=1,
                      gate_feature_dim=n_emb)
        if use_adaLN:
            model = DiT_MoE_B_8_for_action(input_dim=action_dim,
                                           output_dim=action_dim,
                                           seq_len=action_horizon,
                                           max_cond_tokens=obs_tokens + 1,
                                           moe=moe)
            print("Use MoE DiT_B_for_action.")
        else:
            model = TransformerForActionDiffusion(
                input_dim=action_dim,
                output_dim=action_dim,
                action_horizon=action_horizon,
                n_layer=n_layer,
                n_head=n_head,
                n_emb=n_emb,
                max_cond_tokens=obs_tokens + 1,  # obs tokens + 1 token for time
                p_drop_attn=p_drop_attn,
                moe=moe
            )

        self.obs_encoder = obs_encoder
        self.model = model
        self.noise_scheduler = noise_scheduler
        self.normalizer = LinearNormalizer()
        self.action_dim = action_dim
        self.action_horizon = action_horizon
        self.input_pertub = input_pertub
        self.kwargs = kwargs

        if num_inference_steps is None:
            num_inference_steps = noise_scheduler.config.num_train_timesteps
        self.num_inference_steps = num_inference_steps

        logger.info(
            "number of parameters: %e", sum(p.numel() for p in self.parameters())
        )

    # ========= training  ============
    def set_normalizer(self, normalizer: LinearNormalizer):
        """Copy the statistics of ``normalizer`` into this policy's normalizer.

        The state is transferred through ``state_dict``/``load_state_dict``, so
        the policy keeps its own ``LinearNormalizer`` instance rather than
        holding a reference to the argument.
        """
        source_state = normalizer.state_dict()
        self.normalizer.load_state_dict(source_state)

    def get_optimizer(
            self,
            lr: float,
            weight_decay: float,
            obs_encoder_lr: float,
            obs_encoder_weight_decay: float,
            betas: Tuple[float, float]
    ) -> torch.optim.Optimizer:
        """Build an AdamW optimizer covering every trainable sub-module.

        Parameter groups:
          * the groups returned by ``self.model.get_optim_groups``;
          * obs-encoder parameters whose name starts with ``key_model_map``
            (the backbone), using ``obs_encoder_lr`` / ``obs_encoder_weight_decay``;
          * the remaining obs-encoder parameters, using ``lr`` / ``weight_decay``;
          * in curriculum (``di``) mode, the gating network, using
            ``lr`` / ``weight_decay``. Previously this group was missing, so the
            gating loss produced gradients that no optimizer step applied.

        NOTE(review): the extra group changes the number of parameter groups, so
        optimizer state saved by the previous version cannot be loaded with
        ``load_state_dict``. Confirm the resume path.
        """
        optim_groups = self.model.get_optim_groups(
            weight_decay=weight_decay)

        backbone_params = list()
        other_obs_params = list()
        for key, value in self.obs_encoder.named_parameters():
            if key.startswith('key_model_map'):
                backbone_params.append(value)
            else:
                other_obs_params.append(value)
        optim_groups.append({
            "params": backbone_params,
            "weight_decay": obs_encoder_weight_decay,
            "lr": obs_encoder_lr,  # for fine tuning
            'initial_lr': obs_encoder_lr
        })
        optim_groups.append({
            "params": other_obs_params,
            "weight_decay": weight_decay,
            'initial_lr': lr
        })
        if self.use_di:
            # The gating network is trained by gating_loss in compute_loss.
            optim_groups.append({
                "params": list(self.gating_network.parameters()),
                "weight_decay": weight_decay,
                'initial_lr': lr
            })
        optimizer = torch.optim.AdamW(
            optim_groups, lr=lr, betas=betas
        )
        return optimizer

    def expert_forward(self, batch, use_expert_i):
        """Denoising training loss of one expert on a raw (un-normalized) batch.

        Args:
            batch: dict with ``'obs'`` (dict of observation tensors) and
                ``'action'``. Both are normalized here with ``self.normalizer``.
                A ``'valid_mask'`` key is asserted absent.
            use_expert_i: expert index forwarded to ``self.model``.

        Returns:
            Scalar tensor: the MSE between the model prediction and its target,
            averaged over all elements. The target is the unperturbed noise for
            ``'epsilon'`` prediction and the clean trajectory for ``'sample'``.

        Raises:
            ValueError: if the scheduler's ``prediction_type`` is unsupported.
        """
        assert 'valid_mask' not in batch
        nobs = self.normalizer.normalize(batch['obs'])
        nactions = self.normalizer['action'].normalize(batch['action'])
        trajectory = nactions

        # Conditioning tokens; computed with gradients, so this loss also trains the encoder.
        obs_tokens = self.obs_encoder(nobs)

        # Sample the noise that the model has to predict.
        noise = torch.randn(trajectory.shape, device=trajectory.device)
        # Input perturbation: corrupt with slightly different noise than the
        # target to alleviate exposure bias (https://github.com/forever208/DDPM-IP).
        noise_new = noise + self.input_pertub * torch.randn(trajectory.shape, device=trajectory.device)

        # One random diffusion timestep per sample.
        timesteps = torch.randint(
            0, self.noise_scheduler.config.num_train_timesteps,
            (nactions.shape[0],), device=trajectory.device
        ).long()

        # Forward diffusion: noise the clean trajectories to their timesteps.
        noisy_trajectory = self.noise_scheduler.add_noise(
            trajectory, noise_new, timesteps)

        pred = self.model(
            noisy_trajectory,
            timesteps,
            use_expert_i=use_expert_i,
            cond=obs_tokens
        )

        pred_type = self.noise_scheduler.config.prediction_type
        if pred_type == 'epsilon':
            target = noise
        elif pred_type == 'sample':
            target = trajectory
        else:
            raise ValueError(f"Unsupported prediction type {pred_type}")

        loss = F.mse_loss(pred, target, reduction='none')
        loss = reduce(loss, 'b ... -> b (...)', 'mean')
        return loss.mean()

    def gating_forward(self, nactions, global_cond, use_expert_i, reduction='mean'):
        """Denoising loss of one expert on already-normalized inputs.

        Unlike ``expert_forward``, this neither normalizes nor encodes. The
        caller passes normalized actions and precomputed conditioning tokens.

        Args:
            nactions: normalized action trajectories, batch-first.
            global_cond: conditioning tokens passed to ``self.model`` as ``cond``.
            use_expert_i: expert index forwarded to ``self.model``.
            reduction: ``'mean'`` (default) returns the scalar mean over all
                elements. ``'none'`` returns one loss per batch element, with
                shape ``(B,)``.

        Raises:
            ValueError: on an unsupported ``reduction`` or prediction type.
        """
        if reduction not in ('mean', 'none'):
            raise ValueError(f"Unsupported reduction {reduction}")
        trajectory = nactions

        # Sample the noise that the model has to predict.
        noise = torch.randn(trajectory.shape, device=trajectory.device)
        # Input perturbation to alleviate exposure bias (https://github.com/forever208/DDPM-IP).
        noise_new = noise + self.input_pertub * torch.randn(trajectory.shape, device=trajectory.device)

        # One random diffusion timestep per sample.
        timesteps = torch.randint(
            0, self.noise_scheduler.config.num_train_timesteps,
            (nactions.shape[0],), device=trajectory.device
        ).long()

        # Forward diffusion: noise the clean trajectories to their timesteps.
        noisy_trajectory = self.noise_scheduler.add_noise(
            trajectory, noise_new, timesteps)

        pred = self.model(
            noisy_trajectory,
            timesteps,
            use_expert_i=use_expert_i,
            cond=global_cond
        )

        pred_type = self.noise_scheduler.config.prediction_type
        if pred_type == 'epsilon':
            target = noise
        elif pred_type == 'sample':
            target = trajectory
        else:
            raise ValueError(f"Unsupported prediction type {pred_type}")

        loss = F.mse_loss(pred, target, reduction='none')
        if reduction == 'none':
            # Mean over every non-batch axis -> (B,).
            return reduce(loss, 'b ... -> b', 'mean')
        loss = reduce(loss, 'b ... -> b (...)', 'mean')
        return loss.mean()

    # p(c)=sum(p(c|o)p(o))
    def log_cmp_m_ctxt_densities(self, log_gating_probs):
        """Return log p(c) = logsumexp_o(log p(c|o) + log p(o)) for each context.

        Args:
            log_gating_probs: log p(c|o), shape (B, num_experts).

        Returns:
            Tensor of shape (B,).
        """
        exp_arg = log_gating_probs + self._log_weights.to(log_gating_probs.device)  # log(p(c|o)p(o))
        log_marg_ctxt_densities = torch.logsumexp(exp_arg, dim=1)
        return log_marg_ctxt_densities

    # p(o|c)=p(c|o)p(o)/p(c)
    def log_resps(self, log_gating_probs):
        """Return the log responsibilities log p(o|c) via Bayes' rule.

        Args:
            log_gating_probs: log p(c|o), shape (B, num_experts).

        Returns:
            log p(o|c), shape (B, num_experts).
        """
        log_marg_ctxt_densities = self.log_cmp_m_ctxt_densities(log_gating_probs)  # log(p(c))

        log_gating_probs = log_gating_probs + self._log_weights.to(log_gating_probs.device) - log_marg_ctxt_densities[:, None]

        return log_gating_probs

    def compute_loss(self, batch, use_expert_i):
        """Joint curriculum (gating) and expert loss for expert ``use_expert_i``.

        Steps:
          1. Encode the normalized observations and score every batch element
             with the gating network. A log-softmax over the batch dimension
             gives log p(c|o): for each expert o, a distribution over the
             contexts c in this batch.
          2. Draw ``num_samples_per_expert`` elements from that distribution
             for this expert, push them (raw) into its replay buffer, and train
             the expert on a buffer batch via ``expert_forward``.
          3. Gating loss: the batch mean of
             p(c|o) * (gamma * l(c) - beta * log p(o|c) + beta * log p(c|o)),
             where l(c) is the expert's per-sample denoising loss (no gradient).

        Returns:
            ``(total_loss, gating_loss, expert_loss)``.
        """
        assert 'valid_mask' not in batch
        nobs = self.normalizer.normalize(batch['obs'])
        nactions = self.normalizer['action'].normalize(batch['action'])
        B = nactions.shape[0]

        if self.encoder_nograd:
            with torch.no_grad():
                global_cond = self.obs_encoder(nobs)
        else:
            global_cond = self.obs_encoder(nobs)

        # reshape rather than view: the encoder output is not guaranteed to be contiguous.
        gating_logits = self.gating_network(global_cond.reshape(B, -1))

        # Softmax over dim=0 (the batch): each column is a distribution over contexts.
        log_gating_probs = torch.log_softmax(gating_logits, dim=0)  # (B, num_experts)  log p(c | o)
        expert_log_probs = log_gating_probs[:, use_expert_i]  # (B,)

        sampled_indices = Categorical(torch.exp(expert_log_probs)).sample((self.num_samples_per_expert,))

        obs_keys = ('camera0_rgb', 'robot0_eef_pos', 'robot0_eef_rot_axis_angle', 'robot0_gripper_width',
                    'robot0_language_instruction', 'robot0_eef_rot_axis_angle_wrt_start')
        obs = batch['obs']
        new_samples = {key: obs[key][sampled_indices] for key in obs_keys}
        new_samples['action'] = batch['action'][sampled_indices]
        self.data_buffer.add(use_expert_i, new_samples)

        sampled_data = self.data_buffer.get(use_expert_i)
        sampled_batch = {
            'obs': {key: sampled_data[key] for key in obs_keys},
            'action': sampled_data['action'],
        }

        expert_loss = self.expert_forward(
            batch=dict_apply(sampled_batch, lambda x: x.to(nactions.device, non_blocking=True)),
            use_expert_i=use_expert_i)

        # Quantities for the gating-network update; none of them carry gradients.
        with torch.no_grad():
            log_resps = self.log_resps(log_gating_probs)  # (B, num_experts)  log p(o|c)
            # Must be per-sample. p(c|o) sums to 1 over the batch, so a scalar
            # batch mean would make the gamma term constant w.r.t. the gating network.
            per_sample_loss = self.gating_forward(nactions=nactions, global_cond=global_cond,
                                                  use_expert_i=use_expert_i, reduction='none')  # (B,)

        entropy = -expert_log_probs  # (B,)

        gating_loss = (torch.exp(expert_log_probs) * (
                self.gamma * per_sample_loss
                - self.beta * log_resps[:, use_expert_i]
                - self.beta * entropy)).mean()

        total_loss = gating_loss + expert_loss

        return total_loss, gating_loss, expert_loss

    def forward(self, batch, use_expert_i):
        """Training entry point: delegate to ``compute_loss``.

        Returns:
            The ``(total_loss, gating_loss, expert_loss)`` tuple produced by
            ``compute_loss(batch, use_expert_i)``.
        """
        loss_terms = self.compute_loss(batch, use_expert_i)
        return loss_terms