import torch.nn as nn
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Dataset
import pandas as pd
import os
import numpy as np
from typing import Any, Callable, Optional, Tuple
import io
import pickle
import sys
from collections import Counter

# Default seed for the torch and NumPy global RNGs, applied at import time.
# When this file is run as a script, the entry point re-seeds both generators
# with the value passed on the command line.
TORCH_SEED = 431
torch.manual_seed(TORCH_SEED)
np.random.seed(TORCH_SEED)


class Net(nn.Module):
    """Three-layer fully connected classifier that returns raw logits.

    The input is flattened to ``in_features`` values per sample, passed
    through two ReLU-activated hidden layers (128 and 64 units) and projected
    to ``num_classes`` logits. No softmax/sigmoid is applied to the output.

    Args:
        in_features: Number of values per flattened sample. The default
            (``28 * 28 * 3``) matches a 3-channel 28x28 image.
        num_classes: Number of output logits.
    """

    def __init__(self, in_features: int = 28 * 28 * 3, num_classes: int = 2):
        super().__init__()
        self.final_classes = num_classes
        # Layers are created in this fixed order so that a given RNG seed
        # always yields the same initial weights.
        self.linear1 = nn.Linear(in_features, 128)
        self.linear2 = nn.Linear(128, 64)
        self.final = nn.Linear(64, self.final_classes)
        self.relu = nn.ReLU()
        # Not used by forward(); kept so code that accesses ``net.sigm`` keeps
        # working.
        self.sigm = nn.Sigmoid()

    def forward(self, img):
        """Return logits of shape ``(N, num_classes)`` for a batch of inputs.

        ``img`` may have any shape whose element count is a multiple of the
        first layer's input width; it is flattened to ``(N, in_features)``.
        """
        # reshape() (unlike view()) also accepts non-contiguous tensors, e.g.
        # a permuted NHWC batch. The width is read from the layer itself so
        # previously pickled instances need no extra attribute.
        x = img.reshape(-1, self.linear1.in_features)
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        return self.final(x)


