import os
import torch
from tqdm import tqdm
import pandas as pd
from models import VGG11
from train import train, train_mixup
from test import test
from data_utils import get_imagenette2_loaders


# Benchmark script: trains VGG11 on the loaders returned by
# get_imagenette2_loaders under several regularisation settings, repeats the
# whole comparison `nb_runs` times, and writes the collected accuracies to
# ./results/supervised/imagenette.csv.

trainloader, testloader = get_imagenette2_loaders(
    batch_size=64, shuffle=True, num_workers=8
)

# Hyper-parameters shared by every experiment.
learning_rate = 5e-4
epochs = 90
nb_runs = 5

result_columns = ['regul', 'train_accuracy', 'test_accuracy', 'run_id']
results = pd.DataFrame(columns=result_columns)

# One entry per experiment, executed in this order within each run:
# (label stored in the 'regul' column, VGG11 keyword arguments,
#  training function, extra keyword arguments for the training function).
experiments = [
    ('No_regul', {}, train, {}),
    ('batch_norm', {'regul': 'batch_norm'}, train, {}),
    ('dropout', {'regul': 'dropout', 'p': 0.6}, train, {}),
    ('clop', {'regul': 'clop', 'p': 0.6}, train, {}),
    ('mixup', {}, train_mixup, {'alpha': 0.3}),
]


def _run_experiment(model_kwargs, train_fn, train_kwargs):
    """Train a fresh VGG11 and evaluate it on the test loader.

    Args:
        model_kwargs: keyword arguments forwarded to ``VGG11``.
        train_fn: training routine, called as ``train_fn(model, optimizer,
            trainloader, epochs=..., scheduler=..., **train_kwargs)``.
        train_kwargs: extra keyword arguments for ``train_fn``.

    Returns:
        A ``(train_accuracy, test_accuracy)`` tuple: the value returned by
        ``train_fn`` and the value returned by ``test``.
    """
    model = VGG11(**model_kwargs).cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    train_accuracy = train_fn(
        model,
        optimizer,
        trainloader,
        epochs=epochs,
        scheduler=scheduler,
        **train_kwargs,
    )
    test_accuracy = test(model=model, testloader=testloader)
    return train_accuracy, test_accuracy


outdir = './results/supervised'
path = os.path.join(outdir, 'imagenette.csv')

# Rows are collected in a plain list and turned into a DataFrame when saving:
# DataFrame.append was removed in pandas 2.0.
rows = []

for i in tqdm(range(nb_runs), desc='run_loop', leave=False):
    for regul, model_kwargs, train_fn, train_kwargs in experiments:
        train_accuracy, test_accuracy = _run_experiment(
            model_kwargs, train_fn, train_kwargs
        )
        rows.append(
            {
                "run_id": i,
                "regul": regul,
                "test_accuracy": test_accuracy,
                "train_accuracy": train_accuracy,
            }
        )

    # Rewrite the CSV after every run so that completed runs are on disk
    # even if a later run is interrupted.
    results = pd.DataFrame(rows, columns=result_columns)
    os.makedirs(outdir, exist_ok=True)
    results.to_csv(path)