# Modified DiT model from models.py to support action generation instead of images
# Input shape: (B, T, input_dim), Output shape: (B, T, output_dim)

import torch
import torch.nn as nn
import numpy as np
import math
from timm.models.vision_transformer import Attention, Mlp
import torch.nn.functional as F
from diffusion_policy.model.diffusion.moe import MoeArgs, MoeLayer1
from typing import Optional, Any, Union, Callable

class FeedForward(nn.Module):
    """Two-layer GELU feed-forward block mapping ``dim -> hidden_dim -> dim``.

    Dropout is applied once after the activation and once after the output
    projection.

    Args:
        dim: Width of the input and output features.
        hidden_dim: Width of the intermediate layer.
        dropout: Probability used by both dropout layers.
        bias: Whether the two linear layers carry a bias term.
        device: Forwarded to both ``nn.Linear`` constructors.
        dtype: Forwarded to both ``nn.Linear`` constructors.
    """

    def __init__(self, dim: int, hidden_dim: int, dropout: float = 0.1, bias: bool = True, device=None, dtype=None):
        super().__init__()
        tensor_opts = dict(device=device, dtype=dtype)

        # Submodules are registered in a fixed order so that seeded
        # initialisation and the state_dict layout stay stable.
        self.linear1 = nn.Linear(dim, hidden_dim, bias=bias, **tensor_opts)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(hidden_dim, dim, bias=bias, **tensor_opts)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run expand -> GELU -> dropout -> project -> dropout on ``x``."""
        hidden = F.gelu(self.linear1(x))
        hidden = self.dropout(hidden)
        projected = self.linear2(hidden)
        return self.dropout2(projected)

def modulate(x, shift, scale):
    """Apply a per-sample affine modulation ``x * (1 + scale) + shift``.

    ``shift`` and ``scale`` each receive a new axis at position 1 so that
    they broadcast over axis 1 of ``x``.
    """
    scale_b = scale.unsqueeze(1)
    shift_b = shift.unsqueeze(1)
    return x * (1 + scale_b) + shift_b

def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """Build sine/cosine positional embeddings for the given positions.

    Args:
        embed_dim: Output width per position; must be even.
        pos: NumPy array of positions. It is flattened to length ``M``.

    Returns:
        Array of shape ``(M, embed_dim)``. The first half of every row holds
        the sines and the second half the cosines.
    """
    assert embed_dim % 2 == 0
    half_dim = embed_dim // 2
    # Inverse frequencies 1 / 10000^(i / (embed_dim / 2)) for i in [0, half_dim).
    exponents = np.arange(half_dim, dtype=np.float64) / (embed_dim / 2.)
    inv_freq = 1. / 10000**exponents
    # Outer product of positions and frequencies via broadcasting -> (M, half_dim).
    flat_pos = pos.reshape(-1)
    angles = flat_pos[:, None] * inv_freq[None, :]
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)

def get_1d_sincos_pos_embed(embed_dim, seq_len):
    """Return sine/cosine embeddings for positions ``0 .. seq_len - 1``.

    The positions are generated as float32 and handed to
    ``get_1d_sincos_pos_embed_from_grid`` together with ``embed_dim``.
    """
    positions = np.arange(seq_len, dtype=np.float32)
    return get_1d_sincos_pos_embed_from_grid(embed_dim, positions)

class TimestepEmbedder(nn.Module):
    """Embed scalar timesteps into ``hidden_size``-wide vectors.

    Timesteps are first expanded to ``frequency_embedding_size`` sinusoidal
    features and then passed through a Linear -> SiLU -> Linear MLP.
    """

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """Create sinusoidal timestep features.

        Args:
            t: 1-D tensor of timesteps (indexed as ``t[:, None]``).
            dim: Width of the returned embedding.
            max_period: Controls the lowest frequency of the embedding.

        Returns:
            Tensor of shape ``(len(t), dim)`` laid out as cosines followed by
            sines; when ``dim`` is odd a trailing zero column is appended.
        """
        half = dim // 2
        steps = torch.arange(start=0, end=half, dtype=torch.float32)
        # Frequencies are built on the default device and then moved to t's device.
        freqs = torch.exp(-math.log(max_period) * steps / half).to(device=t.device)
        angles = t[:, None].float() * freqs[None]
        pieces = [torch.cos(angles), torch.sin(angles)]
        if dim % 2:
            # Pad odd widths with zeros so the output is exactly `dim` wide.
            pieces.append(torch.zeros_like(angles[:, :1]))
        return torch.cat(pieces, dim=-1)

    def forward(self, t):
        """Return the MLP-projected sinusoidal embedding of timesteps ``t``."""
        sinusoid = self.timestep_embedding(t, self.frequency_embedding_size)
        return self.mlp(sinusoid)

class DiTBlock(nn.Module):
    """Transformer block whose attention and feed-forward paths are modulated by a conditioning vector.

    The conditioning vector is mapped (SiLU -> Linear) to six chunks: shift,
    scale and gate for the attention branch, and shift, scale and gate for the
    feed-forward branch. The feed-forward branch is either a timm ``Mlp`` or,
    when ``moe`` is given, a ``MoeLayer1`` built from ``FeedForward`` experts.

    Args:
        hidden_size: Feature width of the tokens and of the conditioning vector.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden width of the feed-forward path as a multiple of ``hidden_size``.
        moe: Optional MoE configuration; ``None`` selects the dense ``Mlp``.
        dropout: Accepted for interface compatibility; not used by this block.
    """

    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, moe: Optional[MoeArgs] = None, dropout: float = 0.1):
        super().__init__()
        ff_hidden = int(hidden_size * mlp_ratio)

        self.norm1 = nn.LayerNorm(hidden_size, eps=1e-6)
        self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True)
        self.norm2 = nn.LayerNorm(hidden_size, eps=1e-6)
        self.use_moe = moe is not None

        if not self.use_moe:
            self.mlp = Mlp(in_features=hidden_size, hidden_features=ff_hidden, act_layer=nn.GELU)
        else:
            # Experts are created shared-first, each with dropout disabled.
            shared = [FeedForward(hidden_size, ff_hidden, dropout=0.) for _ in range(moe.num_shared_experts)]
            unshared = [FeedForward(hidden_size, ff_hidden, dropout=0.) for _ in range(moe.num_experts)]
            self.moe = MoeLayer1(
                shared_experts=shared,
                unshared_experts=unshared,
                num_experts_per_tok=moe.num_experts_per_tok,
            )

        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 6 * hidden_size)
        )

    def forward(self, x, c, use_expert_i=None):
        """Apply gated, modulated attention and feed-forward residual updates.

        Args:
            x: Token tensor; assumed ``(B, T, hidden_size)`` -- TODO confirm with callers.
            c: Conditioning tensor, chunked into six parts along dim 1.
            use_expert_i: Forwarded to the MoE layer; ignored on the dense path.

        Returns:
            The updated token tensor (a single tensor in both modes).
        """
        chunks = self.adaLN_modulation(c).chunk(6, dim=1)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = chunks

        attn_out = self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
        x = x + gate_msa.unsqueeze(1) * attn_out

        ff_in = modulate(self.norm2(x), shift_mlp, scale_mlp)
        if self.use_moe:
            ff_out = self.moe(ff_in, use_expert_i=use_expert_i)
        else:
            ff_out = self.mlp(ff_in)
        return x + gate_mlp.unsqueeze(1) * ff_out

class FinalLayer(nn.Module):
    """Output head: modulated LayerNorm followed by a linear projection to ``out_dim``.

    The conditioning vector is mapped (SiLU -> Linear) to a shift and a scale
    that modulate the normalised tokens before the projection.
    """

    def __init__(self, hidden_size, out_dim):
        super().__init__()
        self.norm = nn.LayerNorm(hidden_size, eps=1e-6)
        self.linear = nn.Linear(hidden_size, out_dim)
        # Produces 2 * hidden_size features: one half for shift, one for scale.
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size)
        )

    def forward(self, x, c):
        """Project tokens ``x`` to ``out_dim`` after modulation derived from ``c``."""
        params = self.adaLN_modulation(c)
        shift, scale = params.chunk(2, dim=1)
        modulated = modulate(self.norm(x), shift, scale)
        return self.linear(modulated)

class DiTForActions(nn.Module):
    """DiT-style transformer that maps noisy action sequences to a prediction of the same length.

    Inputs are projected to ``hidden_size``, offset by a learnable positional
    embedding (initialised from a 1-D sin/cos table), processed by ``depth``
    ``DiTBlock`` layers conditioned on a single vector ``c``, and projected to
    ``output_dim`` by ``FinalLayer``.

    The conditioning vector is ``condition_proj(cat([t_emb, flat_cond]))``, so
    ``max_cond_tokens`` counts the timestep embedding as one token: the
    flattened condition must carry ``(max_cond_tokens - 1) * hidden_size``
    features per sample.

    Args:
        seq_len: Maximum sequence length covered by the positional embedding.
        input_dim: Feature width of the input sequence.
        output_dim: Feature width of the output sequence.
        max_cond_tokens: Number of ``hidden_size``-wide tokens fed to
            ``condition_proj`` (timestep token included).
        hidden_size: Transformer width.
        depth: Number of ``DiTBlock`` layers.
        num_heads: Attention heads per block.
        mlp_ratio: Feed-forward hidden width as a multiple of ``hidden_size``.
        moe: Optional MoE configuration passed to every block.
    """

    def __init__(self, seq_len=32, input_dim=20, output_dim=20, max_cond_tokens=11,
                 hidden_size=768, depth=12, num_heads=12, mlp_ratio=4.0, moe: Optional[MoeArgs] = None):
        super().__init__()
        self.input_proj = nn.Linear(input_dim, hidden_size)
        print("max_cond_tokens:", max_cond_tokens)
        self.condition_proj = nn.Linear(hidden_size * max_cond_tokens, hidden_size)  # for timestep + condition
        self.t_embedder = TimestepEmbedder(hidden_size)

        # Learnable positional embedding, initialised from the sin/cos table.
        self.pos_embed = nn.Parameter(torch.zeros(1, seq_len, hidden_size), requires_grad=True)
        pos_emb_np = get_1d_sincos_pos_embed(hidden_size, seq_len)
        with torch.no_grad():
            self.pos_embed.copy_(torch.from_numpy(pos_emb_np).float().unsqueeze(0))

        self.use_moe = moe is not None
        self.blocks = nn.ModuleList([
            DiTBlock(hidden_size, num_heads, mlp_ratio, moe=moe) for _ in range(depth)
        ])
        self.final_layer = FinalLayer(hidden_size, output_dim)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Xavier-uniform init for every ``nn.Linear`` weight; zero its bias."""
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def get_optim_groups(self, weight_decay: float = 1e-3):
        """Split parameters into weight-decayed and non-decayed optimizer groups.

        Biases and ``nn.LayerNorm`` weights are never decayed; ``nn.Linear``
        weights are decayed. Parameters that match none of these rules fall
        back to: no decay for ``pos_embed``, decay for other names ending in
        ``weight``, and no decay for everything else.

        Args:
            weight_decay: Decay value assigned to the decayed group.

        Returns:
            A two-element list of optimizer param-group dicts: the decayed
            group first, then the group with ``weight_decay`` 0.0.
        """
        decay = set()
        no_decay = set()
        whitelist_weight_modules = (nn.Linear,)
        blacklist_weight_modules = (nn.LayerNorm,)

        param_dict = {}
        for mn, m in self.named_modules():
            for pn, p in m.named_parameters(recurse=False):
                fpn = f"{mn}.{pn}" if mn else pn
                param_dict[fpn] = p
                if pn.endswith("bias"):
                    no_decay.add(fpn)
                elif pn.endswith("weight") and isinstance(m, whitelist_weight_modules):
                    decay.add(fpn)
                elif pn.endswith("weight") and isinstance(m, blacklist_weight_modules):
                    no_decay.add(fpn)

        # Fallback for anything the module-type rules above did not classify
        # (e.g. the raw `pos_embed` parameter). Keyed on "not yet categorised"
        # so that it actually runs: every parameter is already in param_dict.
        for fpn in sorted(param_dict.keys() - (decay | no_decay)):
            if 'pos_embed' in fpn:
                no_decay.add(fpn)
            elif fpn.endswith("weight"):
                decay.add(fpn)
            else:
                no_decay.add(fpn)

        inter_params = decay & no_decay
        union_params = decay | no_decay

        assert len(inter_params) == 0, f"Parameters {inter_params} in both decay and no_decay"
        assert len(param_dict.keys() - union_params) == 0, f"Parameters {param_dict.keys() - union_params} were not categorized"

        optim_groups = [
            {"params": [param_dict[pn] for pn in sorted(decay)], "weight_decay": weight_decay},
            {"params": [param_dict[pn] for pn in sorted(no_decay)], "weight_decay": 0.0},
        ]
        return optim_groups

    def forward(self, x, t, cond, use_expert_i=None):
        """Run the model on a batch of sequences.

        Args:
            x: Input of shape ``(B, T, input_dim)`` with ``T <= seq_len``.
            t: Diffusion timestep(s): a Python number, a 0-d tensor, or a
                tensor expandable to ``(B,)``.
            cond: Conditioning features with batch as the first axis. All
                trailing axes are flattened, so ``(B, N, hidden_size)`` and
                pre-flattened ``(B, N * hidden_size)`` are both accepted.
            use_expert_i: Forwarded to every block (used by MoE blocks only).

        Returns:
            Tensor of shape ``(B, T, output_dim)``.
        """
        timesteps = t
        if not torch.is_tensor(timesteps):
            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            timesteps = torch.tensor([timesteps], dtype=torch.long, device=x.device)
        else:
            # Keep timesteps on the same device as the activations.
            timesteps = timesteps.to(x.device)
            if timesteps.ndim == 0:
                timesteps = timesteps[None]
        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(x.shape[0])

        x = self.input_proj(x) + self.pos_embed[:, :x.shape[1], :]
        t_emb = self.t_embedder(timesteps)  # (B, hidden_size)

        # Flatten everything after the batch axis. `reshape` also handles
        # non-contiguous inputs, and 2-D conditions pass through unchanged.
        condition_emb = cond.reshape(cond.shape[0], -1)
        c = self.condition_proj(torch.cat([t_emb, condition_emb], dim=-1))

        # DiTBlock.forward returns a single tensor in both dense and MoE mode.
        for block in self.blocks:
            x = block(x, c, use_expert_i=use_expert_i)

        return self.final_layer(x, c)

