import logging
import sys
import os
import time
import torch
import torchvision
import torch.nn as nn
import torch.nn.parallel
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader, Dataset, random_split
import numpy as np
import pandas as pd
from PIL import Image
import torch_pruning as tp
import random
import archs.resnet_imagenet as imagenet_models
from torch.utils.data import Subset

def set_random_seed(num):
    """Seed Python's, NumPy's and PyTorch's global random number generators.

    Args:
        num: integer seed applied to every generator.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(num)
    # The CUDA generator is only seeded when a CUDA runtime is available.
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(num)

def _make_loader(dataset, batch_size, shuffle):
    """Wrap ``dataset`` in a DataLoader that uses 4 worker processes."""
    return DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        num_workers=4,
        shuffle=shuffle,
    )

def _load_celeba(data_dir, split, transform):
    """Construct one CelebA split with attribute targets (``download=True``)."""
    return datasets.CelebA(
        root=data_dir,
        split=split,
        target_type="attr",
        transform=transform,
        download=True,
    )

def get_fairness_data(args):
    """Build CelebA train/validation/test loaders for a fairness experiment.

    While the data is prepared, the global RNGs are seeded with a fixed value
    (2024). This makes the optional training subset drawn by ``random_split``
    reproducible. Before returning, the RNGs are re-seeded with
    ``args.random_seed``.

    Args:
        args: namespace read for ``dataset``, ``dataset_dir``, ``target_attr``,
            ``sensitive_attr``, ``train_data_ratio``, ``batch_size`` and
            ``random_seed``. A ``train_data_ratio`` in (0, 1] keeps that
            fraction of the training split; a value <= 0 keeps all of it.

    Returns:
        ``(train_loader, val_loader, test_loader, target_idx, sensitive_idx)``.
        The two indices are the positions of ``args.target_attr`` and
        ``args.sensitive_attr`` in the CelebA ``attr_names`` list.

    Raises:
        NotImplementedError: if ``args.dataset`` is not ``"celeba"``.
        ValueError: if ``train_data_ratio`` is greater than 1, or if either
            attribute name is not in ``attr_names`` (from ``list.index``).
    """
    set_random_seed(2024)

    if args.dataset != "celeba":
        raise NotImplementedError("Not supported dataset")

    ratio = args.train_data_ratio
    if ratio > 1:
        # With ratio > 1 the second random_split length would be negative.
        # The resulting "subset" would silently be the whole training split.
        raise ValueError(
            "train_data_ratio must be at most 1, got {}".format(ratio)
        )

    data_dir = args.dataset_dir
    normalize = transforms.Normalize(
        mean=[0.5063486, 0.4258108, 0.38318512],
        std=[0.26577517, 0.24520662, 0.24129295],
    )
    train_transform = transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.RandomRotation(15),
            transforms.ToTensor(),
            normalize,
        ]
    )
    val_transform = transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            normalize,
        ]
    )

    train_dataset = _load_celeba(data_dir, "train", train_transform)
    target_idx = train_dataset.attr_names.index(args.target_attr)
    sensitive_idx = train_dataset.attr_names.index(args.sensitive_attr)

    if ratio > 0:
        n_total = len(train_dataset)
        n_keep = int(ratio * n_total)
        # Uses the global torch RNG, seeded with 2024 above, so the subset is
        # reproducible across runs.
        train_subset, _ = torch.utils.data.random_split(
            train_dataset, [n_keep, n_total - n_keep]
        )
        print("Train Data split {}/{}".format(len(train_subset), n_total))
    else:
        train_subset = train_dataset

    train_loader = _make_loader(train_subset, args.batch_size, shuffle=True)

    val_dataset = _load_celeba(data_dir, "valid", val_transform)
    test_dataset = _load_celeba(data_dir, "test", val_transform)
    val_loader = _make_loader(val_dataset, args.batch_size, shuffle=False)
    test_loader = _make_loader(test_dataset, args.batch_size, shuffle=False)

    set_random_seed(args.random_seed)
    return train_loader, val_loader, test_loader, target_idx, sensitive_idx

def get_model(args):
    """Build the requested ResNet and load its weights from a checkpoint.

    Args:
        args: namespace read for the following attributes:
            ``arch``: must name a lower-case callable in ``archs.resnet_imagenet``
                whose name starts with ``"resnet"``.
            ``num_class``: number of output classes.
            ``load_dir``: checkpoint path.
            ``gpu``: device passed to ``model.cuda``.

    Returns:
        The model with the checkpoint weights loaded, moved to CUDA.

    Raises:
        NotImplementedError: if ``args.arch`` is not a supported architecture.
    """
    imagenet_model_names = sorted(
        name
        for name in imagenet_models.__dict__
        if name.islower()
        and not name.startswith("__")
        and name.startswith("resnet")
        and callable(imagenet_models.__dict__[name])
    )
    if args.arch not in imagenet_model_names:
        raise NotImplementedError("Not supported architecture")

    model = imagenet_models.__dict__[args.arch](
        pretrained=False, num_classes=args.num_class
    )
    # NOTE(review): torch.load unpickles the file; only point load_dir at
    # trusted checkpoints.
    ckpt = torch.load(args.load_dir, map_location=torch.device("cpu"))
    # Accept both {"state_dict": ...} training checkpoints and bare state dicts.
    state_dict = ckpt["state_dict"] if "state_dict" in ckpt else ckpt
    # Weights saved from a DataParallel/DistributedDataParallel wrapper have a
    # "module." prefix on every key, which the unwrapped model would reject.
    prefix = "module."
    if state_dict and all(key.startswith(prefix) for key in state_dict):
        state_dict = {
            key[len(prefix):]: value for key, value in state_dict.items()
        }
    model.load_state_dict(state_dict)
    model.cuda(args.gpu)
    return model

def set_logger(args, name=""):
    """Configure and return the shared ``"train_logger"`` logger.

    Records at INFO level and above go to a console ``StreamHandler`` and to
    the file ``{args.save_dir}/{args.arch}_{MM-DD}_{name}.log``. Handlers
    attached by a previous call are removed and closed first. Without this,
    calling the function twice would emit every record twice. The Python,
    PyTorch, cuDNN and torchvision versions are logged once set up.

    Args:
        args: namespace read for ``save_dir`` (created if missing) and ``arch``.
        name: suffix appended to the log file name.

    Returns:
        The configured ``logging.Logger``.
    """
    logger = logging.getLogger("train_logger")
    logger.setLevel(logging.DEBUG)
    # getLogger returns the same object on every call, so drop stale handlers
    # (closing the old log file) before attaching new ones.
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()

    formatter = logging.Formatter(
        "%(asctime)s - %(levelname)s: - %(message)s", datefmt="%m-%d %H:%M"
    )
    if args.save_dir:
        # FileHandler fails with FileNotFoundError if the directory is missing.
        os.makedirs(args.save_dir, exist_ok=True)
    fh = logging.FileHandler(
        f'{args.save_dir}/{args.arch}_{time.strftime("%m-%d", time.localtime())}_{name}.log'
    )
    fh.setLevel(logging.INFO)
    fh.setFormatter(formatter)
    ch = logging.StreamHandler()
    ch.setLevel(logging.INFO)
    ch.setFormatter(formatter)
    logger.addHandler(ch)
    logger.addHandler(fh)
    logger.info("PyThon  version : {}".format(sys.version.replace("\n", " ")))
    logger.info("PyTorch version : {}".format(torch.__version__))
    logger.info("cuDNN   version : {}".format(torch.backends.cudnn.version()))
    logger.info("Vision  version : {}".format(torchvision.__version__))
    return logger