import torch.nn as nn
import torch
import numpy as np

from .gpt import GPT


class GPTLM(nn.Module):
    """
    GPT backbone wrapped with the ``TBTMaskedSeqModel`` multi-task head.

    The backbone's output is handed unchanged to the head, together with the
    original ``data`` batch and the ``train`` flag.
    """

    def __init__(self, gpt=None, args=None):
        """
        :param gpt: backbone module to be trained; called as ``gpt(data)``
        :param args: config mapping forwarded to ``TBTMaskedSeqModel``
        """
        super().__init__()
        self.gpt = gpt
        self.mask_lm = TBTMaskedSeqModel(args)

    def forward(self, data, train):
        """Run the backbone on ``data`` and return the head's prediction dict."""
        features = self.gpt(data)
        return self.mask_lm(features, data, train)

class TBTMaskedSeqModel(nn.Module):
    """
    Multi-task prediction head on top of the backbone's hidden states.

    Layout (all built in ``__init__``):
      * ``num_experts`` shared ``Encoder`` experts, all applied to the same input;
      * one softmax gate per task in ``task_name`` that mixes the expert outputs
        (multi-gate mixture-of-experts style);
      * one ``Decoder`` per output head: ``trigger_info``, ``trigger_action``,
        ``action``, ``info`` and ``voiceTimes``.

    The ``trigger`` output is assembled position-wise from two heads. It uses
    ``trigger_action`` where the info id equals ``ACTION_TRIGGER_INFO_ID`` and
    ``trigger_info`` everywhere else.
    """

    # Info id at which the trigger prediction is taken from the ``trigger_action``
    # head instead of the ``trigger_info`` head.
    # NOTE(review): presumably the "no info" class (== args['info_element_cnt']?) — confirm.
    ACTION_TRIGGER_INFO_ID = 25

    def __init__(self, args):
        """
        :param args: config mapping. This class reads ``info_element_cnt``,
            ``action_element_cnt``, ``voiceTimes_cnt``, ``decoder_mlp_hid_dim``,
            ``decoder_mlp_layer_num``, ``decoder_drop_out`` and (in ``one_hot``)
            ``device``. ``Encoder`` reads ``hidden``, ``mlp_hid_dim`` and
            ``mlp_layer_num``.
        """
        super().__init__()
        self.args = args

        self.info_onehot_cnt = args['info_element_cnt'] + 1
        self.task_name = ['trigger', 'action', 'info', 'voiceTimes']

        self.num_experts = 3
        self.experts_shared = nn.ModuleList([Encoder(args) for _ in range(self.num_experts)])

        gate_in_dim = self.experts_shared[0].input_layer.in_features
        self.gate_specific = nn.ModuleDict({
            task: nn.Sequential(nn.Linear(gate_in_dim, self.num_experts), nn.Softmax(dim=-1))
            for task in self.task_name
        })

        # Decoders consume the gated expert output, whose width is the experts'
        # hidden size (``mlp_hid_dim``), not the backbone width ``args['hidden']``.
        # The two were previously conflated, which only worked when they were equal.
        expert_out_dim = self.experts_shared[0].hidden_dim
        num_out_dict = {
            'trigger_info': 1,
            'trigger_action': 1,
            'action': args["action_element_cnt"],
            'info': args["info_element_cnt"] + 1,
            'voiceTimes': args["voiceTimes_cnt"],
        }
        self.decoders = nn.ModuleDict({
            task: Decoder(expert_out_dim, args['decoder_mlp_hid_dim'],
                          args['decoder_mlp_layer_num'], n_out,
                          args['decoder_drop_out'])
            for task, n_out in num_out_dict.items()
        })

        self.depend_tasks = ['info']
        self.importance_task = ['action', 'trigger']

        # Re-create the first projection of these heads with identical dimensions.
        # This is functionally a re-initialisation only. It is kept (in the original
        # order) so parameter initialisation under a fixed seed stays unchanged.
        input_dim = self.decoders['action'].project[0].in_features
        for task in ('action', 'trigger_info', 'trigger_action'):
            first = self.decoders[task].project[0]
            self.decoders[task].project[0] = nn.Linear(input_dim, first.out_features)

    def forward(self, x, data, train=True):
        """
        :param x: pair ``(features, _)``. Only the first element is used.
        :param data: mapping with at least ``'info'`` (ground-truth info ids, used
            when ``train`` is true).
        :param train: if true, pick the trigger head from ground-truth info ids;
            otherwise from the ids predicted by the ``info`` head.
        :return: dict with keys "trigger", "action", "info", "voiceTimes"
            holding the raw decoder outputs.
        """
        x, _ = x
        experts_shared_rep = torch.stack([expert(x) for expert in self.experts_shared])
        gate_rep = {}
        for task in self.task_name:
            selector = self.gate_specific[task](x)
            # Weighted sum over the expert axis ``i``. ``experts_shared_rep`` is
            # (experts, j, k, ...) and ``selector`` is (j, k, experts).
            gate_rep[task] = torch.einsum('ijk..., jki -> jk...', experts_shared_rep, selector)

        out = {task: self.decoders[task](gate_rep[task]) for task in ('action', 'info', 'voiceTimes')}

        trigger_rep = gate_rep['trigger']
        trigger_info = self.decoders['trigger_info'](trigger_rep)
        trigger_action = self.decoders['trigger_action'](trigger_rep)

        info_ids = data['info'] if train else self.process_info_pred(out['info'])
        use_action_head = (info_ids == self.ACTION_TRIGGER_INFO_ID).unsqueeze(-1)
        out['trigger'] = trigger_action.where(use_action_head, trigger_info)

        return {
            "trigger": out['trigger'],
            "action": out['action'],
            "info": out['info'],
            "voiceTimes": out['voiceTimes']
        }

    def one_hot(self, labels, num):
        """
        One-hot encode integer ``labels`` into ``num`` classes.

        :param labels: integer tensor whose first two dims are (batch, seq). Any
            trailing dims are flattened into the last output dim.
        :param num: number of classes. Every label must lie in ``[0, num)``.
        :return: int64 tensor of shape (batch, seq, -1) on ``self.args['device']``.
        """
        batch_size, seq_len = labels.size(0), labels.size(1)
        # Encode on the labels' own device instead of round-tripping through numpy/CPU.
        onehot = nn.functional.one_hot(labels.long(), num)
        return onehot.reshape(batch_size, seq_len, -1).to(self.args['device'])

    def process_action_pred(self, pred_action):
        """Threshold sigmoid(``pred_action``) at 0.5 and return int64 0/1 labels."""
        pred_action = torch.sigmoid(pred_action)
        pred_action = (pred_action > 0.5).long()
        return pred_action

    def process_info_pred(self, pred_info):
        """Return the int64 index of the highest score along the last dim."""
        # Softmax is monotonic, so argmax over the raw scores picks the same class.
        return pred_info.argmax(dim=-1)

