import gc
import shutil
import time

import torch
import os
import argparse

from process import process

# Limit OpenMP to a single thread. The environment variable is assigned after
# `torch` has already been imported, so PyTorch's intra-op thread pool may not
# pick it up; cap PyTorch's thread count explicitly as well.
os.environ['OMP_NUM_THREADS'] = '1'
torch.set_num_threads(1)

import warnings

# Silence runtime, user and future warnings for the whole process.
# NOTE(review): this also hides genuine numerical warnings (e.g. divide by zero).
warnings.filterwarnings("ignore", category=RuntimeWarning)
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

# Dataset names configured by `set_arg` in the main block:
# MNIST-USPS, BDGP, Fashion, NUSWIDE.
#
# Every numeric option declares a `type=` so values given on the command line
# are converted (argparse passes them through as `str` otherwise). argparse
# only applies `type` to *string* defaults, so the defaults below keep their
# original Python types.
parser = argparse.ArgumentParser(description='train')
parser.add_argument('--dataset', default='NUSWIDE')
parser.add_argument('--batch_size', default=4096, type=int)
parser.add_argument("--temperature_f", default=0.5, type=float)
parser.add_argument("--learning_rate", default=0.0003, type=float)
parser.add_argument("--weight_decay", default=0., type=float)
parser.add_argument("--main_epochs", default=3, type=int)  # local training rounds
parser.add_argument("--output_dim", default=20, type=int)  # d_m and d
parser.add_argument("--num_users", default=24, type=int)  # number of clients
parser.add_argument("--Dirichlet_alpha", default=9999, type=float)
parser.add_argument("--interval_epoch", default=32, type=int)
parser.add_argument("--participate", default=1, type=float)  # client participation rates
parser.add_argument('--lay_nums', nargs='*', type=int, default=[6, 6, 6])
parser.add_argument('--l2_factor', default=0, type=float)
parser.add_argument('--seed', default=10, type=int)
parser.add_argument('--repeated_experiment_epoch', default=5, type=int)

args = parser.parse_args()

# Prefer GPU 0; fall back to the CPU instead of crashing on machines without
# CUDA.
if torch.cuda.is_available():
    device = torch.device("cuda")
    torch.cuda.set_device(0)
else:
    device = torch.device("cpu")


def clear_cache():
    """Reclaim memory between runs.

    Runs a full Python garbage-collection pass and, only when a CUDA device
    is available, asks PyTorch to release its cached GPU allocator blocks.
    """
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()

def delete_files_in_directory(directory):
    """Remove every entry inside ``directory``, keeping the directory itself.

    Regular files and symlinks are unlinked. Real sub-directories are removed
    recursively. Symlinks, including links to directories and dangling links,
    are removed as links and never followed, so nothing outside ``directory``
    is touched.

    Deletion is best-effort: an entry that cannot be removed is reported on
    stdout and the remaining entries are still processed.

    Args:
        directory: Path of an existing directory to empty.

    Raises:
        OSError: If ``directory`` itself cannot be listed.
    """
    for filename in os.listdir(directory):
        file_path = os.path.join(directory, filename)
        try:
            # `isdir` follows symlinks, and `shutil.rmtree` refuses to operate
            # on a symlink, so links must be checked first and unlinked.
            if os.path.isdir(file_path) and not os.path.islink(file_path):
                shutil.rmtree(file_path)
            else:
                os.remove(file_path)
        except OSError as e:
            print(f"Failed to delete {file_path}. Reason: {e}")

