# Copyright (c) OpenMMLab. All rights reserved.
import os
import math
from typing import Sequence
from functools import partial

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from mmcv.cnn.bricks.transformer import FFN, PatchEmbed
from mmengine.model import BaseModule, ModuleList
from mmengine.utils import digit_version
from mmengine.registry import MODELS

from mmpretrain.models.utils import (MultiheadAttention, SwiGLUFFNFused, build_norm_layer,
                                     resize_pos_embed, to_2tuple)
from mmpretrain.models import VisionTransformer

# `torch.div` only accepts the `rounding_mode` keyword from PyTorch 1.8 on;
# older releases fall back to `torch.floor_divide`.
if digit_version(torch.__version__) >= digit_version('1.8.0'):
    floor_div = partial(torch.div, rounding_mode='floor')
else:
    floor_div = torch.floor_divide


class AddAuxiliaryLoss(torch.autograd.Function):
    """Attach an auxiliary (aux) loss to the autograd graph of ``x``.

    The forward pass returns ``x`` unchanged. In the backward pass the
    incoming gradient flows to ``x`` untouched, and ``loss`` receives a
    gradient of one (when it requires grad). The aux loss is therefore
    optimised as if it had been added to the final objective.
    """

    @staticmethod
    def forward(ctx, x, loss):
        # Default: nothing to back-propagate into the aux loss.
        ctx.required_aux_loss = False
        if loss is None:
            return x
        assert loss.numel() == 1
        ctx.dtype = loss.dtype
        ctx.required_aux_loss = loss.requires_grad
        return x

    @staticmethod
    def backward(ctx, grad_output):
        if not ctx.required_aux_loss:
            return grad_output, None
        # d(total)/d(aux_loss) == 1 for an additive auxiliary term.
        unit_grad = torch.ones(1, dtype=ctx.dtype, device=grad_output.device)
        return grad_output, unit_grad