class Encoder(nn.Module):
    """
    MLP expert.

    Applies Linear(``args['hidden']`` -> ``args['mlp_hid_dim']``) + LeakyReLU, then
    ``args['mlp_layer_num']`` blocks of Linear(mlp_hid_dim -> mlp_hid_dim) + LeakyReLU.
    """

    def __init__(self, args):
        super(Encoder, self).__init__()
        in_width = args['hidden']
        width = args['mlp_hid_dim']
        depth = args['mlp_layer_num']
        self.hidden_dim = width
        self.layer_num = depth

        self.input_layer = nn.Linear(in_width, width)
        self.relu = torch.nn.LeakyReLU()
        # One nn.Sequential(Linear, LeakyReLU) per block; this keeps the
        # ``hidden_layer.<i>.0.*`` parameter names.
        self.hidden_layer = nn.Sequential(*(
            nn.Sequential(nn.Linear(width, width), torch.nn.LeakyReLU())
            for _ in range(depth)
        ))

    def forward(self, total_feature):
        """Map ``total_feature`` (last dim ``hidden``) to the expert hidden size."""
        activated = self.relu(self.input_layer(total_feature))
        return self.hidden_layer(activated)

class Decoder(nn.Module):
    """
    MLP output head.

    ``project`` is:
    Linear(input_dim -> hid_dim), LeakyReLU,
    ``layer_num`` x [Linear(hid_dim -> hid_dim), LeakyReLU],
    Dropout(drop), Linear(hid_dim -> out_num).
    """

    def __init__(self, input_dim, hid_dim, layer_num, out_num, drop):
        super(Decoder, self).__init__()
        self.input_hidden_dim = input_dim
        self.hidden_dim = hid_dim
        self.layer_num = layer_num
        self.dropout = drop

        entry = nn.Linear(input_dim, hid_dim)
        # Plain list (not a ModuleList); the modules are registered via ``project``.
        self.mid_layers = [
            module
            for _ in range(layer_num)
            for module in (nn.Linear(hid_dim, hid_dim), torch.nn.LeakyReLU())
        ]
        exit_layers = [nn.Dropout(p=drop, inplace=False), nn.Linear(hid_dim, out_num)]
        self.project = nn.Sequential(entry, torch.nn.LeakyReLU(), *self.mid_layers, *exit_layers)

    def forward(self, x):
        """Apply the ``project`` stack to ``x``."""
        projected = self.project(x)
        return projected


class NextSentencePrediction(nn.Module):
    """
    Two-way classifier (is_next / is_not_next) over position 0 of the sequence.

    Returns log-probabilities over the two classes.
    """

    def __init__(self, hidden):
        """
        :param hidden: feature size of the incoming representation
        """
        super().__init__()
        self.linear = nn.Linear(hidden, 2)
        self.softmax = nn.LogSoftmax(dim=-1)

    def forward(self, x):
        """Score ``x[:, 0]`` and return log-probabilities of shape (..., 2)."""
        first_position = x[:, 0]
        scores = self.linear(first_position)
        return self.softmax(scores)


class MaskedLanguageModel(nn.Module):
    """
    Per-position classifier: projects each position to ``vocab_size`` scores
    and returns log-probabilities over the last dimension.
    """

    def __init__(self, hidden, args=None, vocab_size=None):
        """
        :param hidden: feature size of the incoming representation
        :param args: config object; accepted for interface compatibility, not read here
        :param vocab_size: number of output classes. If omitted, no output
            projection is built and ``forward`` raises ``AttributeError``.
        """
        super().__init__()
        # Previously ``self.linear`` was never created, so ``forward`` could not run.
        self.linear = nn.Linear(hidden, vocab_size) if vocab_size is not None else None
        self.softmax = nn.LogSoftmax(dim=-1)

    def forward(self, x):
        """Return log-probabilities of shape (..., vocab_size) for input ``x``."""
        if self.linear is None:
            raise AttributeError(
                "MaskedLanguageModel has no output projection; construct it with vocab_size"
            )
        return self.softmax(self.linear(x))
