import copy
from torch.nn import Module, Linear, init
from divmorph.config import cfg
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules import ModuleList

from .svd_module import SVDMultiheadAttention, SVDLinearOptimized


def build_mlp(hidden_size, projector_dim, z_dim):
    """Build a three-layer MLP projector with SiLU activations.

    The network maps ``hidden_size -> projector_dim -> projector_dim -> z_dim``
    with a SiLU between consecutive linear layers and no activation after
    the last one.

    Args:
        hidden_size: Input feature size.
        projector_dim: Width of the two hidden layers.
        z_dim: Output feature size.

    Returns:
        An ``nn.Sequential`` of five modules (Linear, SiLU, Linear, SiLU, Linear).
    """
    widths = (hidden_size, projector_dim, projector_dim, z_dim)
    stages = []
    for fan_in, fan_out in zip(widths, widths[1:]):
        # Activations sit only *between* linear layers, never after the last.
        if stages:
            stages.append(nn.SiLU())
        stages.append(nn.Linear(fan_in, fan_out))
    return nn.Sequential(*stages)

class Gate(nn.Module):
    """Top-k expert gate with optional bias-based load balancing.

    Inputs are projected to one logit per expert with a bias-free linear map
    (``weight``, initialised to zeros). The ``k`` best experts are selected,
    the remaining logits are masked to ``-inf``, and a softmax over the
    survivors yields the routing weights.

    When ``use_auxfreeloss`` is true a non-trainable ``expert_bias`` buffer is
    added to the logits *for selection only*; the routing weights themselves
    are still computed from the un-biased logits. During training the bias is
    nudged by ``bias_lr`` towards experts that were selected less often than
    average.

    Args:
        hidden_dim: Size of the input feature dimension.
        num_experts: Number of experts to route between.
        other_size: Total width that is split evenly across experts; must be
            divisible by ``num_experts``.
        bias_lr: Step size of the ``expert_bias`` update.
        use_auxfreeloss: Enables the ``expert_bias`` buffer and its update.
        aux_loss: Accepted for interface compatibility; not used by this class.

    Raises:
        ValueError: If ``other_size`` is not divisible by ``num_experts``.
    """

    def __init__(self, hidden_dim, num_experts, other_size, bias_lr=0.001, use_auxfreeloss=True, aux_loss=False):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_experts = num_experts
        self.use_auxfreeloss = use_auxfreeloss

        per_expert, leftover = divmod(other_size, num_experts)
        if leftover != 0:
            raise ValueError(f"other_size ({other_size}) must be divisible by num_experts ({num_experts})")
        self.size_per_expert = per_expert

        self.weight = nn.Parameter(torch.empty(num_experts, hidden_dim))
        self.softmax = nn.Softmax(dim=-1)

        if self.use_auxfreeloss:
            # Buffer, not a parameter: it is updated by hand in get_topk.
            self.register_buffer('expert_bias', torch.zeros(num_experts))
            self.bias_lr = bias_lr

        self.reset_parameters()

    def reset_parameters(self):
        """Zero-initialise the gating projection."""
        nn.init.zeros_(self.weight)

    def get_topk(self, x, k):
        """Keep the ``k`` best expert logits and mask the rest with ``-inf``.

        Args:
            x: Per-expert logits; the last dimension indexes experts and
                ``x.shape[1]`` is compared against ``cfg.PPO.BATCH_SIZE``
                (presumably the batch dimension -- confirm with callers).
            k: Number of experts to keep per position.

        Returns:
            ``(masked, indices)`` where ``masked`` has the shape of ``x`` with
            the un-biased logits at the selected positions and ``-inf``
            elsewhere, and ``indices`` holds the selected expert ids.
        """
        batch = x.shape[1]
        if self.use_auxfreeloss:
            ranking = x + self.expert_bias
        else:
            ranking = x

        picked = torch.topk(ranking, k, dim=-1)[1]
        masked = torch.full_like(x, float('-inf'))
        # The bias only influences *which* experts win; the kept values are
        # the original logits.
        masked.scatter_(dim=-1, index=picked, src=x.gather(dim=-1, index=picked))

        # No bias adaptation in TRANSFER mode.
        if cfg.PPO.LOAD_MODE == 'TRANSFER':
            return masked, picked

        if not (self.use_auxfreeloss and self.training):
            return masked, picked

        full_batch = batch == cfg.PPO.BATCH_SIZE
        distilling = cfg.DISTILL.IS_DISTILL
        if full_batch or distilling:
            load = torch.bincount(picked.view(-1), minlength=self.num_experts).to(torch.float32)
            # +1 for under-used experts, -1 for over-used ones, 0 at the mean.
            direction = torch.sign(load.mean() - load)
            self.expert_bias = self.expert_bias + self.bias_lr * direction

        return masked, picked

    def forward(self, x, ids, k):
        """Compute routing weights shared by every position along dim 0.

        Args:
            x: Input features; projected with ``weight`` along the last
                dimension. Presumably shaped (limbs, batch, hidden_dim) --
                the code repeats over three dimensions; confirm with callers.
            ids: Unused.
            k: Number of experts to activate.

        Returns:
            ``(scores, indices)``. ``scores`` are the softmax weights of the
            selected experts multiplied by ``k`` (zero for unselected
            experts), with every expert's weight repeated
            ``size_per_expert`` times along the last dimension. ``indices``
            are the selected expert ids.
        """
        limb_num = x.shape[0]
        logits = F.linear(x, self.weight)
        out_dtype = logits.dtype

        # Routing is done in float32; ``Tensor.to`` returns the tensor itself
        # when the dtype already matches.
        logits = logits.to(torch.float32)

        # One routing decision for all entries along dim 0: average, then
        # broadcast the average back to the original length.
        pooled = logits.mean(dim=0, keepdim=True)
        pooled = pooled.repeat(limb_num, 1, 1)

        masked, indices = self.get_topk(pooled, k)

        weights = self.softmax(masked).to(out_dtype)
        weights = k * weights
        weights = weights.repeat_interleave(self.size_per_expert, dim=-1)

        return weights, indices
    
    
