


import os
import copy
import numpy as np
import random
import torch

import pdb
import torch.nn as nn
from tqdm import tqdm
from options import args_parser, args_parser_cifar10
from util.update_baseline import *
from util.fedavg import *

from util.dataset import *
from model.build_model import build_model
from util.dispatch import *
from util.losses import *

# Print numpy arrays in full instead of abbreviating them with "...".
np.set_printoptions(threshold=np.inf)

# Selects which option parser the script uses: "cifar100" or "cifar10".
dataset_switch = "cifar100"


def get_acc_file_path(args, root_dir='./temp/'):
    """Build the path of the text file that accuracy results go to.

    The file name encodes the experiment configuration, so runs with
    different hyper-parameters use different files. ``root_dir`` is created
    if it does not exist yet; the file itself is not created here.

    Args:
        args: parsed options; reads ``balanced_global``, ``beta``,
            ``dataset``, ``model``, ``frac``, ``iid``, ``rounds``,
            ``local_ep``, ``lr``, ``num_users``, ``num_classes``, ``seed``,
            ``non_iid_prob_class``, ``alpha_dirichlet``, ``IF`` and
            ``loss_type``.
        root_dir: directory that holds the result files. Defaults to
            ``'./temp/'``, which was previously hard-coded.

    Returns:
        The path of the result file, as a string.
    """
    # exist_ok avoids the check-then-create race when several runs start at once.
    os.makedirs(root_dir, exist_ok=True)

    prefix = 'global_' if args.balanced_global else ''
    prefix += 'fl'
    if args.beta > 0:
        prefix += "_LP_%.2f" % (args.beta)
    fname = prefix + '_acc_{}_{}_cons_frac{}_iid{}_iter{}_ep{}_lr{}_N{}_{}_seed{}_p{}_dirichlet{}_IF{}_Loss{}.txt'.format(
        args.dataset, args.model, args.frac, args.iid, args.rounds, args.local_ep, args.lr, args.num_users,
        args.num_classes, args.seed, args.non_iid_prob_class, args.alpha_dirichlet, args.IF, args.loss_type)
    return os.path.join(root_dir, fname)
   
# Group names used by the head/middle/tail ("3-shot") accuracy metrics.
_SHOT_GROUPS = ("head", "middle", "tail")


def _average_3shot_acc(acc_3shot_list):
    """Average the per-client head/middle/tail accuracies.

    Args:
        acc_3shot_list: one dict per client mapping "head", "middle" and
            "tail" to a pair ``(acc, flag)``. The flags are summed to form the
            divisor. They are presumably bool/0-1 "group present on this
            client" markers — NOTE(review): confirm against ``localtest``.

    Returns:
        A dict keyed by "head", "middle" and "tail". Each value is the sum of
        ``acc`` divided by the sum of ``flag`` for that group. It is NaN when
        no client has the group, instead of raising ZeroDivisionError.
    """
    averages = {}
    for group in _SHOT_GROUPS:
        acc_sum = sum(entry[group][0] for entry in acc_3shot_list)
        count = sum(entry[group][1] for entry in acc_3shot_list)
        averages[group] = acc_sum / count if count else float("nan")
    return averages


def _client_3shot_row(entry):
    """Return ``[head, middle, tail]`` accuracies of one client.

    A group whose flag (the second element of its pair) is falsy is reported
    as 0.
    """
    return [entry[group][0] if entry[group][1] else 0 for group in _SHOT_GROUPS]


