
import copy
from typing import Iterable, List

import torch
from torch import nn


def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])

def apply_residual(hidden_states, residual):
    """Return a new list with ``residual[i] + hidden_states[i]`` for every position ``i``.

    ``residual`` is indexed by the position of each hidden state, so it must be at
    least as long as ``hidden_states``; extra residual entries are ignored.
    """
    summed = []
    for idx, state in enumerate(hidden_states):
        summed.append(residual[idx] + state)
    return summed

def get_embeds_from_ids(input_ids, valid_pos, pad_ids, processors):
    """Run each modality's processor on a copy of ``input_ids`` masked to that modality.

    For modality ``i``, every position where ``valid_pos[i]`` is False is replaced by
    ``pad_ids[i]`` in a fresh copy of ``input_ids`` (the caller's tensor is untouched),
    and the result is passed to ``processors[i]``.

    Returns:
        A list with one processor output per entry of ``valid_pos``.
    """
    def _embed_modality(modality, keep_mask):
        masked_ids = input_ids.clone().masked_fill_(~keep_mask, pad_ids[modality])
        return processors[modality](masked_ids)

    return [_embed_modality(modality, keep_mask) for modality, keep_mask in enumerate(valid_pos)]

    
class MoTBase(nn.Module):
    """Base class for the per-modality ("MoT") module wrappers.

    Subclasses store one module per modality in ``self.fn`` (an ``nn.ModuleList``).
    The default ``forward`` routes ``hidden_states[i]`` through ``self.fn[i]``, so it
    expects the hidden states to be already split into one entry per modality.
    """

    def __init__(self):
        super().__init__()
        # Per-modality positions; written by update_typeids (subclasses also reset it).
        self.valid_pos = None

    def update_typeids(self, type_ids):
        """Store ``type_ids`` as ``self.valid_pos``; passing ``None`` leaves it unchanged."""
        if type_ids is not None:
            self.valid_pos = type_ids

    def forward(self, hidden_states):
        """Apply each modality's module in ``self.fn`` to its entry of ``hidden_states``."""
        return self.apply_module(hidden_states, self.fn)

    def apply_module(
        self, hidden_states: Iterable[torch.Tensor], module: nn.ModuleList
    ) -> List[torch.Tensor]:
        """Return ``[module[i](hidden_states[i]) for each position i]``.

        ``module`` must have at least as many entries as ``hidden_states`` (an
        ``IndexError`` is raised otherwise); extra modules are ignored.
        """
        return [module[type_id](hid) for type_id, hid in enumerate(hidden_states)]

class MoTLayerNorm(MoTBase):
    """Per-modality normalization: one independent copy of ``norm_fn`` per modality.

    Args:
        norm_fn: normalization module, deep-copied ``modality_num`` times.
        modality_num: number of modalities (and copies).
        out_dim: accepted for signature parity with the other MoT wrappers; unused.
    """

    def __init__(self, norm_fn, modality_num=2, out_dim=None):
        super().__init__()
        self.fn = _get_clones(norm_fn, modality_num)
        self.modality_num = modality_num
        self.type_ids = None
        self.valid_pos = None

    def forward(self, hidden_states):
        """Normalize ``hidden_states[i]`` with the i-th copy of the norm.

        Returns a list with one normalized tensor per modality. Assumes
        ``hidden_states`` is already split per modality (one entry each) — TODO confirm
        against callers.
        """
        # Must return the result: a bare stub here would make the layer yield None.
        return self.apply_module(hidden_states, self.fn)

class MoTMod(MoTBase):
    """Per-modality copies of a single module, applied by the inherited ``forward``.

    Args:
        norm_fn: module deep-copied ``modality_num`` times (one copy per modality).
        modality_num: number of modalities (and copies).
        out_dim: output width recorded in ``self.out_dim``. Overridden when ``norm_fn``
            is an ``nn.Linear`` (``out_features``) or ``nn.Embedding`` (``embedding_dim``).
    """

    def __init__(self, norm_fn, modality_num=2, out_dim=None):
        super().__init__()
        self.fn = _get_clones(norm_fn, modality_num)
        self.modality_num = modality_num
        if isinstance(norm_fn, nn.Linear):
            out_dim = norm_fn.out_features
        elif isinstance(norm_fn, nn.Embedding):
            # nn.Embedding has no ``out_features``; its output width is ``embedding_dim``.
            out_dim = norm_fn.embedding_dim
        self.out_dim = out_dim
        self.type_ids = None
        self.valid_pos = None

