import torch
import torch.nn as nn


class Selector(nn.Module):
    """Name-based factory that wraps exactly one scoring sub-module.

    The concrete sub-module is chosen by ``selector_type`` and constructed with
    the remaining positional/keyword arguments. Calling the ``Selector``
    simply delegates to that sub-module.

    Args:
        selector_type: One of ``"multimodalmlp"``, ``"totalmlp"`` (a
            ``MultimodalMLP`` built with ``combine=True``), ``"linear"``,
            ``"mlps"``, ``"totalmlp_att"``, ``"score_att"`` or ``"feat_att"``.
        *args, **kwargs: Forwarded unchanged to the chosen sub-module's
            constructor.

    Raises:
        NotImplementedError: If ``selector_type`` matches none of the names above.
    """

    def __init__(self, selector_type, *args, **kwargs):
        super().__init__()

        # Lambdas defer both the class-name lookup and the construction until
        # the matching entry is found, so only the selected module is built.
        factories = (
            ("multimodalmlp", lambda: MultimodalMLP(*args, **kwargs)),
            ("totalmlp", lambda: MultimodalMLP(*args, combine=True, **kwargs)),
            ("linear", lambda: LinearLayer(*args, **kwargs)),
            ("mlps", lambda: MLPS(*args, **kwargs)),
            ("totalmlp_att", lambda: AttentionModel(*args, **kwargs)),
            ("score_att", lambda: AttentionScore(*args, **kwargs)),
            ("feat_att", lambda: AttentionFeature(*args, **kwargs)),
        )
        for name, factory in factories:
            if selector_type == name:
                self.module = factory()
                return
        raise NotImplementedError("Unknown selector type: {}".format(selector_type))

    def forward(self, *args, **kwargs):
        """Delegate the call to the wrapped sub-module and return its result."""
        return self.module(*args, **kwargs)