if __name__ == '__main__':
    # Start from a clean slate: empty `images/` and `runs/` if they exist,
    # otherwise create them so both directories are always present.
    img_dir = 'images'
    if os.path.exists(img_dir):
        delete_files_in_directory(img_dir)
        print(f"{img_dir}")
    else:
        print(f"{img_dir} not exists")
        os.mkdir(img_dir)

    log_dir = 'runs'
    if os.path.exists(log_dir):
        delete_files_in_directory(log_dir)
        # Short pause after clearing the log directory; presumably lets the
        # deletions settle before new log files are written. TODO confirm.
        time.sleep(1)
        print(f"{log_dir}")
    else:
        print(f"{log_dir} not exists")
        os.mkdir(log_dir)

    def set_arg(dataset_name):
        """Write the hyper-parameters for ``dataset_name`` into the global ``args``.

        Dataset-specific values are set first. The shared values that follow
        then overwrite any command-line value for every dataset.

        Args:
            dataset_name: One of "MNIST-USPS", "BDGP", "Fashion", "NUSWIDE".

        Raises:
            ValueError: If ``dataset_name`` is not one of the above. Without
                this, the dataset-specific attributes ``l1_factor``,
                ``noise_std`` and ``mask_rates`` (not defined by the argument
                parser) would be left unset.
        """
        args.dataset = dataset_name
        if dataset_name == "MNIST-USPS":
            args.num_users = 24
            args.main_epochs = 50
            args.interval_epoch = 5
            args.lay_nums = [6, 6]
            args.l1_factor = 0
            args.l2_factor = 0
            args.noise_std = 0.
            args.mask_rates = []
            args.output_dim = 20
        elif dataset_name == "BDGP":
            args.num_users = 12
            args.main_epochs = 20
            args.interval_epoch = 5
            args.lay_nums = [7, 2]
            args.l1_factor = 0
            args.l2_factor = 0
            args.noise_std = 0.1
            args.mask_rates = []
            args.output_dim = 20
        elif dataset_name == "Fashion":
            args.num_users = 48
            args.main_epochs = 200
            args.interval_epoch = 5
            args.lay_nums = [6, 6, 6]
            args.l1_factor = 0
            args.l2_factor = 0
            args.noise_std = 0.
            args.mask_rates = [[2, 2, 1]]
            args.output_dim = 20
        elif dataset_name == "NUSWIDE":
            args.num_users = 24
            args.main_epochs = 300
            args.interval_epoch = 5
            args.lay_nums = [2, 4, 3, 2, 3]
            args.l1_factor = 0
            args.l2_factor = 0
            args.noise_std = 0.
            args.mask_rates = [[1, 1, 1, 1, 1]]
            args.output_dim = 20
        else:
            raise ValueError(f"Unknown dataset: {dataset_name!r}")

        # Settings shared by every dataset (override command-line values).
        args.mask_rate = 0
        args.missing_rate = 0.5
        args.sample_num = 0
        args.participate = 1
        args.ablation_index = 0
        args.epochs1 = 50
        args.epochs2 = 10
        args.global_epochs3 = 30
        args.local_epochs3 = 5
        args.epsilon = 0
        # Quick-run overrides for debugging:
        # args.epochs1 = 2
        # args.epochs2 = 2
        # args.global_epochs3 = 2
        # args.local_epochs3 = 2

        args.p1 = 0.1
        args.p2 = 0.01
        args.acc_rate = 0.3
        args.temperature = 0.5

    # NOTE(review): the `--dataset` CLI option is overwritten by this list;
    # only the datasets listed here are run.
    dataset_list = ["MNIST-USPS"]
    for dataset_name in dataset_list:
        set_arg(dataset_name)
        miss_rate_list = [0.33333, 0.5, 0.66667]
        # miss_rate_list = [0.5]
        for miss_rate in miss_rate_list:
            args.missing_rate = miss_rate
            # The run name does not depend on the repetition index.
            name = f"PE {dataset_name} miss_rate={miss_rate} p1={args.p1} p2={args.p2} acc_rate={args.acc_rate} temperature={args.temperature}"
            # Cast defensively in case the value arrives as a string from the
            # command line.
            for experiment_epo in range(int(args.repeated_experiment_epoch)):
                process(name, args, experiment_epo, device)
                # Free Python garbage and cached GPU memory left by the
                # finished run before starting the next one.
                clear_cache()