import os
import time
from itertools import islice

import numpy as np
import torch
import torch.nn as nn
import torchvision
from tqdm.auto import tqdm

from experiments.baselines.MLP import MLP

# Seed both RNG sources used by this script (PyTorch and NumPy) with the same value.
seed = 0
torch.manual_seed(seed)
np.random.seed(seed)

# One entry per MLP that gets trained; each entry is handed to MLP(width=...).
# Every architecture shares the first (784) and last (10) size and differs only
# in the sizes listed in between (presumably the hidden layers -- confirm against MLP).
widths = [
    [784, *hidden, 10]
    for hidden in ([], [10], [100], [100, 100], [100, 100, 100])
]

def cycle(iterable):
    """Yield the items of ``iterable`` over and over, without ever stopping.

    Each pass iterates ``iterable`` afresh instead of replaying cached items,
    so the argument must be re-iterable (a one-shot iterator is exhausted after
    the first pass). An empty iterable makes the generator loop forever without
    yielding anything.
    """
    while True:
        yield from iterable

# Prefer CUDA when available. NOTE(review): `device` is not referenced anywhere
# else in the visible part of this script -- confirm whether MLP picks it up.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Both MNIST splits are read from ./tmp; download=False means the files must
# already be present there.
train, test = (
    torchvision.datasets.MNIST(root="./tmp", train=is_train, transform=torchvision.transforms.ToTensor(), download=False)
    for is_train in (True, False)
)


def _flatten_and_scale(images):
    """Reshape ``images`` to (number of images, -1), cast to float and divide by 256."""
    return images.reshape(images.shape[0], -1).float() / 256


# Whole splits are kept as plain tensors; inputs are flattened and scaled,
# labels are cast to int64 class indices.
dataset = {
    'train_input': _flatten_and_scale(train.data),
    'train_label': train.targets.long(),
    'test_input': _flatten_and_scale(test.data),
    'test_label': test.targets.long(),
}

def train_acc():
    """Return the accuracy of the global ``model`` on the training split.

    Reads the module-level ``model`` and ``dataset``. The inputs in
    ``dataset['train_input']`` are evaluated in chunks of 1024 samples, the
    predicted class is the argmax of the logits along dim 1, and the number of
    predictions equal to ``dataset['train_label']`` is divided by the number of
    evaluated samples.

    Returns:
        A 0-dim tensor with the fraction of correctly classified samples.
    """
    batch = 1024
    inputs = dataset['train_input']
    labels = dataset['train_label']
    # Count the samples that are actually evaluated rather than reading the
    # size from a separate global.
    n_samples = inputs.shape[0]
    correct = 0
    # Pure evaluation: disable autograd so no graph is recorded for the forward passes.
    with torch.no_grad():
        # Striding by `batch` covers the ragged tail and never yields an empty
        # chunk, even when n_samples is an exact multiple of `batch`.
        for start in range(0, n_samples, batch):
            logits = model(inputs[start:start + batch])
            pred_class = torch.argmax(logits, dim=1)
            correct += torch.sum(pred_class == labels[start:start + batch])
    return correct / n_samples

def train_loss():
    """Return the mean cross-entropy loss of the global ``model`` on the training split.

    Reads the module-level ``model`` and ``dataset``. The inputs in
    ``dataset['train_input']`` are evaluated in chunks of 1024 samples; the
    per-chunk losses are summed (``reduction='sum'``) and the total is divided
    by the number of evaluated samples, which gives the per-sample mean.

    Returns:
        A 0-dim tensor with the mean loss; it carries no autograd graph.
    """
    batch = 1024
    inputs = dataset['train_input']
    labels = dataset['train_label']
    # Count the samples that are actually evaluated rather than reading the
    # size from a separate global.
    n_samples = inputs.shape[0]
    # Built once instead of once per chunk.
    criterion = torch.nn.CrossEntropyLoss(reduction='sum')
    loss = 0
    # Pure evaluation: without no_grad the running sum would keep the autograd
    # graph of every chunk alive until the result is discarded.
    with torch.no_grad():
        # Striding by `batch` covers the ragged tail and never yields an empty chunk.
        for start in range(0, n_samples, batch):
            logits = model(inputs[start:start + batch])
            loss += criterion(logits, labels[start:start + batch])
    return loss / n_samples

def test_acc():
    batch = 1024
    N_batch = test.data.shape[0] // batch
    correct = 0
    for i_batch in range(N_batch+1):
        logits = model(dataset['test_input'][i_batch*batch:(i_batch+1)*batch])
        pred_class = torch.argmax(logits, dim=1)
        correct += torch.sum(pred_class == dataset['test_label'][i_batch*batch:(i_batch+1)*batch])
    return correct/test.data.shape[0]

def test_loss():
    batch = 1024
    N_batch = test.data.shape[0] // batch
    loss = 0
    for i_batch in range(N_batch+1):
        logits = model(dataset['test_input'][i_batch*batch:(i_batch+1)*batch])
        loss += torch.nn.CrossEntropyLoss(reduction='sum')(logits, dataset['test_label'][i_batch*batch:(i_batch+1)*batch])
    return loss/test.data.shape[0]


# Train one MLP per entry of `widths` and store its final metrics on disk.
# The output directory is created up front so that a missing folder cannot
# make np.savetxt fail after a training run has already finished.
os.makedirs('./results', exist_ok=True)

for width in widths:

    # `model` has to remain a module-level name: train_acc/train_loss/test_acc/
    # test_loss read it as a global.
    model = MLP(width=width)

    # perf_counter is monotonic, so the measured duration cannot be distorted
    # by system clock adjustments during training.
    start_time = time.perf_counter()
    model.fit(dataset, steps=2000, batch=1024, loss_fn=torch.nn.CrossEntropyLoss(), opt="Adam", lr=1e-2)

    end_time = time.perf_counter()
    wall_time = end_time - start_time

    # Convert the 0-dim result tensors to plain Python floats for saving.
    train_accn = train_acc().item()
    test_accn = test_acc().item()
    train_lossn = train_loss().item()
    test_lossn = test_loss().item()

    # One value per line, in this order: train loss, test loss, train accuracy,
    # test accuracy, training time in seconds.
    np.savetxt(f'./results/mlp_width_{width}.txt', [train_lossn, test_lossn, train_accn, test_accn, wall_time])