class MoTEmbed(MoTBase):
    """Per-modality embedding/projection wrapper holding one module per modality.

    Args:
        fn_list: one module per modality, stored in order in ``self.fn``.
        modality_num: number of modalities; must equal ``len(fn_list)``.
        out_dims: output width of each module. When omitted it is read from each module
            (``nn.Linear.out_features``, ``nn.Embedding.embedding_dim`` or an
            ``output_dim`` attribute). ``self.out_dim`` is the largest of them.

    Raises:
        ValueError: ``out_dims`` is omitted and some module's width cannot be inferred.
        AssertionError: ``fn_list`` or ``out_dims`` length differs from ``modality_num``.
    """

    def __init__(self, fn_list, modality_num=2, out_dims=None):
        super().__init__()
        self.fn = nn.ModuleList(fn_list)
        self.modality_num = modality_num
        if out_dims is None:
            # Iterate the stored ModuleList: fn_list may be a one-shot iterator that
            # nn.ModuleList has already consumed.
            out_dims = [self._infer_out_dim(fn) for fn in self.fn]

        assert len(out_dims) == modality_num and len(self.fn) == modality_num, (
            f"expected {modality_num} modules and out_dims, "
            f"got {len(self.fn)} modules and {len(out_dims)} out_dims"
        )
        self.out_dims = out_dims
        self.out_dim = max(out_dims)
        self.type_ids = None
        self.valid_pos = None

    @staticmethod
    def _infer_out_dim(fn):
        """Return the output width of ``fn``, or raise ``ValueError`` if it is unknown."""
        if isinstance(fn, nn.Linear):
            return fn.out_features
        if isinstance(fn, nn.Embedding):
            return fn.embedding_dim
        if hasattr(fn, 'output_dim'):
            return fn.output_dim
        raise ValueError('out_dims is None with not Linear or Embedding term in fn_list')

    def apply_module(self, hidden_states, module):
        """Not implemented for ``MoTEmbed``.

        Raises explicitly rather than silently returning ``None`` (which would make
        ``forward`` return ``None``). NOTE(review): the intended combination of the
        per-modality outputs (``self.out_dims`` vs ``self.out_dim``) is not visible here.
        """
        raise NotImplementedError('MoTEmbed.apply_module is not implemented')

class MoTDiffFuncMod(MoTBase):
    """Per-modality modules applied to the positions each modality owns.

    ``forward`` takes a single tensor and, for modality ``i``, runs ``fn[i]`` on
    ``hidden_states[valid_pos[i]]``. ``valid_pos`` must be set beforehand through the
    inherited ``update_typeids``.

    Args:
        fn_list: one module per modality; entries ``0 .. modality_num - 1`` are used.
        modality_num: number of modalities iterated in ``apply_module``.
        out_dims: output width of each module. When omitted it is inferred from each
            module (``nn.Linear.out_features``, ``nn.Embedding.embedding_dim`` or an
            ``output_dim`` attribute). ``self.out_dim`` is the largest of them.

    Raises:
        ValueError: ``out_dims`` is omitted and some module's width cannot be inferred.
    """

    def __init__(self, fn_list, modality_num=2, out_dims=None):
        super().__init__()
        self.fn = nn.ModuleList(fn_list)
        self.modality_num = modality_num
        if out_dims is None:
            # Without this, the default out_dims=None crashed in max(None).
            out_dims = [self._infer_out_dim(fn) for fn in self.fn]

        self.out_dims = out_dims
        self.out_dim = max(out_dims)
        self.type_ids = None
        self.valid_pos = None

    @staticmethod
    def _infer_out_dim(fn):
        """Return the output width of ``fn``, or raise ``ValueError`` if it is unknown."""
        if isinstance(fn, nn.Linear):
            return fn.out_features
        if isinstance(fn, nn.Embedding):
            return fn.embedding_dim
        if hasattr(fn, 'output_dim'):
            return fn.output_dim
        raise ValueError('out_dims is None with not Linear or Embedding term in fn_list')

    def apply_module(self, hidden_states: torch.Tensor, module: nn.ModuleList) -> List[torch.Tensor]:
        """Return ``[module[i](hidden_states[self.valid_pos[i]]) for i in range(modality_num)]``.

        Each output only covers the positions selected for modality ``i``. If
        ``valid_pos[i]`` are boolean masks, the outputs may differ in length.
        ``update_typeids`` must have been called first, because ``valid_pos`` starts as ``None``.
        """
        valid_pos = self.valid_pos
        return [
            module[type_id](hidden_states[valid_pos[type_id]])
            for type_id in range(self.modality_num)
        ]