import os
import torch
from tqdm import tqdm
import pandas as pd
from models import MNISTClassifier
from train import train, train_mixup
from test import test
from data_utils import get_mnist_loaders, load_usps


# Both loaders are built with identical batching/worker settings:
# train on MNIST, evaluate cross-domain on USPS.
loader_kwargs = dict(batch_size=128, shuffle=True, num_workers=8)
mnist_loader = get_mnist_loaders(**loader_kwargs)[0]  # second returned loader is unused here
uspsloader = load_usps(**loader_kwargs)

# Optimisation hyperparameters shared by every variant and every run.
learning_rate = 5e-4
epochs = 30
nb_runs = 10
results_columns = ['regul', 'train_accuracy', 'test_accuracy', 'run_id']

# Experiment variants: (label written to the 'regul' column,
# extra keyword arguments for MNISTClassifier, training function).
# Uncomment an entry to include that variant in the sweep.
variants = [
    ('No_regul', {}, train),
    # ('batch_norm', {'regul': 'batch_norm'}, train),
    # ('dropout', {'regul': 'dropout', 'p': 0.9}, train),
    # ('clop', {'regul': 'clop', 'p': 0.9}, train),
    # The mixup variant previously called `train`, which made it a duplicate
    # of the plain run. NOTE(review): assumes train_mixup accepts the same
    # arguments as train — confirm against train.py.
    ('mixup', {}, train_mixup),
]


def _run_variant(run_id, regul, model_kwargs, train_fn):
    """Train a fresh model on MNIST, evaluate it on USPS, and return one result row.

    Args:
        run_id: Index of the repeated run, stored in the 'run_id' column.
        regul: Variant label stored in the 'regul' column.
        model_kwargs: Extra keyword arguments forwarded to MNISTClassifier.
        train_fn: Training function called as
            ``train_fn(model, optimizer, loader, epochs=..., scheduler=...)``.

    Returns:
        A dict keyed by the result-table column names.
    """
    model = MNISTClassifier(img_size=(1, 32, 32), **model_kwargs).cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    # The cosine schedule spans the whole training run (T_max == epochs).
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    train_accuracy = train_fn(model, optimizer, mnist_loader, epochs=epochs, scheduler=scheduler)
    test_accuracy = test(model=model, testloader=uspsloader, m_eval=True)
    return {
        "run_id": run_id,
        "regul": regul,
        "test_accuracy": test_accuracy,
        "train_accuracy": train_accuracy,
    }


# DataFrame.append was removed in pandas 2.0 (and was quadratic anyway):
# collect the rows in a list and build the DataFrame once at the end.
records = []
for i in tqdm(range(nb_runs), desc='run_loop', leave=False):
    for regul, model_kwargs, train_fn in variants:
        records.append(_run_variant(i, regul, model_kwargs, train_fn))
results = pd.DataFrame(records, columns=results_columns)

# Persist all runs to ./results/supervised/mnist.csv. exist_ok avoids the
# check-then-create race of a separate os.path.exists test.
outdir = './results/supervised'
os.makedirs(outdir, exist_ok=True)
path = os.path.join(outdir, 'mnist.csv')
results.to_csv(path)