import os
import datetime
import numpy as np
import sys
import argparse
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from collections import OrderedDict
import torch.nn.functional as F
import time
import path
# Make project modules importable when this file is run as a script.
# The grandparent directory of this file (the parent of its folder) is
# appended to sys.path. The directory one level above that is inserted at
# the front, so it takes precedence during import resolution.
folder_path = path.Path(__file__).abspath().parent.parent
sys.path.append(folder_path)
folder_path = folder_path.parent
sys.path.insert(0, folder_path)
from classifier_base import Classifier
from data.pytorch_datasets import *
from resnet import ResNet18
from nn_cifar10 import NN_CIFAR10

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed NumPy's and PyTorch's global RNGs so repeated runs start from the
# same random state.
np.random.seed(0)
torch.manual_seed(0)

# Command-line options:
#   -d/--data    dataset name forwarded (via args) to get_dataset()
#   -t           number of checkpoints to scan (model-1 .. model-t)
#   --model_dir  directory holding the "model-{t}-timestep" checkpoints
parser = argparse.ArgumentParser(description='PyTorch CIFAR MART Defense')
parser.add_argument('-d', '--data',
                    type=str,
                    default='cifar10')
parser.add_argument('-t',
                    type=int,
                    default=120)
parser.add_argument('--model_dir',
                    type=str,
                    default='saved_models/cifar10/2022-04-28 09:56')

if __name__ == "__main__":
    # Evaluate checkpoints model-1 .. model-<t> from --model_dir on the test
    # split and report the one with the highest accuracy.
    args = parser.parse_args()
    test_ds = get_dataset(args)[1]
    n_samples = len(test_ds)
    # The test set and batching are the same for every checkpoint, so the
    # loader is built once rather than once per checkpoint.
    dl = DataLoader(test_ds, batch_size=500)

    # Start below any achievable accuracy so the first evaluated checkpoint
    # is always recorded. With a start of 0, an all-zero run would report
    # the nonexistent "model-0".
    max_acc = float('-inf')
    best_idx = None
    for t in range(1, args.t + 1):
        load_path = os.path.join(args.model_dir, f"model-{t}-timestep")
        net = NN_CIFAR10()
        net.load_model(load_path)
        correct = 0
        total_loss = 0
        for imgs, lbls in dl:
            # NOTE(review): relies on evaluate() returning a dict whose values
            # are ordered (accuracy, loss) -- confirm against NN_CIFAR10.
            acc, loss = list(net.evaluate(imgs, lbls).values())
            # Assumes per-batch means; weight by batch size because the last
            # batch may be smaller than 500.
            correct += acc * len(imgs)
            total_loss += loss * len(imgs)
        acc = correct / n_samples
        total_loss = total_loss / n_samples
        # Strict '>' keeps the earliest checkpoint on ties.
        if acc > max_acc:
            max_acc = acc
            best_idx = t
        print(f"Model {t}, accuracy: {acc}, loss: {total_loss}")

    if best_idx is None:
        print("No checkpoints were evaluated (-t must be at least 1)")
    else:
        print(f"Best model is model-{best_idx} with accuracy {max_acc}")