import torch
import copy
import argparse
import os
import time
import numpy as np

from utils.result_utils import average_data
from flcore.trainmodel.models import *
import torchvision

from flcore.servers.serverfukd import FedFUKD
from flcore.servers.serverbu import FedBU
from flcore.servers.serverpgd import FedPGD
from flcore.servers.serverosd import FedOSD
from flcore.servers.servergs import FedGS
from flcore.servers.servereraser import FedEraser


def run(arg):
    time_list = []
    model_str = args.model

    

    if model_str == "cnn": # ]
        if "MNIST" in args.dataset:
            args.model = CNN(in_features=1, num_classes=args.num_classes, dim=1024).to(args.device)
        elif "Cifar10" in args.dataset:
            args.model = CNN(in_features=3, num_classes=args.num_classes, dim=1600).to(args.device)
        elif "Omniglot" in args.dataset:
            args.model = CNN(in_features=1, num_classes=args.num_classes, dim=33856).to(args.device)
        else:
            args.model = CNN(in_features=3, num_classes=args.num_classes, dim=10816).to(args.device)

    elif model_str == "dnn": # non-convex
        if "MNIST" in args.dataset:
            args.model = DNN(1*28*28, 100, num_classes=args.num_classes).to(args.device)
        elif "Cifar10" in args.dataset:
            args.model = DNN(3*32*32, 100, num_classes=args.num_classes).to(args.device)
        else:
            args.model = DNN(60, 20, num_classes=args.num_classes).to(args.device)
    elif model_str == "resnet": # non-convex
        if "Cifar100" in args.dataset:
            args.model = torchvision.models.resnet18(pretrained=False, num_classes=args.num_classes).to(args.device)
        elif "Cifar10" in args.dataset:
            args.model = torchvision.models.resnet18(pretrained=False, num_classes=args.num_classes).to(args.device)

    print(args.model)

    if args.algorithm == "FedEraser":
        args.head = copy.deepcopy(args.model.fc)
        args.model.fc = nn.Identity()
        args.model = BaseHeadSplit(args.model, args.head)
        server = FedEraser(args)
    elif args.algorithm == "FedFUKD":
        args.head = copy.deepcopy(args.model.fc)
        args.model.fc = nn.Identity()
        args.model = BaseHeadSplit(args.model, args.head)
        server = FedFUKD(args)
    elif args.algorithm == "FedBU":
        args.head = copy.deepcopy(args.model.fc)
        args.model.fc = nn.Identity()
        args.model = BaseHeadSplit(args.model, args.head)
        server = FedBU(args)
    elif args.algorithm == "FedPGD":
        args.head = copy.deepcopy(args.model.fc)
        args.model.fc = nn.Identity()
        args.model = BaseHeadSplit(args.model, args.head)
        server = FedPGD(args)
    elif args.algorithm == "FedOSD":
        args.head = copy.deepcopy(args.model.fc)
        args.model.fc = nn.Identity()
        args.model = BaseHeadSplit(args.model, args.head)
        server = FedOSD(args)
    elif args.algorithm == "FUGAS":
        args.head = copy.deepcopy(args.model.fc)
        args.model.fc = nn.Identity()
        args.model = BaseHeadSplit(args.model, args.head)
        server = FedGS(args)
            
    if(args.learning_state=="learning"):
        print(f"\n============= Training start =============")
        print("Creating server and clients ...")
        start = time.time()
        server.train()

        time_list.append(time.time()-start)

        print(f"\nAverage time cost: {round(np.average(time_list), 2)}s.")
        server.save_loss(args.global_rounds)

    elif(args.learning_state=="unlearning"):
        m = torch.cat([p.view(-1) for p in server.global_model.parameters()], dim=0)
        print(f"\n============= Unlearning start =============")
        
        server.unlearning()
        if(args.post_training_ground!=0):
            print(f"\n============= Post-training start =============")
            server.post_training(m)

    elif(args.learning_state=="retrain"):
        print(f"\n============= Retrain start =============")
        
        server.clients = [client for client in server.clients if client not in server.unlearning_clients]
        server.train()
        
    print("All done!")

    

