import sys
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "../../../")))

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import time 

from src.data import *
from src.models import *
from src.client import *
from src.utils import *
from src.utils.utils import set_seed
from src.nlp import dataset as nlp_dataset, model as nlp_model


def _collate_skip_none(batch):
    """Collate a batch after dropping ``None`` samples.

    Returns ``None`` instead of a collated batch when fewer than two valid
    samples remain, so consumers of the loader must tolerate ``None`` batches.

    Defined at module level (rather than inside the client loop) so that it
    can be pickled for DataLoader worker processes started with ``spawn``.
    """
    batch = [sample for sample in batch if sample is not None]
    if len(batch) > 1:
        return torch.utils.data.dataloader.default_collate(batch)
    return None


def _ensemble_teacher_targets(logits_locals):
    """Build consensus labels and distillation logits from client logits.

    Args:
        logits_locals: list of 2-D arrays, one per sampled client, all indexed
            as ``[sample, class]`` over the same public samples.

    Returns:
        ``(teacher_labels, teacher_logits)`` where ``teacher_labels`` is the
        argmax of the variance-weighted sum of all client logits, and
        ``teacher_logits`` holds, per sample, the variance-weighted sum of the
        logits of only those clients whose own prediction differs from the
        consensus label (all zeros when every client agrees).
    """
    num_samples, num_classes = logits_locals[0].shape[0], logits_locals[0].shape[1]

    # Per-client, per-sample variance across classes, normalised by the mean
    # over clients so the weights of one sample average to 1.
    sample_vars = np.array([np.var(mat, axis=1) for mat in logits_locals])
    vars_normalized = sample_vars / np.mean(sample_vars, axis=0)
    weighted_logits = np.sum(
        [weights[:, np.newaxis] * mat for weights, mat in zip(vars_normalized, logits_locals)],
        axis=0,
    )
    teacher_labels = np.argmax(weighted_logits, axis=1)

    teacher_logits = []
    for kk in range(num_samples):
        # Keep only the clients that disagree with the consensus on sample kk.
        dissenting = [mat[kk] for mat in logits_locals
                      if np.argmax(mat[kk]) != teacher_labels[kk]]
        if dissenting:
            dissent_var = np.var(dissenting, axis=1)
            dissent_weights = dissent_var / np.mean(dissent_var, axis=0)
            teacher_logits.append(
                np.sum([w * row for w, row in zip(dissent_weights, dissenting)], axis=0)
            )
        else:
            # Class count comes from the logits themselves; ``num_labels`` is
            # only defined on the NLP path.
            teacher_logits.append(np.zeros(num_classes))

    return teacher_labels, np.array(teacher_logits)


def _distill_global_model(model, public_dl, args, weight_decay):
    """Train ``model`` in place on the public set with CE + temperature-scaled KL loss.

    Each batch from ``public_dl`` is unpacked as ``(x, y, logits, labels)``;
    ``y`` is ignored and the last two are used as distillation targets.
    The model is left on ``args.device`` in train mode.
    """
    model.to(args.device)
    model.train()

    optimizer = torch.optim.Adam(model.parameters(), lr=args.distill_lr, weight_decay=weight_decay)
    ce_criterion = nn.CrossEntropyLoss()
    kl_criterion = nn.KLDivLoss(reduction="batchmean")
    T = args.distill_T

    for _ in range(args.distill_E):
        for batch_x, _batch_y, batch_logits, batch_labels in public_dl:
            batch_logits = batch_logits.to(args.device)
            batch_labels = batch_labels.to(args.device)
            if args.nlp:
                # Batches are dicts of tensors on the NLP path; 'labels' must
                # not be forwarded to the model.
                inputs = {k: v.to(args.device) for k, v in batch_x.items() if k != 'labels'}
                logits_student = model(**inputs).logits
            else:
                logits_student = model(batch_x.to(args.device))

            ce_loss = ce_criterion(logits_student, batch_labels)
            kd_loss = kl_criterion(F.log_softmax(logits_student/T, dim=1),
                                   F.softmax(batch_logits/T, dim=1))
            loss = ce_loss + 0.1 * (T**2) * kd_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()


