# pylint: disable=E1101,R,C
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch
import torch.utils.data as data_utils
import gzip
import pickle
import numpy as np
import sys
sys.path.insert(0, '../../../..')
import k_operation as kop
from torch.autograd import Variable

# Gzip-compressed pickle that load_data() reads.
MNIST_PATH = "s2_mnist.gz"

# Use the first CUDA device when one is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Optimisation hyper-parameters.
NUM_EPOCHS = 20
BATCH_SIZE = 32
LEARNING_RATE = 5e-4


class ConvNet(nn.Module):
    """Two KOP2D layers with ReLUs, followed by a linear read-out to 10 outputs.

    The first layer takes 60x60 single-channel input. The read-out's input size
    (40 * 5 * 5) assumes the second KOP2D layer yields 5x5 maps.
    """

    def __init__(self):
        super().__init__()

        width1, width2 = 20, 40  # channel counts (32 / 64 were tried earlier)
        num_blocks = 9  # TODO might try other nblock sizes

        # Hyper-parameters that are identical for both KOP2D layers.
        shared = dict(
            kernel_size=5,
            stride=3,
            warm_start=True,
            nblocks=num_blocks,
        )

        # Sub-modules are created in a fixed order: c1, relu1, c2, relu2, out_layer.
        self.c1 = kop.KOP2D(in_size=(60, 60), in_ch=1, out_ch=width1, **shared)
        self.relu1 = torch.nn.ReLU()
        # in_size here must match the spatial output of c1.
        self.c2 = kop.KOP2D(in_size=(19, 19), in_ch=width1, out_ch=width2, **shared)
        self.relu2 = torch.nn.ReLU()
        # Assumes c2 yields width2 x 5 x 5 features per sample.
        self.out_layer = torch.nn.Linear(width2 * 5 ** 2, 10)

    def forward(self, x):
        """Apply c1 -> ReLU -> c2 -> ReLU, flatten per sample, then the linear head."""
        hidden = self.relu1(self.c1(x))
        hidden = self.relu2(self.c2(hidden))
        flat = hidden.view(hidden.shape[0], -1)
        return self.out_layer(flat)

    def load_params(self, model, requires_grad=True):
        """Copy the butterfly twiddle tensors of ``model`` into this network.

        For layers c1 and c2, and each of their K1, Kd and K2 factors, the
        ``map1.twiddle`` and ``map2.twiddle`` tensors are copied in place from
        ``model``. Each copied tensor then has ``requires_grad`` set to
        ``requires_grad`` (pass False to freeze them). No other parameters are
        copied.
        """
        layer_pairs = ((self.c1, model.c1), (self.c2, model.c2))
        for factor in ("K1", "Kd", "K2"):
            for mine, theirs in layer_pairs:
                dst = getattr(mine, factor)
                src = getattr(theirs, factor)
                for dst_map, src_map in ((dst.map1, src.map1), (dst.map2, src.map2)):
                    dst_map.twiddle.data.copy_(src_map.twiddle)
                    dst_map.twiddle.requires_grad = requires_grad


def _to_tensor_dataset(split):
    """Build a TensorDataset from one split dict holding "images" and "labels" arrays.

    Images are indexed as a rank-3 array (sample, row, col). A channel axis is
    inserted after the sample axis, and the images are cast to float32. Labels
    are cast to int64 class indices.
    """
    images = torch.from_numpy(split["images"][:, None, :, :].astype(np.float32))
    labels = torch.from_numpy(split["labels"].astype(np.int64))
    return data_utils.TensorDataset(images, labels)


def load_data(path, batch_size):
    """Load the gzip-pickled train/test splits at ``path`` into DataLoaders.

    Args:
        path: gzip-compressed pickle of a dict with "train" and "test" keys.
            Each maps to a dict holding "images" and "labels" arrays.
        batch_size: mini-batch size for both loaders.

    Returns:
        ``(train_loader, test_loader, train_dataset, test_dataset)``. The train
        loader shuffles each epoch. The test loader keeps the stored order, so
        evaluation is deterministic.
    """
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    with gzip.open(path, 'rb') as f:
        dataset = pickle.load(f)

    train_dataset = _to_tensor_dataset(dataset["train"])
    test_dataset = _to_tensor_dataset(dataset["test"])

    # TODO normalize dataset (e.g. with the training-set mean and std)

    train_loader = data_utils.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True)
    # Accuracy does not depend on sample order, so shuffling the test set only
    # adds nondeterminism.
    test_loader = data_utils.DataLoader(
        test_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, test_loader, train_dataset, test_dataset


def _train_one_epoch(classifier, criterion, optimizer, train_loader, epoch):
    """Run one optimisation pass over ``train_loader``, printing a progress line."""
    classifier.train()
    num_batches = len(train_loader)  # exact count, including a partial last batch
    for i, (images, labels) in enumerate(train_loader):
        images = images.to(DEVICE)
        labels = labels.to(DEVICE)

        optimizer.zero_grad()
        outputs = classifier(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        print('\rEpoch [{0}/{1}], Iter [{2}/{3}] Loss: {4:.4f}'.format(
            epoch + 1, NUM_EPOCHS, i + 1, num_batches, loss.item()), end="")
    print("")


def _evaluate(classifier, test_loader):
    """Return the top-1 accuracy (in percent) of ``classifier`` on ``test_loader``.

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    classifier.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(DEVICE)
            labels = labels.to(DEVICE)

            _, predicted = torch.max(classifier(images), 1)
            total += labels.size(0)
            correct += (predicted == labels).long().sum().item()
    if total == 0:
        raise ValueError("test_loader yielded no samples")
    return 100 * correct / total


def main():
    """Train ConvNet on the data at MNIST_PATH and print test accuracy per epoch.

    Each online phase trains a freshly initialised ConvNet. When ``offline`` is
    set, one extra phase is appended at the end. It builds a new ConvNet, copies
    the previous network's butterfly twiddles into it with requires_grad=False,
    and trains only the parameters that still require gradients.
    """
    sgdr_num = 1
    offline = True

    if offline:
        sgdr_num += 1

    # The loaders can be re-iterated, so load the data once for all phases.
    train_loader, test_loader, _, _ = load_data(MNIST_PATH, BATCH_SIZE)

    for sgdr_iter in range(sgdr_num):  # TODO implement SGDR
        # Only the trailing phase added because of `offline` is an offline phase.
        # Without the `offline and` guard, a run with offline=False would treat
        # its single phase as offline and read `classifier` before assignment.
        is_offline = offline and sgdr_iter == sgdr_num - 1

        if is_offline:
            print("OFFLINE EVAL")
            classifier_offline = ConvNet()
            classifier_offline.to(DEVICE)
            classifier_offline.load_params(classifier, requires_grad=False)
            classifier = classifier_offline
        else:
            classifier = ConvNet()
            classifier.to(DEVICE)

        criterion = nn.CrossEntropyLoss().to(DEVICE)

        # Frozen (requires_grad=False) parameters are excluded from optimisation.
        parameters = [p for p in classifier.parameters() if p.requires_grad]
        print("#params", sum(p.numel() for p in parameters))

        optimizer = torch.optim.Adam(parameters, lr=LEARNING_RATE)

        label = 'Offline Accuracy' if is_offline else 'Test Accuracy'
        for epoch in range(NUM_EPOCHS):
            _train_one_epoch(classifier, criterion, optimizer, train_loader, epoch)
            accuracy = _evaluate(classifier, test_loader)
            print('{0}: {1}'.format(label, accuracy))


if __name__ == '__main__':
    main()
