from FPVE_config import *
import sys
from collections import OrderedDict
import torchvision.models as models

def set_random_seed(num):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        num: Seed value passed unchanged to every generator.

    When CUDA is available, ``torch.cuda.manual_seed`` is called with the
    same value as well.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(num)

    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(num)


def _fraction_split(dataset, ratio):
    """Return a random subset holding ``int(ratio * len(dataset))`` samples."""
    total = len(dataset)
    keep = int(ratio * total)
    subset, _ = torch.utils.data.random_split(dataset, [keep, total - keep])
    return subset


def _celeba_split(root, split, transform):
    """Open one CelebA split with attribute targets (never downloads)."""
    return datasets.CelebA(
        root=root,
        split=split,
        target_type="attr",
        transform=transform,
        download=False,
    )


def _make_loader(dataset, batch_size, shuffle):
    """Wrap ``dataset`` in a DataLoader with the loader settings used here."""
    return DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        num_workers=4,
        pin_memory=True,
        shuffle=shuffle,
    )


def get_fairness_data_FPVE(args):
    """Build the CelebA data loaders used for training and evaluation.

    The global RNGs are seeded with 2024 before any split is drawn and are
    re-seeded with ``args.random_seed`` just before returning.

    Args:
        args: Namespace-like object. Reads ``dataset``, ``dataset_dir``,
            ``target_attr``, ``train_data_ratio``, ``batch_size``,
            ``valid_batch_size``, ``FPVE_fitness_data_ratio`` and
            ``random_seed``.

    Returns:
        ``(train_loader, val_loader, test_loader, target_idx, sensitive_idx)``
        and, only when ``args.FPVE_fitness_data_ratio > 0``, a sixth element
        holding the fitness loader. ``target_idx`` / ``sensitive_idx`` are the
        positions of ``args.target_attr`` and ``"Male"`` in the dataset's
        ``attr_names``.

    Raises:
        NotImplementedError: If ``args.dataset`` is not ``"celeba"``.
    """
    set_random_seed(2024)

    if args.dataset != "celeba":
        raise NotImplementedError("Not supported dataset")

    data_dir = args.dataset_dir
    normalize = transforms.Normalize(
        mean=[0.5063486, 0.4258108, 0.38318512],
        std=[0.26577517, 0.24520662, 0.24129295],
    )

    # Training images get random flips/rotation; evaluation images do not.
    augmentations = [
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.RandomRotation(15),
    ]
    train_transform = transforms.Compose(
        [
            transforms.Resize((224, 224)),
            *augmentations,
            transforms.ToTensor(),
            normalize,
        ]
    )
    val_transform = transforms.Compose(
        [transforms.Resize((224, 224)), transforms.ToTensor(), normalize]
    )

    train_full = _celeba_split(data_dir, "train", train_transform)
    target_idx = train_full.attr_names.index(args.target_attr)
    sensitive_idx = train_full.attr_names.index("Male")

    # Optionally train on a random fraction of the training split.
    train_source = train_full
    if args.train_data_ratio > 0:
        train_source = _fraction_split(train_full, args.train_data_ratio)
        print(
            "Train Data split {}/{}".format(len(train_source), len(train_full))
        )
    train_loader = _make_loader(train_source, args.batch_size, shuffle=True)

    val_dataset = _celeba_split(data_dir, "valid", val_transform)
    test_dataset = _celeba_split(data_dir, "test", val_transform)
    val_loader = _make_loader(val_dataset, args.valid_batch_size, shuffle=False)
    test_loader = _make_loader(test_dataset, args.valid_batch_size, shuffle=False)

    fitness_loader = None
    if args.FPVE_fitness_data_ratio > 0:
        # Second copy of the training split, using the evaluation transform.
        eval_train = _celeba_split(data_dir, "train", val_transform)

        # NOTE(review): this draws a new random split from the global RNG, so
        # the subset is presumably not the same samples as the training
        # subset above -- confirm whether that is intended.
        pool = eval_train
        if args.train_data_ratio > 0:
            pool = _fraction_split(eval_train, args.train_data_ratio)
        fitness_subset = _fraction_split(pool, args.FPVE_fitness_data_ratio)

        print(
            "Fitness Data split {}/{}".format(
                len(fitness_subset), len(eval_train)
            )
        )
        fitness_loader = _make_loader(
            fitness_subset, args.valid_batch_size, shuffle=True
        )

    set_random_seed(args.random_seed)

    outputs = (train_loader, val_loader, test_loader, target_idx, sensitive_idx)
    if args.FPVE_fitness_data_ratio > 0:
        return outputs + (fitness_loader,)
    return outputs

