import argparse
import logging
import os
import random
import sys
import datetime
import numpy as np
import torch


from data_loader import load_data_digit10
from office_loader import load_office_home
from office31_loader import load_office_31
from domainnet_loader import load_domainnet
from cifar10_loader import load_cifar10
from cifar100_loader import load_cifar100
from tiny_imagenet_loader import load_tiny_imagenet

from nflows.flows.base import Flow
from nflows.transforms.permutations import RandomPermutation, ReversePermutation
from nflows.transforms.base import CompositeTransform
from nflows.transforms.coupling import AffineCouplingTransform
from nflows.nn.nets.myresnet import ResidualNet
from torch.nn import functional as F
from nflows.distributions.normal import StandardNormal

from model import Resnet18_plus
from fdil_api import FDIL

from my_model_trainer_classification import MyModelTrainer as MyModelTrainerCLS
import warnings

warnings.filterwarnings('ignore')

def add_args(parser):
    """Register every command-line option of the training script on ``parser``.

    Args:
        parser: an ``argparse.ArgumentParser`` (or any object exposing a
            compatible ``add_argument`` method).

    Returns:
        The same ``parser`` object, with all options added.
    """
    # One (flag, add_argument keyword arguments) pair per option, listed in
    # the order in which the options are registered.
    option_table = (
        # --- model / data selection ---
        ('--model', dict(type=str, default='resnet56', metavar='N',
                         help='neural network used in training')),
        ('--dataset', dict(type=str, default='cifar10', metavar='N',
                           help='dataset used for training')),
        ('--batch_size', dict(type=int, default=128, metavar='N',
                              help='input batch size for training')),
        # --- optimisation ---
        ('--client_optimizer', dict(type=str, default='adam',
                                    help='SGD with momentum; adam')),
        ('--lr', dict(type=float, default=0.001, metavar='LR',
                      help='learning rate')),
        ('--lr_p', dict(type=float, default=0.0001, metavar='LR',
                        help='learning rate for persoanl model')),
        ('--epochs', dict(type=int, default=5, metavar='EP',
                          help='how many epochs will be trained locally')),
        ('--epochs_personal', dict(type=int, default=5, metavar='EP',
                                   help='how many epochs will be trained locally for personal model')),
        # --- federated / incremental schedule ---
        ('--incremental_round', dict(type=int, default=100,
                                     help='how many rounds after can add data')),
        ('--client_num_in_total', dict(type=int, default=10, metavar='NN',
                                       help='number of workers in a distributed cluster')),
        ('--incremental_stage', dict(type=int, default=3, metavar='NN',
                                     help='number of incremental tasks')),
        ('--client_num_per_round', dict(type=int, default=10, metavar='NN',
                                        help='number of workers')),
        ('--baseline', dict(default="our_method",
                            help='Training model')),
        ('--comm_round', dict(type=int, default=10,
                              help='how many round of communications we shoud use')),
        # The help text below is Chinese for: "Dirichlet distribution
        # parameter of the data; the smaller it is, the more non-IID the data".
        ("--alpha", dict(help="数据dirichlet分布参数 越小数据越Non-IID", type=float, default=0.1)),
        ("--memory_size", dict(type=int, default=2000)),
        ("--lambda_p", dict(type=float, default=0.5)),
        # --- hardware ---
        ('--gpu', dict(type=int, default=0,
                       help='gpu')),
    )

    for flag, spec in option_table:
        parser.add_argument(flag, **spec)

    return parser


def load_data(args, dataset_name):
    """Load the dataset selected by ``dataset_name`` through its loader.

    Every loader receives ``args.client_num_in_total``, ``args.alpha`` and
    ``args.batch_size``. In addition, ``digit10`` receives
    ``args.incremental_stage`` and ``cifar10`` / ``cifar100`` /
    ``tiny_imagenet`` receive ``args.memory_size``.

    Args:
        args: parsed command-line namespace providing the attributes above.
        dataset_name: one of ``"digit10"``, ``"office-home"``, ``"office31"``,
            ``"domainnet"``, ``"cifar10"``, ``"cifar100"``, ``"tiny_imagenet"``.

    Returns:
        A list ``[local_num_dict, train_data_local_dict, test_data_local_dict,
        incremental_train_data, incremental_test_data, class_num]`` holding
        the six values returned by the selected loader, in that order.

    Raises:
        ValueError: if ``dataset_name`` is not a supported dataset. (The
            previous implementation fell through every branch and failed
            later with an ``UnboundLocalError``.)
    """
    # Each entry is a zero-argument callable so that neither ``args`` nor the
    # loader function is touched until the dataset name has been validated.
    loaders = {
        "digit10": lambda: load_data_digit10(
            args.client_num_in_total, args.alpha, args.batch_size, args.incremental_stage),
        "office-home": lambda: load_office_home(
            args.client_num_in_total, args.alpha, args.batch_size),
        "office31": lambda: load_office_31(
            args.client_num_in_total, args.alpha, args.batch_size),
        "domainnet": lambda: load_domainnet(
            args.client_num_in_total, args.alpha, args.batch_size),
        "cifar10": lambda: load_cifar10(
            args.client_num_in_total, args.alpha, args.batch_size, args.memory_size),
        "cifar100": lambda: load_cifar100(
            args.client_num_in_total, args.alpha, args.batch_size, args.memory_size),
        "tiny_imagenet": lambda: load_tiny_imagenet(
            args.client_num_in_total, args.alpha, args.batch_size, args.memory_size),
    }

    if dataset_name not in loaders:
        raise ValueError(
            "Unsupported dataset_name %r; expected one of: %s"
            % (dataset_name, ", ".join(sorted(loaders)))
        )

    logging.info("load_data. dataset_name = %s", dataset_name)

    # Unpack explicitly so a loader returning the wrong number of values
    # still fails here, exactly as it did before.
    (local_num_dict, train_data_local_dict, test_data_local_dict,
     incremental_train_data, incremental_test_data, class_num) = loaders[dataset_name]()

    return [local_num_dict, train_data_local_dict, test_data_local_dict,
            incremental_train_data, incremental_test_data, class_num]


