import os
import random
import sys
import logging

import matplotlib
from torch import nn
from torch.utils.data import Dataset, DataLoader

from models.Update_domain import DomainClientUpdate, atk_train
from models.vggmodule import vgg
import matplotlib.pyplot as plt
from PIL import Image
import copy
import numpy as np
from torchvision import datasets, transforms
import torch
import time
from matplotlib import pyplot as plt
import torch.optim as optim
from models.backdoor import create_trigger_model
from utils.fpl import proto_aggregation
from utils.sampling import mnist_iid, mnist_noniid, cifar_iid, fmnist_iid, svhn_iid
# from utils.options import args_parser
from utils.parser import args_parser
from models.Update import LocalUpdate
from models.Nets import MLP, CNNMnist, CNNCifar, Lenet5, LeNet, DigitModel
from models.Fed import FedAvg
from models.test import test_img
from utils import data_utils
from utils.init_data_model import init_data, init_model, init_data_methodone, get_dataset
from utils.forget_event import compute_forgetting_statistics,order_examples_of_forget,sort_examples_by_forgetting
from utils.evaluate import evaluate
from torch.utils.data import Dataset, TensorDataset, DataLoader
from utils.backdoor_process import backdoor_process, plt_img
from unlearning_methods import (
    learning,
    retrain,
    rapid_retrain,
    federaser,
    increase_loss,
    class_pruning,
    fedsalun,
    fu_dws,
)



if __name__ == '__main__':
    args = args_parser()
    args.save = args.model

    log_dir = f'./result/log/{args.dataset}/{args.target}'
    os.makedirs(log_dir, exist_ok=True)
    timestamp = int(time.time())
    ul_clients_str = '_'.join(str(i) for i in (args.unlearning_client if isinstance(args.unlearning_client, (list, tuple)) else [args.unlearning_client]))
    bd_clients_str = '_'.join(str(i) for i in (args.backdoor_client_idx if isinstance(args.backdoor_client_idx, (list, tuple)) else [args.backdoor_client_idx]))
    log_file = os.path.join(
        log_dir,
        f'log_{args.target}_{ul_clients_str}_{bd_clients_str}_{args.mask_ratio}_{args.diff_mask_ratio}_{timestamp}.log'
    )
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(message)s',
        handlers=[
            logging.FileHandler(log_file),
            logging.StreamHandler(sys.stdout)
        ]
    )
    args.freeze_layers = args.freeze_layers.split(',')
    if  args.model == 'resnet18':
        args.max_pool = 'False'

    if args.dataset == 'domain_digits':
        args.num_users = 5
        args.model = 'cnn'
    elif args.dataset == 'office-caltech10':
        args.num_users = 4
    elif args.dataset == 'DomainNet':
        args.num_users = 6
    elif args.dataset == 'PACS':
        args.num_users = 4
        args.num_classes = 7

    # Update number of users based on domain split settings
    base_domain_count = args.num_users
    datasets_name = get_dataset(args)
    args.num_users = len(datasets_name)

    # Adjust client indices based on specified backdoor domain
    if 0 <= args.bkd_domain_idx < base_domain_count:
        bkd_client_idx = args.bkd_domain_idx * args.domain_times_factor
        args.backdoor_client_idx = [bkd_client_idx]
        args.unlearning_client = [bkd_client_idx]
        
    if args.target == 'learning':
        learning(args)

    elif args.target == 'retrain':
        retrain(args)

    elif args.target == 'rapid_retrain':
        rapid_retrain(args)

    elif args.target == 'federaser':
        federaser(args)
    
    elif args.target == 'increase_loss':
        increase_loss(args)

    elif args.target == 'class_pruning':
        class_pruning(args)

    elif args.target == 'fedsalun':
        fedsalun(args)

    elif args.target == 'fu_dws':
        fu_dws(args)

    
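# Example invocations. NOTE: the targets fedfreeze_layer, fedfreeze_layer_pBN,
# fedfre_test, fedsalun_nlp and fedlc used in some lines below are not among the
# targets dispatched by this script.
# python3 test4unlearning_main.py --target retrain --verify backdoor --unlearning_client 0 --backdoor_client_idx 0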
# python3 test4unlearning_main.py --target fedsalun --verify backdoor --unlearning_client 3 --backdoor_client_idx 3 --unlearn_lr 0.1 --mask_ratio 0.1 --fedsalun_epoch 10 --lamb 0.1
# python3 test4unlearning_main.py --target fedfreeze_layer --verify backdoor --unlearning_client 1 --backdoor_client_idx 1 --freeze_ulr neglabel --frzulr 0.01 --freeze_layers 'conv1,conv2,conv3,fc1,fc2'
# python3 test4unlearning_main.py --target fedfreeze_layer_pBN --verify backdoor --unlearning_client 0 --backdoor_client_idx 0 --freeze_ulr neglabel --frzulr 0.005 --freeze_layers ''
# python3 test4unlearning_main.py --target fedfre_test --verify backdoor --unlearning_client 0 --backdoor_client_idx 0 --freeze_ulr neglabel --frzulr 0.01 --freeze_layers 'conv1,conv2,conv3,fc1,fc2'
# python3 test4unlearning_main.py --target fedsalun_nlp --verify backdoor --unlearning_client 1 --backdoor_client_idx 1 --freeze_layers 'conv1,conv2,conv3,fc1,fc2' --dataset_fullparti False    
# python3 test4unlearning_main.py --target fedsalun_nlp --verify backdoor --unlearning_client 0 --backdoor_client_idx 0 --mask_ratio 0.1 --dataset office-caltech10 --model 'vgg16' --noise_type 'gaussian' --noise_scale 0.01
# python3 test4unlearning_main.py --target fu_dws --verify backdoor --unlearning_client 1 --backdoor_client_idx 1 --mask_ratio 0.25 --fedsalun_epoch 1
# python3 test4unlearning_main.py --target fu_dws --verify backdoor --unlearning_client 1 --backdoor_client_idx 1 --mask_ratio 0.35 --fedsalun_epoch 1 --dataset office-caltech10 --model 'vgg16'
# python3 test4unlearning_main.py --target fedlc --verify backdoor --unlearning_client 1 --backdoor_client_idx 1 --dataset office-caltech10 --model 'vgg16' --act_gate_local 0.7 --act_gate_difference 0.7

# nohup python3 test4unlearning_main.py --target fu_dws --verify backdoor --unlearning_client 2 --backdoor_client_idx 2 --dataset office-caltech10 --model 'vgg16' --mask_ratio 0.5 --diff_mask_ratio 0.5 > output.log 2>&1 &
# nohup python3 test4unlearning_main.py --target increase_loss --verify backdoor --unlearning_client 3 --backdoor_client_idx 3 --dataset office-caltech10 --model 'vgg16' --mask_ratio 0.5 --diff_mask_ratio 0.5 > increase_loss22.log 2>&1 &
# nohup python3 test4unlearning_main.py --target class_pruning --verify backdoor --unlearning_client 3 --backdoor_client_idx 3 --dataset office-caltech10 --model 'vgg16' --mask_ratio 0.5 --diff_mask_ratio 0.5 > class_pruning22.log 2>&1 &