if __name__ == "__main__":
    total_start = time.time()

    parser = argparse.ArgumentParser()

    parser.add_argument('-go', "--goal", type=str, default="test", 
                        help="The goal for this experiment")
    parser.add_argument('-dev', "--device", type=str, default="cuda",
                        choices=["cpu", "cuda"])
    parser.add_argument('-did', "--device_id", type=str, default="0")
    parser.add_argument('-data', "--dataset", type=str, default="MNIST")
    parser.add_argument('-nb', "--num_classes", type=int, default=10)
    parser.add_argument('-m', "--model", type=str, default="cnn")
    parser.add_argument('-lbs', "--batch_size", type=int, default=32)
    parser.add_argument('-lr', "--local_learning_rate", type=float, default=0.005,
                        help="Local learning rate")
    parser.add_argument('-ld', "--learning_rate_decay", type=bool, default=False)
    parser.add_argument('-ldg', "--learning_rate_decay_gamma", type=float, default=0.99)
    parser.add_argument('-gr', "--global_rounds", type=int, default=2000)
    parser.add_argument('-ls', "--local_epochs", type=int, default=1, 
                        help="Multiple update steps in one local epoch.")
    parser.add_argument('-algo', "--algorithm", type=str, default="FedAvg")
    parser.add_argument('-jr', "--join_ratio", type=float, default=1.0,
                        help="Ratio of clients per round")
    parser.add_argument('-rjr', "--random_join_ratio", type=bool, default=False,
                        help="Random ratio of clients per round")
    parser.add_argument('-nc', "--num_clients", type=int, default=20,
                        help="Total number of clients")
    parser.add_argument('-pv', "--prev", type=int, default=0,
                        help="Previous Running times")
    parser.add_argument('-t', "--times", type=int, default=1,
                        help="Running times")
    parser.add_argument('-eg', "--eval_gap", type=int, default=1,
                        help="Rounds gap for evaluation")
    parser.add_argument('-sfn', "--save_folder_name", type=str, default='items')
    parser.add_argument('-ab', "--auto_break", type=bool, default=False)
    parser.add_argument('-dlg', "--dlg_eval", type=bool, default=False)
    parser.add_argument('-dlgg', "--dlg_gap", type=int, default=100)
    parser.add_argument('-bnpc', "--batch_num_per_client", type=int, default=2)
    parser.add_argument('-nnc', "--num_new_clients", type=int, default=0)
    parser.add_argument('-ften', "--fine_tuning_epoch_new", type=int, default=0)
    parser.add_argument('-fd', "--feature_dim", type=int, default=512)
    parser.add_argument('-vs', "--vocab_size", type=int, default=32000, 
                        help="Set this for text tasks. 80 for Shakespeare. 32000 for AG_News and SogouNews.")
    parser.add_argument('-ml', "--max_len", type=int, default=200)
    # practical
    parser.add_argument('-cdr', "--client_drop_rate", type=float, default=0.0,
                        help="Rate for clients that train but drop out")
    parser.add_argument('-tsr', "--train_slow_rate", type=float, default=0.0,
                        help="The rate for slow clients when training locally")
    parser.add_argument('-ssr', "--send_slow_rate", type=float, default=0.0,
                        help="The rate for slow clients when sending global model")
    parser.add_argument('-ts', "--time_select", type=bool, default=False,
                        help="Whether to group and select clients at each round according to time cost")
    parser.add_argument('-tth', "--time_threthold", type=float, default=10000,
                        help="The threthold for droping slow clients")
    

    parser.add_argument("-uc","--unlearning_clients", nargs='+', type=int,default=None,
                         help='an array of integers')
    parser.add_argument("-ugr","--unlearning_ground", type=int,default=2)
    parser.add_argument("-s","--learning_state", type=str,default="learning")
    parser.add_argument("-att","--attack", type=str,default='False')
    parser.add_argument('-ulr', "--unlearning_rate", type=float, default=1e-4)
    parser.add_argument("-pgr","--post_training_ground", type=int,default=0)
    # 用于消融实验
    parser.add_argument('-con', "--contrastive", type=str, default='True')
    parser.add_argument('-gra', "--gradient_hadle", type=str, default="GEM")
    parser.add_argument('-pos', "--positive_sample", type=str, default="None")
    parser.add_argument('-neg', "--negative_sample", type=str, default="None")
    parser.add_argument('-seed', "--seed_num", type=int, default=45)
    args = parser.parse_args()

    os.environ["CUDA_VISIBLE_DEVICES"] = args.device_id

    if args.device == "cuda" and not torch.cuda.is_available():
        print("\ncuda is not avaiable.\n")
        args.device = "cpu"
    print("=" * 50)
    for arg in vars(args):
        print(arg, '=',getattr(args, arg))
    print("=" * 50)

    torch.manual_seed(args.seed_num)           
    torch.cuda.manual_seed(args.seed_num)      
    torch.cuda.manual_seed_all(args.seed_num)

    # with torch.profiler.profile(
    #     activities=[
    #         torch.profiler.ProfilerActivity.CPU,
    #         torch.profiler.ProfilerActivity.CUDA],
    #     profile_memory=True, 
    #     on_trace_ready=torch.profiler.tensorboard_trace_handler('./log')
    #     ) as prof:
    # with torch.autograd.profiler.profile(profile_memory=True) as prof:
    run(args)