def main_fedet(args):
    """Run one FedET trial: local training, per-cluster FedAvg, ensemble KD.

    Clients are organised in clusters (one model architecture per cluster,
    ``args.models[cn]`` with ``args.num_users[cn]`` clients). Every round:

    1. a fraction ``args.fracs[cn]`` of each cluster's clients (at least one)
       trains locally starting from the cluster's global weights;
    2. the sampled clients' weights are averaged per cluster, weighted by
       local training-set size, and the averaged model is evaluated;
    3. all sampled clients run inference on the public dataset and their
       logits are combined into teacher labels/logits;
    4. each cluster's global model is distilled on the public dataset and
       evaluated again.

    Args:
        args: experiment configuration. Attributes read here: ``nlp``,
            ``num_users``, ``models``, ``fracs``, ``data_ratios``, ``dataset``,
            ``datadir``, ``distill_dataset``, ``alg``, ``seed``, ``local_bs``,
            ``local_ep``, ``optim``, ``lr``, ``momentum``, ``local_wd``,
            ``lr_scheduler``, ``device``, ``rounds``, ``distill_lr``,
            ``distill_wd``, ``distill_T``, ``distill_E``. Project helpers
            called with ``args`` may read more.

    Returns:
        Tuple ``(glob_acc_wavg, glob_acc_kd, final_glob, avg_final_glob,
        best_glob, duration)``: per-cluster accuracy histories after FedAvg
        and after KD, then per-cluster final / last-10-round mean / best KD
        accuracy, and the federation wall time in seconds.
    """
    final_window = 10  # rounds averaged for the "Final 10 Rounds" metrics

    print(' ')
    print('\n'.join(f'{k}={v}' for k, v in vars(args).items()))
    ##################################### Data partitioning section
    print('-'*40)
    print('Getting Clients Data')

    num_users = args.num_users
    archs = args.models
    fracs = args.fracs
    num_clusters = len(num_users)

    print(f'Num Users: {num_users}')
    print(f'archs: {archs}')
    print(f'fracs: {fracs}')
    print(f'Data ratios: {args.data_ratios}')

    if args.nlp:
        public_ds = nlp_dataset.get_public_ds(args)
        train_ds_global, val_ds_global, test_ds_global = nlp_dataset.get_train_val_test_ds(args)

        test_dl_global = DataLoader(dataset=test_ds_global, batch_size=128, shuffle=False, drop_last=False)

        num_labels = nlp_dataset.hf_ds_config[args.dataset]['num_labels']

        partitions_train, partitions_test, partitions_train_stat, partitions_test_stat = \
            get_partitions_customD(train_ds_global, test_ds_global, args, num_labels)
        Y_train = np.array(train_ds_global['labels'])
    else:
        public_ds = get_distill_data(args.distill_dataset, args.datadir, args.alg, args.dataset, num_clusters)

        train_ds_global, test_ds_global, _train_dl_global, test_dl_global = \
            get_dataset_global(args.dataset, args.datadir, batch_size=128,
                               p_train=1.0, p_test=1.0, seed=args.seed)

        partitions_train, partitions_test, partitions_train_stat, partitions_test_stat = \
            get_partitions_customD(train_ds_global, test_ds_global, args)
        Y_train = np.array(train_ds_global.target)

    print('-'*40)
    ################################### build model
    print('-'*40)
    print('Building models for clients')

    users_model = []
    net_glob = []
    initial_state_dict = []
    for num, arch in zip(num_users, archs):
        if args.nlp:
            users_model_tmp, net_glob_tmp, initial_state_dict_tmp = \
                nlp_model.get_models(num, arch, num_labels, args.datadir)
        else:
            users_model_tmp, net_glob_tmp, initial_state_dict_tmp = \
                get_models_fedmh(num_users=num, model=arch, dataset=args.dataset, args=args)
        users_model.append(users_model_tmp)
        net_glob.append(net_glob_tmp)
        initial_state_dict.append(initial_state_dict_tmp)

    if not args.nlp:
        # Start every client and the global model of a cluster from identical weights.
        for cn, models in enumerate(users_model):
            for mod in models:
                mod.load_state_dict(initial_state_dict[cn])

        for cn, mod in enumerate(net_glob):
            mod.load_state_dict(initial_state_dict[cn])

    for cn, mod in enumerate(net_glob):
        print(f'Model No {cn+1}')
        total = 0
        for name, param in mod.named_parameters():
            total += np.prod(param.size())
        print(f'total params {total}')
        print('-'*40)

    ################################# Initializing Clients
    print('-'*40)
    print('Initializing Clients')
    clients = []
    for cn in range(num_clusters):
        print(f'---- Clients Group # {cn+1}')
        clients_tmp = []
        for idx in range(num_users[cn]):
            sys.stdout.flush()
            print(f'-- Client {idx}, Train Stat {partitions_train_stat[cn][idx]} Test Stat {partitions_test_stat[cn][idx]}')

            dataidxs = partitions_train[cn][idx]
            dataidxs_test = partitions_test[cn][idx]

            unq, unq_cnt = np.unique(Y_train[dataidxs], return_counts=True)
            print(f'Actual Partition Stat: {dict(zip(unq, unq_cnt))}')

            if args.nlp:
                train_ds_local = train_ds_global.select(dataidxs)
                test_ds_local = test_ds_global.select(dataidxs_test)
                loader_kwargs = dict(num_workers=4, pin_memory=True)
            else:
                train_ds_local = torch.utils.data.Subset(train_ds_global, dataidxs)
                test_ds_local = torch.utils.data.Subset(test_ds_global, dataidxs_test)
                loader_kwargs = dict(collate_fn=_collate_skip_none, num_workers=4, pin_memory=True)

            # Never request a batch larger than the client's local training set.
            bs = min(args.local_bs, len(train_ds_local))

            train_dl_local = DataLoader(dataset=train_ds_local, batch_size=bs, shuffle=True,
                                        drop_last=False, **loader_kwargs)
            test_dl_local = DataLoader(dataset=test_ds_local, batch_size=64, shuffle=False,
                                       drop_last=False, **loader_kwargs)

            clients_tmp.append(Client_FedDFMH(idx, users_model[cn][idx], args.local_bs, args.local_ep[cn], args.optim,
                    args.lr[cn], args.momentum, args.local_wd[cn], args.lr_scheduler[cn], args.device,
                    train_dl_local, test_dl_local, args.nlp))

        clients.append(clients_tmp)
    print('-'*40)
    ###################################### Federation
    print('Starting FL')
    print('-'*40)
    start = time.time()

    glob_acc_wavg = [[] for _ in range(num_clusters)]
    glob_acc_kd = [[] for _ in range(num_clusters)]

    w_glob = copy.deepcopy(initial_state_dict)

    for iteration in range(args.rounds):
        iter_start_time = time.time()

        # Sample the participating clients of every cluster (at least one each).
        idxs_users = []
        for cn in range(num_clusters):
            m = max(int(fracs[cn] * num_users[cn]), 1)
            idxs_users.append(np.random.choice(range(num_users[cn]), m, replace=False))

        print(f'----- ROUND {iteration+1} -----')
        if torch.cuda.is_available():
            # synchronize() raises on hosts without CUDA, so only touch the
            # CUDA runtime when a device is actually present.
            torch.cuda.synchronize()
            torch.cuda.empty_cache()
        sys.stdout.flush()

        ####### Local training #######
        s_time = time.time()
        loss_locals = []
        for cn in range(num_clusters):
            for idx in idxs_users[cn]:
                clients[cn][idx].set_state_dict(copy.deepcopy(w_glob[cn]))

                loss = clients[cn][idx].train()
                loss_locals.append(copy.deepcopy(loss))

        e_time = time.time()
        # print loss
        loss_avg = sum(loss_locals) / len(loss_locals)
        template = '-- Average Train loss {:.3f}, Time: {:.2f}'
        print(template.format(loss_avg, e_time-s_time))

        ####### FedAvg ####### START
        for cn in range(num_clusters):
            selected = idxs_users[cn]
            sizes = [len(partitions_train[cn][r]) for r in selected]
            total_data_points = sum(sizes)
            fed_avg_freqs = [size / total_data_points for size in sizes]
            w_locals = [copy.deepcopy(clients[cn][idx].get_state_dict()) for idx in selected]

            ww = AvgWeights(w_locals, weight_avg=fed_avg_freqs)
            w_glob[cn] = copy.deepcopy(ww)
            net_glob[cn].load_state_dict(copy.deepcopy(ww))

            acc = eval_test(net_glob[cn], args, test_dl_global)

            glob_acc_wavg[cn].append(acc)
        ####### FedAvg ####### END

        ###### Logits Avg #######
        s_time = time.time()
        logits_locals = [
            np.array(clients[cn][idx].inference(public_ds).cpu())
            for cn in range(num_clusters)
            for idx in idxs_users[cn]
        ]

        teacher_labels, teacher_logits = _ensemble_teacher_targets(logits_locals)
        public_ds.set_labels(teacher_labels)
        public_ds.set_logits(teacher_logits)

        e_time = time.time()
        print(f'Inference & Logit Avg done: {e_time-s_time:.2f}')
        ###### Logits Avg #######

        ##### Global Model KD #####
        public_dl = torch.utils.data.DataLoader(public_ds, batch_size=128, shuffle=True, drop_last=False,
                                                num_workers=4, pin_memory=True)
        for cn in range(num_clusters):
            net_glob[cn].load_state_dict(copy.deepcopy(w_glob[cn]))
            _distill_global_model(net_glob[cn], public_dl, args, args.distill_wd[cn])

            w_glob[cn] = copy.deepcopy(net_glob[cn].state_dict())
            acc_kd = eval_test(net_glob[cn], args, test_dl_global)
            net_glob[cn] = net_glob[cn].cpu()
            glob_acc_kd[cn].append(acc_kd)
        ##### Global Model KD #####
        print("-- Results:")
        for cn in range(num_clusters):
            template = "--[Cluster {:.1f}]: {}, Global Acc KD: {:.2f}, Best: {:.2f}"
            print(template.format(cn, archs[cn], glob_acc_kd[cn][-1], np.max(glob_acc_kd[cn])))

        iter_end_time = time.time()

        print(f'Round {iteration+1} Time: {(iter_end_time - iter_start_time)/60:.1f} mins')

    end = time.time()
    duration = end-start
    print('-'*40)
    ############################### FedAvg Final Results
    print('-'*40)
    print('FINAL RESULTS')
    for cn in range(num_clusters):
        template = "-- Global Acc KD: {:.2f}"
        print(template.format(glob_acc_kd[cn][-1]))

        template = "-- Global Acc Avg Final 10 Rounds: {:.2f}"
        print(template.format(np.mean(glob_acc_kd[cn][-final_window:])))

        template = "-- Global Best Acc: {:.2f}"
        print(template.format(np.max(glob_acc_kd[cn])))

        print(f'-- FL Time: {duration/60:.2f} minutes')
    print('-'*40)

    final_glob = [history[-1] for history in glob_acc_kd]
    # Same 10-round window as the value printed above and as the
    # "Final 10 Rounds" label used when these numbers are reported.
    avg_final_glob = [np.mean(history[-final_window:]) for history in glob_acc_kd]
    best_glob = [np.max(history) for history in glob_acc_kd]

    return (glob_acc_wavg, glob_acc_kd, final_glob, avg_final_glob, best_glob, duration)

