# Standard library imports
import csv
import itertools
import os
import random
import uuid

# Third party library imports
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset
import wandb
from tqdm import tqdm
import pandas as pd

# Local Imports
import tools
from dataset import CelebAHQ
from clf import Classifier
from explainer import Explainer

"""
PREREQUISITES
"""
# Set device
device = torch.device("cuda:6" if torch.cuda.is_available() else "cpu")

# Directory names
root_dir = "/export/io85/data/bbharti1/CelebAMask-HQ"
data_dir = "data"
img_dir = os.path.join(root_dir, "CelebA-HQ-img")
model_dir = os.path.join("checkpoints", "final_models")
os.makedirs(model_dir, exist_ok=True)

# Load training and validation data
batch_size = 32
_csv_file = "train.csv"
train_data, val_data = tools.load_train_val_data(
    root_dir, img_dir, _csv_file, model_type="exp"
)
dataloaders = {
    "train": DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=4),
    "val": DataLoader(val_data, batch_size=batch_size, shuffle=False, num_workers=4),
}

# Load background
bkgd = torch.load(os.path.join(data_dir, "mean_img.pt"))

# Load trained classifier
clf_state_dict_path = os.path.join(model_dir, "clf.pt")
clf = tools.load_classifier(clf_state_dict_path, eval_mode=True)

# Experiment parameters
num_samples = 5
num_epochs = 20
_y_0 = 0.18

"""
MAIN FUNCTION
"""
if __name__ == "__main__":
    alpha = 0
    sp_mults = [0.05, 0.1, 0.15, 0.2]
    sm_mults = [0.05, 0.1, 0.15, 0.2]
    sh_mults = [0]
    for sp_mult in sp_mults:
        for sm_mult in sm_mults:
            for sh_mult in sh_mults:
                print(f"Learning model with sp_mult = {sp_mult}, sm_mult = {sm_mult}, sh_mult = {sh_mult}")
                exp = tools.load_explainer(clf, device, bkgd, num_channels=3, state_dict_path=None, eval_mode=False)
                optimizer = torch.optim.Adam(exp.parameters(), lr=1e-3)
                model_name = (
                    "explainer_"
                    + str(alpha)
                    + "_"
                    + str(sp_mult)
                    + "_"
                    + str(sm_mult)
                    + "_"
                    + str(sh_mult)
                    + ".pt"
                )
                # Train
                exp.train_model(
                    alpha,
                    _y_0,
                    sp_mult,
                    sm_mult,
                    sh_mult,
                    dataloaders,
                    optimizer,
                    num_samples,
                    num_epochs,
                    save_path=model_dir,
                    model_name=model_name,
                    log_step=25,
                    log=False,
                )