# Configuration

def DiT_B_12_for_action(**kwargs):
    """Build a ``DiTForActions`` with depth 12, hidden size 768 and 12 heads.

    ``input_dim`` (20), ``output_dim`` (20), ``seq_len`` (16) and
    ``max_cond_tokens`` (11) are read from ``kwargs`` with the defaults shown;
    any remaining keyword arguments are forwarded unchanged.
    """
    print("Use DiT_B_12_for_action.")
    defaults = (('input_dim', 20), ('output_dim', 20), ('seq_len', 16), ('max_cond_tokens', 11))
    overrides = {key: kwargs.pop(key, fallback) for key, fallback in defaults}
    return DiTForActions(depth=12, hidden_size=768, num_heads=12, **overrides, **kwargs)

def DiT_B_8_for_action(**kwargs):
    """Build a ``DiTForActions`` with depth 8, hidden size 768 and 8 heads.

    ``input_dim`` (20), ``output_dim`` (20), ``seq_len`` (16) and
    ``max_cond_tokens`` (11) are read from ``kwargs`` with the defaults shown;
    any remaining keyword arguments are forwarded unchanged.
    """
    print("Use DiT_B_8_for_action.")
    defaults = (('input_dim', 20), ('output_dim', 20), ('seq_len', 16), ('max_cond_tokens', 11))
    overrides = {key: kwargs.pop(key, fallback) for key, fallback in defaults}
    return DiTForActions(depth=8, hidden_size=768, num_heads=8, **overrides, **kwargs)