class CIFARModifiedDataset(datasets.CIFAR10):
    """CIFAR-10 with its class labels collapsed into a binary target.

    Samples whose original label lies in the inclusive range 2..7 get target
    ``1``; every other label gets target ``0``. Images are returned exactly as
    the parent dataset produces them (including any configured transform).

    NOTE(review): in the standard CIFAR-10 class ordering, 2..7 are presumably
    the six animal classes and the rest the vehicle classes — confirm against
    the dataset documentation.
    """

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Return ``(image, binary_label)`` for the sample at ``index``."""
        sample, original_label = super().__getitem__(index)
        binary_label = 1 if 2 <= original_label <= 7 else 0
        return sample, binary_label


def run_epoch(model: nn.Module, loader: Any, criterion: Callable, device: torch.device,
              optimizer: Optional[torch.optim.Optimizer] = None,
              on_batch_end: Optional[Callable[[int], None]] = None) -> Tuple[float, float]:
    """Run one full pass over ``loader`` and return ``(mean_loss, accuracy)``.

    When ``optimizer`` is given the model is put in training mode and updated
    after every batch; otherwise the model is put in eval mode and gradients
    are disabled.

    Both metrics are weighted by the number of samples in each batch, so a
    smaller final batch does not skew them (a plain mean of per-batch means
    would over-weight it). The loss weighting assumes ``criterion`` returns
    the mean loss over the batch.

    Args:
        model: Network mapping a batch of inputs to per-class logits.
        loader: Iterable yielding ``(inputs, labels)`` batches.
        criterion: Loss callable taking ``(logits, labels)``.
        device: Device the batches are moved to before the forward pass.
        optimizer: Optimizer to step after each batch, or ``None`` to evaluate.
        on_batch_end: Optional callback invoked with the zero-based batch
            index after each batch (after the optimizer step when training).

    Returns:
        ``(mean_loss, accuracy)`` over all samples seen, or ``(0.0, 0.0)`` if
        ``loader`` yields nothing.
    """
    training = optimizer is not None
    model.train(training)

    loss_sum = 0.0
    correct = 0
    seen = 0
    with torch.set_grad_enabled(training):
        for step, (inputs, labels) in enumerate(loader):
            inputs = inputs.to(device)
            labels = torch.flatten(labels.to(device))

            if training:
                optimizer.zero_grad()
            output = model(inputs)
            loss = criterion(output.float(), labels)
            if training:
                loss.backward()
                optimizer.step()

            n = labels.size(0)
            loss_sum += loss.item() * n
            correct += (output.detach().argmax(dim=1) == labels).sum().item()
            seen += n

            if on_batch_end is not None:
                on_batch_end(step)

    if seen == 0:
        return 0.0, 0.0
    return loss_sum / seen, correct / seen


def main(argv=None):
    """Train ``Net`` on the binary CIFAR-10 task and record losses/accuracies.

    ``argv[1]`` must be the integer RNG seed. Model snapshots are written to
    ``cifar_models_epoch_1/`` and the per-epoch metrics are pickled to
    ``training_data_cifar/train_test_losses_<seed>.pkl``.
    """
    argv = sys.argv if argv is None else argv
    if len(argv) < 2:
        sys.exit("usage: python <script> SEED")

    seed = int(argv[1])
    torch.manual_seed(seed)
    np.random.seed(seed)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    net = Net().float().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=0.0001, weight_decay=1e-4)

    normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    # RandomResizedCrop is the current name of the deprecated RandomSizedCrop
    # alias, which newer torchvision releases no longer provide.
    transform_train = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomResizedCrop(28),
        transforms.ToTensor(),
        normalize,
    ])
    transform_test = transforms.Compose([
        transforms.CenterCrop(28),
        transforms.ToTensor(),
        normalize,
    ])

    batch_size = 32

    train_data = CIFARModifiedDataset(root='./data', train=True,
                                      download=True, transform=transform_train)
    trainloader = DataLoader(train_data, batch_size=batch_size,
                             shuffle=True, num_workers=2)

    test_data = CIFARModifiedDataset(root='./data', train=False,
                                     download=True, transform=transform_test)
    testloader = DataLoader(test_data, batch_size=batch_size,
                            shuffle=False, num_workers=2)

    # Snapshot roughly every tenth of an epoch; clamp to 1 so a loader with
    # fewer than 10 batches does not cause a modulo-by-zero.
    step_backs = max(1, len(trainloader) // 10)
    num_backs = len(trainloader)
    print(step_backs, num_backs, len(trainloader))

    train_losses = []
    test_losses = []
    train_accs = []
    test_accs = []
    num_epochs = 1
    model_dir = "cifar_models_epoch_1/"
    results_dir = "training_data_cifar"
    # Create output locations up front so a missing directory fails fast
    # instead of after (or in the middle of) training.
    os.makedirs(model_dir, exist_ok=True)
    os.makedirs(results_dir, exist_ok=True)

    def save_snapshot(step):
        # File name encodes the fraction of the epoch completed (one decimal).
        # NOTE: the name carries no epoch index, so with num_epochs > 1 later
        # epochs overwrite these files. The "metfaces" prefix is kept as-is so
        # existing consumers of the snapshots keep finding them.
        if step % step_backs == 0:
            torch.save(net, model_dir + "metfaces_model_%s_seed_%s.pkl"
                       % (round(step / num_backs, 1), seed))

    print("*****************************")
    for epoch in range(num_epochs):
        net.train()
        # Snapshot of the model before this epoch's updates.
        torch.save(net, model_dir + "metfaces_model_%s_seed_%s.pkl" % (epoch, seed))

        train_loss, train_acc = run_epoch(net, trainloader, criterion, device,
                                          optimizer=optimizer,
                                          on_batch_end=save_snapshot)
        test_loss, test_acc = run_epoch(net, testloader, criterion, device)

        test_accs.append(test_acc)
        train_accs.append(train_acc)
        train_losses.append(train_loss)
        test_losses.append(test_loss)
        print("Epoch: %s, Train loss: %s, Test loss: %s, Train acc: %s, Test acc: %s" % (epoch, train_loss, test_loss,
                                                                                        train_acc, test_acc))

    with open("training_data_cifar/train_test_losses_%s.pkl" % seed, "wb") as f_out:
        pickle.dump((train_losses, test_losses, train_accs, test_accs), f_out)


if __name__ == '__main__':
    main()