class MultimodalMLP(nn.Module):
    """MLP that scores concatenated image/text (and optionally instruction) features.

    Args:
        feat_size: Dict with entries ``'ins_embed_size'``, ``'image_feat_size'``
            and ``'text_feat_size'``. These are the last-dimension sizes of
            ``instruction_logits``, ``image_feats`` and ``text_feats``.
        layer_num: Number of hidden ``Linear -> Dropout -> [BatchNorm1d] -> ReLU``
            stages before the final ``Linear(dimens, 1)``. Must be a positive
            integer; 1 and 2 build the same layer layout as before.
        dimens: Hidden width of the predictor.
        combine: If true (and ``concat == 0``), ``instruction_logits`` are
            concatenated to the image/text features.
        average: Chooses the ``BatchNorm1d`` channel count (see ``__init__``).
        concat: 0 joins the inputs along the last dim. 1 joins them along
            dim 1, so each input's last dim must equal ``ins_embed_size``.
        **kwargs: Accepted and ignored.

    Raises:
        ValueError: If ``layer_num`` is not a positive integer.
    """

    def __init__(self, feat_size, layer_num, dimens, combine=False, average=1, concat=0, **kwargs):
        super(MultimodalMLP, self).__init__()

        ins_embed_size = feat_size['ins_embed_size']
        n_instructions = 4
        self.concat = concat
        self.use_softmax = False
        self.use_qi_embed = False
        self.combine = combine

        image_feat_size = feat_size['image_feat_size']
        text_feat_size = feat_size['text_feat_size']
        # The embedding layers keep each modality's width unchanged.
        image_embed_size = image_feat_size
        text_embed_size = text_feat_size

        if self.use_qi_embed:
            qi_feat_size = feat_size['text_feat_size']
            qi_embed_size = feat_size['text_feat_size']
            input_size = image_embed_size + text_embed_size + ins_embed_size + qi_embed_size
        elif self.combine:
            input_size = image_embed_size + text_embed_size + ins_embed_size
        else:
            input_size = image_embed_size + text_embed_size

        if self.concat == 1:
            # Inputs are joined along dim 1 in forward(), so the predictor
            # sees each input's own last dim, which is ins_embed_size.
            input_size = ins_embed_size

        # Optional feature pooling (currently disabled).
        self.pool_image_feats = False
        self.pool_text_feats = False
        self.pool_image_dim = 1
        self.pool_text_dim = 1
        self.pool_type = None
        if (self.pool_image_feats or self.pool_text_feats) and self.pool_type is None:
            raise ValueError('pool_type must be set when feature pooling is enabled.')

        # BatchNorm1d normalises dim 1 of the predictor activations:
        #   concat == 0 -> the hidden width (dimens) for 2-D input;
        #   otherwise   -> the 3 inputs joined along dim 1.
        # NOTE(review): 114 (used when average != 0) looks like a
        # dataset-specific dim-1 size; confirm where it comes from.
        use_batchnorm = True
        if concat == 0:
            batchnorm_size = dimens if average == 0 else 114
        else:
            batchnorm_size = 3 if average == 0 else 114 * 3

        # Fail at construction time instead of leaving selective_predictor
        # undefined and crashing later in forward().
        n_hidden = self._num_hidden_layers(layer_num)
        layers = []
        in_features = input_size
        for _ in range(n_hidden):
            layers += [nn.Linear(in_features, dimens), nn.Dropout(p=0.1)]
            if use_batchnorm:
                layers.append(nn.BatchNorm1d(batchnorm_size))
            layers.append(nn.ReLU())
            in_features = dimens
        layers.append(nn.Linear(dimens, 1))
        self.selective_predictor = nn.Sequential(*layers)

        # Instruction embedding. forward() currently does not apply s_embed,
        # image_embed or text_embed, but they stay registered, so they still
        # appear in parameters() and state_dict().
        first_activation = nn.Softmax(dim=-1) if self.use_softmax else nn.ReLU()
        self.s_embed = nn.Sequential(
            first_activation, nn.Linear(n_instructions, ins_embed_size), nn.ReLU()
        )

        # Text & image embedding layers
        self.init_embedding_layers(
            False,
            use_batchnorm,
            image_feat_size,
            image_embed_size,
            text_feat_size,
            text_embed_size
        )

        # Fused text+image feature layer
        if self.use_qi_embed:
            self.qi_embed = nn.Sequential(
                nn.ReLU(), nn.Linear(qi_feat_size, qi_embed_size), nn.ReLU()
            )

    @staticmethod
    def _num_hidden_layers(layer_num):
        """Return ``layer_num`` as a positive ``int`` or raise ``ValueError``."""
        try:
            count = int(layer_num)
        except (TypeError, ValueError):
            count = 0
        if count < 1 or count != layer_num:
            raise ValueError(f'layer_num must be a positive integer, got {layer_num!r}')
        return count

    def init_embedding_layers(
            self,
            double_embedding_layers,
            use_batchnorm,
            image_feat_size,
            image_embed_size,
            text_feat_size,
            text_embed_size
    ):
        """Create ``self.image_embed`` and ``self.text_embed``.

        Each is ``ReLU -> [BatchNorm1d] -> Linear -> ReLU``. When
        ``double_embedding_layers`` is set, a second ``Linear -> ReLU`` is
        appended.
        """
        def make_embed(feat_dim, embed_dim):
            layers = [nn.ReLU()]
            if use_batchnorm:
                layers.append(nn.BatchNorm1d(feat_dim))
            layers += [nn.Linear(feat_dim, embed_dim), nn.ReLU()]
            if double_embedding_layers:
                layers += [nn.Linear(embed_dim, embed_dim), nn.ReLU()]
            return nn.Sequential(*layers)

        self.image_embed = make_embed(image_feat_size, image_embed_size)
        self.text_embed = make_embed(text_feat_size, text_embed_size)

    def pool_features(self, features, pool_dim, pool_type):
        """Reduce ``features`` over ``pool_dim`` with ``'max'`` or ``'mean'``.

        Raises:
            ValueError: For any other ``pool_type``.
        """
        if pool_type == 'max':
            return features.max(pool_dim).values
        if pool_type == 'mean':
            return features.mean(pool_dim)
        raise ValueError(f'Pool type {pool_type} not recognized.')

    def forward(self, instruction_logits, image_feats, text_feats, qi_embed=None, **kwargs):
        """Score the inputs with the MLP predictor.

        Args:
            instruction_logits: Instruction features. Used when ``combine`` is
                set or ``concat != 0``.
            image_feats, text_feats: Image and text features.
            qi_embed: Only read when ``use_qi_embed`` is enabled (hard-coded to
                False in ``__init__``).
            **kwargs: Accepted and ignored.

        Returns:
            ``{"scores": tensor}`` whose trailing dimension has size 1.
        """
        if self.pool_image_feats:
            image_feats = self.pool_features(
                image_feats, self.pool_image_dim, self.pool_type
            )
        if self.pool_text_feats:
            text_feats = self.pool_features(
                text_feats, self.pool_text_dim, self.pool_type
            )

        if self.concat == 0:
            parts = [image_feats, text_feats]
            if self.use_qi_embed:
                parts += [self.qi_embed(qi_embed), instruction_logits]
            elif self.combine:
                parts.append(instruction_logits)
            input_feat = torch.cat(parts, -1)
        else:
            # 2-D inputs get a length-1 dim 1 so they can be joined along it.
            if image_feats.dim() == 2:
                image_feats = image_feats.unsqueeze(1)
                text_feats = text_feats.unsqueeze(1)
                instruction_logits = instruction_logits.unsqueeze(1)
            input_feat = torch.cat((image_feats, text_feats, instruction_logits), dim=1)
        return {"scores": self.selective_predictor(input_feat)}

