import torch
import torch.nn as nn
import os
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import argparse
from vgg import vgg16
from transform import transform, unblind, blind, collect_weights, update
from torch.nn import CrossEntropyLoss 
from tqdm import tqdm
from data_set import *
import sys
import argparse

parser = argparse.ArgumentParser()
# `--blind 1` sets is_blind to True; any other integer leaves it False.
parser.add_argument('--blind', type=int, default=0)
args = parser.parse_args()
is_blind = args.blind == 1
# NOTE(review): is_blind is only printed here; the imported blind/unblind/
# transform helpers are not called below — confirm whether that is intended.
print(is_blind)

# specifying devices
device = torch.device("cuda:0")

num_classes = 10
model = vgg16(num_classes=num_classes)

# Move the model to the target device (and into training mode) BEFORE building
# the optimizer, as recommended by the torch.optim documentation, so the
# optimizer is constructed over the device-resident parameters.
model.to(device).train()

# optimizer: plain SGD (no momentum / weight decay)
opt = torch.optim.SGD(model.parameters(), lr=1e-2)

# shape specifiers
batch_size = 32
# NOTE(review): max_cpu_thread, input_size and output_size are not referenced
# by the training loop below; kept for reference only.
max_cpu_thread = 24
input_size = (batch_size, 3, 224, 224)
output_size = (batch_size, num_classes)
# (A dead `torch.randn((2, 64, 224, 224), device=device)` allocation that was
# never used was removed here; it only consumed GPU memory.)

# Loss module placed on the same device as the model.
loss_func = CrossEntropyLoss().to(device)

# preprocessing / data loading
# NOTE(review): the meaning of data_loader's first positional argument (1) is
# defined in data_set, which is not visible here — confirm.
train_loader = data_loader(1, batch_size)
val_loader = data_loader(1, batch_size, is_train=False)

# log file: one line per epoch is appended by the training loop; the loop is
# responsible for closing it.
filename = 'loss_log_full.txt'
log = open(filename, 'w')

# Train for 50 epochs. After each epoch, measure top-1 accuracy on the
# validation set, then write one line to the log file and echo it to stdout.
# The log file is closed when the loop finishes or if training raises.
try:
    for epoch in range(50):
        # ---- training pass ----
        # Re-enter training mode each epoch because validation switches to eval.
        model.train()
        average_loss = 0.0

        for i, (image, target) in enumerate(tqdm(train_loader), start=1):
            image = image.to(device)
            target = target.to(device)

            y_pred = model(image)
            loss = loss_func(y_pred, target)
            opt.zero_grad()
            loss.backward()
            opt.step()
            # Incremental mean over the batches seen so far this epoch:
            # mean_i = mean_{i-1} + (x_i - mean_{i-1}) / i  (i starts at 1).
            average_loss += (loss.item() - average_loss) / i

        # ---- validation pass ----
        # eval() disables train-only behavior (e.g. dropout); no_grad() avoids
        # building autograd graphs, which are never used here.
        model.eval()
        num_correct = 0
        num_seen = 0
        with torch.no_grad():
            for image, target in tqdm(val_loader):
                image = image.to(device)
                target = target.to(device)

                pred = torch.argmax(model(image), dim=1)
                # Count correct predictions on-device, then fetch one scalar.
                num_correct += (pred == target).sum().item()
                num_seen += target.size(0)

        # Normalize by the number of samples actually evaluated rather than a
        # hard-coded dataset size; guard against an empty validation loader.
        top1_acc = num_correct / num_seen if num_seen else 0.0

        log.write(str(epoch))
        log.write(", average_train loss ")
        log.write(str(average_loss))
        log.write(", top1 acc ")
        log.write(str(top1_acc))
        log.write("\n")
        log.flush()
        print("Epoch ", epoch, "average train loss ", average_loss, "top 1 acc ", top1_acc)
finally:
    log.close()
