from typing import Dict, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, reduce
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler

from diffusion_policy.model.common.normalizer import LinearNormalizer
from diffusion_policy.policy.base_image_policy import BaseImagePolicy
from diffusion_policy.model.diffusion.transformer_for_action_diffusion import TransformerForActionDiffusion
from diffusion_policy.common.pytorch_util import dict_apply
from diffusion_policy.model.vision.transformer_obs_encoder import TransformerObsEncoder
from diffusion_policy.model.diffusion.moe import MoeArgs, MoeLayer
from diffusion_policy.model.diffusion.DiT_for_action_diffusion import DiT_MoE_B_8_for_action
import re

# Action categories in priority order (highest first). A task's class id is
# the index of the first entry that has a keyword occurring in the task text
# as a whole word/phrase; the last entry is the fallback class.
action_keywords = [
    ("pour", ["pour"]),
    ("unplug", ["unplug", "pull out"]),
    ("open", ["open"]),
    ("close", ["close"]),
    ("fold", ["fold"]),
    ("rearrange", ["rearrange"]),
    ("throw", ["throw"]),
    ("spin", ["spin"]),
    ("pick", ["pick up", "pick", "grasp", "grab"]),
    ("place", ["place", "put"]),
    ("", [""]),
]

# One pre-compiled pattern per category, built once at import time instead of
# re-building a regex for every keyword on every call. An alternation wrapped
# in \b...\b matches exactly when at least one of its keywords would match on
# its own with the same word-boundary anchors.
# NOTE: derived from `action_keywords` at import; later in-place edits of that
# list are not reflected here.
_ACTION_PATTERNS = [
    re.compile(r'\b(?:' + '|'.join(re.escape(kw) for kw in keywords) + r')\b')
    for _, keywords in action_keywords
]


def classify_action(task: str) -> int:
    """Return the index in ``action_keywords`` of the action found in ``task``.

    Matching is case-insensitive and anchored on word boundaries, so
    ``"placing"`` does not count as ``"place"``. Categories are tried in list
    order and the first hit wins, which is what gives earlier entries
    priority when a task mentions several actions.

    Args:
        task: Free-form task / instruction text.

    Returns:
        The index of the first matching category, or
        ``len(action_keywords) - 1`` (the fallback entry) when nothing
        matches, including for empty or punctuation-only text.
    """
    task_lower = task.lower()
    for index, pattern in enumerate(_ACTION_PATTERNS):
        if pattern.search(task_lower):
            return index
    return len(action_keywords) - 1

