import os
import sys

BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(BASE_DIR)
from functools import partial
import archs.resnet_imagenet as imagenet_models
import argparse
import logging
import random
import torch.optim.lr_scheduler as lr_scheduler
import torch
from utils.train_eval import *
import utils.tp_utils as tp_utils
import warnings

warnings.filterwarnings("ignore")
import torch_pruning as tp
import math


# Public lowercase "resnet*" constructors exported by archs.resnet_imagenet,
# in sorted order; get_model() accepts exactly these names for --arch.
imagenet_model_names = sorted(
    attr
    for attr, obj in vars(imagenet_models).items()
    if attr.startswith("resnet")
    and attr.islower()
    and not attr.startswith("__")
    and callable(obj)
)
def get_model(args):
    """Build an untrained ImageNet-style ResNet selected by ``args.arch``.

    Args:
        args: parsed options; ``args.arch`` names a constructor in
            ``archs.resnet_imagenet`` and ``args.num_class`` is forwarded as
            ``num_classes``.

    Returns:
        The freshly constructed model (always built with ``pretrained=False``).

    Raises:
        NotImplementedError: if ``args.arch`` is not in ``imagenet_model_names``.
    """
    if args.arch not in imagenet_model_names:
        # --arch has no argparse `choices`, so report what was given and what
        # is available instead of a bare "not supported".
        raise NotImplementedError(
            "Not supported architecture: {!r} (choose from: {})".format(
                args.arch, ", ".join(imagenet_model_names)
            )
        )
    return imagenet_models.__dict__[args.arch](
        pretrained=False, num_classes=args.num_class
    )

# Command-line interface. Flag names (including the historical spelling of
# --use_schedular) are part of the script's interface and must not change.
parser = argparse.ArgumentParser(description="ResNets for CelebA in pytorch")
# basic options
parser.add_argument("--mode", type=str, required=True, choices=["train", "test"])
parser.add_argument("--dataset", type=str, default="celeba", choices=["celeba"])
parser.add_argument("--dataset_dir", type=str, default="../datasets/")
parser.add_argument("--arch", type=str, default="resnet18", help="model architecture")
parser.add_argument(
    "--target_attr",
    type=str,
    default="Attractive",
    help="target-attr: Attractive, Blond_Hair, etc.",
)
parser.add_argument(
    "--sensitive_attr", type=str, default="Male", help="sensitive-attr: Male"
)
parser.add_argument("--fitness", type=str, default="DEO", help="fitness: DEO, DI")
parser.add_argument(
    "--random_seed",
    "-rd",
    default=2,
    type=int,
    help="random seed (python, numpy and torch) and seed for the dataset split",
)
parser.add_argument(
    "--use_schedular",
    action="store_true",
    default=False,
    help="anneal the learning rate with CosineAnnealingLR over --epochs",
)
parser.add_argument("--num_class", type=int, default=2, help="num of classes")
parser.add_argument(
    "--num_sensitive_class", type=int, default=2, help="num of sensitive classes"
)

parser.add_argument(
    "--workers",
    "-j",
    default=4,
    type=int,
    metavar="N",
    help="number of data loading workers (default: 4)",
)
parser.add_argument(
    "--epochs", default=50, type=int, metavar="N", help="number of total epochs to run"
)
parser.add_argument(
    "--start_epoch",
    default=0,
    type=int,
    metavar="N",
    help="manual epoch number (useful on restarts; overridden by a resumed checkpoint)",
)
parser.add_argument(
    "--batch_size",
    "-b",
    default=128,
    type=int,
    metavar="N",
    help="mini-batch size (default: 128)",
)
parser.add_argument(
    "--print_freq",
    "-p",
    default=50,
    type=int,
    metavar="N",
    help="print frequency (default: 50)",
)
parser.add_argument(
    "--save_every",
    dest="save_every",
    help="Saves checkpoints at every specified number of epochs",
    type=int,
    default=10000,
)

