import argparse
import torch
from Dassl.dassl.utils import set_random_seed
from Dassl.dassl.config import get_cfg_default
from Dassl.dassl.engine import build_trainer

import os
import math
import copy
import pickle
import numpy as np
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score

def extend_cfg(cfg, args):
    """
    Add the DP-FPL specific options to ``cfg`` in place.

    Every value is copied from the parsed command-line namespace ``args``;
    nothing is returned.
    """
    from yacs.config import CfgNode as CN

    cfg.DEVICE = args.gpu

    # Prompt factorization method and its rank.
    cfg.FACTORIZATION = args.factorization
    cfg.RANK = args.rank

    # Differential privacy: clipping threshold and noise scale.
    cfg.NORM_THRESH = args.norm_thresh
    cfg.NOISE = args.noise

    # Trainer selection plus a fresh sub-node for the DP_FPL prompt settings.
    trainer = cfg.TRAINER
    trainer.NAME = 'DP_FPL'
    trainer.DP_FPL = CN()
    prompt_cfg = trainer.DP_FPL
    prompt_cfg.N_CTX = args.n_ctx  # number of learnable context vectors
    prompt_cfg.PREC = "fp32"  # one of fp16, fp32, amp
    prompt_cfg.CLASS_TOKEN_POSITION = "end"  # one of 'middle', 'end', 'front'

    dataset = cfg.DATASET
    dataset.ROOT = args.root  # dataset path
    dataset.USERS = args.num_users  # number of clients
    dataset.IID = args.iid
    dataset.USEALL = args.useall  # all training data instead of few-shot
    dataset.NUM_SHOTS = args.num_shots  # caltech101, dtd, oxford_flowers, oxford_pets, food101
    dataset.PARTITION = args.partition  # cifar10, cifar100
    dataset.BETA = args.beta  # cifar10, cifar100

    train_x = cfg.DATALOADER.TRAIN_X
    # 6 domains with 6 clients, otherwise 4 (domainnet, office).
    train_x.N_DOMAIN = 6 if args.num_users == 6 else 4
    # Few-shot runs use a batch of exactly num_shots samples.
    train_x.BATCH_SIZE = args.train_batch_size if args.useall else args.num_shots
    cfg.DATALOADER.TEST.BATCH_SIZE = args.test_batch_size

    optim = cfg.OPTIM
    optim.ROUND = args.round  # number of global communication rounds
    optim.MAX_EPOCH = args.local_round  # local epochs
    optim.LR = args.lr

    cfg.MODEL.BACKBONE.PRETRAINED = True

    cfg.SEED = args.seed


def setup_cfg(args):
    """Build and freeze the experiment config.

    Starts from Dassl's default config, adds the DP-FPL options via
    ``extend_cfg``, then merges the dataset YAML and the method YAML, in that
    order, skipping either one whose path is empty.
    """
    cfg = get_cfg_default()  # yacs CfgNode from Dassl's defaults
    extend_cfg(cfg, args)

    # Dataset config first, then the method/model config.
    for yaml_path in (args.dataset_config_file, args.config_file):
        if yaml_path:
            cfg.merge_from_file(yaml_path)

    cfg.freeze()
    return cfg

# def setup_cfg2(args):
#     cfg = get_cfg_default() # arguments list, type yacs.config.CfgNode _C from defaults.py
#     extend_cfg(cfg, args) # add more arguments
#
#
#     # 1. From the dataset config file
#     if args.dataset_config_file:
#         cfg.merge_from_file(args.dataset_config_file2) # load another dataset
#
#     # 2. From the method config file
#     if args.config_file:
#         cfg.merge_from_file(args.config_file) # load model
#
#     cfg.freeze()
#
#     return cfg


def save_checkpoint(args, epoch, local_weights, local_acc, neighbor_acc, file_dir):
    """Write per-client weights and accuracy history with ``torch.save``.

    The file is ``<file_dir>/<factorization>_<rank>_<noise>_<seed>.pth.tar``.

    Args:
        args: parsed command-line namespace; ``factorization``, ``rank``,
            ``noise`` and ``seed`` form the file name.
        epoch: index of the global round that just finished. ``epoch + 1`` is
            stored so a resumed run starts at the following round.
        local_weights: per-client state-dict fragments.
        local_acc: history of averaged local test accuracies.
        neighbor_acc: history of neighbor-split test accuracies.
        file_dir: directory that receives the file; it must already exist.
    """
    filename = '{}_{}_{}_{}.pth.tar'.format(args.factorization, args.rank, args.noise, args.seed)
    state = {
        "epoch": epoch + 1,
        "local_weights": local_weights,
        "local_acc": local_acc,
        "neighbor_acc": neighbor_acc,
    }
    torch.save(state, os.path.join(file_dir, filename))

