from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_butterfly.complex_utils import view_as_real, view_as_complex
import numpy as np
import sys
sys.path.insert(0, '../..')
import k_operation as kop

class ButterfLeNet1D(nn.Module):
    """LeNet-style 1D text classifier built on word embeddings.

    Pipeline: embedding -> C1 -> ReLU -> (S2) -> C3 -> ReLU -> (S4) ->
    C5 -> ReLU -> F6 -> ReLU -> OUTPUT. The two feature layers C1/C3 are
    chosen by the configuration flags:

    * ``FC`` truthy   -> bias-free ``nn.Linear`` layers on the flattened
      embedding (pooling is skipped),
    * ``KOP`` truthy  -> ``kop.KOP1D`` layers (moved to CUDA on creation),
    * otherwise       -> bias-free ``nn.Conv1d`` layers with circular padding.

    All configuration is read from keyword arguments (every key below is
    required unless noted):

    MODEL: one of "static", "non-static", "multichannel"; any other value
        leaves the embedding at its default initialisation.
    BATCH_SIZE, DROPOUT_PROB, FILTERS, FILTER_NUM: stored on the instance;
        FILTERS and FILTER_NUM must have equal length.
    MAX_SENT_LEN: number of tokens per input sequence.
    WORD_DIM: embedding dimension.
    VOCAB_SIZE: vocabulary size (two extra rows are added, see below).
    CLASS_SIZE: number of output logits.
    KOP, TIED, WARM_START: options for the ``kop.KOP1D`` branch.
    FC: select the fully connected branch.
    WV_MATRIX: numpy array copied into the embedding weights; only read
        for the three named MODEL values.
    """

    def __init__(self, **kwargs):
        super(ButterfLeNet1D, self).__init__()

        self.MODEL = kwargs["MODEL"]
        self.BATCH_SIZE = kwargs["BATCH_SIZE"]
        self.MAX_SENT_LEN = kwargs["MAX_SENT_LEN"]
        self.WORD_DIM = kwargs["WORD_DIM"]
        self.VOCAB_SIZE = kwargs["VOCAB_SIZE"]
        self.CLASS_SIZE = kwargs["CLASS_SIZE"]
        self.FILTERS = kwargs["FILTERS"]
        self.FILTER_NUM = kwargs["FILTER_NUM"]
        self.DROPOUT_PROB = kwargs["DROPOUT_PROB"]
        self.IN_CHANNEL = 1
        self.KOP = kwargs["KOP"]
        self.TIED = kwargs["TIED"]
        self.WARM_START = kwargs["WARM_START"]
        self.USE_FC = kwargs["FC"]

        self.bn = False  # Don't use batch norm

        kernel_size = 5
        nblocks = 6
        # "Same" padding for an odd kernel with stride 1.
        same_padding = (kernel_size - 1) // 2

        assert (len(self.FILTERS) == len(self.FILTER_NUM))

        # one for UNK and one for zero padding
        self.embedding = nn.Embedding(
            self.VOCAB_SIZE + 2, self.WORD_DIM,
            padding_idx=self.VOCAB_SIZE + 1)
        if self.MODEL in ("static", "non-static", "multichannel"):
            self.WV_MATRIX = kwargs["WV_MATRIX"]
            self.embedding.weight.data.copy_(torch.from_numpy(self.WV_MATRIX))
            if self.MODEL == "static":
                # Pre-trained vectors stay frozen.
                self.embedding.weight.requires_grad = False
            elif self.MODEL == "multichannel":
                # Second, frozen copy of the pre-trained vectors; the first
                # copy remains trainable.
                self.embedding2 = nn.Embedding(
                    self.VOCAB_SIZE + 2, self.WORD_DIM,
                    padding_idx=self.VOCAB_SIZE + 1)
                self.embedding2.weight.data.copy_(
                    torch.from_numpy(self.WV_MATRIX))
                self.embedding2.weight.requires_grad = False
                self.IN_CHANNEL = 2

        # Channels seen by C1: one WORD_DIM-sized group per embedding table.
        in_channels = self.IN_CHANNEL * self.WORD_DIM

        if self.USE_FC:
            hidden_features = 6 * (self.MAX_SENT_LEN // 2)
            c5_in_features = 16 * (self.MAX_SENT_LEN // 4)
            self.C1 = nn.Linear(
                in_channels * self.MAX_SENT_LEN,
                hidden_features, bias=False)
            self.C3 = nn.Linear(
                hidden_features,
                c5_in_features, bias=False)

        elif self.KOP:
            self.C1 = kop.KOP1D(
                self.MAX_SENT_LEN,
                in_channels,
                6,
                kernel_size,
                nblocks=nblocks,
                warm_start=self.WARM_START,
                padding=same_padding,
                stride=1).cuda()
            # With TIED set, C3 is handed C1's K1/Kd/K2 objects instead of
            # creating its own.
            self.C3 = kop.KOP1D(
                self.MAX_SENT_LEN,
                6,
                16,
                kernel_size,
                nblocks=nblocks,
                warm_start=self.WARM_START,
                padding=same_padding,
                stride=1,
                K1=self.C1.K1 if self.TIED else None,
                Kd=self.C1.Kd if self.TIED else None,
                K2=self.C1.K2 if self.TIED else None).cuda()
            self.mode = 'warm_start'
            # NOTE(review): KOP1D's output length is not visible from this
            # file, so keep the historical C5 width rather than guess.
            c5_in_features = 144

        else:
            self.C1 = nn.Conv1d(
                in_channels,
                6,
                kernel_size,
                padding=same_padding, padding_mode='circular',
                bias=False)
            self.C3 = nn.Conv1d(
                6,
                16,
                kernel_size,
                padding=same_padding, padding_mode='circular',
                bias=False)
            # The convolutions keep the length; each of the two poolings
            # halves it (floor), leaving 16 channels of MAX_SENT_LEN // 4.
            c5_in_features = 16 * (self.MAX_SENT_LEN // 4)

        self.C5 = nn.Linear(c5_in_features, 120)
        torch.nn.init.kaiming_normal_(self.C5.weight, nonlinearity='relu')

        self.S2 = nn.AvgPool1d(2, stride=2)
        self.S4 = nn.AvgPool1d(2, stride=2)

        self.F6 = nn.Linear(120, 84)
        torch.nn.init.kaiming_normal_(self.F6.weight, nonlinearity='relu')

        self.OUTPUT = nn.Linear(84, self.CLASS_SIZE)

        torch.nn.init.kaiming_normal_(self.OUTPUT.weight, nonlinearity='relu')

        # Only applied in forward() when self.bn is switched on; sized for
        # the 6- and 16-channel outputs of C1 and C3.
        self.BN1 = nn.BatchNorm1d(6)
        self.BN2 = nn.BatchNorm1d(16)

    def forward(self, inp):
        """Compute class logits for a batch of token-index sequences.

        Args:
            inp: integer tensor of embedding indices, expected shape
                (batch, MAX_SENT_LEN).

        Returns:
            Tensor of shape (batch, CLASS_SIZE) holding raw scores; no
            softmax is applied.
        """
        # (batch, length, WORD_DIM) -> (batch, WORD_DIM, length)
        x = self.embedding(inp).transpose(2, 1)
        if self.MODEL == "multichannel":
            # Lay out the frozen embedding like the trainable one and stack
            # it on the channel axis, giving the IN_CHANNEL * WORD_DIM
            # channels C1 was built for.
            x2 = self.embedding2(inp).transpose(2, 1)
            x = torch.cat((x, x2), 1)

        if self.USE_FC:
            x = torch.flatten(x, 1)

        x = self.C1(x)
        if self.bn:
            # BatchNorm is not in-place: its result must be kept.
            x = self.BN1(x)
        x = torch.relu(x)
        if not self.USE_FC:
            x = self.S2(x)

        x = self.C3(x)
        if self.bn:
            x = self.BN2(x)
        x = torch.relu(x)
        if not self.USE_FC:
            x = self.S4(x)
            x = torch.flatten(x, 1)

        x = torch.relu(self.C5(x))
        x = torch.relu(self.F6(x))
        x = self.OUTPUT(x)

        # Use cross entropy loss
        return x
