import utils
import models
import math
import copy
import itertools
import numpy as np
from agent import Agent
from agent_sparse import Agent as Agent_s
from options import args_parser
from aggregation import Aggregation
import torch
import random
from torch.utils.data import DataLoader
import torch.nn as nn
from torch.nn.utils import parameters_to_vector
import logging
import datetime
import os
from utils import KL_between_normals,load_model
from Nets_VIB import client_model_VIB
from models import CNN_MNIST
from mask_filter import Masked_BN2d


def write_file(filename, main_accuracy, backdoor_accuaracy):
    """Append a main-task / backdoor-task accuracy pair to a text file.

    Two lines are appended on every call::

        main_task_accuracy=<str(main_accuracy)>
        backdoor_task_accuracy=<str(backdoor_accuaracy)>

    Args:
        filename: Path of the file to append to (created if missing).
        main_accuracy: Any object; written via ``str()``.
        backdoor_accuaracy: Any object; written via ``str()``.
    """
    # The context manager guarantees the handle is flushed and closed even if
    # a write fails; the previous version never closed the file.
    with open(filename, "a") as f:
        # `!s` forces str() so objects with a custom __format__ (e.g. 0-dim
        # tensors) are rendered exactly as before.
        f.write(f"main_task_accuracy={main_accuracy!s}\n")
        f.write(f"backdoor_task_accuracy={backdoor_accuaracy!s}\n")
    