def run_fedet(args, fname):
    """Run ``args.ntrials`` FedET trials and report the aggregated results.

    Every trial calls ``main_fedet(args)``. Per-trial results are printed as
    they finish; the mean and standard deviation across trials are printed and
    also appended to ``<fname>_results_summary.txt``. The per-round accuracy
    histories of all trials are saved to ``<fname>_glob_acc_wavg.npy`` and
    ``<fname>_glob_acc_kd.npy``.

    Args:
        args: experiment configuration; ``ntrials`` is read here and the whole
            object is forwarded to ``main_fedet``.
        fname: path prefix for the summary text file and the ``.npy`` files.

    Returns:
        None.
    """
    alg_name = 'FedET'

    # Seeds intended for per-trial seeding; seeding is currently disabled
    # (see the commented-out call inside the trial loop).
    seeds = [2023, 2024, 2022, 2021]

    exp_glob_acc_wavg, exp_glob_acc_kd = [], []
    exp_final_glob, exp_avg_final_glob, exp_best_glob = [], [], []
    exp_fl_time = []

    def two_decimals(values):
        # Round each entry to two decimals via string formatting.
        return [float(f'{item:.2f}') for item in values]

    def write_summary(stream=None):
        # Emit the cross-trial summary; ``stream=None`` means standard output.
        print('*'*40, file=stream)
        print(' '*20, alg_name, file=stream)
        print(' '*20, 'Avg %d Trial Results'%args.ntrials, file=stream)

        sections = (
            ("-- Global Final Acc: {} +- {}", exp_final_glob),
            ("-- Global Avg Final 10 Rounds Acc: {} +- {}", exp_avg_final_glob),
            ("-- Global Best Acc: {} +- {}", exp_best_glob),
        )
        for template, per_trial in sections:
            r1 = two_decimals(np.mean(per_trial, axis=0))
            r2 = two_decimals(np.std(per_trial, axis=0))
            print(template.format(r1, r2), file=stream)

        print(f'-- FL Time: {np.mean(exp_fl_time):.2f} minutes', file=stream)

    for trial in range(args.ntrials):
        trial_no = trial + 1
        print('*'*40)
        print(' '*20, alg_name)
        print(' '*20, f'Trial {trial_no}')

        # set_seed(seeds[trial])
        (glob_acc_wavg, glob_acc_kd, final_glob,
         avg_final_glob, best_glob, duration) = main_fedet(args)

        exp_glob_acc_wavg.append(glob_acc_wavg)
        exp_glob_acc_kd.append(glob_acc_kd)
        exp_final_glob.append(final_glob)
        exp_avg_final_glob.append(avg_final_glob)
        exp_best_glob.append(best_glob)
        exp_fl_time.append(duration/60)

        print('*'*40)
        print(' '*20, f'End of Trial {trial_no}')
        print(' '*20, 'Final Results')

        trial_report = (
            ("-- Global Final Acc: {}", final_glob),
            ("-- Global Avg Final 10 Rounds Acc : {}", avg_final_glob),
            ("-- Global Best Acc: {}", best_glob),
        )
        for template, values in trial_report:
            print(template.format(two_decimals(values)))

        print(f'-- FL Time: {exp_fl_time[-1]:.2f} minutes')

    # Aggregate summary: first to the console, then appended to the summary file.
    write_summary()

    with open(fname+'_results_summary.txt', 'a') as text_file:
        write_summary(text_file)
        print('*'*40)

    print('Saving Global Accuracy')
    np.save(fname+'_glob_acc_wavg.npy', np.array(exp_glob_acc_wavg))
    np.save(fname+'_glob_acc_kd.npy', np.array(exp_glob_acc_kd))
    return