parser.add_argument(
    "--resume",
    default="",
    type=str,
    metavar="PATH",
    help="checkpoint to resume training from or to test, "
    "for example: ../models/DIRNAME/MODELNAME.pth",
)
parser.add_argument(
    "--save_dir",
    dest="save_dir",
    help="root directory for experiment outputs (checkpoints and logs)",
    default="../models",
    type=str,
)
parser.add_argument("--gpu", type=int, default=3, help="index of the CUDA device to use")

parser.add_argument(
    "--train_data_ratio",
    default=0,
    type=float,
    help="Rate of training data utilization",
)

# Module-level training state; main() updates it through its `global` statement.
best_prec_val = best_prec_test = best_epoch = 0
# Command-line options are parsed once, when the module is loaded.
args = parser.parse_args()


def _experiment_dir(args, num):
    """Return the output directory path for experiment number ``num``.

    Layout:
    ``<save_dir>/<dataset>/<arch>/<target_attr>/<mode>/exp<NNN>data_ratio<ratio>``.
    """
    return os.path.join(
        args.save_dir,
        args.dataset,
        args.arch,
        args.target_attr,
        args.mode,
        "exp" + str(num).rjust(3, "0") + "data_ratio" + str(args.train_data_ratio),
    )


def _create_experiment_dir(args):
    """Create the first experiment directory that does not exist yet and return it."""
    num = 0
    exp_dir = _experiment_dir(args, num)
    # Never reuse an existing run directory: bump the counter until a free name is found.
    while os.path.exists(exp_dir):
        num += 1
        exp_dir = _experiment_dir(args, num)
    # makedirs also creates args.save_dir and every intermediate level.
    os.makedirs(exp_dir)
    return exp_dir


def _build_logger(args):
    """Return the "train_logger" logger, writing INFO+ records to the console
    and to ``<save_dir>/<arch>_<MM-DD>_<seed>.log``."""
    import time  # explicit import instead of relying on the star import

    logger = logging.getLogger("train_logger")
    logger.setLevel(logging.DEBUG)
    formatter = logging.Formatter(
        "%(asctime)s - %(levelname)s: - %(message)s", datefmt="%m-%d %H:%M"
    )

    fh = logging.FileHandler(
        f'{args.save_dir}/{args.arch}_{time.strftime("%m-%d", time.localtime())}_{args.random_seed}.log'
    )
    fh.setLevel(logging.INFO)
    fh.setFormatter(formatter)

    ch = logging.StreamHandler()
    ch.setLevel(logging.INFO)
    ch.setFormatter(formatter)
    logger.addHandler(ch)
    logger.addHandler(fh)
    return logger