class MoEGate(nn.Module):
    """Linear softmax router that selects the top-k experts for each token.

    Args:
        hidden_size (int): Feature dimension of the routed tokens.
        n_routed_experts (int): Number of experts to route among.
        num_experts_per_tok (int): Number ``k`` of experts chosen per token.
        norm_topk_prob (bool): If True and ``k > 1``, rescale the selected
            weights so they sum to (almost, ``+1e-6`` in the denominator)
            one. Defaults to False.
        aux_loss_alpha (float): Scale of the load-balancing loss. A value
            ``<= 0`` disables it. Defaults to 0.001.
        use_seq_aux (bool): If True, compute the balance loss per sample
            (over its sequence) and average it over the batch. Otherwise
            compute it over all tokens at once. Defaults to True.
        use_dist_aux (bool): For the all-token variant, average the expert
            load across ranks when ``torch.distributed`` is available and
            initialised. Defaults to False.
    """

    def __init__(self,
                 hidden_size,
                 n_routed_experts,
                 num_experts_per_tok,
                 norm_topk_prob=False,
                 aux_loss_alpha=0.001,
                 use_seq_aux=True,
                 use_dist_aux=False) -> None:
        super().__init__()
        self.top_k = num_experts_per_tok
        self.n_routed_experts = n_routed_experts

        self.alpha = aux_loss_alpha
        self.use_seq_aux = use_seq_aux
        self.use_dist_aux = use_dist_aux

        # topk selection algorithm
        self.norm_topk_prob = norm_topk_prob
        self.gating_dim = hidden_size
        self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Initialise the routing matrix like ``nn.Linear`` does."""
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, hidden_states):
        """Route tokens to experts.

        Args:
            hidden_states (Tensor): Shape ``(bsz, seq_len, hidden_size)``.
                It does not need to be contiguous.

        Returns:
            tuple: ``topk_idx`` of shape ``(bsz * seq_len, k)`` holding the
            selected expert indices, ``topk_weight`` of the same shape, and
            ``aux_loss``. ``aux_loss`` is a scalar tensor in training mode
            with ``aux_loss_alpha > 0``, otherwise ``None``.
        """
        bsz, seq_len, h = hidden_states.shape
        ### compute gating score
        # reshape (not view) so non-contiguous inputs are accepted too.
        hidden_states = hidden_states.reshape(-1, h)
        logits = F.linear(hidden_states, self.weight, None)
        scores = logits.softmax(dim=-1)

        ### select top-k experts
        topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)

        ### norm gate to sum 1
        if self.top_k > 1 and self.norm_topk_prob:
            denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-6
            topk_weight = topk_weight / denominator

        if not (self.training and self.alpha > 0.0):
            return topk_idx, topk_weight, None

        ### expert-level computation auxiliary loss
        # always compute aux loss based on the naive greedy topk method
        topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
        if self.use_seq_aux:
            scores_for_seq_aux = scores.view(bsz, seq_len, -1)
            # ce[b, e]: assignments to expert e in sample b, normalised so a
            # perfectly uniform load gives 1 for every expert.
            ce = torch.zeros(bsz, self.n_routed_experts, device=hidden_states.device)
            ce.scatter_add_(
                1, topk_idx_for_aux_loss,
                torch.ones(bsz, seq_len * self.top_k, device=hidden_states.device)
            ).div_(seq_len * self.top_k / self.n_routed_experts)
            aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(dim=1).mean() * self.alpha
        else:
            mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts)
            ce = mask_ce.float().mean(0)

            # Distributed computation of auxiliary loss. torch.distributed
            # exposes no other API when it is unavailable, so check first.
            if self.use_dist_aux and dist.is_available() and dist.is_initialized():
                dist.all_reduce(ce, op=dist.ReduceOp.SUM)
                ce = ce / dist.get_world_size()

            Pi = scores.mean(0)
            fi = ce * self.n_routed_experts
            aux_loss = (Pi * fi).sum() * self.alpha
        return topk_idx, topk_weight, aux_loss


class MoETransformerEncoderLayer(BaseModule):
    """Implements one encoder layer in MoE Vision Transformer.

    The layer computes ``x = x + attn(ln1(x))``, then adds a
    mixture-of-experts FFN applied to ``ln2(x)`` back onto that residual.
    The MoE FFN sums two optional branches:

    * routed experts: ``n_routed_experts`` FFNs of width
      ``feedforward_channels``. For every token (``moe_type='token'``) or
      every image (``moe_type='image'``, routed on token 0),
      :class:`MoEGate` picks ``num_experts_per_tok`` of them and weights
      their outputs.
    * a shared expert: one always-active FFN of width
      ``feedforward_channels * n_shared_experts``.

    At least one branch must be enabled.

    Args:
        embed_dims (int): Token feature dimension.
        num_heads (int): Number of attention heads.
        feedforward_channels (int): Hidden width of each routed expert.
        layer_scale_init_value (float): Passed to the attention module and
            every FFN.
        drop_rate (float): Attention projection dropout and FFN dropout.
        attn_drop_rate (float): Attention-weight dropout.
        drop_path_rate (float): DropPath probability for attention and FFNs.
        moe_type (str): ``'token'`` or ``'image'``. Any other value raises
            ``NotImplementedError`` in :meth:`forward`.
        gate_type (str): Only ``'moe'`` is supported.
        n_shared_experts (int | None): Width multiplier of the shared
            expert. ``None`` or ``0`` disables it.
        n_routed_experts (int | None): Number of routed experts. ``None``
            or ``0`` disables routing.
        num_experts_per_tok (int): Experts selected per token/image.
        norm_topk_prob, aux_loss_alpha, use_seq_aux, use_dist_aux:
            Forwarded to :class:`MoEGate`.
        num_fcs (int): Number of linear layers of an ``'origin'`` FFN.
        qkv_bias (bool): Whether the qkv projection has a bias.
        ffn_type (str): ``'origin'`` (mmcv ``FFN``) or ``'swiglu_fused'``.
        act_cfg (dict): Activation config of ``'origin'`` FFNs.
        norm_cfg (dict): Config of the ``ln1``/``ln2`` norm layers.
        layer_idx (int): Index of the layer. Not used inside this class.
        init_cfg (dict, optional): Initialization config.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 feedforward_channels,
                 layer_scale_init_value=0.,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 moe_type='image',
                 gate_type='moe',
                 n_shared_experts=2,
                 n_routed_experts=64,
                 num_experts_per_tok=6,
                 norm_topk_prob=False,
                 aux_loss_alpha=0.001,
                 use_seq_aux=True,
                 use_dist_aux=False,
                 num_fcs=2,
                 qkv_bias=True,
                 ffn_type='origin',
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN'),
                 layer_idx=0,
                 init_cfg=None) -> None:
        super(MoETransformerEncoderLayer, self).__init__(init_cfg=init_cfg)

        self.embed_dims = embed_dims
        self.feedforward_channels = feedforward_channels
        self.ffn_type = ffn_type
        self.moe_type = moe_type
        self.n_shared_experts = n_shared_experts
        self.n_routed_experts = n_routed_experts
        self.num_experts_per_tok = num_experts_per_tok
        # Without any expert branch forward() would have no FFN output.
        assert self.with_shared_experts or self.with_routed_experts

        self.ln1 = build_norm_layer(norm_cfg, self.embed_dims)

        self.attn = MultiheadAttention(
            embed_dims=embed_dims,
            num_heads=num_heads,
            attn_drop=attn_drop_rate,
            proj_drop=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
            qkv_bias=qkv_bias,
            layer_scale_init_value=layer_scale_init_value)

        self.ln2 = build_norm_layer(norm_cfg, self.embed_dims)

        if self.with_routed_experts:
            if gate_type == 'moe':
                self.gate = MoEGate(
                    hidden_size=embed_dims,
                    n_routed_experts=n_routed_experts,
                    num_experts_per_tok=num_experts_per_tok,
                    norm_topk_prob=norm_topk_prob,
                    aux_loss_alpha=aux_loss_alpha,
                    use_seq_aux=use_seq_aux,
                    use_dist_aux=use_dist_aux)
            else:
                raise NotImplementedError
            # NOTE(review): experts receive flattened (tokens, C) inputs, so
            # any DropPath inside them drops individual tokens rather than
            # whole samples -- confirm this is intended.
            if ffn_type == 'origin':
                self.experts = nn.ModuleList([
                    FFN(embed_dims=embed_dims,
                        feedforward_channels=feedforward_channels,
                        num_fcs=num_fcs,
                        ffn_drop=drop_rate,
                        dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
                        act_cfg=act_cfg,
                        add_identity=False,
                        layer_scale_init_value=layer_scale_init_value)
                    for _ in range(n_routed_experts)])
            elif ffn_type == 'swiglu_fused':
                self.experts = nn.ModuleList([
                    SwiGLUFFNFused(
                        embed_dims=embed_dims,
                        feedforward_channels=feedforward_channels,
                        add_identity=False,
                        layer_scale_init_value=layer_scale_init_value)
                    for _ in range(n_routed_experts)])
            else:
                raise NotImplementedError
        if self.with_shared_experts:
            if ffn_type == 'origin':
                self.shared_experts = FFN(
                    embed_dims=embed_dims,
                    feedforward_channels=feedforward_channels * n_shared_experts,
                    num_fcs=num_fcs,
                    ffn_drop=drop_rate,
                    dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
                    act_cfg=act_cfg,
                    add_identity=False,
                    layer_scale_init_value=layer_scale_init_value)
            elif ffn_type == 'swiglu_fused':
                self.shared_experts = SwiGLUFFNFused(
                    embed_dims=embed_dims,
                    feedforward_channels=feedforward_channels * n_shared_experts,
                    add_identity=False,
                    layer_scale_init_value=layer_scale_init_value)
            else:
                raise NotImplementedError

    @property
    def norm1(self):
        return self.ln1

    @property
    def norm2(self):
        return self.ln2

    @property
    def with_shared_experts(self) -> bool:
        """Check if the model has shared experts."""
        return self.n_shared_experts is not None and self.n_shared_experts > 0

    @property
    def with_routed_experts(self) -> bool:
        """Check if the model has routed experts."""
        return self.n_routed_experts is not None and self.n_routed_experts > 0

    def init_weights(self):
        """Xavier-init every expert Linear weight; near-zero normal biases."""
        super(MoETransformerEncoderLayer, self).init_weights()
        expert_groups = []
        if self.with_routed_experts:
            expert_groups.append(self.experts)
        if self.with_shared_experts:
            expert_groups.append(self.shared_experts)
        for group in expert_groups:
            for m in group.modules():
                if isinstance(m, nn.Linear):
                    nn.init.xavier_uniform_(m.weight)
                    # Bias-less Linear layers have ``bias is None``.
                    if m.bias is not None:
                        nn.init.normal_(m.bias, std=1e-6)

    def ffn_convertor(self, state_dict, ffn_keys):
        """Fit dense-FFN checkpoint tensors to ``feedforward_channels``.

        This only acts for ``ffn_type='origin'``. It expects exactly two
        weight keys and two bias keys (a two-layer FFN; layer-scale
        ``gamma2`` weights are ignored). When the hidden width differs, it
        keeps only the first ``feedforward_channels`` rows of the first
        layer (and of its bias) and the first ``feedforward_channels``
        columns of the second layer.

        Args:
            state_dict (dict): Checkpoint state dict. It is modified in place.
            ffn_keys (list[str]): Keys of the dense FFN in ``state_dict``.

        Returns:
            dict: The same ``state_dict`` object.
        """
        if self.ffn_type == 'origin':
            weight_keys = sorted([key for key in ffn_keys if 'weight' in key and 'gamma2' not in key])
            bias_keys = sorted([key for key in ffn_keys if 'bias' in key])
            assert len(weight_keys) == len(bias_keys) == 2
            for i, (weight_key, bias_key) in enumerate(zip(weight_keys, bias_keys)):
                weight, bias = state_dict[weight_key], state_dict[bias_key]
                # The hidden dim is dim 0 of the first layer, dim 1 of the second.
                if weight.size(i) != self.feedforward_channels:
                    from mmengine.logging import MMLogger
                    logger = MMLogger.get_current_instance()
                    logger.info(
                        f'Resize the feedforward channels from {weight.size(i)} '
                        f'to {self.feedforward_channels}.')
                    weight = weight[:self.feedforward_channels, :] if i == 0 else weight[:, :self.feedforward_channels]
                    bias = bias[:self.feedforward_channels] if i == 0 else bias
                    state_dict[weight_key], state_dict[bias_key] = weight, bias
        return state_dict

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs) -> None:
        """Load a dense-ViT checkpoint by copying its FFN into every expert.

        This only happens when the checkpoint has ``ffn`` keys for this
        layer and no ``experts`` keys. Each ``ffn`` tensor is then copied to
        every routed expert (when present) and to the shared expert (when
        present), and the original ``ffn`` key is removed.
        """
        ffn_prefix = prefix + 'ffn'
        expert_prefix = prefix + 'experts'
        shared_expert_prefix = prefix + 'shared_experts'
        ffn_keys = [k for k in state_dict.keys() if k.startswith(ffn_prefix)]
        has_expert_keys = any(k.startswith(expert_prefix) for k in state_dict.keys())
        if not has_expert_keys and ffn_keys:
            state_dict = self.ffn_convertor(state_dict, ffn_keys)
            # NOTE(review): the shared expert is feedforward_channels *
            # n_shared_experts wide, so for n_shared_experts > 1 the copied
            # tensors will not match its parameter shapes -- confirm how
            # such checkpoints are meant to be converted.
            for ffn_key in ffn_keys:
                suffix = ffn_key[len(ffn_prefix):]
                # Guard: n_routed_experts may be None for shared-only layers.
                if self.with_routed_experts:
                    for i in range(self.n_routed_experts):
                        state_dict[f'{expert_prefix}.{i}{suffix}'] = state_dict[ffn_key]
                if self.with_shared_experts:
                    state_dict[shared_expert_prefix + suffix] = state_dict[ffn_key]
                state_dict.pop(ffn_key)
        super()._load_from_state_dict(state_dict, prefix, local_metadata,
                                      strict, missing_keys, unexpected_keys,
                                      error_msgs)

    def forward(self, x: torch.Tensor):
        """Apply attention, then the MoE FFN.

        Args:
            x (Tensor): Tokens of shape ``(B, N, C)``. With
                ``moe_type='image'``, token 0 is used for routing (the class
                token when the backbone prepends one).

        Returns:
            tuple: Output tokens with the shape of ``x``, and the gate's
            auxiliary loss. The loss is ``None`` when there are no routed
            experts or the gate computed none. It is returned rather than
            attached to the graph with ``AddAuxiliaryLoss``.
        """
        x = x + self.attn(self.ln1(x))
        identity = x
        x = self.ln2(x)

        y = None
        aux_loss = None
        if self.with_routed_experts:
            if self.moe_type == 'token':
                topk_idx, topk_weight, aux_loss = self.gate(x)
            elif self.moe_type == 'image':
                # Route each image once (on token 0), then give all of its
                # tokens the same experts and weights.
                topk_idx, topk_weight, aux_loss = self.gate(x[:, 0:1])
                topk_idx = topk_idx.repeat_interleave(x.shape[1], dim=0)
                topk_weight = topk_weight.repeat_interleave(x.shape[1], dim=0)
            else:
                raise NotImplementedError
            hidden_states = x.view(-1, x.shape[-1])
            flat_topk_idx = topk_idx.view(-1)
            if self.training:
                # Autograd-friendly path: duplicate each token k times so
                # row ``t * k + j`` is processed by expert ``topk_idx[t, j]``.
                hidden_states = hidden_states.repeat_interleave(self.num_experts_per_tok, dim=0)
                y = torch.empty_like(hidden_states)
                for i, expert in enumerate(self.experts):
                    mask = flat_topk_idx == i
                    y[mask] = expert(hidden_states[mask]).type_as(y)
                y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
                y = y.view(*x.shape)
            else:
                y = self.moe_infer(hidden_states, flat_topk_idx, topk_weight.view(-1, 1)).view(*x.shape)
        if self.with_shared_experts:
            shared_out = self.shared_experts(x)
            y = shared_out if y is None else y + shared_out
        return identity + y, aux_loss

    @torch.no_grad()
    def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
        """Sparse inference: run each expert once on the tokens routed to it.

        Args:
            x (Tensor): Flattened tokens ``(T, C)``.
            flat_expert_indices (Tensor): ``(T * k,)`` expert index of each
                token/slot pair, in token-major order.
            flat_expert_weights (Tensor): ``(T * k, 1)`` matching weights.

        Returns:
            Tensor: ``(T, C)``, the weighted sum of expert outputs per token.
        """
        expert_cache = torch.zeros_like(x)
        # Group slot positions by expert; cumulative counts give each
        # expert's [start, end) range in the sorted order.
        idxs = flat_expert_indices.argsort()
        tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
        token_idxs = floor_div(idxs, self.num_experts_per_tok)
        for i, end_idx in enumerate(tokens_per_expert):
            start_idx = 0 if i == 0 else tokens_per_expert[i - 1]
            if start_idx == end_idx:
                continue
            expert = self.experts[i]
            exp_token_idx = token_idxs[start_idx:end_idx]
            expert_tokens = x[exp_token_idx]
            # Match the cache dtype (as the training path does with
            # ``type_as``); under autocast the expert may return a lower
            # precision than ``x``, which scatter_add_ would reject.
            expert_out = expert(expert_tokens).type_as(expert_cache)
            expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
            # scatter_add_ equals scatter_reduce_(reduce='sum') with the
            # default include_self=True, without a torch feature probe.
            expert_cache.scatter_add_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), expert_out)
        return expert_cache


