import torch.nn as nn
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Dataset
import pandas as pd
import os
import numpy as np
import io
import pickle
import sys

# Default seed, applied to both the torch and NumPy global RNGs as soon as the
# module is imported. The script entry point further down reseeds both RNGs
# from its first command-line argument.
TORCH_SEED = 431
torch.manual_seed(TORCH_SEED)
np.random.seed(TORCH_SEED)


class Net(nn.Module):
    """Three-layer fully connected classifier over flattened images.

    Every sample is flattened to a vector of ``in_features`` values and passed
    through two ReLU-activated hidden layers followed by a linear output
    layer. The output holds raw, un-normalised class scores (logits).

    Args:
        in_features: Number of values per flattened sample. The default
            corresponds to a 128 x 128 image with 3 channels.
        hidden1: Width of the first hidden layer.
        hidden2: Width of the second hidden layer.
        num_classes: Number of output scores per sample.
    """

    def __init__(self, in_features=128 * 128 * 3, hidden1=256, hidden2=64,
                 num_classes=2):
        super(Net, self).__init__()
        self.final_classes = num_classes
        self.linear1 = nn.Linear(in_features, hidden1)
        self.linear2 = nn.Linear(hidden1, hidden2)
        self.final = nn.Linear(hidden2, self.final_classes)
        self.relu = nn.ReLU()
        # Not used by forward(); kept so code that accesses ``net.sigm``
        # keeps working.
        self.sigm = nn.Sigmoid()

    def forward(self, img):
        """Return a ``(batch, num_classes)`` tensor of logits for ``img``.

        Args:
            img: Tensor whose total element count is a multiple of the input
                width of the first layer; it is flattened to
                ``(-1, in_features)`` before the first layer.
        """
        # reshape() also accepts non-contiguous tensors (e.g. a permuted
        # channels-last batch), where view() raises. The width is read from
        # the layer itself so models pickled before the constructor grew
        # parameters still work.
        x = img.reshape(-1, self.linear1.in_features)
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        return self.final(x)