class DiffusionTransformerTimmPolicy(BaseImagePolicy):
    """Diffusion policy whose denoiser is conditioned on observation tokens
    and on an action class derived from the language instruction.

    The observation encoder produces conditioning tokens. The instruction
    stored under ``robot0_language_instruction`` is decoded to text, mapped to
    an integer class with ``classify_action`` and handed to the denoiser as
    ``use_expert_i``. During training the same class ids are also used as
    cross-entropy targets for the gate logits returned by the denoiser.
    """

    def __init__(self,
                 shape_meta: dict,
                 noise_scheduler: DDPMScheduler,
                 obs_encoder: TransformerObsEncoder,
                 num_inference_steps=None,
                 input_pertub=0.1,
                 # arch
                 n_layer=7,
                 n_head=8,
                 n_emb=768,
                 p_drop_attn=0.1,
                 use_adaLN=False,
                 rms_norm=False,
                 moe_num_shared_experts=0,
                 moe_num_experts=12,
                 moe_num_experts_per_tok=1,
                 # parameters passed to step
                 **kwargs):
        """Build the denoising model and store sampling configuration.

        Args:
            shape_meta: Must provide ``shape_meta['action']['shape']`` (a
                1-element shape giving the action dimension) and
                ``shape_meta['action']['horizon']``.
            noise_scheduler: Scheduler used for noising (training) and for
                the reverse steps (sampling).
            obs_encoder: Encoder whose ``output_shape()`` must end in
                ``(..., obs_tokens, n_emb)``.
            num_inference_steps: Number of sampling steps; defaults to the
                scheduler's ``num_train_timesteps`` when ``None``.
            input_pertub: Scale of the extra noise added to the sampled
                noise before the forward diffusion step in ``compute_loss``.
            n_layer, n_head, p_drop_attn: Forwarded to
                ``TransformerForActionDiffusion`` (only when ``use_adaLN`` is
                false).
            n_emb: Embedding width; checked against the encoder output,
                used as the MoE ``gate_feature_dim`` and forwarded to
                ``TransformerForActionDiffusion``.
            use_adaLN: Select ``DiT_MoE_B_8_for_action`` instead of
                ``TransformerForActionDiffusion``.
            rms_norm: Accepted but not used by this class.
            moe_num_shared_experts, moe_num_experts, moe_num_experts_per_tok:
                Forwarded to ``MoeArgs``.
            **kwargs: Stored and later forwarded by ``predict_action`` to
                ``conditional_sample`` (and from there to ``scheduler.step``).
        """
        super().__init__()

        # parse shapes
        action_shape = shape_meta['action']['shape']
        assert len(action_shape) == 1
        action_dim = action_shape[0]
        action_horizon = shape_meta['action']['horizon']

        # the encoder must already emit tokens of the denoiser's width
        obs_shape = obs_encoder.output_shape()
        assert obs_shape[-1] == n_emb
        obs_tokens = obs_shape[-2]

        moe = MoeArgs(num_shared_experts=moe_num_shared_experts,
                      num_experts=moe_num_experts,
                      num_experts_per_tok=moe_num_experts_per_tok,
                      gate_feature_dim=n_emb)

        if use_adaLN:
            model = DiT_MoE_B_8_for_action(input_dim=action_dim,
                                           output_dim=action_dim,
                                           seq_len=action_horizon,
                                           max_cond_tokens=obs_tokens + 1,
                                           moe=moe)
            print(f"Use MoE DiT_B_for_action.{moe_num_shared_experts}-{moe_num_experts}-{moe_num_experts_per_tok}")
        else:
            model = TransformerForActionDiffusion(
                input_dim=action_dim,
                output_dim=action_dim,
                action_horizon=action_horizon,
                n_layer=n_layer,
                n_head=n_head,
                n_emb=n_emb,
                max_cond_tokens=obs_tokens + 1,  # obs tokens + 1 token for time
                p_drop_attn=p_drop_attn,
                moe=moe
            )

        self.obs_encoder = obs_encoder
        self.model = model
        self.noise_scheduler = noise_scheduler
        self.normalizer = LinearNormalizer()
        self.action_dim = action_dim
        self.action_horizon = action_horizon
        self.input_pertub = input_pertub
        self.kwargs = kwargs

        if num_inference_steps is None:
            num_inference_steps = noise_scheduler.config.num_train_timesteps
        self.num_inference_steps = num_inference_steps

    @staticmethod
    def _select_experts(instructions: torch.Tensor, device) -> torch.Tensor:
        """Map a batch of encoded instructions to action-class ids.

        Shared by ``predict_action`` and ``compute_loss`` so that both decode
        the instruction in exactly the same way.

        Args:
            instructions: Integer tensor indexed as ``[batch, step, char]``.
                Only step 0 is read; each value is treated as a character
                code and zeros are skipped as padding.
            device: Device on which to create the result.

        Returns:
            ``torch.long`` tensor with one ``classify_action`` result per
            batch entry.
        """
        # One host transfer for the whole batch instead of one per sample.
        first_step = instructions[:, 0].detach().cpu().tolist()
        expert_ids = []
        for codes in first_step:
            text = ''.join(chr(c) for c in codes if c != 0).strip()
            expert_ids.append(classify_action(text))
        return torch.tensor(expert_ids, dtype=torch.long, device=device)

    # ========= inference  ============
    def conditional_sample(self,
                           condition_data, condition_mask,
                           cond=None, generator=None, use_expert_i=None,
                           # keyword arguments to scheduler.step
                           **kwargs
                           ):
        """Run the reverse diffusion loop and return the final trajectory.

        Args:
            condition_data: Tensor whose shape, dtype and device define the
                sampled trajectory; its values are written into the
                trajectory wherever ``condition_mask`` is true.
            condition_mask: Boolean mask selecting the entries to overwrite.
            cond: Conditioning passed to the model (observation tokens when
                called from ``predict_action``).
            generator: Optional ``torch.Generator`` used for the initial
                noise and passed to ``scheduler.step``.
            use_expert_i: Per-sample class ids passed through to the model.
            **kwargs: Extra keyword arguments for ``scheduler.step``.

        Returns:
            The trajectory after the last scheduler step, with the
            conditioning re-applied.
        """
        model = self.model
        scheduler = self.noise_scheduler

        # start from pure noise
        trajectory = torch.randn(
            size=condition_data.shape,
            dtype=condition_data.dtype,
            device=condition_data.device,
            generator=generator)

        # set step values
        scheduler.set_timesteps(self.num_inference_steps)

        for t in scheduler.timesteps:
            # 1. apply conditioning
            trajectory[condition_mask] = condition_data[condition_mask]

            # 2. predict model output (second return value is unused here)
            model_output, _ = model(trajectory, t, cond, use_expert_i)

            # 3. compute previous sample: x_t -> x_t-1
            trajectory = scheduler.step(
                model_output, t, trajectory,
                generator=generator,
                **kwargs
            ).prev_sample

        # finally make sure conditioning is enforced
        trajectory[condition_mask] = condition_data[condition_mask]

        return trajectory

    def predict_action(self, obs_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Sample an action trajectory for a batch of observations.

        Args:
            obs_dict: Observation tensors; must contain
                ``robot0_language_instruction`` and must not contain
                ``past_action``.

        Returns:
            Dict with ``'action'`` and ``'action_pred'``, both the same
            unnormalized tensor of shape
            ``(B, action_horizon, action_dim)``.
        """
        assert 'past_action' not in obs_dict  # not implemented yet
        # normalize input
        nobs = self.normalizer.normalize(obs_dict)
        B = next(iter(nobs.values())).shape[0]

        # class ids come from the raw (un-normalized) instruction codes
        use_expert_i = self._select_experts(
            obs_dict['robot0_language_instruction'], device=self.device)

        # process input
        obs_tokens = self.obs_encoder(nobs)
        # (B, N, n_emb)

        # empty data for action; the all-False mask means nothing is inpainted
        cond_data = torch.zeros(size=(B, self.action_horizon, self.action_dim), device=self.device, dtype=self.dtype)
        cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)

        # run sampling
        nsample = self.conditional_sample(
            condition_data=cond_data,
            condition_mask=cond_mask,
            cond=obs_tokens,
            use_expert_i=use_expert_i,
            **self.kwargs)

        # unnormalize prediction
        assert nsample.shape == (B, self.action_horizon, self.action_dim)
        action_pred = self.normalizer['action'].unnormalize(nsample)

        result = {
            'action': action_pred,
            'action_pred': action_pred
        }
        return result

    # ========= training  ============
    def set_normalizer(self, normalizer: LinearNormalizer):
        """Copy the state of ``normalizer`` into this policy's normalizer."""
        self.normalizer.load_state_dict(normalizer.state_dict())

    def get_optimizer(
            self,
            lr: float,
            weight_decay: float,
            obs_encoder_lr: float,
            obs_encoder_weight_decay: float,
            betas: Tuple[float, float]
    ) -> torch.optim.Optimizer:
        """Create an AdamW optimizer with separate groups for the encoder.

        Parameter groups, in order:
          * the groups returned by ``self.model.get_optim_groups``;
          * encoder parameters whose name starts with ``key_model_map``,
            using ``obs_encoder_lr`` / ``obs_encoder_weight_decay``;
          * all remaining encoder parameters, using ``weight_decay`` and the
            optimizer-level default ``lr`` (no per-group ``lr`` is set).

        Args:
            lr: Default learning rate.
            weight_decay: Weight decay for the model groups and the
                non-backbone encoder parameters.
            obs_encoder_lr: Learning rate for the ``key_model_map`` group.
            obs_encoder_weight_decay: Weight decay for that group.
            betas: AdamW betas.

        Returns:
            The configured ``torch.optim.AdamW`` instance.
        """
        optim_groups = self.model.get_optim_groups(
            weight_decay=weight_decay)

        backbone_params = list()
        other_obs_params = list()
        for key, value in self.obs_encoder.named_parameters():
            if key.startswith('key_model_map'):
                backbone_params.append(value)
            else:
                other_obs_params.append(value)
        optim_groups.append({
            "params": backbone_params,
            "weight_decay": obs_encoder_weight_decay,
            "lr": obs_encoder_lr,  # for fine tuning
            'initial_lr': obs_encoder_lr
        })
        optim_groups.append({
            "params": other_obs_params,
            "weight_decay": weight_decay,
            'initial_lr': lr
        })
        optimizer = torch.optim.AdamW(
            optim_groups, lr=lr, betas=betas
        )
        return optimizer

    def compute_loss(self, batch):
        """Compute the training loss for one batch.

        Args:
            batch: Dict with ``'obs'`` (observation dict that includes
                ``robot0_language_instruction``) and ``'action'``; must not
                contain ``'valid_mask'``.

        Returns:
            Scalar tensor: the mean squared error between the model
            prediction and the scheduler-dependent target, plus a
            cross-entropy term on the gate logits when the MoE branch below
            is taken.

        Raises:
            ValueError: If the scheduler's ``prediction_type`` is neither
                ``'epsilon'`` nor ``'sample'``.
        """
        # normalize input
        assert 'valid_mask' not in batch
        nobs = self.normalizer.normalize(batch['obs'])
        nactions = self.normalizer['action'].normalize(batch['action'])
        trajectory = nactions

        # class ids come from the raw (un-normalized) instruction codes
        use_expert_i = self._select_experts(
            batch['obs']["robot0_language_instruction"],
            device=trajectory.device)

        # process input
        obs_tokens = self.obs_encoder(nobs)
        # (B, N, n_emb)

        # Sample noise that we'll add to the trajectories
        noise = torch.randn(trajectory.shape, device=trajectory.device)
        # input perturbation by adding additional noise to alleviate exposure bias
        # reference: https://github.com/forever208/DDPM-IP
        noise_new = noise + self.input_pertub * torch.randn(trajectory.shape, device=trajectory.device)

        # Sample a random timestep for each trajectory
        timesteps = torch.randint(
            0, self.noise_scheduler.config.num_train_timesteps,
            (nactions.shape[0],), device=trajectory.device
        ).long()

        # Add the perturbed noise according to the noise magnitude at each
        # timestep (this is the forward diffusion process)
        noisy_trajectory = self.noise_scheduler.add_noise(
            trajectory, noise_new, timesteps)

        # Predict the noise residual
        pred, gate_logits_list = self.model(
            noisy_trajectory,
            timesteps,
            use_expert_i=use_expert_i,
            cond=obs_tokens
        )

        # the target is the un-perturbed noise, not `noise_new`
        pred_type = self.noise_scheduler.config.prediction_type
        if pred_type == 'epsilon':
            target = noise
        elif pred_type == 'sample':
            target = trajectory
        else:
            raise ValueError(f"Unsupported prediction type {pred_type}")

        loss = F.mse_loss(pred, target, reduction='none')
        loss = reduce(loss, 'b ... -> b (...)', 'mean')
        loss = loss.mean()

        # NOTE(review): this tests `is not None`; if `use_moe` is a plain
        # bool the branch is always taken -- confirm against the model class.
        if self.model.use_moe is not None:
            # gate_logits_list: (B, layer_num, experts)
            B, L, E = gate_logits_list.shape

            # Expand use_expert_i to (B, L) to match gate_logits_list
            targets = use_expert_i.unsqueeze(1).expand(-1, L)  # shape: (B, L)

            # Flatten inputs for loss computation; reshape (unlike view) also
            # accepts non-contiguous logits.
            logits_flat = gate_logits_list.reshape(B * L, E)
            targets_flat = targets.reshape(B * L)

            # Cross entropy loss: applies LogSoftmax + NLLLoss
            router_loss = F.cross_entropy(logits_flat, targets_flat)

            return loss + router_loss

        return loss

    def forward(self, batch):
        """Return ``compute_loss(batch)``."""
        return self.compute_loss(batch)