def main():
    """Train or evaluate a ResNet on CelebA while tracking fairness metrics.

    Reads the module-level ``args``. It first creates a fresh experiment
    directory under ``args.save_dir`` and rebinds ``args.save_dir`` to it. Then
    it sets up logging and optionally loads ``args.resume``. Finally it either:

    * trains for epochs ``args.start_epoch .. args.epochs - 1``, validating on
      the validation and test splits after each epoch and saving checkpoints, or
    * runs a single evaluation pass on the test split.

    Updates the module-level ``best_prec_val``, ``best_prec_test`` and
    ``best_epoch``.

    Raises:
        ValueError: on an unknown mode, a ``--resume`` path that is not a
            file, or ``--mode test`` without ``--resume``.
    """
    global args, best_prec_val, best_prec_test, best_epoch
    import numpy as np  # explicit import instead of relying on the star import

    # Seed every RNG the run touches and force deterministic cuDNN kernels.
    random.seed(args.random_seed)
    np.random.seed(args.random_seed)
    torch.manual_seed(args.random_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(args.random_seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    args.save_dir = _create_experiment_dir(args)
    logger = _build_logger(args)

    if args.mode == "train":
        if args.resume:
            logger.info("train- {}".format(args.resume))
        else:
            logger.info("train_from_scratch")
    elif args.mode == "test":
        logger.info("test- {}".format(args.resume))
    else:
        raise ValueError("Please provide a mode: train or test")

    logger.info(args.__dict__)
    train_loader, val_loader, test_loader, target_idx, sensitive_idx = (
        tp_utils.get_fairness_data(args)
    )
    model = get_model(args)

    # True when --resume points at a pruned model stored with tp.state_dict.
    # Such runs are saved back with tp.state_dict in the save section below.
    mode_train_prune_save = False
    if args.resume:
        if not os.path.isfile(args.resume):
            logger.error("=> no checkpoint found at '{}'".format(args.resume))
            raise ValueError("No checkpoint found")
        logger.info("=> loading checkpoint '{}'".format(args.resume))
        # Load onto the CPU: the model is only moved to args.gpu further down.
        # This also avoids failures when the file was saved on an absent GPU.
        checkpoint = torch.load(args.resume, map_location="cpu")

        if checkpoint.get("best_prec1", -1) == -1:
            # Checkpoints written by save_checkpoint() in this script carry
            # "best_prec1". Anything without it is treated as a torch_pruning
            # state dict of a pruned model.
            tp.load_state_dict(model, state_dict=checkpoint)
            logger.info("=> loaded pruned model")
            model.zero_grad()
            # Keep the pruned architecture but reinitialize its weights.
            model.reset_parameters()
            print(model)
            logger.info("=> pruned model reinitialize")
            logger.info(model)
            mode_train_prune_save = True
        else:
            # last_model.pth files from older runs lack "epoch"; keep --start_epoch then.
            args.start_epoch = checkpoint.get("epoch", args.start_epoch)
            best_prec_val = checkpoint["best_prec1"]
            model.load_state_dict(checkpoint["state_dict"])
            logger.info("best_prec1:{}".format(best_prec_val))
            logger.info("=> loaded checkpoint  (epoch {})".format(args.start_epoch))
    elif args.mode != "train":
        raise ValueError("Please provide a checkpoint to test the model")

    if args.gpu is not None:
        model.cuda(args.gpu)

    optimizer = torch.optim.SGD(
        model.parameters(), lr=0.1, momentum=0.9, weight_decay=0.0001
    )
    scheduler = None
    if args.use_schedular:
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)

    criterion = torch.nn.CrossEntropyLoss()
    if args.gpu is not None:
        criterion.cuda(args.gpu)

    if args.mode == "train":
        # Per-epoch history. Index i corresponds to epoch args.start_epoch + i.
        train_acc_list = []
        valid_acc_list, valid_DEO_list, valid_DI_list = [], [], []
        test_acc_list, test_DEO_list, test_DI_list = [], [], []

        for epoch in range(args.start_epoch, args.epochs):
            prec_train, _, _ = train(
                train_loader,
                model,
                criterion,
                optimizer,
                epoch,
                target_idx,
                sensitive_idx,
                args,
            )
            train_acc_list.append(prec_train)

            prec_val, di_val, deo_val = fairness_validate(
                val_loader,
                model,
                criterion,
                args,
                target_idx,
                sensitive_idx,
                mode="Valid",
            )
            is_best_val = prec_val > best_prec_val
            valid_acc_list.append(prec_val)
            valid_DEO_list.append(deo_val)
            valid_DI_list.append(di_val)

            prec_test, di_test, deo_test = fairness_validate(
                test_loader,
                model,
                criterion,
                args,
                target_idx,
                sensitive_idx,
                mode="Test",
            )
            is_best_test = prec_test > best_prec_test
            test_acc_list.append(prec_test)
            test_DEO_list.append(deo_test)
            test_DI_list.append(di_test)

            best_prec_val = max(prec_val, best_prec_val)
            best_prec_test = max(prec_test, best_prec_test)

            if scheduler is not None:
                scheduler.step()

            # One shared rule for both save formats: snapshot after every
            # `save_every` completed epochs. A non-positive value disables it
            # and avoids a modulo by zero.
            save_periodic = args.save_every > 0 and (epoch + 1) % args.save_every == 0

            if not mode_train_prune_save:
                if is_best_val:
                    best_epoch = epoch
                    save_checkpoint(
                        {
                            "epoch": epoch + 1,
                            "state_dict": model.state_dict(),
                            "best_prec1": best_prec_val,
                        },
                        is_best_val,
                        file_name=os.path.join(args.save_dir, "best_model.pth"),
                    )

                if is_best_test:
                    save_checkpoint(
                        {
                            "epoch": epoch + 1,
                            "state_dict": model.state_dict(),
                            "best_prec1": best_prec_test,
                        },
                        is_best_test,
                        file_name=os.path.join(args.save_dir, "best_model_test.pth"),
                    )

                if save_periodic:
                    save_checkpoint(
                        {
                            "epoch": epoch + 1,
                            "state_dict": model.state_dict(),
                            "best_prec1": best_prec_val,
                        },
                        is_best_val,
                        file_name=os.path.join(
                            args.save_dir, "checkpoint_ep_{}.pth".format(epoch + 1)
                        ),
                    )

                # "epoch" is required by the resume path above. Without it,
                # resuming from last_model.pth would fail.
                save_checkpoint(
                    {
                        "epoch": epoch + 1,
                        "state_dict": model.state_dict(),
                        "best_prec1": best_prec_val,
                    },
                    is_best_val,
                    file_name=os.path.join(args.save_dir, "last_model.pth"),
                )
            else:
                state_dict = tp.state_dict(model)
                if is_best_val:
                    best_epoch = epoch
                    torch.save(
                        state_dict,
                        os.path.join(
                            args.save_dir, "pruned_model_finetune_val_best.pth"
                        ),
                    )
                if is_best_test:
                    torch.save(
                        state_dict,
                        os.path.join(
                            args.save_dir, "pruned_model_finetune_test_best.pth"
                        ),
                    )

                if save_periodic:
                    torch.save(
                        state_dict, os.path.join(args.save_dir, "checkpoint.pth")
                    )
                torch.save(
                    state_dict,
                    os.path.join(args.save_dir, "pruned_model_finetune_last.pth"),
                )

            logger.info(f"train_acc:{train_acc_list}")
            logger.info(f"valid_acc:{valid_acc_list}")
            logger.info(f"test_acc:{test_acc_list}")

            logger.info(f"valid_DI:{valid_DI_list}")
            logger.info(f"valid_DEO:{valid_DEO_list}")

            logger.info(f"test_DI:{test_DI_list}")
            logger.info(f"test_DEO:{test_DEO_list}\n\n")

            logger.info(f"best epoch: {best_epoch}")
            # best_epoch is an absolute epoch number, but the history lists start
            # at args.start_epoch. Translate it, and skip the lines when the best
            # model predates this run (e.g. a resumed best that was not beaten).
            best_idx = best_epoch - args.start_epoch
            if 0 <= best_idx < len(valid_acc_list):
                logger.info(
                    f"best val model | valid acc: {valid_acc_list[best_idx]}, valid deo: {valid_DEO_list[best_idx]}"
                )
                logger.info(
                    f"best val model | test acc: {test_acc_list[best_idx]}, test deo: {test_DEO_list[best_idx]}"
                )

    elif args.mode == "test":
        prec_test, di_test, deo_test = fairness_validate(
            test_loader, model, criterion, args, target_idx, sensitive_idx, mode="Test"
        )
        logger.info(f"test acc: {prec_test}, test DI: {di_test}, test DEO: {deo_test}")


if __name__ == "__main__":
    # Script entry point: importing this module only parses args, it does not train.
    main()