class AttentionModel(nn.Module):
    """Self-attention scorer over instruction, image and text features.

    Args:
        feat_size: Dict with ``'ins_embed_size'``, ``'image_feat_size'`` and
            ``'text_feat_size'``.
        nhead: Number of attention heads. ``nn.MultiheadAttention`` requires it
            to divide the model width.
        layer_num: Number of attention layers. The per-layer ``Linear``
            projections are only applied when ``layer_num > 1``.
        residual: 0 (default) adds each attention output to its input. Any
            other value leaves the hidden state unchanged, so the attention
            output is discarded.
        concat: 0 concatenates the inputs along the last dim (width = sum of
            the three sizes). Otherwise they are joined along dim 1 and the
            width is ``ins_embed_size``.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, feat_size, nhead, layer_num, residual=0, concat=0, **kwargs):
        super(AttentionModel, self).__init__()
        self.concat = concat
        if self.concat == 0:
            d_model = feat_size['ins_embed_size'] + feat_size['image_feat_size'] + feat_size['text_feat_size']
        else:
            d_model = feat_size['ins_embed_size']
        # NOTE(review): built with the default batch_first=False, so a 3-D
        # input is read as (seq, batch, embed) and a 2-D input as one unbatched
        # (seq, embed) sequence. Either way attention runs over dim 0
        # (presumably the batch). Confirm this cross-sample attention is intended.
        self.multihead_attns = nn.ModuleList([
            nn.MultiheadAttention(d_model, nhead) for _ in range(layer_num)
        ])
        self.linears = nn.ModuleList([
            nn.Linear(d_model, d_model) for _ in range(layer_num)
        ])
        self.linear = nn.Linear(d_model, 1)
        self.num_layers = layer_num
        self.residual = residual

    def forward(self, instruction_logits, image_feats, text_feats, qi_embed=None, **kwargs):
        """Return ``{"scores": tensor}`` with a trailing dimension of size 1.

        ``qi_embed`` and ``**kwargs`` are accepted (and ignored) so this module
        can be called like its sibling selectors.
        """
        if self.concat == 0:
            x = torch.cat([image_feats, text_feats, instruction_logits], -1)
        else:
            # 2-D inputs get a length-1 dim 1 so they can be joined along it.
            if image_feats.dim() == 2:
                image_feats = image_feats.unsqueeze(1)
                text_feats = text_feats.unsqueeze(1)
                instruction_logits = instruction_logits.unsqueeze(1)
            x = torch.cat((image_feats, text_feats, instruction_logits), dim=1)
        for layer in range(self.num_layers):
            attn_output, _ = self.multihead_attns[layer](x, x, x)  # Self-attention
            # NOTE(review): when residual != 0 the attention output is dropped
            # entirely; `x = attn_output` may have been intended.
            x = x + attn_output if self.residual == 0 else x
            x = self.linears[layer](x) if self.num_layers > 1 else x

        return {"scores": self.linear(x)}

class LinearLayer(nn.Module):
    """Single linear scorer over concatenated image, text and instruction features.

    Args:
        feat_size: Dict with ``'ins_embed_size'``, ``'image_feat_size'`` and
            ``'text_feat_size'``. Their sum is the input width of the layer.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, feat_size, **kwargs):
        super(LinearLayer, self).__init__()
        d_model = feat_size['ins_embed_size'] + feat_size['image_feat_size'] + feat_size['text_feat_size']
        self.linear = nn.Linear(d_model, 1)

    def forward(self, instruction_logits, image_feats, text_feats, qi_embed=None, **kwargs):
        """Return ``{'scores': linear(cat([image, text, instruction], -1))}``.

        ``qi_embed`` and ``**kwargs`` are accepted (and ignored) so this module
        can be called like its sibling selectors.
        """
        x = torch.cat([image_feats, text_feats, instruction_logits], -1)
        return {'scores': self.linear(x)}
    