class MetFacesDataSet(Dataset):
    """Image/label dataset driven by a CSV index file.

    The CSV is read with ``pandas.read_csv`` and must provide an ``img_id``
    column (file name relative to ``img_dir``) and a ``label`` column
    (convertible to ``int``).

    Args:
        csv_file: Path to the CSV index.
        img_dir: Directory that contains the image files.
        transform: Optional callable applied to each loaded image. When it is
            ``None`` the image is returned exactly as ``plt.imread`` loaded it.
    """

    def __init__(self, csv_file, img_dir, transform=None):
        self.csv_file = csv_file
        self.final_classes = 2
        self.img_labels = pd.read_csv(self.csv_file)
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        """Return the number of rows in the CSV index."""
        return len(self.img_labels)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the row at position ``idx``.

        The label is a ``torch.long`` tensor of shape ``(1,)``.
        """
        # Fetch the row once instead of doing two positional lookups.
        row = self.img_labels.iloc[idx]
        img_name = os.path.join(self.img_dir, row["img_id"])
        image = plt.imread(img_name)
        # ``transform`` defaults to None; calling it unconditionally would
        # raise "TypeError: 'NoneType' object is not callable".
        if self.transform is not None:
            image = self.transform(image)
        label_tensor = torch.tensor([int(row["label"])], dtype=torch.long)
        return image, label_tensor


def _forward_batch(net, criterion, batch, device):
    """Run one batch through ``net``.

    Returns:
        ``(loss, n_correct, n_samples)`` where ``loss`` is the criterion
        output (still attached to the graph), ``n_correct`` the number of
        samples whose highest-scoring class equals the label, and
        ``n_samples`` the number of labels in the batch.
    """
    inputs, labels = batch
    inputs = inputs.to(device)
    # Labels may arrive with a trailing singleton dimension, while the
    # criterion is called with one class index per sample.
    targets = labels.to(device).flatten()
    output = net(inputs)
    loss = criterion(output.float(), targets)
    predictions = output.detach().argmax(dim=1)
    n_correct = int((predictions == targets).sum().item())
    return loss, n_correct, targets.numel()


def _safe_mean(total, count):
    """Return ``total / count``, or NaN when ``count`` is zero."""
    return total / count if count else float("nan")


def main():
    """Train ``Net`` on the MetFaces CSV datasets and record the metrics.

    The RNG seed is taken from the first command-line argument; when none is
    given the module-level ``TORCH_SEED`` is used. Model checkpoints are
    written to ``metfaces_models/`` and the per-epoch loss/accuracy lists are
    pickled to ``training_data/``.
    """
    seed = int(sys.argv[1]) if len(sys.argv) > 1 else TORCH_SEED
    torch.manual_seed(seed)
    np.random.seed(seed)

    if torch.cuda.is_available():
        # Prefer the second GPU as before, but fall back to the first one on
        # single-GPU hosts where "cuda:1" does not exist.
        device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
    else:
        device = torch.device("cpu")

    # Normalize takes (mean, std, inplace). The former third positional tuple
    # was being interpreted as ``inplace``; a single mean/std value is
    # broadcast over all channels.
    train_transform = transforms.Compose([transforms.ToTensor(),
                                          transforms.RandomHorizontalFlip(),
                                          transforms.Normalize((0.5,), (0.5,))])
    test_transform = transforms.Compose([transforms.ToTensor(),
                                         transforms.Normalize((0.5,), (0.5,))])

    train_data = MetFacesDataSet('train.csv', 'resized_images_train', transform=train_transform)
    test_data = MetFacesDataSet('test.csv', 'resized_images_test', transform=test_transform)

    net = Net().float().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.5)
    batch_size = 32
    trainloader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=0)
    testloader = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=0)

    # len(DataLoader) is already the number of batches, so it must not be
    # divided by the batch size again. Aim for roughly ten checkpoints per
    # epoch and keep the stride an integer >= 1 so the modulo test below is
    # exact and never divides by zero.
    num_batches = len(trainloader)
    checkpoint_every = max(1, num_batches // 10)

    train_losses = []
    test_losses = []
    train_accs = []
    test_accs = []
    num_epochs = 1
    model_dir = "metfaces_models/"
    metrics_dir = "training_data"
    # torch.save / open() do not create missing directories.
    os.makedirs(model_dir, exist_ok=True)
    os.makedirs(metrics_dir, exist_ok=True)

    print(device)
    print("*****************************")
    for epoch in range(num_epochs):
        # Sums are weighted by batch size so that a smaller final batch does
        # not count as much as a full one.
        train_loss_sum = 0.0
        train_correct = 0
        train_seen = 0
        test_loss_sum = 0.0
        test_correct = 0
        test_seen = 0

        net.train()
        if epoch % 5 == 0:
            torch.save(net, os.path.join(model_dir, "metfaces_model_%s_seed_%s.pkl" % (epoch, seed)))
        for i, batch in enumerate(trainloader):
            optimizer.zero_grad()
            loss, n_correct, n_samples = _forward_batch(net, criterion, batch, device)
            loss.backward()
            optimizer.step()

            train_loss_sum += loss.item() * n_samples
            train_correct += n_correct
            train_seen += n_samples

            if i % checkpoint_every == 0:
                # Tag the checkpoint with the (fractional) epoch reached so
                # far, so later epochs do not overwrite earlier files.
                progress = round(epoch + i / num_batches, 2)
                torch.save(net, os.path.join(model_dir, "metfaces_model_%s_seed_%s.pkl" % (progress, seed)))

        net.eval()
        with torch.no_grad():
            for batch in testloader:
                loss, n_correct, n_samples = _forward_batch(net, criterion, batch, device)
                test_loss_sum += loss.item() * n_samples
                test_correct += n_correct
                test_seen += n_samples

        train_loss = _safe_mean(train_loss_sum, train_seen)
        test_loss = _safe_mean(test_loss_sum, test_seen)
        train_acc = _safe_mean(train_correct, train_seen)
        test_acc = _safe_mean(test_correct, test_seen)
        test_accs.append(test_acc)
        train_accs.append(train_acc)
        train_losses.append(train_loss)
        test_losses.append(test_loss)
        print("Epoch: %s, Train loss: %s, Test loss: %s, Train acc: %s, Test acc: %s" % (epoch, train_loss, test_loss,
                                                                                         train_acc, test_acc))

    with open(os.path.join(metrics_dir, "train_test_losses_%s.pkl" % (seed)), "wb") as f_out:
        pickle.dump((train_losses, test_losses, train_accs, test_accs), f_out)


if __name__ == '__main__':
    main()






