import os
import torch
from tqdm import tqdm
import pandas as pd
from models import VGG9
from train import train
from test import test


import torchvision
import torchvision.transforms as transforms
import torch
import os
from scipy.io import loadmat

# Per-channel normalisation constants shared by both pipelines.
# NOTE(review): these values match the widely used ImageNet mean/std rather
# than statistics computed on STL-10 itself — confirm that is intended.
means = (0.485, 0.456, 0.406)
stds = (0.229, 0.224, 0.225)

# Training pipeline: a random horizontal flip (the only augmentation), then
# tensor conversion and per-channel normalisation.
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=means, std=stds),
])

# Evaluation pipeline: identical to the training one minus the augmentation.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=means, std=stds),
])


def get_stl10_loaders(root_path="./data", *, download=True, **kwargs):
    """Build ``DataLoader``s for the STL-10 train and test splits.

    Args:
        root_path: Directory passed to ``torchvision.datasets.STL10`` as ``root``.
        download: Forwarded to ``torchvision.datasets.STL10``; defaults to True,
            matching the previous hard-coded behaviour.
        **kwargs: Extra ``torch.utils.data.DataLoader`` arguments (e.g.
            ``batch_size``, ``shuffle``, ``num_workers``). They are applied to
            both loaders, except that the test loader always uses
            ``shuffle=False``.

    Returns:
        A ``(trainloader, testloader)`` tuple. The train split is wrapped with
        the module-level ``train_transform`` and the test split with
        ``test_transform``.
    """
    trainset = torchvision.datasets.STL10(
        root=root_path, split="train", download=download, transform=train_transform
    )
    testset = torchvision.datasets.STL10(
        root=root_path, split="test", download=download, transform=test_transform
    )
    trainloader = torch.utils.data.DataLoader(trainset, **kwargs)
    # Shuffling only matters for training. Keep the evaluation order fixed so
    # test runs are reproducible even when the caller passes shuffle=True.
    test_kwargs = dict(kwargs, shuffle=False)
    testloader = torch.utils.data.DataLoader(testset, **test_kwargs)
    return trainloader, testloader


# NOTE(review): this script runs at import time with num_workers=8. On
# platforms whose multiprocessing start method is "spawn" (Windows, macOS),
# DataLoader workers re-import the main module, so an
# `if __name__ == "__main__":` guard would be needed there.
trainloader, testloader = get_stl10_loaders(batch_size=64, shuffle=True, num_workers=8)

learning_rate = 5e-4
epochs = 100
nb_runs = 5
# Values passed to VGG9(position=...) and recorded in the "clop_position" column.
clop_positions = [2, 5, 7, 10, 13, 15, 16]
result_columns = ["clop_position", "train_accuracy", "test_accuracy", "run_id"]
# One dict per (run, position). DataFrame.append was removed in pandas 2.0,
# so results are accumulated here and the DataFrame is rebuilt from the list.
result_rows = []
results = pd.DataFrame(columns=result_columns)

# Create the output directory once, up front; exist_ok avoids a separate
# existence check.
outdir = "./results/supervised"
os.makedirs(outdir, exist_ok=True)
path = os.path.join(outdir, "stl10_position.csv")

for i in tqdm(range(nb_runs), desc="run_loop", leave=False):
    for j in clop_positions:
        # plain model (moved to the GPU; requires a CUDA device)
        model = VGG9(position=j).cuda()
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
        train_accuracy = train(
            model, optimizer, trainloader, epochs=epochs, scheduler=scheduler
        )
        test_accuracy = test(model=model, testloader=testloader)
        result_rows.append(
            {
                "run_id": i,
                "clop_position": j,
                "test_accuracy": test_accuracy,
                "train_accuracy": train_accuracy,
            }
        )

        # Rewrite the CSV after every experiment so partial results survive
        # an interrupted sweep. `columns` keeps the original column order.
        results = pd.DataFrame(result_rows, columns=result_columns)
        results.to_csv(path)
