import torch
from einops import rearrange
from torch import layer_norm, nn
import torch.nn.functional as F


class CausalPatchEmbed(nn.Module):
    """Embed overlapping temporal patches and add a learned positional term.

    A window of ``patch_size`` steps slides along the time axis with step
    ``stride``; every window is flattened (time-major) and linearly projected
    to ``embed_dim``.

    Args:
        length: Sequence length used to size the positional table, which holds
            ``(length - patch_size) // stride + 1`` rows.
        patch_size: Number of time steps in one patch.
        stride: Step between the starts of consecutive patches.
        in_dim: Feature size of a single time step.
        embed_dim: Size of each output embedding.
    """

    def __init__(self, length=20, patch_size=4, stride=2, in_dim=2, embed_dim=64):
        super().__init__()
        self.patch_size = patch_size
        self.stride = stride
        self.in_dim = in_dim
        self.proj = nn.Linear(patch_size * in_dim, embed_dim)

        # One learnable row per patch position of a sequence of ``length`` steps.
        num_positions = (length - patch_size) // stride + 1
        self.pos_embed = nn.Parameter(torch.randn(1, num_positions, embed_dim))

    def forward(self, x):
        """Project sliding-window patches of ``x`` and add positional embeddings.

        Args:
            x: 3-D tensor indexed as (batch, time, feature) whose feature size
                is ``in_dim``.

        Returns:
            Tensor of shape (batch, num_patches, embed_dim), where
            ``num_patches = (time - patch_size) // stride + 1``.

        Raises:
            RuntimeError: If the input yields more patches than the positional
                table has rows (broadcast failure), or if it is shorter than
                ``patch_size`` (raised by ``unfold``).
        """
        batch_size, _, _ = x.shape

        # unfold appends the window axis last: (B, N, D, P).
        windows = x.unfold(1, self.patch_size, self.stride)
        num_patches = windows.shape[1]

        # Reorder to (B, N, P, D) so each flattened patch is time-major.
        patches = windows.permute(0, 1, 3, 2).reshape(
            batch_size, num_patches, self.patch_size * self.in_dim
        )

        # Slice the table so sequences shorter than ``length`` are accepted;
        # out-of-place add keeps the projection output untouched.
        return self.proj(patches) + self.pos_embed[:, :num_patches]


