import argparse
import pickle

import torch
from torch import nn, optim

from model import *
from utils import evaluate_pp_implicit, get_loader, train_implicit


if __name__ == "__main__":

    parser = argparse.ArgumentParser(description="CelebA Experiment")
    parser.add_argument(
        "--target_id", default=2, type=int, help="2:attractive/31:smile/33:wavy hair"
    )
    parser.add_argument("--kappa", default=1e-3, type=float, help="kappa")
    parser.add_argument("--max_inner",default=30, type=int, help="inner optimization steps")
    args = parser.parse_args()

    # Load Celeb dataset
    target_id = args.target_id
    with open("celeba/data_frame.pickle", "rb") as handle:
        df = pickle.load(handle)
    train_df = df["train"]
    train_df = train_df.sample(int(0.5 * len(train_df)))
    valid_df = df["val"]
    test_df = df["test"]

    train_dataloader_0 = get_loader(train_df, "celeba/split/train/", target_id, 64, gender=0)
    train_dataloader_1 = get_loader(train_df, "celeba/split/train/", target_id, 64, gender=1)
    valid_dataloader_0 = get_loader(valid_df, "celeba/split/val/", target_id, 64, gender=0)
    valid_dataloader_1 = get_loader(valid_df, "celeba/split/val/", target_id, 64, gender=1)
    test_dataloader_0 = get_loader(test_df, "celeba/split/test/", target_id, 64, gender=0)
    test_dataloader_1 = get_loader(test_df, "celeba/split/test/", target_id, 64, gender=1)

    # model
    fea = ResNet18_Encoder(pretrained=True).cuda()
    clf_0 = LinearModel().cuda()
    clf_1 = LinearModel().cuda()

    optim_fea = optim.Adam(fea.parameters(), lr=1e-5)
    optim_clf_0 = optim.Adam(clf_0.parameters(), lr=1e-3)
    optim_clf_1 = optim.Adam(clf_1.parameters(), lr=1e-3)

    criterion = nn.BCELoss()
    train_implicit(
        fea,
        clf_0,
        clf_1,
        criterion,
        optim_fea,
        optim_clf_0,
        optim_clf_1,
        train_dataloader_0,
        train_dataloader_1,
        kappa=args.kappa,
        n_epoch=20,
        max_inner=args.max_inner
    )
    print("test:")
    evaluate_pp_implicit(fea, clf_0, clf_1, test_dataloader_0, test_dataloader_1)
