import argparse
import logging
import os
import random
import sys
import datetime
import numpy as np
import torch


from cifar10_loader import load_cifar10
from cifar100_loader import load_cifar100
from data_loader import load_data_digit10
from cifar10_100 import load_cifar10_100
from resnet_fot_bn import get_resnet18_bn
from resnet import resnet18
from resnet_fot_bn_digit import get_resnet18_bn_digit
from fdil_api import FDIL
from resnet_fot import get_resnet18
from my_model_trainer_classification import MyModelTrainer as MyModelTrainerCLS
import warnings
from alexnet_fot import get_alexnet
import os 
# Configure PyTorch's CUDA caching allocator through its environment variable.
# NOTE(review): per the PyTorch CUDA memory-management docs, max_split_size_mb
# stops the allocator from splitting cached blocks larger than this size (MB),
# which can reduce fragmentation; it presumably has to be set before CUDA is
# first initialised in this process — confirm if moved below other code.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:1024"

def add_args(parser):
    """Register the command-line options used by this evaluation script.

    Args:
        parser: an ``argparse.ArgumentParser`` that is extended in place.

    Returns:
        The same ``parser`` object, with every option below added.
    """
    # Training settings
    # Earlier observed results, kept for reference:
    #   resnet_fot {'test_acc': 0.1863, 'test_loss': 69.05954392910003}
    #   resnet18   0.1876  (gpu7)
    parser.add_argument('--model', type=str, default='resnet_fot_bn', metavar='N',
                        help='neural network used in training')

    parser.add_argument('--dataset', type=str, default='digit10', metavar='N',
                        help='dataset used for training')

    parser.add_argument('--batch_size', type=int, default=64, metavar='N',
                        help='input batch size for training')

    parser.add_argument('--client_optimizer', type=str, default='adam',
                        help='SGD with momentum; adam')

    parser.add_argument('--lr', type=float, default=0.001, metavar='LR',
                        help='learning rate')

    parser.add_argument('--epochs', type=int, default=20, metavar='EP',
                        help='how many epochs will be trained locally')

    parser.add_argument('--incremental_round', type=int, default=60,
                        help='how many rounds after can add data')

    parser.add_argument('--client_num_in_total', type=int, default=15, metavar='NN',
                        help='number of workers in a distributed cluster')

    parser.add_argument('--client_num_per_round', type=int, default=6, metavar='NN',
                        help='number of workers')

    parser.add_argument('--baseline', default="FDIL",
                        help='Training model')

    parser.add_argument('--comm_round', type=int, default=240,
                        help='how many round of communications we should use')

    # Dirichlet concentration parameter for splitting data across clients;
    # the smaller it is, the more non-IID the split (see the help text).
    parser.add_argument("--alpha", help="数据dirichlet分布参数 越小数据越Non-IID", type=float, default=0.1)

    # memory_size is also forwarded to the dataset loaders, so it influences
    # how the data is partitioned.
    parser.add_argument("--memory_size", type=int, default=3000000)
    parser.add_argument("--memory_buffer", type=int, default=1500)

    parser.add_argument('--gpu', type=int, default=0,
                        help='gpu')

    parser.add_argument('--num_task', type=int, default=4,
                        help='Number of tasks')
    parser.add_argument('--B', type=int, default=1700,
                        help='computational Budget')
    # A ratio: this used to be ``type=int`` (with a float default), so any
    # fractional value given on the command line was rejected by argparse.
    parser.add_argument('--lable_ratio', type=float, default=0.3,
                        help='lable_ratio')
    parser.add_argument('--epsilon', type=float, default=0.94, help='Epsilon for orth training')
    parser.add_argument('--eps_inc', type=float, default=0.005,
                        help='How much epsilon will increase after each task')

    return parser