def get_1d_nflow_model(feature_dim, hidden_feature, context_feature, num_layers):
    """Build an nflows ``Flow`` made of ``num_layers`` permutation + affine-coupling pairs.

    Each layer appends a permutation followed by an ``AffineCouplingTransform``.
    The first ``num_layers // 2`` layers use ``ReversePermutation`` and the
    remaining layers use ``RandomPermutation``. The base distribution is a
    ``StandardNormal`` of shape ``[feature_dim]``.

    Args:
        feature_dim: number of features the permutations, coupling mask and
            base distribution are built for.
        hidden_feature: forwarded to ``ResidualNet`` as ``hidden_features``.
        context_feature: forwarded to ``ResidualNet`` as ``context_features``.
        num_layers: number of permutation/coupling pairs; must satisfy
            ``num_layers // 2 > 1`` (i.e. at least 4 for integers).

    Returns:
        The assembled ``Flow`` instance.

    Raises:
        ValueError: if ``num_layers // 2 > 1`` does not hold. This replaces an
            in-loop ``assert`` that was stripped under ``python -O`` and was
            never evaluated when ``num_layers <= 0``.
    """
    half = num_layers // 2
    # Validate once, up front, before any module is constructed.
    if half <= 1:
        raise ValueError(
            f"num_layers must be at least 4 (num_layers // 2 > 1), got {num_layers}"
        )

    def make_coupling_net(in_d, out_d):
        # Factory handed to AffineCouplingTransform, which supplies the
        # input/output sizes of the network it needs.
        return ResidualNet(in_features=in_d, out_features=out_d,
                           hidden_features=hidden_feature, context_features=context_feature,
                           num_blocks=2, activation=F.leaky_relu, dropout_probability=0)

    transforms = []
    for layer_idx in range(num_layers):
        if layer_idx < half:
            transforms.append(ReversePermutation(features=feature_dim))
        else:
            transforms.append(RandomPermutation(features=feature_dim))

        # 0.0 for indices below feature_dim // 2, 1.0 for the rest. A fresh
        # tensor is built per layer (as before) so layers never share one.
        mask = (torch.arange(0, feature_dim) >= (feature_dim // 2)).float()
        transforms.append(AffineCouplingTransform(mask=mask, transform_net_create_fn=make_coupling_net))

    transform = CompositeTransform(transforms)
    base_dist = StandardNormal(shape=[feature_dim])
    return Flow(transform, base_dist)

def create_model(args, model_name, output_dim):
    """Create the classifier network and its companion normalizing flow.

    Args:
        args: parsed command-line namespace (not read by this function).
        model_name: only used in the log message; the architecture built
            here does not depend on it.
        output_dim: passed to ``Resnet18_plus`` as ``num_classes`` and to
            ``get_1d_nflow_model`` as ``context_feature``.

    Returns:
        A tuple ``(model_c, model_nf)``: the ``Resnet18_plus`` instance and
        the flow returned by ``get_1d_nflow_model``.
    """
    message = "create_model. model_name = %s, output_dim = %s" % (model_name, output_dim)
    logging.info(message)

    # The classifier's xa_dim and the flow's feature_dim are given the same
    # value, so it is named once here.
    shared_dim = 512

    # NOTE(review): the leading positional 32 presumably is the input image
    # size expected by Resnet18_plus - confirm against model.py.
    classifier = Resnet18_plus(32, xa_dim=shared_dim, num_classes=output_dim)
    flow = get_1d_nflow_model(
        feature_dim=shared_dim,
        hidden_feature=512,
        context_feature=output_dim,
        num_layers=4,
    )

    return classifier, flow


if __name__ == "__main__":
    # Script entry point: parse options, seed the RNGs, load the data,
    # build the models and run the selected training method.
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    parser = add_args(argparse.ArgumentParser(description='Fed_Domain_Incremental'))
    args = parser.parse_args()
    logger.info(args)

    # Only the "FDIL" baseline is wired up below. Previously any other value
    # (including the argument's default) loaded the data, built the model and
    # then exited successfully without training anything. Fail fast instead,
    # before the expensive data loading; parser.error exits with status 2.
    supported_baselines = ("FDIL",)
    if args.baseline not in supported_baselines:
        parser.error(
            "unsupported --baseline %r; supported values: %s"
            % (args.baseline, ", ".join(supported_baselines))
        )

    device = torch.device("cuda:" + str(args.gpu) if torch.cuda.is_available() else "cpu")
    logger.info(device)

    # Seed every random number generator in use with the same fixed value.
    seed = 0
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True

    # load data; the last element of the returned list is class_num
    dataset = load_data(args, args.dataset)

    model = create_model(args, model_name=args.model, output_dim=dataset[-1])

    model_trainer = MyModelTrainerCLS(model, args)

    if args.baseline == "FDIL":
        fdilAPI = FDIL(dataset, device, args, model_trainer)
        fdilAPI.train()

