import torch
import torch.nn as nn
import torch.nn.functional as F
import sys
sys.path.insert(0, '../..')
import k_operation as kop

class CNN(nn.Module):
    """Text classifier: embedding -> parallel 1-D conv branches -> max-over-time pool -> linear.

    All configuration is passed as keyword arguments and read once in ``__init__``:

    MODEL         -- embedding mode. ``"static"``, ``"non-static"`` and ``"multichannel"``
                     initialise the embedding from ``WV_MATRIX``. ``"static"`` freezes it.
                     ``"multichannel"`` adds a second, frozen copy as an extra channel.
                     Any other value keeps the default ``nn.Embedding`` initialisation.
    BATCH_SIZE    -- stored on the instance; not read anywhere in this class.
    MAX_SENT_LEN  -- sequence length every input is expected to have.
    WORD_DIM      -- embedding dimension.
    VOCAB_SIZE    -- vocabulary size. Two extra embedding rows are reserved (UNK and
                     padding); ``VOCAB_SIZE + 1`` is the ``padding_idx``.
    CLASS_SIZE    -- number of output classes.
    FILTERS       -- kernel sizes, one per convolution branch.
    FILTER_NUM    -- output channels per branch; must have the same length as FILTERS.
    DROPOUT_PROB  -- dropout probability applied before the final linear layer.
    KOP           -- if true, branches are ``kop.KOP1D`` layers instead of ``nn.Conv1d``.
    TIED          -- with KOP, branches after the first are constructed with
                     ``conv_0``'s ``K1``/``Kd``/``K2`` (presumably shared; see kop.KOP1D).
    WARM_START    -- forwarded to ``kop.KOP1D``.
    FC            -- if true, the conv branches are replaced by one linear layer ``fc1``
                     over the flattened embeddings.
    WV_MATRIX     -- pretrained vectors as a numpy array, required by the pretrained
                     modes; presumably shaped (VOCAB_SIZE + 2, WORD_DIM).
    NBLOCKS       -- optional ``nblocks`` passed to ``kop.KOP1D`` (default 9).
    """

    def __init__(self, **kwargs):
        super().__init__()

        self.MODEL = kwargs["MODEL"]
        self.BATCH_SIZE = kwargs["BATCH_SIZE"]
        self.MAX_SENT_LEN = kwargs["MAX_SENT_LEN"]
        self.WORD_DIM = kwargs["WORD_DIM"]
        self.VOCAB_SIZE = kwargs["VOCAB_SIZE"]
        self.CLASS_SIZE = kwargs["CLASS_SIZE"]
        self.FILTERS = kwargs["FILTERS"]
        self.FILTER_NUM = kwargs["FILTER_NUM"]
        self.DROPOUT_PROB = kwargs["DROPOUT_PROB"]
        self.IN_CHANNEL = 1
        self.KOP = kwargs["KOP"]
        self.TIED = kwargs["TIED"]
        self.WARM_START = kwargs["WARM_START"]
        self.USE_FC = kwargs["FC"]
        self.NBLOCKS = kwargs.get("NBLOCKS", 9)

        assert len(self.FILTERS) == len(self.FILTER_NUM)

        # one for UNK and one for zero padding
        self.embedding = nn.Embedding(self.VOCAB_SIZE + 2, self.WORD_DIM, padding_idx=self.VOCAB_SIZE + 1)
        if self.MODEL in ("static", "non-static", "multichannel"):
            self.WV_MATRIX = kwargs["WV_MATRIX"]
            self.embedding.weight.data.copy_(torch.from_numpy(self.WV_MATRIX))
            if self.MODEL == "static":
                self.embedding.weight.requires_grad = False
            elif self.MODEL == "multichannel":
                self.embedding2 = nn.Embedding(self.VOCAB_SIZE + 2, self.WORD_DIM, padding_idx=self.VOCAB_SIZE + 1)
                self.embedding2.weight.data.copy_(torch.from_numpy(self.WV_MATRIX))
                self.embedding2.weight.requires_grad = False
                self.IN_CHANNEL = 2

        # Every embedding channel contributes WORD_DIM features per token position.
        in_ch = self.IN_CHANNEL * self.WORD_DIM
        if self.USE_FC:
            # Built exactly once and sized for all channels, matching the per-sample
            # flatten in forward().
            self.fc1 = nn.Linear(in_ch * self.MAX_SENT_LEN, sum(self.FILTER_NUM))
        else:
            for i, (kernel_size, n_filters) in enumerate(zip(self.FILTERS, self.FILTER_NUM)):
                setattr(self, f"conv_{i}", self._make_conv(i, in_ch, kernel_size, n_filters))

        self.fc = nn.Linear(sum(self.FILTER_NUM), self.CLASS_SIZE)

    def _make_conv(self, i, in_ch, kernel_size, n_filters):
        """Build convolution branch ``i``: ``kop.KOP1D`` when KOP is set, else ``nn.Conv1d``."""
        if not self.KOP:
            return nn.Conv1d(in_ch, n_filters, kernel_size, stride=1)

        kop_kwargs = dict(
            in_size=self.MAX_SENT_LEN,
            in_ch=in_ch,
            out_ch=n_filters,
            kernel_size=kernel_size,
            stride=1,
            warm_start=self.WARM_START,
            nblocks=self.NBLOCKS,
        )
        if self.TIED and i > 0:
            # Later branches receive conv_0's K1/Kd/K2 objects. conv_0 is already
            # registered because branches are built in order.
            kop_kwargs.update(K1=self.conv_0.K1, Kd=self.conv_0.Kd, K2=self.conv_0.K2)
        return kop.KOP1D(**kop_kwargs)

    def load_params(self, model, requires_grad=True):
        """Copy the ``K1``/``Kd``/``K2`` twiddles of every conv branch from ``model``.

        ``model`` must expose ``conv_<i>`` attributes (e.g. another KOP ``CNN``) for each
        branch of this model. After copying, each twiddle's ``requires_grad`` is set to
        ``requires_grad``.
        """
        def load_butterfly(p, p_new):
            p.twiddle.data.copy_(p_new.twiddle)
            p.twiddle.requires_grad = requires_grad

        # Same order as before (all K1, then Kd, then K2), so tied (shared) factors end
        # up with the value from the last branch, as they always did.
        for factor in ("K1", "Kd", "K2"):
            for i in range(len(self.FILTERS)):
                load_butterfly(getattr(self.get_conv(i), factor),
                               getattr(getattr(model, f"conv_{i}"), factor))

    def get_conv(self, i):
        """Return convolution branch ``i`` (the ``conv_<i>`` submodule)."""
        return getattr(self, f'conv_{i}')

    def forward(self, inp):
        """Return unnormalised class scores of shape (batch, CLASS_SIZE).

        ``inp`` holds token ids, expected shape (batch, MAX_SENT_LEN). The conv path's
        pooling window and the FC path's ``fc1`` size both assume that length.
        """
        # (batch, seq) -> (batch, seq, WORD_DIM) -> (batch, WORD_DIM, seq) for Conv1d.
        x = self.embedding(inp).transpose(2, 1)
        if self.MODEL == "multichannel":
            # Lay the frozen copy out the same way and stack it on the channel axis,
            # giving IN_CHANNEL * WORD_DIM channels as the conv layers expect.
            x2 = self.embedding2(inp).transpose(2, 1)
            x = torch.cat((x, x2), 1)

        if self.USE_FC:
            # Flatten per sample so the batch dimension survives any channel count.
            x = self.fc1(x.reshape(x.size(0), -1))
        else:
            # Max over the full conv output length -> one value per filter.
            conv_results = [
                F.max_pool1d(F.relu(self.get_conv(i)(x)), self.MAX_SENT_LEN - self.FILTERS[i] + 1)
                .view(-1, self.FILTER_NUM[i])
                for i in range(len(self.FILTERS))
            ]
            x = torch.cat(conv_results, 1)

        x = F.dropout(x, p=self.DROPOUT_PROB, training=self.training)
        return self.fc(x)