def load_data(args, dataset_name):
    """Load the federated, domain-incremental dataset called ``dataset_name``.

    Args:
        args: parsed command-line namespace. ``client_num_in_total``,
            ``alpha`` and ``batch_size`` are forwarded to every loader;
            ``memory_size`` is forwarded to all loaders except the office31
            ones.
        dataset_name: one of ``"cifar10"``, ``"cifar100"``, ``"cifar10_100"``,
            ``"office31"``, ``"office31_ca"``, ``"digit10"``,
            ``"tiny_imagenet"``.

    Returns:
        ``[local_num_dict, train_data_local_dict, test_data_local_dict,
        incremental_train_data, incremental_test_data, class_num]`` as
        produced by the selected loader.

    Raises:
        ValueError: if ``dataset_name`` is not one of the names above
            (previously this surfaced as an ``UnboundLocalError``).
    """
    # Each entry is a lambda so that a loader's name is only looked up when
    # that dataset is actually requested.
    # NOTE(review): load_office_31, load_office_31_ca and load_tiny_imagenet
    # are not imported in this file, so selecting those datasets raises
    # NameError until the matching imports are added.
    loaders = {
        "cifar10": lambda: load_cifar10(*common, args.memory_size),
        "cifar100": lambda: load_cifar100(*common, args.memory_size),
        "cifar10_100": lambda: load_cifar10_100(*common, args.memory_size),
        "office31": lambda: load_office_31(*common),
        "office31_ca": lambda: load_office_31_ca(*common),
        "digit10": lambda: load_data_digit10(*common, args.memory_size),
        "tiny_imagenet": lambda: load_tiny_imagenet(*common, args.memory_size),
    }
    if dataset_name not in loaders:
        raise ValueError(
            "unknown dataset_name %r; expected one of: %s"
            % (dataset_name, ", ".join(sorted(loaders)))
        )

    print("load_data. dataset_name = %s" % dataset_name)
    common = (args.client_num_in_total, args.alpha, args.batch_size)

    local_num_dict, train_data_local_dict, test_data_local_dict, \
        incremental_train_data, incremental_test_data, class_num = loaders[dataset_name]()

    dataset = [local_num_dict, train_data_local_dict, test_data_local_dict,
               incremental_train_data, incremental_test_data, class_num]
    return dataset


def create_model(args, model_name, output_dim):
    """Build the network named ``model_name`` with ``output_dim`` outputs.

    Args:
        args: parsed command-line namespace; ``args.dataset`` selects the
            digit10 variant and ``args.num_task`` is passed to the
            task-aware constructors.
        model_name: one of ``"resnet_fot_bn"``, ``"resnet18"``,
            ``"resnet_fot"``, ``"alexnet_fot"``.
        output_dim: number of output classes.

    Returns:
        The constructed model.

    Raises:
        ValueError: if ``model_name`` is not supported. Previously ``None``
            was returned, which only failed later, inside the trainer.
    """
    print("create_model. model_name = %s, output_dim = %s" % (model_name, output_dim))

    # digit10 uses a dedicated BN-ResNet constructor, so this check must come
    # before the generic "resnet_fot_bn" branch below.
    if args.dataset == "digit10" and model_name == "resnet_fot_bn":
        return get_resnet18_bn_digit(args.num_task, output_dim)
    if model_name == "resnet18":
        return resnet18(class_num=output_dim)
    if model_name == "resnet_fot":
        return get_resnet18(args.num_task, output_dim)
    if model_name == "alexnet_fot":
        return get_alexnet(args.num_task, output_dim)
    if model_name == "resnet_fot_bn":
        return get_resnet18_bn(args.num_task, output_dim)

    raise ValueError(
        "unknown model_name %r; expected one of: "
        "resnet_fot_bn, resnet18, resnet_fot, alexnet_fot" % (model_name,)
    )


if __name__ == "__main__":
    # Evaluate saved FDIL checkpoints. For every task, the checkpoints from the
    # 10 rounds just before ``incremental_round`` are tested and their results
    # averaged. Then the last task's mean ("final") and the mean over all tasks
    # ("average") are printed.
    parser = add_args(argparse.ArgumentParser(description='Fed_Domain_Incremental'))
    args = parser.parse_args()
    print(args)

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda:" + str(args.gpu) if use_cuda else "cpu")
    print(device)

    # Seed the Python, NumPy and torch (CPU and all CUDA devices) RNGs and
    # request deterministic cuDNN kernels, before any data is loaded.
    random.seed(0)
    np.random.seed(0)
    torch.manual_seed(0)
    torch.cuda.manual_seed_all(0)
    torch.backends.cudnn.deterministic = True

    dataset = load_data(args, args.dataset)
    # load_data puts class_num last, which is the model's output size.
    model = create_model(args, model_name=args.model, output_dim=dataset[-1])
    model_trainer = MyModelTrainerCLS(model)
    fdil = FDIL(dataset, device, args, model_trainer)

    # NOTE(review): the checkpoint directory is hard-coded to the digit10 run
    # regardless of --dataset; confirm before evaluating other datasets.
    checkpoint_dir = '/home/ycli/fot_MBR/save_digit'

    per_task_means = []
    for task_id in range(args.num_task):
        # One result per checkpoint of this task (10 values). Per the original
        # note, each is the accuracy over the tasks seen so far.
        round_results = [
            fdil.test(task_id, f"{checkpoint_dir}/model_{task_id}_{round_id}.pth")
            for round_id in range(args.incremental_round - 10, args.incremental_round)
        ]
        per_task_means.append(np.mean(round_results))

    # per_task_means has exactly num_task entries, so [-1] is the last task.
    print("final", per_task_means[-1])
    print("average ", np.mean(per_task_means))

# fot
# office31_ca
# final 0.3279207920792079
# average  0.3840616291843211

# cifar100
# final 0.14979
# average  0.2764482341269841

# digit
# final 0.7908530773420479
# average  0.8913139353369999