class DynamicPatchEmbedding(nn.Module):
    """Multi-scale patch embedding with softmax-gated experts and FPN-style fusion.

    For every patch size in ``patch_list`` the input sequence is cut into
    non-overlapping patches. Each of ``num_experts`` experts owns a separate
    linear projection and positional table per patch size. A gating MLP over
    the flattened input yields one weight per (expert, patch size) pair
    (a single softmax across all pairs). Gated expert embeddings are averaged
    per scale and fused through 1-D lateral/output convolutions.

    Args:
        in_dim: Feature size of a single time step.
        obs_len: Number of time steps the module expects.
        latent_dim: Embedding / channel size.
        patch_list: Patch sizes (time steps per patch), one scale per entry.
        num_experts: Number of experts per scale.
        top_k: Stored as ``min(top_k, num_experts)``; not used by ``forward``.
        mask_mode: Stored only; not used by ``forward``.
        mask_ratio: Stored only; not used by ``forward``.
    """

    def __init__(self, in_dim: int = 2, obs_len: int = 8, latent_dim: int = 64,
                 patch_list: list = [2, 4, 8], num_experts: int = 4, top_k: int = 2,
                 mask_mode: str = 'feature', mask_ratio: float = 0.3):
        super().__init__()
        self.in_dim = in_dim
        self.obs_len = obs_len
        self.latent_dim = latent_dim
        # Copy so that neither the shared default list nor a caller-owned list
        # is aliased by this instance.
        self.patch_list = list(patch_list)
        self.num_experts = num_experts
        self.top_k = min(top_k, num_experts)
        self.mask_mode = mask_mode
        self.mask_ratio = mask_ratio

        # Multi-scale patch projection layers (independent parameters per expert).
        self.expert_proj = nn.ModuleDict({
            f'size{ps}': nn.ModuleList([
                nn.Linear(in_dim * ps, latent_dim)
                for _ in range(num_experts)
            ]) for ps in self.patch_list
        })

        # Expert-specific positional encodings. ParameterList is a Module, so
        # ModuleDict is the matching container; attribute paths and state-dict
        # keys ("pos_embeds.size{ps}.{expert}") are unchanged.
        self.pos_embeds = nn.ModuleDict({
            f'size{ps}': nn.ParameterList([
                nn.Parameter(torch.randn(obs_len // ps, latent_dim))
                for _ in range(num_experts)
            ]) for ps in self.patch_list
        })

        # Gating network: one softmax over all (expert, scale) pairs.
        self.gate_net = nn.Sequential(
            nn.Linear(in_dim * obs_len, latent_dim),
            nn.ReLU(),
            nn.Linear(latent_dim, num_experts * len(self.patch_list)),
            nn.Softmax(dim=-1)
        )

        # FPN-style fusion of the per-scale embeddings.
        self.fpn_lateral = nn.ModuleDict({
            f'size{ps}': nn.Conv1d(latent_dim, latent_dim, kernel_size=1)
            for ps in self.patch_list
        })

        self.fpn_output = nn.ModuleDict({
            f'size{ps}': nn.Sequential(
                nn.Conv1d(latent_dim, latent_dim, kernel_size=3, padding=1),
                nn.ReLU()
            )
            for ps in self.patch_list
        })

        # Resampling layers. NOTE: ``downsample`` is registered but never
        # called by ``forward``.
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
        self.downsample = nn.MaxPool1d(kernel_size=2, stride=2)

    @staticmethod
    def _match_length(feat, target_len):
        """Crop or right-zero-pad the last axis of ``feat`` to ``target_len``."""
        current = feat.shape[-1]
        if current > target_len:
            return feat[..., :target_len]
        if current < target_len:
            return F.pad(feat, (0, target_len - current))
        return feat

    def forward(self, x):
        """Embed ``x`` at every scale, gate the experts and fuse the scales.

        Args:
            x: 3-D tensor indexed as (batch, time, feature); ``time`` must
                equal ``obs_len`` and be divisible by every patch size, and
                ``feature`` must equal ``in_dim``.

        Returns:
            A tuple ``(output, moe_outputs)``:

            * ``output`` - (batch, obs_len // min(patch_list), latent_dim)
              features of the smallest patch size.
            * ``moe_outputs`` - dict with ``'gates'`` of shape
              (batch, num_experts, len(patch_list)) and ``'fpn_features'``,
              a dict mapping patch size to a (batch, latent_dim, num_patches)
              tensor.
        """
        batch_size, _, feat_dim = x.shape

        # Gating weights, reshaped to [B, num_experts, num_scales].
        gates = self.gate_net(x.reshape(batch_size, -1)).view(
            batch_size, -1, len(self.patch_list)
        )

        # A pair is skipped when its gate is zero for the whole batch. One
        # reduction + one host transfer replaces a sync per (expert, scale).
        active = (gates.sum(dim=0) != 0).tolist()

        # The patch split depends only on the scale, so do it once per scale:
        # [B, T, D] -> [B, num_patches, ps * D] (time-major inside a patch).
        patches_by_scale = {
            ps: x.reshape(batch_size, self.obs_len // ps, ps * feat_dim)
            for ps in self.patch_list
        }

        # Gated expert embeddings collected per scale, in expert order.
        scale_features = {ps: [] for ps in self.patch_list}
        for expert_id in range(self.num_experts):
            for scale_idx, ps in enumerate(self.patch_list):
                if not active[expert_id][scale_idx]:
                    continue

                key = f'size{ps}'
                patches = patches_by_scale[ps]
                num_patches = patches.shape[1]
                proj = self.expert_proj[key][expert_id]
                pos = self.pos_embeds[key][expert_id]
                embed = proj(patches) + pos[:num_patches]  # [B, num_patches, latent_dim]

                weight = gates[:, expert_id, scale_idx].view(batch_size, 1, 1)
                scale_features[ps].append(weight * embed)

        # Average the contributing experts per scale; Conv1d wants [B, C, L].
        fpn_features = {}
        for ps in self.patch_list:
            contributions = scale_features[ps]
            if contributions:
                aggregated = sum(contributions) / len(contributions)
                fpn_features[ps] = aggregated.permute(0, 2, 1)
            else:
                # No expert contributed: zeros matching x's dtype and device so
                # the following convolutions accept them.
                fpn_features[ps] = x.new_zeros(
                    batch_size, self.latent_dim, self.obs_len // ps
                )

        # Largest patch size first, i.e. coarse (fewest patches) -> fine.
        coarse_to_fine = sorted(self.patch_list, reverse=True)  # e.g. [8, 4, 2]

        # Top-down pass: each finer level adds the upsampled coarser level.
        pyramid = {}
        prev_ps = None
        for ps in coarse_to_fine:
            key = f'size{ps}'
            lateral = self.fpn_lateral[key](fpn_features[ps])
            if prev_ps is not None:
                upsampled = self._match_length(
                    self.upsample(pyramid[prev_ps]), lateral.shape[-1]
                )
                lateral = lateral + upsampled
            pyramid[ps] = self.fpn_output[key](lateral)
            prev_ps = ps

        # Second pass, fine -> coarse: the finest level is taken as is; every
        # coarser level adds the previous (finer) enhanced map after 2x
        # upsampling and cropping/padding to its own length.
        # NOTE(review): upsampling a finer map and cropping it keeps only its
        # leading positions, and ``self.downsample`` is never used - confirm
        # whether pooling was intended here. Behaviour is kept as is; the
        # returned ``output`` does not depend on this pass.
        enhanced_features = {}
        prev_ps = None
        for ps in reversed(coarse_to_fine):
            if prev_ps is None:
                enhanced_features[ps] = pyramid[ps]
            else:
                carried = self._match_length(
                    self.upsample(enhanced_features[prev_ps]),
                    pyramid[ps].shape[-1],
                )
                enhanced_features[ps] = pyramid[ps] + carried
            prev_ps = ps

        # Output: enhanced features of the smallest patch size.
        finest_scale = min(self.patch_list)
        output = enhanced_features[finest_scale].permute(0, 2, 1)  # [B, num_patches, latent_dim]

        moe_outputs = {
            'gates': gates,
            'fpn_features': enhanced_features
        }

        return output, moe_outputs