import torch
import torch.nn as nn
import torch.nn.functional as F
# from einops.layers.torch import Rearrange

class AddAuxiliaryLoss(torch.autograd.Function):
    """
    Attach an auxiliary (aux) loss to a tensor without changing its value.

    The forward pass returns ``x`` untouched. During backpropagation the
    incoming gradient is handed back to ``x`` as-is, and the aux loss
    receives a gradient of one, i.e. it is differentiated as if it had been
    added to the final objective.

    Args:
        x: tensor that continues through the network.
        loss: single-element tensor holding the aux loss.

    Returns:
        ``x`` unchanged.
    """
    @staticmethod
    def forward(ctx, x, loss):
        assert loss.numel() == 1
        ctx.dtype = loss.dtype
        # Remember the loss's own shape and device so that backward can build
        # a gradient that matches it exactly, whatever its rank (0-dim, (1,),
        # (1, 1), ...), instead of assuming shape (1,) on grad_output's device.
        ctx.loss_shape = loss.shape
        ctx.loss_device = loss.device
        ctx.required_aux_loss = loss.requires_grad
        return x

    @staticmethod
    def backward(ctx, grad_output):
        grad_loss = None
        if ctx.required_aux_loss:
            # d(total)/d(aux_loss) = 1: the aux loss acts as an additive term.
            grad_loss = torch.ones(ctx.loss_shape, dtype=ctx.dtype, device=ctx.loss_device)
        return grad_output, grad_loss


class Downsample1d(nn.Module):
    """Strided 1-D convolution (kernel 3, stride 2, padding 1) that keeps the
    channel count and shortens the last axis to ``ceil(length / 2)``."""

    def __init__(self, dim):
        super().__init__()
        self.conv = nn.Conv1d(
            in_channels=dim,
            out_channels=dim,
            kernel_size=3,
            stride=2,
            padding=1,
        )

    def forward(self, x):
        downsampled = self.conv(x)
        return downsampled

class Upsample1d(nn.Module):
    """Transposed 1-D convolution (kernel 4, stride 2, padding 1) that keeps
    the channel count and doubles the length of the last axis."""

    def __init__(self, dim):
        super().__init__()
        self.conv = nn.ConvTranspose1d(
            in_channels=dim,
            out_channels=dim,
            kernel_size=4,
            stride=2,
            padding=1,
        )

    def forward(self, x):
        upsampled = self.conv(x)
        return upsampled

class Conv1dBlock(nn.Module):
    '''
        Conv1d --> GroupNorm --> Mish

        With ``num_experts=None`` this is a single Conv1d/GroupNorm/Mish stack.
        Otherwise ``num_experts`` independent copies of that stack are created
        and every batch element is processed by one of them:

        * ``vanilla_moe=True``  -- a learned linear gate (fed with the input
          averaged over its last axis) picks the top-1 expert, the expert
          output is scaled by its gate score, and a load-balancing auxiliary
          loss (weight ``alpha``) is attached to the result.
        * ``vanilla_moe=False`` -- the caller chooses the expert per batch
          element through ``use_expert_i``.

        Args:
            inp_channels: number of input channels.
            out_channels: number of output channels (must be divisible by
                ``n_groups``, as required by GroupNorm).
            kernel_size: Conv1d kernel size; padding is ``kernel_size // 2``.
            n_groups: number of GroupNorm groups.
            num_experts: number of expert stacks, or None for a dense block.
            vanilla_moe: select the gated variant described above.
    '''

    def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8, num_experts=None, vanilla_moe=False):
        super().__init__()

        def make_stack():
            # One Conv1d --> GroupNorm --> Mish pipeline.
            return nn.Sequential(
                nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
                nn.GroupNorm(n_groups, out_channels),
                nn.Mish(),
            )

        self.num_experts = num_experts

        if num_experts is None:
            self.block = make_stack()
            return

        self.block = nn.ModuleList([make_stack() for _ in range(num_experts)])

        if vanilla_moe:
            self.mode = "vanilla_moe"
            self.topk = 1
            self.alpha = 0.01
            self.gate = nn.Linear(inp_channels, num_experts, bias=False)
        else:
            self.mode = "moe"

    def _expert_output_shape(self, x):
        # Shape an expert produces for input x. Derived from the conv's own
        # geometry because the output differs from x in channel count and,
        # for even kernel sizes, in length as well.
        conv = self.block[0][0]
        kernel, stride = conv.kernel_size[0], conv.stride[0]
        padding, dilation = conv.padding[0], conv.dilation[0]
        horizon = (x.shape[-1] + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1
        return (x.shape[0], conv.out_channels, horizon)

    def _dispatch(self, x, weights, selected_experts):
        # Run every expert on the batch rows routed to it and accumulate the
        # weighted outputs. `weights` and `selected_experts` have one row per
        # batch element and one column per selected expert.
        results = x.new_zeros(self._expert_output_shape(x))

        for i, expert in enumerate(self.block):
            batch_idx, nth_expert = torch.where(selected_experts == i)

            if batch_idx.shape[0] == 0:
                continue

            w = weights[batch_idx, nth_expert, None, None]
            results[batch_idx] += w * expert(x[batch_idx])

        return results

    def forward(self, x, use_expert_i=None):
        '''
            Apply the block to ``x``.

            Args:
                x: input whose dim 0 is the batch, dim 1 the channels and
                    last dim the sequence axis.
                use_expert_i: integer (long) tensor with one expert index per
                    batch element; required in "moe" mode, ignored otherwise.

            Returns:
                Tensor with ``out_channels`` channels. In "vanilla_moe" mode
                the load-balancing auxiliary loss is attached to it so that
                its gradient is included during backpropagation.
        '''
        if self.num_experts is None:
            return self.block(x)

        if self.mode == "vanilla_moe":  # learned top-k gate + balancing aux loss
            gate_logits = self.gate(x.mean(dim=-1))
            scores = F.softmax(gate_logits, dim=1, dtype=torch.float).to(x.dtype)
            weights, selected_experts = torch.topk(scores, self.topk, dim=-1)

            # Balance loss: for each expert, (mean gate score) times
            # (fraction of routed slots it received, scaled by num_experts),
            # summed over experts and weighted by alpha.
            mask_ce = F.one_hot(selected_experts.view(-1), num_classes=self.num_experts)
            ce = mask_ce.float().mean(0)
            Pi = scores.mean(0)
            fi = ce * self.num_experts
            aux_loss = (Pi * fi).sum() * self.alpha

            results = self._dispatch(x, weights, selected_experts)
            return AddAuxiliaryLoss.apply(results, aux_loss)

        if self.mode == "moe":  # expert index supplied by the caller
            assert use_expert_i is not None, 'use_expert_i is required in "moe" mode'
            # one_hot + top-1 yields weight 1 for the requested expert and
            # also rejects indices outside [0, num_experts).
            weights, selected_experts = torch.topk(
                F.one_hot(use_expert_i, num_classes=self.num_experts), 1, dim=-1
            )  # (B, 1)
            return self._dispatch(x, weights, selected_experts)

        raise ValueError(f"unknown Conv1dBlock mode: {self.mode!r}")


def test():
    """Smoke test: run an all-zero (1, 256, 16) batch through a dense
    256 -> 128 channel Conv1dBlock with kernel size 3."""
    block = Conv1dBlock(256, 128, kernel_size=3)
    dummy_input = torch.zeros((1, 256, 16))
    block(dummy_input)