def load_checkpoint(args):
    """Load this run's checkpoint from ``<cwd>/checkpoints/hou_<dataset>/``.

    ``<dataset>`` is the file name of ``args.dataset_config_file`` up to its
    first dot. The file name matches the one written by ``save_checkpoint``.

    Returns:
        ``(epoch, local_weights, local_acc, neighbor_acc)``. If no checkpoint
        exists, returns ``(0, [{}, ...], [], [])`` with one independent empty
        dict per client.
    """
    dataset = args.dataset_config_file.split('/')[-1].split('.')[0]
    # No leading '/' on the relative parts: os.path.join would otherwise drop
    # the cwd and look in the filesystem root instead of where main() saves.
    save_filename = os.path.join(
        os.getcwd(), 'checkpoints', f'hou_{dataset}',
        f'{args.factorization}_{args.rank}_{args.noise}_{args.seed}.pth.tar')
    if not os.path.exists(save_filename):
        return 0, [{} for _ in range(args.num_users)], [], []
    device = torch.device(f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu")
    # NOTE(review): torch.load unpickles; only load checkpoints this script wrote.
    checkpoint = torch.load(save_filename, map_location=device)
    return (checkpoint["epoch"], checkpoint["local_weights"],
            checkpoint["local_acc"], checkpoint["neighbor_acc"])


# Factorization methods whose prompt learner has a ``local_ctx`` tensor.
_LOCAL_CTX_METHODS = ('fedotp', 'dplora', 'dpfpl', 'secfpp')
# Factorization methods whose prompt learner has ``local_u_ctx``/``local_v_ctx``.
_UV_CTX_METHODS = ('fedpgp', 'dplora', 'dpfpl')
# Test and checkpoint every this many global rounds.
_EVAL_INTERVAL = 10
# secfpp re-clusters clients every this many global rounds after warm-up.
_CLUSTER_INTERVAL = 10
# secfpp keeps all clients in one cluster for rounds below this index.
_CLUSTER_WARMUP_ROUNDS = 5
# A clustering attempt is accepted only if its silhouette score exceeds this.
_SILHOUETTE_THRESHOLD = 0.9


def _trainable_keys(factorization):
    """Return the prompt-learner state-dict keys trained by ``factorization``."""
    keys = ['prompt_learner.global_ctx']
    if factorization in _LOCAL_CTX_METHODS:
        keys.append('prompt_learner.local_ctx')
    if factorization in _UV_CTX_METHODS:
        keys.extend(['prompt_learner.local_u_ctx', 'prompt_learner.local_v_ctx'])
    return keys


def _snapshot(model, keys):
    """Return independent deep copies of ``keys`` from ``model.state_dict()``."""
    state = model.state_dict()
    return {key: copy.deepcopy(state[key]) for key in keys}


def _client_embeddings(local_weights, num_users, rank):
    """Standardize and PCA-reduce each client's 2-D ``local_ctx``, then flatten it.

    Returns an array with one row per client, used as the KMeans input.
    """
    rows = []
    for idx in range(num_users):
        local_ctx = local_weights[idx]['prompt_learner.local_ctx'].detach().cpu().numpy()
        scaled = StandardScaler().fit_transform(local_ctx)
        reduced = PCA(n_components=rank).fit_transform(scaled)
        rows.append(reduced.reshape(-1))
    return np.stack(rows, axis=0)


def _grow_clusters(embeddings, cluster_number, clusters):
    """Try one more KMeans cluster at a time while the silhouette score stays high.

    Starts from the current ``cluster_number``. Returns the last accepted
    ``(cluster_number, clusters)``, or the inputs unchanged if the first attempt
    is rejected.
    """
    num_clients = embeddings.shape[0]
    # silhouette_score requires 2 <= n_labels <= n_samples - 1.
    while cluster_number + 1 < num_clients:
        try_cluster_number = cluster_number + 1
        kmeans = KMeans(n_clusters=try_cluster_number, random_state=42)
        labels = kmeans.fit_predict(embeddings)
        adaptive_score = silhouette_score(embeddings, labels)
        print("silhouette_score:", adaptive_score)
        accepted = adaptive_score > _SILHOUETTE_THRESHOLD
        if accepted:
            cluster_number = try_cluster_number
            clusters = kmeans.labels_
        print("cluster_number", cluster_number, "; cluster result:", clusters)
        if not accepted:
            break
    return cluster_number, clusters


def _group_by_cluster(clusters, cluster_number, num_users):
    """Map each cluster id in ``range(cluster_number)`` to its member client indices."""
    groups = {key: [] for key in range(cluster_number)}
    for idx in range(num_users):
        groups[clusters[idx]].append(idx)
    return groups


def _average_per_cluster(gradients, cluster_group):
    """Average the clients' gradients within each non-empty cluster."""
    return {key: sum(gradients[idx] for idx in members) / len(members)
            for key, members in cluster_group.items() if members}


def main(args):
    """Run federated prompt learning, with optional client clustering for secfpp.

    Each global round iterates over batches. For every batch, each client
    runs ``train_forward`` on its own weights. The global-prompt gradients are
    then averaged within each client cluster. Next, every client applies its
    cluster's averaged global gradient and, for methods with a local prompt,
    its own local gradient.

    Every ``_EVAL_INTERVAL`` rounds, each client is tested on its local split.
    A checkpoint and the accuracy history are then written under
    ``checkpoints/hou_<dataset>`` and ``outputs/hou_<dataset>``.
    """
    cfg = setup_cfg(args)

    if cfg.SEED >= 0:
        set_random_seed(cfg.SEED)

    print("GPU", args.gpu, "is used")
    if torch.cuda.is_available() and cfg.USE_CUDA:
        torch.backends.cudnn.benchmark = True

    num_users = args.num_users
    keys = _trainable_keys(args.factorization)
    has_local_ctx = args.factorization in _LOCAL_CTX_METHODS
    dataset = args.dataset_config_file.split('/')[-1].split('.')[0]

    global_gradients = [None] * num_users
    local_gradients = [None] * num_users
    # Per-client trainable weights; loaded into the shared model with strict=False.
    local_weights = [{} for _ in range(num_users)]
    # Latest independent copies of each client's trainable tensors.
    snapshots = [{} for _ in range(num_users)]

    # One trainer/model is shared by all clients; their weights are swapped in and out.
    local_trainer = build_trainer(cfg)
    initial_weights = copy.deepcopy(local_trainer.model.state_dict())

    start_epoch = 0
    max_epoch = cfg.OPTIM.ROUND
    local_acc_list, neighbor_acc_list = [], []
    if args.resume == 'True':
        start_epoch, local_weights, local_acc_list, neighbor_acc_list = load_checkpoint(args)
        print('Resume from epoch', start_epoch)
    # Checkpoints store "last finished round + 1", so a finished run resumes with
    # start_epoch == max_epoch; the loop below is then skipped and only the
    # summary is printed.

    # Defined up front so a resumed secfpp run past warm-up has a valid clustering.
    cluster_number = 1
    clusters = [0] * num_users

    for epoch in range(start_epoch, max_epoch):  # global communication loop
        print("------------local train start epoch:", epoch, "-------------")

        local_trainer.set_model_mode("train")
        loaders = [local_trainer.fed_train_loader_x_dict[idx] for idx in range(num_users)]
        data_iters = [iter(loader) for loader in loaders]
        # NOTE(review): the batch count comes from the last client's loader, as in
        # the original; clients with shorter loaders may run out — confirm.
        max_batch = len(loaders[-1])

        if epoch < _CLUSTER_WARMUP_ROUNDS or args.factorization != 'secfpp':
            cluster_number = 1
            clusters = [0] * num_users
        elif (epoch + 1) % _CLUSTER_INTERVAL == 0:
            # Embeddings do not change while searching, so compute them once.
            embeddings = _client_embeddings(local_weights, num_users, args.rank)
            cluster_number, clusters = _grow_clusters(embeddings, cluster_number, clusters)

        cluster_group = _group_by_cluster(clusters, cluster_number, num_users)
        print("cluster_group", cluster_group)

        for batch in range(max_batch):
            local_trainer.set_model_mode("train")

            # Forward pass for every client on its own current weights.
            for idx in range(num_users):
                # Clients without stored weights yet start from the initial model state.
                start_state = local_weights[idx] if local_weights[idx] else initial_weights
                local_trainer.model.load_state_dict(start_state, strict=False)

                local_trainer.train_forward(idx=idx, train_iter=data_iters[idx])

                prompt_learner = local_trainer.model.prompt_learner
                # Clone: the shared model's .grad buffers may be reused or zeroed
                # in place by the next client's pass.
                global_gradients[idx] = prompt_learner.global_ctx.grad.detach().clone()
                if has_local_ctx:
                    local_gradients[idx] = prompt_learner.local_ctx.grad.detach().clone()
                snapshots[idx] = _snapshot(local_trainer.model, keys)

            gradient_in_cluster = _average_per_cluster(global_gradients, cluster_group)

            # Backward/update for every client.
            for idx in range(num_users):
                local_weights[idx].update(snapshots[idx])
                local_trainer.model.load_state_dict(local_weights[idx], strict=False)

                local_trainer.train_backward_global(avg_global_gradient=gradient_in_cluster[clusters[idx]])
                if has_local_ctx:
                    local_trainer.train_backward_local(local_gradient=local_gradients[idx])

                # Store the updated weights so the next batch/round continues from
                # them rather than from the pre-update values.
                snapshots[idx] = _snapshot(local_trainer.model, keys)
                local_weights[idx].update(snapshots[idx])

        if (epoch + 1) % _EVAL_INTERVAL == 0:
            print("------------local test start-------------")
            local_trainer.set_model_mode("eval")
            results_local = []
            for idx in range(num_users):
                local_trainer.model.load_state_dict(local_weights[idx], strict=False)
                results_local.append(local_trainer.test(idx=idx, split='local'))
            # Neighbor-split evaluation is disabled; neighbor_acc_list stays empty
            # but is kept in the checkpoint/output format.

            print("results_local:", results_local)
            local_acc = [result[0] for result in results_local]
            mean_local_acc = sum(local_acc) / len(local_acc)
            local_acc_list.append(mean_local_acc)
            print("Global test local acc:", mean_local_acc)
            print("------------local test finish-------------")
            print(f"Epoch: {epoch}/{max_epoch}\tfinished batch : {max_batch - 1}/{max_batch}")

            checkpoint_dir = os.path.join(os.getcwd(), f'checkpoints/hou_{dataset}')
            output_dir = os.path.join(os.getcwd(), f'outputs/hou_{dataset}')
            os.makedirs(checkpoint_dir, exist_ok=True)
            os.makedirs(output_dir, exist_ok=True)
            save_checkpoint(args, epoch, local_weights, local_acc_list, neighbor_acc_list, checkpoint_dir)
            acc_path = os.path.join(output_dir, f'acc_{args.factorization}_{args.rank}_{args.noise}_{args.seed}.pkl')
            with open(acc_path, 'wb') as acc_file:
                pickle.dump([local_acc_list, neighbor_acc_list], acc_file)

    if local_acc_list:
        print("maximum test local acc:", max(local_acc_list))
        print("mean of local acc:", np.mean(local_acc_list[-5:]))
    else:
        print("no test accuracy recorded; tests run every", _EVAL_INTERVAL, "rounds")

def str2bool(value):
    """Parse a command-line boolean such as ``True``/``false``/``1``/``no``.

    Without a ``type``, argparse keeps the raw string, so ``--useall False``
    would yield the truthy string ``'False'``.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    if normalized in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--round', type=int, default=100, help="number of communication round")
    parser.add_argument('--num-users', type=int, default=10, help="number of users")
    parser.add_argument('--lr', type=float, default=0.001, help='learning rate')
    parser.add_argument('--train-batch-size', type=int, default=32, help="number of trainer batch size")
    parser.add_argument('--test-batch-size', type=int, default=100, help="number of test batch size")
    parser.add_argument("--seed", type=int, default=1, help="only non-negative value enables a fixed seed")

    # parameters of factorization and differential privacy
    parser.add_argument('--factorization', type=str, default='dpfpl', help='Choose from: promptfl, fedotp, fedpgp, dplora, dpfpl, secfpp')
    parser.add_argument('--rank', type=int, default=8, help='matrix factorization rank')
    parser.add_argument('--norm-thresh', type=float, default=10.0, help='clipping norm threshold')
    parser.add_argument('--noise', type=float, default=0.0, help='differential privacy noise scale')

    # parameters of datasets
    # caltech101, oxford_flowers, oxford_pets, food101 and dtd
    parser.add_argument('--iid', type=str2bool, default=False, help="is iid, control the iid of caltech101, oxford_flowers, oxford_pets, food101 and dtd")
    parser.add_argument('--num-shots', type=int, default=16, help="number of shots in few shot setting")
    parser.add_argument('--useall', type=str2bool, default=True, help="is useall, True for all training samples, False for few shot learning")
    # cifar10, cifar100
    parser.add_argument('--partition', type=str, default='noniid-labeldir', help='the data partitioning strategy of cifar10 and cifar100, select from "homo, noniid-labeluni, noniid-labeldir,noniid-labeldir100"')
    parser.add_argument('--beta', type=float, default=0.3, help='The parameter for the dirichlet distribution for data partitioning')
    parser.add_argument('--local_round', type=int, default=10, help="number of local training round")

    # parameters of learnable prompts
    parser.add_argument('--n_ctx', type=int, default=16, help="number of learnable context vectors in the text prompt")

    # parameters of path
    parser.add_argument("--root", type=str, default="/datasets", help="path to dataset")
    parser.add_argument("--config-file", type=str, default="configs/trainers/DP-FPL/vit_b16.yaml", help="path to config file")
    parser.add_argument("--dataset-config-file", type=str, default="configs/datasets/cifar100.yaml", help="path to config file for dataset setup")
    parser.add_argument("--resume", type=str, default="False", help="resume training or not")

    # new args
    parser.add_argument('--gpu', default='0', type=str, help='which gpu the code runs on')

    args = parser.parse_args()
    main(args)

