import argparse
import torch
import time
import random
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader, Subset
import pickle
import numpy as np
import archs.resnet_imagenet as imagenet_models
import torchvision.models.vgg as vgg_imagenet_models
import torchvision
import logging
import os
from thop import profile

# Public ResNet constructors exposed by archs.resnet_imagenet: lowercase,
# "resnet"-prefixed, callable module attributes, sorted alphabetically.
# (A "resnet" prefix already rules out dunder names, so no separate check.)
imagenet_model_names = sorted(
    name
    for name, obj in vars(imagenet_models).items()
    if name.islower() and name.startswith("resnet") and callable(obj)
)

# Architecture names advertised in the --arch help text below.
model_names = imagenet_model_names

parser = argparse.ArgumentParser(description="Pruning Neural Network by FPVE")

# Common
parser.add_argument(
    "--arch",
    "-a",
    metavar="ARCH",
    default="resnet18",
    # The list of discovered ResNet names was previously missing from the
    # help text; list them so `--help` shows what is available. The value is
    # not restricted to these names (no `choices=`).
    help="model architecture: " + " | ".join(model_names) + " (default: resnet18)",
)
parser.add_argument("--dataset", help="choose datasets ", default="celeba", type=str)
parser.add_argument(
    "--random_seed",
    "-rd",
    # Drawn once when this module runs, so each launch gets a different seed
    # unless -rd is passed explicitly (needed for a reproducible split).
    default=random.randint(0, 2024),
    type=int,
    help="seed for dataset split",
)

# Directories
parser.add_argument("--load_dir", dest="load_dir", default="./save_model/", type=str)
parser.add_argument(
    "--dataset_dir", metavar="DIR", help="path to dataset", default="./", type=str
)
parser.add_argument("--save_dir", help="path to save directory", default="./save", type=str)

# Training
parser.add_argument(
    "-b",
    "--batch-size",
    default=128,
    type=int,
    metavar="N",
    help="mini-batch size (default: 128)",
)
parser.add_argument(
    "-vb",
    "--valid-batch-size",
    default=128,
    type=int,
    metavar="N",
    help="mini-batch size for validation(default: 128)",
)
parser.add_argument(
    "-j",
    "--workers",
    default=16,
    type=int,
    metavar="N",
    # Help previously claimed "default: 4" while the real default is 16;
    # interpolate the actual default so the two cannot drift apart again.
    help="number of data loading workers (default: %(default)s)",
)
parser.add_argument(
    "-p",
    "--print-freq",
    default=50,
    type=int,
    metavar="N",
    # Help previously claimed "default: 500" while the real default is 50.
    help="print frequency (default: %(default)s)",
)
parser.add_argument(
    "--lr",
    "--learning-rate",
    default=0.01,
    type=float,
    metavar="LR",
    help="initial learning rate",
)
parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
parser.add_argument(
    "--weight-decay",
    "--wd",
    default=1e-4,
    type=float,
    metavar="W",
    help="weight decay (default: 1e-4)",
)

# No default: args.lr_milestone is None unless at least one value is given.
parser.add_argument(
    "--lr_milestone", type=int, nargs="+", help="lr milestone for fine-tune"
)

# The flag spelling ("schedular") is part of the public CLI; do not rename it.
parser.add_argument("--noadv_add_schedular", action="store_true", default=False)

# FPVE search settings: sub-EA evolution, population and pruning knobs.
parser.add_argument(
    "--evolution_epoch",
    "-ep",
    type=int,
    default=10,
    help="Evolution epoch",
)
parser.add_argument(
    "--fitness_mode",
    default="ACC",
    type=str,
    help="Evaluation mode of individual fitness in sub-EAs",
)
parser.add_argument(
    "--mutation_rate",
    type=float,
    default=0.1,
    help="Mutation rate in sub-EAs",
)
parser.add_argument(
    "--pop_init_rate",
    type=float,
    default=0.95,
    help="init pruning rate when init population",
)
# NOTE(review): the next two options have no help text; their exact meaning
# is defined by the code that reads them — confirm before documenting.
parser.add_argument("--pruning_ratio", default=0.9, type=float)
parser.add_argument("--iterative_steps", type=int, default=9)
parser.add_argument(
    "--pop_size",
    type=int,
    default=5,
    help="population size for FPVE",
)
# Default is the int 0 (argparse does not coerce non-string defaults).
parser.add_argument(
    "--FPVE_fitness_data_ratio",
    type=float,
    default=0,
    help="FPVE_fitness_data ratio",
)

# Others
parser.add_argument(
    "-des",
    "--description",
    default="test",
    type=str,
    help="training log in ./output/#description/",
)
parser.add_argument("--gpu", default=0, type=int, help="cuda training")

# Fairness: which attribute is predicted and which is treated as sensitive.
parser.add_argument(
    "--target-attr",
    default="Attractive",
    type=str,
    help="target-attr: Attractive, Blond_Hair",
)
parser.add_argument(
    "--sensitive-attr",
    default="Male",
    type=str,
    help="sensitive-attr: Male",
)

# for prune_finetune
parser.add_argument("--num_class", default=2, type=int, help="num of classes")
parser.add_argument("--ft_epochs", default=20, type=int, help="finetune epochs")
parser.add_argument(
    "--num_sensitive_class",
    default=2,
    type=int,
    help="num of sensitive classes",
)
parser.add_argument(
    "--adv_mode",
    default=False,
    action="store_true",
    help="using adversarial finetuning mode",
)
parser.add_argument("--w", default=0.1, type=float, help="debias weight")
parser.add_argument("--p-lr", default=1e-4, type=float, help="predictor learning rate")
parser.add_argument("--a-lr", default=1e-4, type=float, help="adversary learning rate")
parser.add_argument("--w_decay", default=False, action="store_true")
parser.add_argument("--use_projection", default=False, action="store_true")

# Default is the int 0 (argparse does not coerce non-string defaults).
parser.add_argument(
    "--train_data_ratio",
    type=float,
    default=0,
    help="Rate of training data utilization",
)
parser.add_argument("--scalar", type=float, default=1, help="scalar for fairness")