def _get_clones(module, N):
    """Return a ``ModuleList`` holding ``N`` independent deep copies of ``module``."""
    clones = ModuleList()
    for _ in range(N):
        clones.append(copy.deepcopy(module))
    return clones


def _get_activation_fn(activation):
    """Map an activation name to the matching ``torch.nn.functional`` callable.

    Args:
        activation: Either ``"relu"`` or ``"gelu"``.

    Returns:
        ``F.relu`` or ``F.gelu``.

    Raises:
        RuntimeError: If ``activation`` is any other value. Failing here
            avoids handing ``None`` back to the caller, which would only
            surface later as "'NoneType' object is not callable".
    """
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    raise RuntimeError(f"activation should be relu/gelu, not {activation}")


class TransformerEncoder(nn.Module):
    """Stack of ``num_layers`` deep-copied encoder layers with an optional final norm.

    Args:
        encoder_layer: Layer module to clone; each clone is called with
            ``src_mask``, ``src_key_padding_mask``, ``context``,
            ``morphology_info``, ``gate1`` and ``gate2`` keyword arguments.
        num_layers: Number of clones in the stack.
        norm: Optional module applied to the output of the last layer.
    """

    __constants__ = ["norm"]

    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__()

        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, src, mask=None, src_key_padding_mask=None, context=None, morphology_info=None, gate1=None, gate2=None):
        """Run ``src`` through every layer in order and return the result.

        All keyword arguments are forwarded unchanged to each layer
        (``mask`` is passed as ``src_mask``). ``self.norm`` is applied to the
        final output when it is not ``None``.
        """
        shared_kwargs = dict(
            src_mask=mask,
            src_key_padding_mask=src_key_padding_mask,
            context=context,
            morphology_info=morphology_info,
            gate1=gate1,
            gate2=gate2,
        )

        hidden = src
        for layer in self.layers:
            hidden = layer(hidden, **shared_kwargs)

        return hidden if self.norm is None else self.norm(hidden)

    def get_attention_maps(self, src, mask=None, src_key_padding_mask=None, context=None, morphology_info=None, gate1=None, gate2=None):
        """Same as ``forward`` but also collect each layer's attention output.

        Every layer is called with ``return_attention=True`` and is expected
        to return an ``(output, attention_map)`` pair.

        Returns:
            ``(output, attention_maps)`` where ``attention_maps`` is a list
            with one entry per layer, in layer order.
        """
        shared_kwargs = dict(
            src_mask=mask,
            src_key_padding_mask=src_key_padding_mask,
            return_attention=True,
            context=context,
            morphology_info=morphology_info,
            gate1=gate1,
            gate2=gate2,
        )

        collected = []
        hidden = src
        for layer in self.layers:
            hidden, layer_map = layer(hidden, **shared_kwargs)
            collected.append(layer_map)

        if self.norm is not None:
            hidden = self.norm(hidden)

        return hidden, collected


