import sys
import os
base_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(base_path)
from utils.config import img_param_init, set_random_seed
from utils.prepare_data_dg_clip import *
import copy
import argparse
from nets.models import ClipModelat
import torch.optim as optim
import torch
import numpy as np
from utils.training import train
from utils.testing import test
from utils.aggregation import communication
from utils.loss_function import get_lossfn
from adaptation import LMMDLoss

from tqdm import tqdm

from distr import edic
from adaptation import AdversarialLoss

from distr.utils_main import get_discr_domain, get_gen_domain, get_frame_domain

class ParamGroupsCollector:
    """Accumulate optimizer parameter groups from several models.

    Every collected group is a dict with a ``'params'`` entry and an ``'lr'``
    entry; ``self.param_groups`` is the list handed to the optimizer.
    """

    def __init__(self, lr):
        """Create an empty collector whose base learning rate is ``lr``."""
        self.reset(lr)

    def reset(self, lr):
        """Set the base learning rate to ``lr`` and discard collected groups."""
        self.lr = lr
        self.param_groups = []

    def collect_params(self, *models):
        """Append the parameter groups of each model in ``models``.

        A model that exposes ``parameter_groups()`` supplies its own group
        dicts, which are updated in place:

        * a group with ``'lr_ratio'`` gets ``lr = base_lr * lr_ratio``;
        * a group that already has ``'lr'`` keeps it;
        * any other group gets the base learning rate.

        A model without ``parameter_groups()`` contributes a single group
        holding all of ``model.parameters()`` at the base learning rate.
        """
        for model in models:
            if hasattr(model, 'parameter_groups'):
                groups_inc = list(model.parameter_groups())
                for grp in groups_inc:
                    if 'lr_ratio' in grp:
                        grp['lr'] = self.lr * grp['lr_ratio']
                    elif 'lr' not in grp:  # Do not overwrite existing lr assignments
                        grp['lr'] = self.lr
                self.param_groups += groups_inc
            else:
                # Materialise the parameters: ``parameters()`` returns a
                # one-shot generator, which would be empty the second time
                # the group is iterated (inspection, a second optimizer, ...).
                self.param_groups.append(
                    {'params': list(model.parameters()), 'lr': self.lr})

def model_average(client_net_states, client_weights):
    """Return the weighted average of several model state dicts.

    Args:
        client_net_states: Non-empty sequence of state dicts. Every dict must
            provide the keys of the first one; values must support
            ``value * float`` and ``+`` (e.g. ``torch.Tensor``).
        client_weights: One weight per state dict. The weights are normalised
            to sum to one before averaging, so only their ratios matter.

    Returns:
        A shallow copy of the first state dict (same mapping type and key
        order) in which every value is replaced by the weighted average of
        that entry over all clients. The inputs are not modified.

    Raises:
        ValueError: If no state is given, or the number of weights differs
            from the number of states.
    """
    if not client_net_states:
        raise ValueError('model_average needs at least one client state')
    if len(client_net_states) != len(client_weights):
        raise ValueError(
            'model_average got {} client states but {} weights'.format(
                len(client_net_states), len(client_weights)))

    # Normalise once instead of re-summing the weights for every element.
    total = sum(client_weights)
    norm_weights = [w / total for w in client_weights]

    # A shallow copy keeps the mapping type and any attributes attached to it
    # without cloning tensors that are overwritten below anyway.
    state_avg = copy.copy(client_net_states[0])
    for k in state_avg:
        terms = (state[k] * w
                 for state, w in zip(client_net_states, norm_weights))
        acc = next(terms)
        for term in terms:
            acc = acc + term
        state_avg[k] = acc

    return state_avg