class MLPS(nn.Module):
    """MLP that scores the instruction logits alone.

    Args:
        feat_size: Accepted for interface compatibility; not read.
        layer_num: Number of hidden ``Linear -> Dropout -> ReLU`` stages before
            the final ``Linear(dimens, 1)``. Must be a positive integer; 1 and
            2 build the same layer layout as before.
        dimens: Hidden width.
        **kwargs: Accepted and ignored.

    Raises:
        ValueError: If ``layer_num`` is not a positive integer. Previously such
            values left ``self.mlp`` undefined and only failed in ``forward``.
    """

    def __init__(self, feat_size, layer_num, dimens, **kwargs):
        super(MLPS, self).__init__()
        n_instructions = 4  # width of instruction_logits' last dimension
        try:
            n_hidden = int(layer_num)
        except (TypeError, ValueError):
            n_hidden = 0
        if n_hidden < 1 or n_hidden != layer_num:
            raise ValueError(f'layer_num must be a positive integer, got {layer_num!r}')

        layers = []
        in_features = n_instructions
        for _ in range(n_hidden):
            layers += [nn.Linear(in_features, dimens), nn.Dropout(p=0.1), nn.ReLU()]
            in_features = dimens
        layers.append(nn.Linear(dimens, 1))  # Output layer
        self.mlp = nn.Sequential(*layers)

    def forward(self, instruction_logits, image_feats=None, text_feats=None, qi_embed=None, **kwargs):
        """Return ``{'scores': mlp(instruction_logits)}``; other inputs are ignored."""
        logits = self.mlp(instruction_logits)
        return {'scores': logits}

class AttentionScore(nn.Module):
    """Self-attention scorer over the instruction logits alone.

    Args:
        feat_size: Dict with ``'ins_embed_size'``, the model width.
        nhead: Number of attention heads. ``nn.MultiheadAttention`` requires it
            to divide the model width.
        layer_num: Number of attention layers. The per-layer ``Linear``
            projections are only applied when ``layer_num > 1``.
        residual: 0 (default) adds each attention output to its input. Any
            other value leaves the hidden state unchanged, so the attention
            output is discarded.
        concat: Stored on ``self.concat`` but not read by this class.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, feat_size, nhead, layer_num, residual=0, concat=0, **kwargs):
        super(AttentionScore, self).__init__()
        self.concat = concat

        d_model = feat_size['ins_embed_size']
        # NOTE(review): default batch_first=False, so attention runs over dim 0
        # of the input (presumably the batch); confirm this is intended.
        self.multihead_attns = nn.ModuleList([
            nn.MultiheadAttention(d_model, nhead) for _ in range(layer_num)
        ])
        self.linears = nn.ModuleList([
            nn.Linear(d_model, d_model) for _ in range(layer_num)
        ])
        self.linear = nn.Linear(d_model, 1)
        self.num_layers = layer_num
        self.residual = residual

    def forward(self, instruction_logits, image_feats=None, text_feats=None, qi_embed=None, **kwargs):
        """Return ``{"scores": tensor}`` computed from ``instruction_logits`` only.

        The remaining arguments are accepted (and ignored) so this module can
        be called like its sibling selectors.
        """
        x = instruction_logits
        for layer in range(self.num_layers):
            attn_output, _ = self.multihead_attns[layer](x, x, x)  # Self-attention
            # NOTE(review): when residual != 0 the attention output is dropped.
            x = x + attn_output if self.residual == 0 else x
            x = self.linears[layer](x) if self.num_layers > 1 else x

        return {"scores": self.linear(x)}
    
class AttentionFeature(nn.Module):
    """Self-attention scorer over concatenated image and text features.

    Args:
        feat_size: Dict with ``'image_feat_size'`` and ``'text_feat_size'``.
            Their sum is the model width.
        nhead: Number of attention heads. ``nn.MultiheadAttention`` requires it
            to divide the model width.
        layer_num: Number of attention layers. The per-layer ``Linear``
            projections are only applied when ``layer_num > 1``.
        residual: 0 (default) adds each attention output to its input. Any
            other value leaves the hidden state unchanged, so the attention
            output is discarded (NOTE(review): confirm that is intended).
        concat: Stored on ``self.concat`` but not read by this class.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, feat_size, nhead, layer_num, residual=0, concat=0, **kwargs):
        super().__init__()
        self.concat = concat

        width = feat_size['image_feat_size'] + feat_size['text_feat_size']
        # Default batch_first=False: attention runs over dim 0 of the input.
        attention_layers = [nn.MultiheadAttention(width, nhead) for _ in range(layer_num)]
        self.multihead_attns = nn.ModuleList(attention_layers)
        projection_layers = [nn.Linear(width, width) for _ in range(layer_num)]
        self.linears = nn.ModuleList(projection_layers)
        self.linear = nn.Linear(width, 1)
        self.num_layers = layer_num
        self.residual = residual

    def forward(self, instruction_logits, image_feats, text_feats, qi_embed=None, **kwargs):
        """Return ``{"scores": tensor}`` computed from image and text features.

        ``instruction_logits``, ``qi_embed`` and ``**kwargs`` are ignored.
        """
        hidden = torch.cat([image_feats, text_feats], -1)
        project = self.num_layers > 1
        for index in range(self.num_layers):
            attended, _ = self.multihead_attns[index](hidden, hidden, hidden)
            if self.residual == 0:
                hidden = hidden + attended
            if project:
                hidden = self.linears[index](hidden)
        return {"scores": self.linear(hidden)}