class TransformerEncoderLayerResidual(nn.Module):
    """Encoder layer that normalises before each sub-layer and adds the result back.

    The layer consists of an ``SVDMultiheadAttention`` block followed by a
    two-layer feed-forward block built from ``SVDLinearOptimized``. Each
    sub-layer receives a LayerNorm-ed input and its (dropped-out) output is
    added to the un-normalised input. ``gate1`` / ``gate2`` are forwarded to
    the attention module and to both feed-forward projections.

    Args:
        d_model: Model feature size.
        nhead: Number of attention heads.
        gene_size, other_size_1, other_size_2: Passed through unchanged to
            the SVD attention and linear modules.
        dim_feedforward: Hidden size of the feed-forward block.
        dropout: Dropout probability used throughout the layer.
        activation: Name of the feed-forward activation, resolved by
            ``_get_activation_fn``.
    """

    def __init__(
        self, d_model, nhead, gene_size, other_size_1, other_size_2, dim_feedforward=2048, dropout=0.1, activation="relu"
    ):
        super().__init__()

        # Arguments shared by every SVD-based sub-module.
        svd_sizes = (gene_size, other_size_1, other_size_2)

        # NOTE: sub-modules are created in a fixed order so that parameter
        # initialisation and state_dict layout stay stable.
        self.self_attn = SVDMultiheadAttention(d_model, nhead, *svd_sizes, dropout=dropout)

        self.linear1 = SVDLinearOptimized(d_model, dim_feedforward, *svd_sizes)
        self.dropout = nn.Dropout(dropout)

        self.linear2 = SVDLinearOptimized(dim_feedforward, d_model, *svd_sizes)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        # NOTE(review): norm_context only exists when FIX_ATTENTION is set,
        # but forward() uses it whenever a context is supplied.
        if cfg.MODEL.TRANSFORMER.FIX_ATTENTION:
            self.norm_context = nn.LayerNorm(d_model)

        self.activation = _get_activation_fn(activation)

    def __setstate__(self, state):
        # Restored states that lack an activation entry fall back to ReLU.
        state.setdefault("activation", F.relu)
        super().__setstate__(state)

    def forward(self, src, gate1, gate2, src_mask=None, src_key_padding_mask=None, return_attention=False, context=None, morphology_info=None):
        """Apply the attention and feed-forward sub-layers to ``src``.

        Args:
            src: Input tensor.
            gate1, gate2: Forwarded to the attention module and to both
                feed-forward projections.
            src_mask: Forwarded to the attention module as ``src_mask``.
            src_key_padding_mask: Forwarded to the attention module as ``mask``.
            return_attention: When true, also return the attention weights.
            context: Optional tensor; when given, its normalised form is used
                as the first two attention inputs in place of the normalised
                ``src``.
            morphology_info: Accepted but not used by this layer.

        Returns:
            The updated ``src``, or ``(src, attn_weights)`` when
            ``return_attention`` is true.
        """
        normed_src = self.norm1(src)

        # The first two attention inputs come from the context when one is
        # provided; the third is always the normalised source.
        if context is None:
            query_key = normed_src
        else:
            query_key = self.norm_context(context)

        attn_out, attn_weights = self.self_attn(
            query_key, query_key, normed_src, src_mask=src_mask, mask=src_key_padding_mask, gate1=gate1, gate2=gate2
        )
        src = src + self.dropout1(attn_out)

        # Feed-forward sub-layer.
        hidden = self.linear1(self.norm2(src), gate1=gate1, gate2=gate2)
        hidden = self.dropout(self.activation(hidden))
        ff_out = self.linear2(hidden, gate1=gate1, gate2=gate2)
        src = src + self.dropout2(ff_out)

        if return_attention:
            return src, attn_weights
        return src
        
    