import torch as T
import torch.nn as nn
import torch.nn.functional as F
from torchcrf import CRF

class seq_label_framework(nn.Module):
    """Sequence labelling model: embeddings -> encoder -> 2-layer head -> CRF.

    Keys read from ``data``: ``classes_num``, ``embeddings`` (pretrained
    matrix or ``None``), ``PAD_id``, ``UNK_id``, plus ``vocab_len`` when no
    pretrained embeddings are given and ``charvocab_len`` / ``char_pad_id``
    when character features are enabled.

    Keys read from ``config``: ``out_dropout``, ``in_dropout``,
    ``hidden_size``, ``word_embd_freeze``, ``embd_dim``, ``use_char_feats``,
    ``char_embed_dim``, ``char_out_dim``, ``char_window_size``,
    ``initial_transform``, ``post_cat_transform``, ``encoder_type``,
    ``use_feats`` and, optionally, ``input_size``.
    """

    def __init__(self, data, config):
        super().__init__()

        self.config = config
        self.classes_num = data["classes_num"]
        embedding_data = data["embeddings"]
        self.pad_id = data["PAD_id"]
        self.unk_id = data["UNK_id"]

        self.out_dropout = config["out_dropout"]
        self.in_dropout = config["in_dropout"]
        self.hidden_size = config["hidden_size"]
        self.unk_embed = None

        # Plain tensor attributes (not registered buffers), so they are not
        # moved by ``.to(device)``; they are not used inside this class.
        self.ATT_PAD = T.tensor(-999999).float()
        self.zeros = T.tensor(0.0).float()

        if embedding_data is not None:
            embedding_data = T.tensor(embedding_data)
            # Separate trainable vector that replaces the UNK row in embed(),
            # so UNK can be learned even when the pretrained table is frozen.
            self.unk_embed = nn.Parameter(T.randn(embedding_data.size(-1)))
            self.word_embedding = nn.Embedding.from_pretrained(
                embedding_data,
                freeze=config["word_embd_freeze"],
                padding_idx=self.pad_id)
        else:
            self.word_embedding = nn.Embedding(data["vocab_len"],
                                               config["embd_dim"],
                                               padding_idx=self.pad_id)

        if config["use_char_feats"]:
            self.char_embedding = nn.Embedding(data["charvocab_len"],
                                               config["char_embed_dim"],
                                               padding_idx=data["char_pad_id"])

            # padding = window // 2 preserves the character length for odd
            # windows; an even window yields one extra position, which is
            # harmless because char_embed() max-pools over that axis.
            window = config["char_window_size"]
            self.char_conv = nn.Conv1d(in_channels=config["char_embed_dim"],
                                       out_channels=config["char_out_dim"],
                                       kernel_size=window,
                                       stride=1,
                                       padding=window // 2,
                                       dilation=1,
                                       groups=1,
                                       bias=True,
                                       padding_mode='zeros')

        self.embd_dim = self.word_embedding.weight.size(-1)

        if self.config["initial_transform"]:
            # Projects word embeddings before any feature concatenation.
            out_features = (config["input_size"] if "input_size" in self.config
                            else config["hidden_size"])
            self.transform_linear = nn.Linear(self.embd_dim, out_features)

        if self.config["post_cat_transform"]:
            # NOTE(review): "input_size" must equal the width of the
            # concatenated input seen in forward() (word + feats + char);
            # the embd_dim fallback only fits when nothing is concatenated
            # and initial_transform is off -- confirm against the configs.
            in_features = (config["input_size"] if "input_size" in self.config
                           else self.embd_dim)
            self.transform_linear2 = nn.Linear(in_features, self.hidden_size)

        # SECURITY: eval() executes an arbitrary expression taken from the
        # config and resolves it in this module's namespace. Only use with
        # trusted configs; an explicit name -> class registry would be safer.
        encoder_fn = eval(config["encoder_type"])
        self.encoder = encoder_fn(config)

        self.prediction_linear1 = nn.Linear(config["hidden_size"], config["hidden_size"])
        self.prediction_linear2 = nn.Linear(config["hidden_size"], self.classes_num)

        # The CRF must have exactly as many tags as the logits have classes,
        # so size it from the same value used for prediction_linear2.
        self.crf = CRF(self.classes_num, batch_first=True)

    # %%
    def embed(self, sequence_idx, input_mask):
        """Look up word embeddings and zero out masked positions.

        Args:
            sequence_idx: ``(N, S)`` tensor of token ids.
            input_mask: mask that can be viewed as ``(N, S, 1)``; it is
                multiplied into the embeddings.

        Returns:
            Tuple ``(sequence, input_mask)`` where ``sequence`` is
            ``(N, S, dim)`` and ``input_mask`` is returned unchanged.
        """
        N, S = sequence_idx.size()

        sequence = self.word_embedding(sequence_idx)

        if self.unk_id is not None and self.unk_embed is not None:
            # Swap in the learned UNK vector; broadcasting covers (N, S, D),
            # so no materialised repeat of the vector is needed.
            is_unk = sequence_idx.unsqueeze(-1) == self.unk_id
            sequence = T.where(is_unk, self.unk_embed.view(1, 1, -1), sequence)

        assert sequence.size() == (N, S, self.embd_dim)

        if self.config["initial_transform"]:
            sequence = self.transform_linear(sequence)

        # Applied after the linear layer so its bias does not leak into
        # masked positions.
        sequence = sequence * input_mask.view(N, S, 1)

        return sequence, input_mask

    def char_embed(self, char_sequence_idx):
        """Build per-token character features with a CNN and max-pooling.

        Args:
            char_sequence_idx: ``(N, S, C)`` tensor of character ids.

        Returns:
            ``(N, S, D)`` tensor, where ``D`` is the number of output
            channels of ``self.char_conv``.
        """
        N, S, C = char_sequence_idx.size()

        chars = self.char_embedding(char_sequence_idx)
        E = chars.size(-1)

        # Conv1d expects (batch, channels, length): fold tokens into the
        # batch axis and move the embedding axis to the channel position.
        chars = chars.view(N * S, C, E).permute(0, 2, 1).contiguous()
        chars = self.char_conv(chars)

        # The output length is C for odd windows and C + 1 for even ones, so
        # only the batch axis is checked; the length is pooled away below.
        assert chars.dim() == 3 and chars.size(0) == N * S
        D = chars.size(1)

        pooled = T.max(chars, dim=-1)[0]
        assert pooled.size() == (N * S, D)

        return pooled.view(N, S, D)

    # %%
    def forward(self, batch):
        """Run the model on one batch and return the loss and decoded tags.

        Args:
            batch: dict with ``sequences_vec`` (``(N, S)`` token ids),
                ``input_masks`` and ``labels_vec``; additionally ``feats``
                when ``config["use_feats"]`` is set and
                ``char_sequences_vec`` when ``config["use_char_feats"]`` is
                set.

        Returns:
            Dict with ``loss`` (negative CRF log-likelihood, plus the
            encoder's mean ``aux_loss`` while training) and ``predictions``
            (the result of ``self.crf.decode``).
        """
        token_idx = batch["sequences_vec"]
        input_mask = batch["input_masks"]
        labels = batch["labels_vec"]

        N, S = token_idx.size()

        # EMBEDDING BLOCK
        sequence, input_mask = self.embed(token_idx, input_mask)
        sequence = F.dropout(sequence, p=self.in_dropout, training=self.training)

        if self.config["use_feats"]:
            # Read lazily so batches without "feats" work when it is unused.
            sequence = T.cat([sequence, batch["feats"]], dim=-1)

        if self.config["use_char_feats"]:
            char_features = self.char_embed(batch["char_sequences_vec"])
            sequence = T.cat([sequence, char_features], dim=-1)

        if self.config["post_cat_transform"]:
            sequence = self.transform_linear2(sequence)

        # ENCODER BLOCK
        sequence_dict = self.encoder(sequence, input_mask)
        sequence = sequence_dict["sequence"]

        aux_loss = None
        if "aux_loss" in sequence_dict:
            aux_loss = sequence_dict["aux_loss"].mean()

        # PREDICTION BLOCK
        sequence = F.dropout(sequence, p=self.out_dropout, training=self.training)
        intermediate = F.gelu(self.prediction_linear1(sequence))
        logits = self.prediction_linear2(intermediate)

        assert logits.size() == (N, S, self.classes_num)

        crf_mask = input_mask.bool()

        # The CRF returns a log-likelihood (to be maximised); negate it
        # *before* adding the auxiliary loss so that both terms are
        # minimised. Adding first and negating afterwards would flip the
        # sign of aux_loss and train the model to maximise it.
        loss = -self.crf(logits, labels, mask=crf_mask)

        if aux_loss is not None and self.training:
            loss = loss + aux_loss

        predictions = self.crf.decode(logits, mask=crf_mask)

        return {"loss": loss, "predictions": predictions}