@MODELS.register_module()
class MoEVisionTransformer(VisionTransformer):
    """MoE Vision Transformer.

    This is a Vision Transformer backbone whose encoder layers are
    :class:`MoETransformerEncoderLayer`. It reuses ``arch_zoo``,
    ``OUT_TYPES``, ``_prepare_pos_embed`` and ``_format_output`` from
    :class:`VisionTransformer`, but builds its own modules (the parent
    ``__init__`` is deliberately skipped).

    Args:
        arch (str | dict): Name in ``arch_zoo``, or a dict with at least
            ``embed_dims``, ``num_layers``, ``num_heads`` and
            ``feedforward_channels``.
        img_size (int | tuple): Input image size.
        patch_size (int): Patch size (conv kernel and stride).
        in_channels (int): Number of input image channels.
        out_indices (int | Sequence[int]): Layers whose outputs are
            returned. Negative values count from the end. The argument is
            copied, never modified.
        drop_rate (float): Dropout after the position embedding and inside
            the layers.
        drop_path_rate (float): Maximum stochastic-depth rate, increasing
            linearly over the layers.
        qkv_bias (bool): Whether the qkv projections have a bias.
        norm_cfg (dict): Norm layer config.
        final_norm (bool): Apply ``ln1`` after the last layer.
        out_type (str): One of ``self.OUT_TYPES``.
        with_cls_token (bool): Prepend a class token. It must be True when
            ``out_type='cls_token'``.
        frozen_stages (int): Freeze the embeddings and the first
            ``frozen_stages`` layers when ``> 0``.
        interpolate_mode (str): Mode used to resize the position embedding.
        layer_scale_init_value (float): Layer-scale init for every layer.
        patch_cfg (dict): Overrides for the ``PatchEmbed`` config.
        layer_cfgs (dict | Sequence[dict]): Per-layer overrides, merged
            into each layer's default config.
        pre_norm (bool): Add a norm before the layers. This also disables
            the patch-embedding bias.
        init_cfg (dict, optional): Initialization config.

    Forward returns a tuple of the requested outputs. When at least one
    layer produced an auxiliary loss, it returns ``(outs, losses)`` instead,
    with ``losses`` keyed ``'layer{i}.aux_loss'``.
    """

    def __init__(self,
                 arch='base',
                 img_size=224,
                 patch_size=16,
                 in_channels=3,
                 out_indices=-1,
                 drop_rate=0.,
                 drop_path_rate=0.,
                 qkv_bias=True,
                 norm_cfg=dict(type='LN', eps=1e-6),
                 final_norm=True,
                 out_type='cls_token',
                 with_cls_token=True,
                 frozen_stages=-1,
                 interpolate_mode='bicubic',
                 layer_scale_init_value=0.,
                 patch_cfg=dict(),
                 layer_cfgs=dict(),
                 pre_norm=False,
                 init_cfg=None):
        # Skip VisionTransformer.__init__: every module is built below.
        super(VisionTransformer, self).__init__(init_cfg)

        if isinstance(arch, str):
            arch = arch.lower()
            assert arch in set(self.arch_zoo), \
                f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
            self.arch_settings = self.arch_zoo[arch]
        else:
            essential_keys = {
                'embed_dims', 'num_layers', 'num_heads', 'feedforward_channels'
            }
            assert isinstance(arch, dict) and essential_keys <= set(arch), \
                f'Custom arch needs a dict with keys {essential_keys}'
            self.arch_settings = arch

        self.embed_dims = self.arch_settings['embed_dims']
        self.num_layers = self.arch_settings['num_layers']
        self.img_size = to_2tuple(img_size)

        # Set patch embedding
        _patch_cfg = dict(
            in_channels=in_channels,
            input_size=img_size,
            embed_dims=self.embed_dims,
            conv_type='Conv2d',
            kernel_size=patch_size,
            stride=patch_size,
            bias=not pre_norm,  # disable bias if pre_norm is used(e.g., CLIP)
        )
        _patch_cfg.update(patch_cfg)
        self.patch_embed = PatchEmbed(**_patch_cfg)
        self.patch_resolution = self.patch_embed.init_out_size
        num_patches = self.patch_resolution[0] * self.patch_resolution[1]

        # Set out type
        if out_type not in self.OUT_TYPES:
            raise ValueError(f'Unsupported `out_type` {out_type}, please '
                             f'choose from {self.OUT_TYPES}')
        self.out_type = out_type

        # Set cls token
        self.with_cls_token = with_cls_token
        if with_cls_token:
            self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))
        elif out_type != 'cls_token':
            self.cls_token = None
            self.num_extra_tokens = 0
        else:
            raise ValueError(
                'with_cls_token must be True when `out_type="cls_token"`.')

        # Set position embedding
        self.interpolate_mode = interpolate_mode
        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + self.num_extra_tokens,
                        self.embed_dims))
        self._register_load_state_dict_pre_hook(self._prepare_pos_embed)

        self.drop_after_pos = nn.Dropout(p=drop_rate)

        if isinstance(out_indices, int):
            out_indices = [out_indices]
        assert isinstance(out_indices, Sequence), \
            f'"out_indices" must by a sequence or int, ' \
            f'get {type(out_indices)} instead.'
        # Copy before normalising: tuples cannot be assigned into, and a
        # caller's list (often a shared config value) must not be rewritten.
        out_indices = list(out_indices)
        for i, index in enumerate(out_indices):
            if index < 0:
                out_indices[i] = self.num_layers + index
            assert 0 <= out_indices[i] <= self.num_layers, \
                f'Invalid out_indices {index}'
        self.out_indices = out_indices

        # stochastic depth decay rule
        dpr = np.linspace(0, drop_path_rate, self.num_layers)

        self.layers = ModuleList()
        if isinstance(layer_cfgs, dict):
            layer_cfgs = [layer_cfgs] * self.num_layers
        for i in range(self.num_layers):
            _layer_cfg = dict(
                embed_dims=self.embed_dims,
                num_heads=self.arch_settings['num_heads'],
                feedforward_channels=self.
                arch_settings['feedforward_channels'],
                layer_scale_init_value=layer_scale_init_value,
                drop_rate=drop_rate,
                drop_path_rate=dpr[i],
                qkv_bias=qkv_bias,
                norm_cfg=norm_cfg,
                layer_idx=i)
            _layer_cfg.update(layer_cfgs[i])
            self.layers.append(MoETransformerEncoderLayer(**_layer_cfg))

        self.frozen_stages = frozen_stages
        if pre_norm:
            self.pre_norm = build_norm_layer(norm_cfg, self.embed_dims)
        else:
            self.pre_norm = nn.Identity()

        self.final_norm = final_norm
        if final_norm:
            self.ln1 = build_norm_layer(norm_cfg, self.embed_dims)
        if self.out_type == 'avg_featmap':
            self.ln2 = build_norm_layer(norm_cfg, self.embed_dims)

        # freeze stages only when self.frozen_stages > 0
        if self.frozen_stages > 0:
            self._freeze_stages()

    def _freeze_stages(self):
        """Put the embeddings and the first ``frozen_stages`` layers in eval
        mode and stop their gradients. The final norms are also frozen when
        every layer is frozen."""
        # freeze position embedding
        if self.pos_embed is not None:
            self.pos_embed.requires_grad = False
        # set dropout to eval model
        self.drop_after_pos.eval()
        # freeze patch embedding
        self.patch_embed.eval()
        for param in self.patch_embed.parameters():
            param.requires_grad = False
        # freeze pre-norm
        for param in self.pre_norm.parameters():
            param.requires_grad = False
        # freeze cls_token
        if self.cls_token is not None:
            self.cls_token.requires_grad = False
        # freeze layers
        for i in range(1, self.frozen_stages + 1):
            m = self.layers[i - 1]
            m.eval()
            for param in m.parameters():
                param.requires_grad = False
        # freeze the last layer norm
        if self.frozen_stages == len(self.layers):
            if self.final_norm:
                self.ln1.eval()
                for param in self.ln1.parameters():
                    param.requires_grad = False

            if self.out_type == 'avg_featmap':
                self.ln2.eval()
                for param in self.ln2.parameters():
                    param.requires_grad = False

    def forward(self, x):
        """Embed patches, run the MoE layers and collect outputs/aux losses."""
        B = x.shape[0]
        x, patch_resolution = self.patch_embed(x)

        if self.cls_token is not None:
            # stole cls_tokens impl from Phil Wang, thanks
            cls_token = self.cls_token.expand(B, -1, -1)
            x = torch.cat((cls_token, x), dim=1)

        x = x + resize_pos_embed(
            self.pos_embed,
            self.patch_resolution,
            patch_resolution,
            mode=self.interpolate_mode,
            num_extra_tokens=self.num_extra_tokens)
        x = self.drop_after_pos(x)

        x = self.pre_norm(x)

        outs = []
        losses = dict()
        for i, layer in enumerate(self.layers):
            x, loss = layer(x)

            if loss is not None:
                assert loss.numel() == 1
                losses[f'layer{i}.aux_loss'] = loss

            if i == len(self.layers) - 1 and self.final_norm:
                x = self.ln1(x)

            if i in self.out_indices:
                outs.append(self._format_output(x, patch_resolution))

        if len(losses) > 0:
            return tuple(outs), losses
        else:
            return tuple(outs)