def DiT_B_6_for_action(**kwargs):
    """Build a ``DiTForActions`` with depth 6, hidden size 768 and 6 heads.

    ``input_dim`` (20), ``output_dim`` (20), ``seq_len`` (16) and
    ``max_cond_tokens`` (11) are read from ``kwargs`` with the defaults shown;
    any remaining keyword arguments are forwarded unchanged.
    """
    print("Use DiT_B_6_for_action.")
    defaults = (('input_dim', 20), ('output_dim', 20), ('seq_len', 16), ('max_cond_tokens', 11))
    overrides = {key: kwargs.pop(key, fallback) for key, fallback in defaults}
    return DiTForActions(depth=6, hidden_size=768, num_heads=6, **overrides, **kwargs)

def DiT_B_4_for_action(**kwargs):
    """Build a ``DiTForActions`` with depth 4, hidden size 768 and 4 heads.

    ``input_dim`` (20), ``output_dim`` (20), ``seq_len`` (16) and
    ``max_cond_tokens`` (11) are read from ``kwargs`` with the defaults shown;
    any remaining keyword arguments are forwarded unchanged.
    """
    print("Use DiT_B_4_for_action.")
    defaults = (('input_dim', 20), ('output_dim', 20), ('seq_len', 16), ('max_cond_tokens', 11))
    overrides = {key: kwargs.pop(key, fallback) for key, fallback in defaults}
    return DiTForActions(depth=4, hidden_size=768, num_heads=4, **overrides, **kwargs)

def DiT_MoE_B_8_for_action(**kwargs):
    """Build a ``DiTForActions`` with depth 8, hidden size 768, 8 heads and an optional MoE config.

    ``input_dim`` (20), ``output_dim`` (20), ``seq_len`` (16),
    ``max_cond_tokens`` (11) and ``moe`` (None) are read from ``kwargs`` with
    the defaults shown; any remaining keyword arguments are forwarded
    unchanged.
    """
    print("Use DiT_MoE_B_8_for_action.")
    defaults = (
        ('input_dim', 20),
        ('output_dim', 20),
        ('seq_len', 16),
        ('max_cond_tokens', 11),
        ('moe', None),
    )
    overrides = {key: kwargs.pop(key, fallback) for key, fallback in defaults}
    return DiTForActions(depth=8, hidden_size=768, num_heads=8, **overrides, **kwargs)