def get_model(args):
    """Build the requested architecture, load its checkpoint and move it to GPU.

    Args:
        args: Namespace-like object. Reads ``arch``, ``num_class``,
            ``load_dir`` (path handed to ``torch.load``) and ``gpu`` (device
            passed to ``model.cuda``).

    Returns:
        The model with the checkpoint's ``"state_dict"`` entry loaded, after
        ``model.cuda(args.gpu)`` has been called.

    Raises:
        NotImplementedError: If ``args.arch`` is not in
            ``imagenet_model_names``.
    """
    if args.arch not in imagenet_model_names:
        raise NotImplementedError("Not supported architecture")

    constructor = imagenet_models.__dict__[args.arch]
    model = constructor(pretrained=False, num_classes=args.num_class)

    # Deserialize on CPU first; the model is moved to the GPU afterwards.
    # NOTE(review): torch.load may unpickle arbitrary objects depending on the
    # PyTorch version's defaults -- only load checkpoints from trusted sources.
    checkpoint = torch.load(args.load_dir, map_location=torch.device("cpu"))
    model.load_state_dict(checkpoint["state_dict"])

    model.cuda(args.gpu)
    return model

def set_logger(args):
    """Configure the ``"train_logger"`` logger and log the run configuration.

    A run-specific directory name is appended to ``args.save_dir`` (the
    attribute is overwritten in place), the directory is created, and a file
    handler plus a console handler are attached to the logger. Both handlers
    emit records at INFO level and above.

    Args:
        args: Namespace-like object. Reads ``save_dir``, ``dataset``, ``arch``,
            ``pop_init_rate``, ``ft_epochs``, ``random_seed``, ``load_dir``,
            ``lr``, ``lr_milestone``, ``iterative_steps``, ``pruning_ratio``,
            ``evolution_epoch`` and ``pop_size``; writes ``save_dir``.

    Returns:
        None. Retrieve the logger with ``logging.getLogger("train_logger")``.
    """
    logger = logging.getLogger("train_logger")
    logger.setLevel(logging.DEBUG)

    formatter = logging.Formatter(
        "%(asctime)s - %(levelname)s: - %(message)s", datefmt="%m-%d %H:%M"
    )
    # Tag slot in the run name; currently always empty.
    c = ""

    # Read the clock once so the directory name and the log file name are
    # derived from the same instant.
    now = time.localtime()
    run_prefix = f"pruning_{args.dataset}_{args.arch}_{c}_"

    args.save_dir = (
        args.save_dir
        + "/"
        + run_prefix
        + time.strftime("%m_%d_%H_%M_%S", now)
        + f"_{args.pop_init_rate}_{args.ft_epochs}_{args.random_seed}"
    )
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(args.save_dir, exist_ok=True)

    log_date = time.strftime("%m-%d", now)
    fh = logging.FileHandler(f"{args.save_dir}/{run_prefix}{log_date}.log")
    fh.setLevel(logging.INFO)
    fh.setFormatter(formatter)

    ch = logging.StreamHandler()
    ch.setLevel(logging.INFO)
    ch.setFormatter(formatter)

    logger.addHandler(ch)
    logger.addHandler(fh)

    # Environment details.
    logger.info("PyThon  version : {}".format(sys.version.replace("\n", " ")))
    logger.info("PyTorch version : {}".format(torch.__version__))
    logger.info("cuDNN   version : {}".format(torch.backends.cudnn.version()))
    logger.info("Vision  version : {}".format(torchvision.__version__))

    # Run configuration.
    logger.info(f"Network: {args.arch}")
    logger.info(f"Dataset: {args.dataset}")
    if args.dataset != "ImageNet":
        logger.info(f"Original Model: {args.load_dir}")
    logger.info(f"Learning rate: {args.lr}")
    logger.info(f"Fine-tune after pruning: {args.ft_epochs}")
    logger.info(f"Fine-tune lr milestone: {args.lr_milestone}")
    logger.info(f"Iterative Steps: {args.iterative_steps}")
    logger.info(f"Population Init Rate: {args.pop_init_rate}")
    logger.info(f"Pruning Ratio: {args.pruning_ratio}")
    logger.info(f"Evolution Round: {args.evolution_epoch}")
    logger.info(f"Random seed:{args.random_seed}")
    logger.info(f"Population Size: {args.pop_size}")