if __name__ == '__main__':
    # Seed every RNG in use (torch CPU/CUDA, numpy, stdlib random) and ask
    # cuDNN for deterministic kernels with autotuning (benchmark) turned off.
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = False
    torch.manual_seed(0)
    torch.cuda.manual_seed_all(0)
    np.random.seed(0)
    random.seed(0)
    torch.backends.cudnn.deterministic = True
    args = args_parser()

    # Each run writes into its own folder, named after the dataset, the start
    # time (second resolution) and two experiment settings. The folder is also
    # exposed to the rest of the code through args.folder.
    current_time = datetime.datetime.now().strftime('%b.%d_%H.%M.%S')
    folder_path = f'save/dataset_{args.data}_{current_time}_selction{args.agent_frac}_attacker{args.num_corrupt}'
    args.folder = folder_path

    try:
        # makedirs instead of mkdir: also creates the parent `save/` directory
        # on a fresh checkout instead of failing with FileNotFoundError.
        os.makedirs(folder_path)
    except FileExistsError:
        # The original called an undefined `logger` here (NameError). Logging
        # handlers are only installed further down, so report with print()
        # rather than touching the not-yet-configured logging module.
        print('Folder already exists')

    # Text files that collect results over the course of the run.
    filepath = './'+'/'+folder_path+'/log_accuracy.txt'
    filepath_selec = './'+'/'+folder_path+'/selection_client.txt'
    filepath_per = './'+'/'+folder_path+'/per_class_acc.txt'
    filepath_poi = './'+'/'+folder_path+'/poi_acc.txt'
    

    # Configure the root logger to accept everything from DEBUG upwards.
    # Outside debug mode, records are also written to a log file inside the
    # run folder; the file name encodes the main experiment settings.
    rootLogger = logging.getLogger()
    rootLogger.setLevel(logging.DEBUG)
    logFormatter = logging.Formatter("%(asctime)s [%(threadName)-12.12s] [%(levelname)-5.5s]  %(message)s")
    if not args.debug:
        logPath = folder_path
        # The explicit comparison with True is kept so that only values equal
        # to True select the "NoGradient" name.
        use_no_gradient_name = args.dis_check_gradient == True
        # Settings shared by both file-name templates, in template order.
        name_fields = (
            args.num_corrupt, args.num_agents, args.method, args.data, args.alpha,
            args.local_ep, args.poison_frac, args.aggr, args.non_iid, args.attack,
        )
        if use_no_gradient_name:
            fileName = "NoGradientAckRatio{}_{}_Method{}_data{}_alpha{}_Epoch{}_inject{}_Agg{}_noniid{}_attack{}.pt.pt".format(
                *name_fields)
        else:
            # args.cease_poison is supplied as an extra positional argument that
            # the template has no placeholder for; str.format ignores it.
            fileName = "AckRatio{}_{}_Method{}_data{}_alpha{}_Epoch{}_inject{}_Agg{}_noniid{}__attack{}.pt".format(
                *name_fields, args.cease_poison)
        fileHandler = logging.FileHandler(f"{logPath}/{fileName}.log")
        fileHandler.setFormatter(logFormatter)
        rootLogger.addHandler(fileHandler)
    logging.info(args)

    # Running sum of the attack success ratio over all evaluated rounds.
    cum_poison_acc_mean = 0

    # Per-evaluation histories that are written to log_accuracy.txt at the end.
    main_task_accuracy = []
    backdoor_task_accuracy = []

    # load dataset and user groups (i.e., user to data mapping)
    train_dataset, val_dataset = utils.get_datasets(args.data)
    # Number of classes: 100 for cifar100, 10 for every other dataset name.
    if args.data == "cifar100":
        num_target = 100
    else:
        num_target = 10
    # Clean validation loader (fixed order, no shuffling).
    val_loader = DataLoader(val_dataset, batch_size=args.bs, shuffle=False, num_workers=args.num_workers,
                            pin_memory=False)
    # Partition the training set across agents; the helper is chosen by the
    # non_iid flag. user_groups is later indexed by agent id.
    if args.non_iid:
        user_groups = utils.distribute_data_dirichlet(train_dataset, args)
    else:
        user_groups = utils.distribute_data(train_dataset, args, n_classes=num_target)

    # Indices of validation samples whose label differs from the target class.
    # Assumes val_dataset.targets supports tensor-style comparison/.nonzero()
    # -- TODO confirm for every dataset returned by utils.get_datasets.
    idxs = (val_dataset.targets != args.target_class).nonzero().flatten().tolist()
    # logging.info(idxs)
    # print(f'train_dataset label is {train_dataset.targets[20326]}')
    # poison the validation dataset
    # A deep copy is poisoned so the clean val_dataset stays untouched.
    poisoned_val_set = utils.DatasetSplit(copy.deepcopy(val_dataset), idxs)
    utils.poison_dataset(poisoned_val_set.dataset, args, idxs, poison_all=True)

    poisoned_val_loader = DataLoader(poisoned_val_set, batch_size=args.bs, shuffle=False, num_workers=args.num_workers,
                                     pin_memory=False)

    # Second poisoned copy over the same indices, built with modify_label=False
    # (presumably the inputs are altered but the original labels are kept --
    # verify against utils.poison_dataset).
    idxs = (val_dataset.targets != args.target_class).nonzero().flatten().tolist()
    poisoned_val_set_only_x = utils.DatasetSplit(copy.deepcopy(val_dataset), idxs)
    utils.poison_dataset(poisoned_val_set_only_x.dataset, args, idxs, poison_all=True, modify_label=False)

    poisoned_val_only_x_loader = DataLoader(poisoned_val_set_only_x, batch_size=args.bs, shuffle=False,
                                            num_workers=args.num_workers,
                                            pin_memory=False)

    # Initialize the global model. The architecture depends on both the
    # dataset family (CIFAR vs. MNIST/FMNIST) and the training method
    # ('Grace' vs. 'avg').
    if args.data in ['cifar10','cifar100'] and args.method == 'Grace':
        global_model = models.get_model(args.device,args.data).to(args.device)
        # global_model = models.get_model(args.device,args.data,norm_layer=Masked_BN2d)
    elif args.data in ['mnist','fmnist'] and args.method == 'Grace':
        global_model = client_model_VIB(args, args.dimZ, args.alpha_cr, args.data).to(args.device)
    elif args.data in ['cifar10','cifar100'] and args.method == 'avg':
        global_model = models.get_model_avg(args.device,args.data).to(args.device)
    elif args.data in ['mnist','fmnist'] and args.method == 'avg':
        global_model = CNN_MNIST().to(args.device)
    else:
        # Without this branch an unsupported combination only surfaced later
        # as a NameError on `global_model`; fail here with a clear message.
        raise ValueError(
            f"Unsupported data/method combination: data={args.data!r}, method={args.method!r}")

    # State carried across rounds. neurotoxin_mask and updates_dict are passed
    # to local training and overwritten by each aggregation step; global_mask
    # is not read anywhere in the visible code.
    global_mask = {}
    neurotoxin_mask = {}
    updates_dict = {}

    # Total number of scalar entries in the model's state_dict, flattened.
    n_model_params = len(parameters_to_vector([ global_model.state_dict()[name] for name in global_model.state_dict()]))
    # Deep copy of the initial state_dict tensors (not read in the visible code).
    params = {name: copy.deepcopy(global_model.state_dict()[name]) for name in global_model.state_dict()}
    # Build one agent per id and record each agent's n_data for the aggregator.
    agents, agent_data_sizes = [], {}
    for _id in range(0, args.num_agents):
        if args.method == "Grace":
            # Grace agents are given the global model and their own deep copy
            # of the training set.
            agent = Agent_s(global_model, _id, args, copy.deepcopy(train_dataset), user_groups[_id])
        else:
            # Other methods share the single train_dataset object.
            agent = Agent(_id, args, train_dataset, user_groups[_id])
        agent_data_sizes[_id] = agent.n_data
        agents.append(agent)
        # fileName_client_asr = folder_path
        # Per-client result path; computed but not used in the visible code
        # (the write below is commented out).
        file_client_asr = os.path.join(folder_path, 'client_{}_asr_results.txt'.format(_id))

        # with open(file_client_asr, 'w') as f:
        #     for item in test_acc_list:
        #         f.write("{}".format(item))
        #         f.write('\n')

    # aggregation server and the loss function
    # print(f'train_dataset label is {train_dataset.targets[20326]}')
    aggregator = Aggregation(agent_data_sizes, n_model_params, poisoned_val_loader, args, None)
    criterion = nn.CrossEntropyLoss().to(args.device)
    # Bookkeeping containers. Only worker_id_list, acc_vec, asr_vec, pacc_vec
    # and per_class_vec are appended to in the visible code; the rest stay empty.
    agent_updates_list = []
    worker_id_list = []
    agent_updates_dict = {}
    mask_aggrement = []

    # Per-evaluation metric histories.
    acc_vec = []
    asr_vec = []
    pacc_vec = []
    per_class_vec = []

    clean_asr_vec = []
    clean_acc_vec = []
    clean_pacc_vec = []
    clean_per_class_vec = []
    

    # Main federated loop; rounds are numbered from 1 to args.rounds inclusive.
    for rnd in range(1, args.rounds + 1):
        logging.info("--------round {} ------------".format(rnd))
        # if rnd == 101:
        #     state_dict = torch.load(args.checkpoint, map_location=args.device)
        #     load_model(global_model, orig_state_dict=state_dict)

        # Flat snapshot of the global state_dict at the start of the round.
        rnd_global_params = parameters_to_vector([ copy.deepcopy(global_model.state_dict()[name]) for name in global_model.state_dict()])
        agent_updates_dict = {}
        # Client sampling without replacement. For Grace the fraction is fixed
        # at 1, so every agent trains locally each round; other methods sample
        # an agent_frac share of the agents.
        if args.method == "Grace":
            chosen = np.random.choice(args.num_agents, math.floor(args.num_agents * 1), replace=False)
        else:
            chosen = np.random.choice(args.num_agents, math.floor(args.num_agents * args.agent_frac), replace=False)
            print("chosen clients are:{}".format(chosen))
        # True only in the final round; forwarded to Grace agents.
        last = rnd == args.rounds

        for agent_id in chosen:
            # logging.info(torch.sum(rnd_global_params))
            global_model = global_model.to(args.device)
            # if args.sync == 'True':
            #     agents[agent_id].synchronize_with_server(global_model, w_glob_keys)

            # Grace agents additionally receive the previous round's
            # updates_dict, the `last` flag and the aggregator as `server`.
            if args.method == "Grace":
                update = agents[agent_id].local_train(global_model, criterion,rnd, neurotoxin_mask = neurotoxin_mask, updates_dict=updates_dict,last=last,server=aggregator,poi_loader=poisoned_val_loader)
            else:
                update = agents[agent_id].local_train(global_model, criterion,rnd,neurotoxin_mask=neurotoxin_mask,poi_loader=poisoned_val_loader)

            # client_poi = './'+'/'+folder_path+'/poi_acc_{}.txt'.format(agent_id)
            # write_file(client_poi,main_accuracy=None,backdoor_accuaracy=c_asr)

            agent_updates_dict[agent_id] = update
            # Presumably writes the round-start snapshot back into global_model
            # so the next agent starts from the same weights -- verify against
            # utils.vector_to_model.
            utils.vector_to_model(copy.deepcopy(rnd_global_params), global_model)

        # if rnd < 6 and args.method == 'Grace':
        #     iter_aggregator_agents = {chosen[i]:agents[chosen[i]] for i in range(len(chosen))}
        #     outlier = aggregator.outlier(clients=iter_aggregator_agents)
        #     # if rnd % 2 == 0:
        #     #     sorted_outlier_score = dict(sorted(outlier.items(), key=lambda item: item[1]))
        #     # else:
        #     #     sorted_outlier_score = dict(sorted(outlier.items(), key=lambda item: item[1]),reversed=True)
        #     sorted_outlier_score = dict(sorted(outlier.items(), key=lambda item: item[1]))
            # print(f'sorted_outlier is {sorted_outlier_score}')

        # Server-side aggregation. With client selection enabled under Grace,
        # the aggregator first picks clients (FLAME or its own `select`);
        # otherwise every collected update is aggregated.
        if args.selection and args.method == "Grace":

            agent_updates_aggregate_dict = {}
            # iter_select_clients_0 = []

            if args.aggr == 'flame':
                # FLAME path: all updates are handed over and the selected
                # clients are passed separately through iter_client.
                iter_select_clients = aggregator.FLAME(clients=agents,rnd=rnd)
                agent_updates_aggregate_dict = agent_updates_dict
                updates_dict,neurotoxin_mask = aggregator.aggregate_updates(global_model, agent_updates_aggregate_dict,rnd,iter_client=iter_select_clients)
            else:
                # Number of clients to keep this round.
                m = math.floor(args.num_agents * args.agent_frac)
                # s_0 = int(m/2)
                # s_1 = m-s_0
                # keys = list(sorted_outlier_score.keys())
                # iter_select_clients_0 = keys[:m] 
                # select_clt = 0
                # for key,val in sorted_outlier_score.items():                    
                #     if select_clt < m:
                #         select_clt +=1
                #         iter_select_clients_0.append(key)
                # print(f'outlier clients: {iter_select_clients_0}')
                # logging.info(f'| outlier clients: {iter_select_clients_0}|')

                iter_select_clients_1 = aggregator.select(clients=agents,iter=rnd,m=m)
                # if rnd < 6:
                #     iter_select_clients = iter_select_clients_0
                # else:
                #     iter_select_clients = iter_select_clients_1
                # iter_select_clients = list(set(iter_select_clients_0) & set(iter_select_clients_1))
                iter_select_clients = iter_select_clients_1
                # print(f'select clients: {iter_select_clients}')
                logging.info(f'| select clients: {iter_select_clients}|')
                # The selected ids are logged through the accuracy writer; the
                # backdoor field is a constant 0 placeholder.
                write_file(filepath_selec,main_accuracy=iter_select_clients,backdoor_accuaracy=0)

                # Keep only the updates of selected clients. `i` runs over
                # 0..len(agent_updates_dict)-1, so this relies on the trained
                # agent ids being exactly those integers (all agents are
                # sampled for Grace in the client-sampling step of this round).
                for i in range(len(agent_updates_dict)):
                    for j in range(len(iter_select_clients)):
                        if i == iter_select_clients[j]:
                            agent_updates_aggregate_dict[i]=agent_updates_dict[i] 
                            # agent_select[]              
                updates_dict,neurotoxin_mask = aggregator.aggregate_updates(global_model, agent_updates_aggregate_dict,rnd)
        else:
            updates_dict,neurotoxin_mask = aggregator.aggregate_updates(global_model, agent_updates_dict,rnd)

        # elif args.method == "Grace":
        #     agent_updates_aggregate_dict = {}
        #     for i in range(len(agent_updates_dict)):
        #         for j in range(len(chosen)):
        #             if i == chosen[j]:
        #                 agent_updates_aggregate_dict[i]=agent_updates_dict[i]

        # if args.method == "Grace":
        #     if args.selection:
        #         updates_dict,neurotoxin_mask = aggregator.aggregate_updates(global_model, agent_updates_aggregate_dict,rnd)       
        #     else:
        #             updates_dict,neurotoxin_mask = aggregator.aggregate_updates(global_model, agent_updates_aggregate_dict,rnd)
        # else:
        #     updates_dict,neurotoxin_mask = aggregator.aggregate_updates(global_model, agent_updates_dict,rnd)
        # NOTE(review): agent_id is the variable left over from the local
        # training loop, i.e. the last agent trained this round, and
        # worker_id_list is not read anywhere in the visible code -- confirm
        # this is intended.
        worker_id_list.append(agent_id + 1)

        # Map of agent id -> agent for the clients that take part in the
        # global_POE step: the selected ones when selection is on, otherwise
        # every agent sampled this round.
        if args.selection and args.method == "Grace":
            iter_aggregator_agents = {iter_select_clients[i]:agents[iter_select_clients[i]] for i in range(len(iter_select_clients))}
        elif args.method == "Grace":
            iter_aggregator_agents = {chosen[i]:agents[chosen[i]] for i in range(len(chosen))}

        # iter_aggregator_agents = {chosen[i]:agents[chosen[i]] for i in range(len(chosen))}

        # Grace-only server step over the participating clients.
        if args.method == 'Grace':
            aggregator.global_POE(clients=iter_aggregator_agents)
        
        # print(f'train_dataset label is {train_dataset.targets[20326]}')
        
        # Every args.snap rounds, evaluate the global model on the clean
        # validation set, the poisoned set and the poisoned-inputs-only set.
        if rnd % args.snap == 0:
            # Grace models and plain averaging models use different
            # evaluation helpers that share one call signature.
            if args.method == 'Grace':
                evaluate = utils.get_loss_n_accuracy
            else:
                evaluate = utils.get_accuracy_avg

            # --- clean validation set ---
            val_loss, (val_acc, val_per_class_acc), _ = evaluate(
                global_model, criterion, val_loader, args, rnd, num_target)
            logging.info(f'| Val_Loss/Val_Acc: {val_loss:.3f} / {val_acc:.3f} |')
            logging.info(f'| Val_Per_Class_Acc: {val_per_class_acc} ')
            acc_vec.append(val_acc)
            main_task_accuracy.append(val_acc)
            per_class_vec.append(val_per_class_acc)
            write_file(filename=filepath_per, main_accuracy=val_per_class_acc, backdoor_accuaracy=None)

            # --- poisoned validation set: accuracy is reported as the
            # attack success ratio ---
            poison_loss, (asr, _), fail_samples = evaluate(
                global_model, criterion, poisoned_val_loader, args, rnd, num_target)
            cum_poison_acc_mean += asr
            asr_vec.append(asr)
            backdoor_task_accuracy.append(asr)
            logging.info(f'| Attack Loss/Attack Success Ratio: {poison_loss:.3f} / {asr:.3f} |')

            # --- poisoned set built with modify_label=False ---
            poison_loss, (poison_acc, _), fail_samples = evaluate(
                global_model, criterion, poisoned_val_only_x_loader, args, rnd, num_target)
            pacc_vec.append(poison_acc)
            logging.info(f'| Poison Loss/Poison accuracy: {poison_loss:.3f} / {poison_acc:.3f} |')
            write_file(filename=filepath_poi, main_accuracy=poison_acc, backdoor_accuaracy=None)

        # Persist the global model's state_dict into the run folder every
        # `save_frequency` rounds.
        save_frequency = 50
        if not rnd % save_frequency:
            save_path = f"{folder_path}/{rnd}round_model.pt"
            torch.save(global_model.state_dict(), save_path)
    
    # After the last round, append the complete main-task and backdoor
    # accuracy histories to the run's log_accuracy.txt.
    write_file(filepath, main_task_accuracy, backdoor_task_accuracy)
    logging.info('Training has finished!')
