# Standard library imports
import csv
import os
import random
import uuid

# Third party library imports
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset
import wandb
from tqdm import tqdm
import pandas as pd

# Local Imports
import utils
from dataset import celebahq
from clf import Classifier

"""
PREREQUISITES
"""
# Compute device: the first CUDA GPU when CUDA is available, otherwise the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Filesystem layout.
#   root_dir  - dataset root
#   img_dir   - image folder inside the dataset root
#   model_dir - checkpoint folder, relative to the current working directory
root_dir = "/export/io85/data/bbharti1/CelebAMask-HQ"
img_dir = os.path.join(root_dir, "CelebA-HQ-img")
model_dir = "checkpoints"

# Create the checkpoint folder up front; an already existing folder is fine.
os.makedirs(model_dir, exist_ok=True)

# Load training and validation data.
# `utils` is the project-local helper module imported at the top of the file;
# load_train_val_data returns a (train, val) pair built from the CSV below.
batch_size = 64
_csv_file = "train.csv"
train_data, val_data = utils.load_train_val_data(
    root_dir, img_dir, _csv_file, model_type="clf"
)

# Settings shared by both loaders; only the shuffling differs between splits.
_loader_kwargs = {"batch_size": batch_size, "num_workers": 4}
dataloaders = {
    # Training batches are reshuffled every epoch.
    "train": DataLoader(train_data, shuffle=True, **_loader_kwargs),
    # Validation batches keep the dataset order.
    "val": DataLoader(val_data, shuffle=False, **_loader_kwargs),
}

# Instantiate the classifier with fine_tune=False.
# NOTE(review): the model is not moved to `device` here; `device` is only
# passed to train_model further down - presumably the transfer happens there,
# confirm in Classifier.train_model.
clf = Classifier(fine_tune=False)

# Optimizer: Adam over every parameter exposed by the classifier, lr = 1e-3.
optimizer = torch.optim.Adam(params=clf.parameters(), lr=1e-3)

"""
MAIN FUNCTION
"""
if __name__ == "__main__":
    # Keyword options forwarded to Classifier.train_model. save_path is the
    # checkpoint directory created earlier in this script.
    # NOTE(review): log=True presumably turns on experiment logging - confirm
    # against Classifier.train_model.
    training_options = {
        "num_epochs": 20,
        "save_path": model_dir,
        "log": True,
        "device": device,
    }
    clf.train_model(dataloaders, optimizer, **training_options)