if __name__ == '__main__':
    # Evaluate saved checkpoints: one local-update pass per client, then a
    # global test and per-client local tests. Metrics are printed and
    # appended to the accuracy file.
    if dataset_switch == 'cifar100':
        args = args_parser()
    elif dataset_switch == 'cifar10':
        args = args_parser_cifar10()
    else:
        # Fail fast instead of hitting a NameError on `args` further down.
        raise ValueError("unsupported dataset_switch: %r" % (dataset_switch,))

    # Seed every RNG in use so that runs are reproducible.
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    # benchmark=True lets cuDNN auto-tune and pick algorithms
    # nondeterministically, which defeats deterministic=True.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    fpath = get_acc_file_path(args)
    print(fpath)

    datasetObj = myDataset(args)
    if args.balanced_global:
        dataset_train, dataset_test, dict_users, dict_localtest = datasetObj.get_balanced_dataset(datasetObj.get_args())
    else:
        dataset_train, dataset_test, dict_users, dict_localtest = datasetObj.get_imbalanced_dataset(datasetObj.get_args())
    print(len(dict_users))

    # The freshly built model is replaced by the checkpoint loaded below.
    # NOTE(review): kept in case build_model has side effects; drop it if not.
    model = build_model(args)

    # Evaluate every client. fpath above was built from the command-line
    # frac, before this override.
    args.frac = 1

    # These checkpoints are pickled full modules. Only load files you trust.
    load_dir = ""
    model = torch.load(os.path.join(load_dir, "model_499.pth")).to(args.device)
    g_head = torch.load(os.path.join(load_dir, "g_head_499.pth")).to(args.device)
    g_aux = torch.load(os.path.join(load_dir, "g_aux_499.pth")).to(args.device)
    l_heads = [
        torch.load(os.path.join(load_dir, "l_head_" + str(i) + ".pth")).to(args.device)
        for i in range(args.num_users)
    ]

    # One local update per client, starting from the loaded global state.
    w_glob = model.state_dict()
    w_locals = [copy.deepcopy(w_glob) for _ in range(args.num_users)]
    g_auxs_intervaria = []
    epoch = 0
    for client_id in range(args.num_users):
        local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users[client_id])
        w_locals[client_id], g_aux_intervaria, l_heads[client_id], loss_local = local.update_weights_norm_init(
            net=copy.deepcopy(model).to(args.device),
            g_head=copy.deepcopy(g_head).to(args.device),
            g_aux=copy.deepcopy(g_aux).to(args.device),
            l_head=copy.deepcopy(l_heads[client_id]).to(args.device),
            seed=args.seed,
            net_glob=model.to(args.device),
            epoch=epoch,
        )
        g_auxs_intervaria.append(g_aux_intervaria)

    acc_s2, global_3shot_acc = globaltest_feat_collapse(
        copy.deepcopy(model).to(args.device),
        g_head=copy.deepcopy(g_head).to(args.device),
        test_dataset=dataset_test,
        args=args,
        dataset_class=datasetObj,
    )

    # Per-client local tests; every list below is indexed by client id.
    acc_list = []
    f1_macro_list = []
    f1_weighted_list = []
    acc_3shot_local_list = []
    for i in range(args.num_users):
        acc_local, f1_macro, f1_weighted, acc_3shot_local = localtest(
            copy.deepcopy(model).to(args.device),
            copy.deepcopy(g_auxs_intervaria[i]).to(args.device),
            copy.deepcopy(l_heads[i]).to(args.device),
            dataset_test,
            dataset_class=datasetObj,
            idxs=dict_localtest[i],
            user_id=i,
        )
        acc_list.append(acc_local)
        f1_macro_list.append(f1_macro)
        f1_weighted_list.append(f1_weighted)
        acc_3shot_local_list.append(acc_3shot_local)

    avg3shot_acc = _average_3shot_acc(acc_3shot_local_list)
    for i, entry in enumerate(acc_3shot_local_list):
        head_acc, middle_acc, tail_acc = _client_3shot_row(entry)
        print("3shot of client {}:head:{}, middle:{}, tail:{}".format(i, head_acc, middle_acc, tail_acc))

    avg_local_acc = sum(acc_list) / len(acc_list)

    # The weights must be in the same client order (0..num_users-1) as the F1
    # lists, or each score gets paired with another client's data size.
    dict_len = [len(dict_users[i]) for i in range(args.num_users)]
    avg_f1_macro = Weighted_avg_f1(f1_macro_list, dict_len=dict_len)
    avg_f1_weighted = Weighted_avg_f1(f1_weighted_list, dict_len=dict_len)

    rnd = 0
    report = [
        'round %d, local average test acc  %.4f \n' % (rnd, avg_local_acc),
        'round %d, local macro average F1 score  %.4f \n' % (rnd, avg_f1_macro),
        'round %d, local weighted average F1 score  %.4f \n' % (rnd, avg_f1_weighted),
        'round %d, average 3shot acc: [head: %.4f, middle: %.4f, tail: %.4f] \n' % (
            rnd, avg3shot_acc["head"], avg3shot_acc["middle"], avg3shot_acc["tail"]),
        'round %d, global 3shot acc: [head: %.4f, middle: %.4f, tail: %.4f] \n' % (
            rnd, global_3shot_acc["head"], global_3shot_acc["middle"], global_3shot_acc["tail"]),
        'round %d, global test acc  %.4f \n' % (rnd, acc_s2),
    ]
    for line in report:
        print(line)
    # Append the metrics to the accuracy file; `with` guarantees it is closed.
    with open(fpath, 'a') as f_acc:
        f_acc.writelines(report)