if __name__ == '__main__':
    # Script entry point. This first part selects the compute device, parses
    # the command line and finalises the run configuration in `args`.
    # NOTE: everything here runs at module scope, so each name bound below is
    # part of the namespace that `edic(locals())` snapshots further down.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, default='BrainTumor')
    # --lr is used for the CLIP-side optimizers; --clr is passed as the default
    # learning rate of the causal-model optimizers (both help texts read
    # 'learning rate').
    parser.add_argument('--lr', type=float, default=5e-5, help='learning rate')
    parser.add_argument('--clr', type=float, default=5e-5, help='learning rate')
    parser.add_argument('--datapercent', type=float,
                        default=6e-1, help='data percent to use')
    parser.add_argument('--batch', type=int, default=32, help='batch size')
    parser.add_argument('--root_dir', type=str, default='./data/')
    # --iters: number of communication rounds; --wk_iters: local epochs run by
    # every training client inside one round.
    parser.add_argument('--iters', type=int, default=50,
                        help='iterations for communication')
    parser.add_argument('--wk_iters', type=int, default=1,
                        help='optimization iters in local worker between communication')
    parser.add_argument('--mode', type=str, default='FedAtImg')
    parser.add_argument('--net', type=str, default='ViT-B/32',
                        help='[RN50 | RN101 | RN50x4 | RN50x16 | RN50x64 | ViT-B/32 | ViT-B/16 | ViT-L/14 | ViT-L/14@336px]')
    parser.add_argument('--seed', type=int, default=0, help='random seed')
    # NOTE: the value given for --n_clients is overwritten after parsing,
    # depending on --is_alpha (see below).
    parser.add_argument('--n_clients', type=int, default=20)
    parser.add_argument('--n_iter', type=int, default=200)
    parser.add_argument('--test_envs', type=int, nargs='+', default=[3]) # Global client [0-N]
    # Adam hyper-parameters shared by every optimizer this script creates.
    parser.add_argument('--beta1', type=float, default=0.9)
    parser.add_argument('--beta2', type=float, default=0.98)
    parser.add_argument('--eps', type=float, default=1e-6)
    # --step is incremented by one in the training loop for every trained
    # client and local epoch; it is a float option used as a counter.
    parser.add_argument('--step', type=float, default=0)
    parser.add_argument('--aggmode', type=str, default='att') # att or avg
    parser.add_argument('--weight_decay', type=float, default=0.02)
    parser.add_argument('--method', type=str, default='ours') # ours, fedclip, fedavg, moon, fedprox, etc.
    parser.add_argument('--temp', type=float, default=0.5)
    parser.add_argument('--factor', type=float, default=0.01)
    parser.add_argument('--exp_weight', type=float, default=1.0)


    # Boolean switches; all default to False and are enabled by passing the flag.
    parser.add_argument('--use_clips', action='store_true', default=False)
    parser.add_argument('--use_clipz', action='store_true', default=False)
    parser.add_argument('--use_clipc', action='store_true', default=False)
    parser.add_argument('--is_alpha', action='store_true', default=False)
    parser.add_argument('--use_do', action='store_true', default=False)
    parser.add_argument('--a_p_num', type=int, default=3)
    parser.add_argument('--r_p_num', type=int, default=3)
    parser.add_argument('--alpha', type=float, default=1.2, help='parameter of dirichlet distribution')
    parser.add_argument('--K', type=int, default=5, help='client num')
    args = parser.parse_args()
    # A dedicated NumPy generator with a fixed seed of 1, independent of --seed.
    args.random_state = np.random.RandomState(1)
    set_random_seed(args.seed)
    # The client count is fixed by the partitioning scheme: 12 clients when
    # --is_alpha is set, otherwise 4. This replaces any --n_clients value.
    if args.is_alpha:
        args.n_clients = 12
    else:
        args.n_clients = 4

    # img_param_init returns the (possibly extended) namespace; later code
    # reads args.num_classes, which the parser does not define, so it is
    # presumably filled in here -- confirm in utils.config.
    args = img_param_init(args)
    print(args)
    # Build the server-side models: the CLIP wrapper plus the discriminative,
    # generative and frame models configured through `dc_vars`.
    os.makedirs('./data/', exist_ok=True)
    # fedprox / fedavg / moon construct the CLIP wrapper with freezepy=False;
    # every other method uses freezepy=True (presumably freezing the CLIP
    # backbone -- confirm in nets.models.ClipModelat).
    if args.method == 'fedprox' or args.method == 'fedavg' or args.method == 'moon':
        server_model = ClipModelat(
            args.net, device=device, attention=True, freezepy=False, method=args.method)
    else:
        server_model = ClipModelat(
            args.net, device=device, attention=True, freezepy=True, method=args.method)
    server_discr_model = None
    archtype = 'mlp'
    # Snapshot of the current namespace wrapped in the project's `edic` type.
    # At module scope locals() is the module namespace, so the snapshot holds
    # every name bound so far (imports, device, args, server_model,
    # server_discr_model=None, archtype, ...). Renaming or adding names above
    # this line therefore changes the contents of dc_vars.
    dc_vars = edic(locals())
    # Dimensions passed to the model factories below; all four feature
    # dimensions are 512 and dim_y is the number of classes.
    dc_vars['dim_z'] = 512
    dc_vars['dim_s'] = 512
    dc_vars['dim_c'] = 512
    dc_vars['dim_x'] = 512
    dc_vars['dim_y'] = args.num_classes
    dc_vars['std_c1s_val'] = 0.3 # 1.0

    # Standard-deviation style settings (names suggest prior stds for s, z
    # and x -- confirm in distr.utils_main).
    dc_vars['pstd_s'] = 0.3
    dc_vars['pstd_z'] = 0.3
    dc_vars['pstd_x'] = 0.3
    dc_vars['actv'] = 'Sigmoid'
    # dc_vars['actv'] = 'ReLU'
    # NOTE: these two entries are hard-coded to False; the --use_do command
    # line flag is not consulted here.
    dc_vars['use_clip'] = False
    dc_vars['use_do'] = False
    # The generator is built on top of the discriminator, and the frame model
    # combines both.
    server_discr_model = get_discr_domain(archtype, dc_vars)
    server_discr_model.to(device)
    server_gen_model = get_gen_domain(archtype, dc_vars, server_discr_model)
    server_gen_model.to(device)
    server_frame_model = get_frame_domain(server_discr_model, server_gen_model, dc_vars, device)

    # Build the per-client data loaders. Both factories receive the server
    # CLIP wrapper and return the same five loader collections.
    if args.is_alpha:
        print("by alpha loading")
        train_loaders, val_loaders, test_loaders, test_train, train_test_loaders = getfeadataloader_by_alpha(args, server_model)
    else:
        train_loaders, val_loaders, test_loaders, test_train, train_test_loaders = get_data(
                args.dataset)(args, server_model)
    # NOTE(review): the loader index 3 is hard-coded and does not follow
    # --test_envs; confirm this is intended.
    server_model.initdgatal(test_loaders[3])
    # For multi client
    client_num = len(test_loaders)

    # Translate a *pair* of held-out indices given on the command line into
    # the client indices used by the loaders. The pair is compared as a tuple
    # rather than as concatenated digits, so a single test env such as 12, 13
    # or 23 is no longer mistaken for the pairs [1, 2], [1, 3] or [2, 3].
    _test_envs = ''.join([str(e) for e in args.test_envs])
    _env_pair = tuple(args.test_envs)
    _pair_to_clients = {
        (0, 1): [0, 1],
        (0, 2): [0, 6],
        (1, 2): [5, 6],
        (1, 3): [5, 11],
        (2, 3): [10, 11],
    }
    # For the pair (0, 3) the second client index depends on --K; an
    # unsupported K leaves args.test_envs untouched.
    _k_to_clients_03 = {3: [0, 7], 5: [0, 11], 10: [0, 21], 25: [0, 51]}
    if _env_pair == (0, 3):
        args.test_envs = _k_to_clients_03.get(args.K, args.test_envs)
    elif _env_pair in _pair_to_clients:
        args.test_envs = _pair_to_clients[_env_pair]

    # Number of clients that actually train (all clients minus held-out ones).
    sclient_num = client_num-len(args.test_envs)
    if sclient_num <= 0:
        raise ValueError(
            f'no training clients left: {client_num} clients, '
            f'test_envs={args.test_envs}')
    # Every client, held-out ones included, gets the weight 1 / sclient_num,
    # so the first sclient_num weights sum to one.
    client_weights = [float(1 / sclient_num) for i in range(client_num)]
    print(client_weights)
    # One private copy of the server CLIP wrapper per client.
    models = [copy.deepcopy(server_model) for idx in range(client_num)]
    # Per-client causal components; they stay None unless --method ours.
    lossfns = [None for idx in range(client_num)]
    models_discr = [None for idx in range(client_num)]
    models_gen = [None for idx in range(client_num)]
    models_frame = [None for idx in range(client_num)]
    if args.method == 'ours':
        models_discr = [get_discr_domain(archtype, dc_vars).to(device) for idx in range(client_num)]
        # NOTE(review): every client generator is built on the *server*
        # discriminator, not on models_discr[idx]; confirm this is intended.
        models_gen = [get_gen_domain(archtype, dc_vars, server_discr_model).to(device) for idx in range(client_num)]
        models_frame = [get_frame_domain(models_discr[idx], models_gen[idx], dc_vars, device) for idx in range(client_num)]

        lossfns = [get_lossfn(models_discr[idx], models_frame[idx], dc_vars['dim_y']) for idx in range(client_num)]
    for i in range(client_num):
        models[i].model.to(device)
        models[i].fea_attn.to(device)
    best_changed = False
    server_model.to(device)
    # Snapshots taken before a communication round: the previous server model
    # and each client's previous model are passed to train().
    server_model_pre = copy.deepcopy(server_model)
    server_model_pre.to(device)
    best_acc = [0 for idx in range(0, client_num)]
    finalrecord = ''
    logrecord = ''
    log = [[] for idx in range(0, client_num)]
    previous_nets = [copy.deepcopy(models[idx].to(device)) for idx in range(0, client_num)]
    adv = [AdversarialLoss(device=device) for idx in range(0, client_num)]

    # Adam settings shared by every optimizer created below.
    adam_kwargs = dict(betas=(args.beta1, args.beta2), eps=args.eps,
                       weight_decay=args.weight_decay)
    if args.aggmode == 'att':
        # Attention mode: only the fea_attn adapter of each client is optimised.
        # These optimizers are created for every method; previously they only
        # existed for 'ours', so any other method crashed with a NameError on
        # the first train() call.
        optimizers = [
            optim.Adam(params=[{'params': models[idx].fea_attn.parameters()}],
                       lr=args.lr, **adam_kwargs)
            for idx in range(client_num)]
        optimizers_causal = [None for idx in range(client_num)]
        if args.method == 'ours':
            for idx in range(client_num):
                # NOTE(review): the collector stamps an explicit 'lr' derived
                # from args.lr on every group, so the lr=args.clr default
                # below never takes effect; confirm whether --clr was meant
                # to be the collector's base rate.
                pgc = ParamGroupsCollector(args.lr)
                pgc.collect_params(models_discr[idx])
                pgc.collect_params(models_gen[idx], models_frame[idx])
                optimizers_causal[idx] = optim.Adam(
                    params=pgc.param_groups, lr=args.clr, **adam_kwargs)
    elif args.aggmode == 'avg':
        # Averaging mode: all parameters of each client model are optimised and
        # no causal optimizer is used.
        optimizers_causal = [None for idx in range(client_num)]
        optimizers = [
            optim.Adam(params=[{'params': models[idx].parameters()}],
                       lr=args.lr, **adam_kwargs)
            for idx in range(client_num)]
    else:
        # Fail early instead of hitting an undefined `optimizers` in training.
        raise ValueError(
            f"unsupported --aggmode {args.aggmode!r}; expected 'att' or 'avg'")
    # Main federated loop: local training on every training client, server
    # aggregation, then evaluation of the aggregated server model on all clients.
    for a_iter in tqdm(range(args.iters), colour='blue'):  # communication rounds
        # Per-round buffers for the training clients' causal-model states.
        client_discr_net_states = []
        client_gen_net_states = []
        client_frame_net_states = []
        client_clip_ada_net_states = []
        for wi in range(args.wk_iters):  # local epochs between communications
            print("============ Train epoch {} ============".format(wi + a_iter * args.wk_iters))
            logrecord += 'Train epoch:%d\n' % (wi + a_iter * args.wk_iters)
            for client_idx, model in enumerate(models):
                # Held-out (test) clients never train locally.
                if client_idx not in args.test_envs:
                    # The supported datasets share one train() call and differ
                    # only in the loader collection they draw from.
                    if args.dataset == 'BrainTumor':
                        client_loader = train_loaders[client_idx]
                    elif args.dataset in ('RealSkin', 'OfficeHome', 'ModernOffice31', 'PACS'):
                        client_loader = train_test_loaders[client_idx]
                    else:
                        # Any other dataset name performs no local training.
                        client_loader = None
                    if client_loader is not None:
                        train(args, model, client_loader, optimizers[client_idx], device, test_train[0],
                              adv[client_idx], server_model_pre, previous_nets[client_idx], lossfn=lossfns[client_idx],
                              discr=models_discr[client_idx], gen=models_gen[client_idx], frame=models_frame[client_idx],
                              optimizer_causal=optimizers_causal[client_idx])
                    args.step += 1

        # Gather the causal sub-model states of the training clients only, in
        # client order, so they line up with client_weights[:sclient_num].
        if args.method == 'ours':
            for client_idx, model in enumerate(models):
                if client_idx not in args.test_envs:
                    client_discr_net_states.append(models_discr[client_idx].state_dict())
                    client_gen_net_states.append(models_gen[client_idx].state_dict())
                    client_frame_net_states.append(models_frame[client_idx].state_dict())


        with torch.no_grad():
            # Keep the pre-aggregation server and client models for the next
            # round's train() calls, then aggregate.
            server_model_pre = copy.deepcopy(server_model.to(device))
            previous_nets = [copy.deepcopy(models[idx].to(device)) for idx in range(0, client_num)]
            server_model, models = communication(
                args, server_model, models, client_weights)
            if args.method == 'ours':
                # Average the discriminator and generator over the training
                # clients and broadcast the result to the server and to every
                # client (held-out clients included).
                global_discr_net_state = model_average(client_discr_net_states, client_weights[:sclient_num])
                global_gen_net_state = model_average(client_gen_net_states, client_weights[:sclient_num])
                server_discr_model.load_state_dict(global_discr_net_state)
                server_gen_model.load_state_dict(global_gen_net_state)
                for client_idx in range(client_num):
                    models_discr[client_idx].load_state_dict(global_discr_net_state)
                    models_gen[client_idx].load_state_dict(global_gen_net_state)

            for client_idx in range(client_num):
                # The aggregated server model is evaluated on every client's
                # test loader; the metrics are logged the same way for all.
                (test_acc, bacc, f1, net_benefits, tpr, fpr,
                 net_benefits_all, net_benefits_none) = test(
                    args, server_model, test_loaders[client_idx], device, server_discr_model)
                log[client_idx].append([test_acc, bacc, f1])
                print(
                    ' Test site-{:d}| Test Acc: {:.4f} | Bacc: {:.4f} | F1: {:.4f}'.format(client_idx, test_acc, bacc, f1))
                if client_idx in args.test_envs:
                    # Held-out client: a new best accuracy also saves a checkpoint.
                    if test_acc > best_acc[client_idx]:
                        best_acc[client_idx] = test_acc
                        net_benefits = np.array(net_benefits, dtype=float)
                        print('Test site-{}'.format(client_idx), 'net_benefits:', net_benefits)
                        net_benefits_all = np.array(net_benefits_all, dtype=float)
                        print('Test site-{}'.format(client_idx), 'net_benefits_all:', net_benefits_all)
                        net_benefits_none = np.array(net_benefits_none, dtype=float)
                        print('Test site-{}'.format(client_idx), 'net_benefits_none:', net_benefits_none)
                        print('Test site-{}'.format(client_idx), 'tpr:', tpr)
                        print('Test site-{}'.format(client_idx), 'fpr:', fpr)
                        if args.method == 'ours':
                            # NOTE(review): the key 'gem_model' looks like a
                            # typo for 'gen_model'; it is kept because code
                            # loading these checkpoints may rely on it.
                            checkpoint = {'backbone': server_model.state_dict(),
                                          'discr_model': global_discr_net_state,
                                          'gem_model': global_gen_net_state}
                        else:
                            checkpoint = {'backbone': server_model.state_dict()}
                        save_path = os.path.join(f'./save_model/{args.dataset}/{args.test_envs}_{args.is_alpha}/{args.method}')
                        os.makedirs(save_path, exist_ok=True)
                        torch.save(checkpoint, os.path.join(save_path, f'{args.test_envs}_best_acc_lr{args.lr}_clr{args.clr}.pth'))
                else:
                    # Training client: a new best accuracy exports the
                    # "all" net-benefit curve as CSV.
                    if test_acc > best_acc[client_idx]:
                        best_acc[client_idx] = test_acc
                        net_benefits = np.array(net_benefits, dtype=float)
                        print('Test site-{}'.format(client_idx), 'net_benefits:', net_benefits)
                        net_benefits_all = np.array(net_benefits_all, dtype=float)
                        print('Test site-{}'.format(client_idx), 'net_benefits_all:', net_benefits_all)
                        # np.savetxt does not create missing directories, so
                        # make sure the output folder exists first.
                        os.makedirs('./results', exist_ok=True)
                        np.savetxt('./results/Client'+str(client_idx)+'BrainTumorNetBenefitALL.csv',
                                   net_benefits_all, delimiter=',', fmt='%.6f')
                        net_benefits_none = np.array(net_benefits_none, dtype=float)

                        print('Test site-{}'.format(client_idx), 'net_benefits_none:', net_benefits_none)
                        print('Test site-{}'.format(client_idx), 'tpr:', tpr)
                        print('Test site-{}'.format(client_idx), 'fpr:', fpr)
