import os

# Number GPUs by PCI bus ID and restrict this process to GPUs 0-3.
# These are assigned before `import torch` below so they are already in the
# environment when torch initialises CUDA.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
import time
import torch
import argparse
import random
import numpy as np
from torch.utils.tensorboard import SummaryWriter
# Star import: presumably supplies the dataset helpers used by main()
# (find_dataset_mean_std, init_pretrain_transform, init_dataset,
# get_unsuper_datasets, get_super_dataset, divide_dataset) -- verify.
from prepare_dataset import *
from train_server import finetune_VIT
from train_client import pretrain_MAE
from federated_learning import federatedModelWeightUpdate, aggregate_from_clients, aggregate_from_clients_both
from debug import ckpt_diff
import util.misc as misc
import copy

import model_ViT
from util.pos_embed import interpolate_pos_embed
from timm.models.layers import trunc_normal_
import PIL

def _str2bool(value):
    """Parse a command-line boolean flag value.

    ``argparse`` with ``type=bool`` calls ``bool(str)``, which is True for
    every non-empty string (so ``-a False`` would enable the flag). This
    parser accepts the usual true/false spellings explicitly instead.

    Args:
        value: Raw command-line string (a bool is passed through unchanged).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "on", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "off", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def get_argparser():
    """Build the command-line parser for the FedMAE federated-learning driver.

    Returns:
        argparse.ArgumentParser: Parser whose namespace is consumed by ``main``.
    """
    parser = argparse.ArgumentParser(description="Run federated learning demo")
    parser.add_argument('-ra', '--ratio', default=0, type=float,
                        help='The ratio of labelled images')
    parser.add_argument('-rd', '--rounds', default=200, type=int,
                        help='The number of rounds for federated learning')
    parser.add_argument('-tc', '--totalNum_clients', default=100, type=int,
                        help='The total number of client agents used in demo')
    parser.add_argument('-sc', '--num_of_Clients', default=5, type=int,
                        help='The number of client agents sampled in each round')
    parser.add_argument('-le', '--num_of_local_epochs', default=10, type=int,
                        help='The number of epochs in local training')
    parser.add_argument('-ced', '--client_encoder_depth', default=1, type=int,
                        help='The number of transformer blocks in client encoder')
    parser.add_argument('-cdd', '--client_decoder_depth', default=1, type=int,
                        help='The number of transformer blocks in client decoder')
    parser.add_argument('-bed', '--backbone_encoder_depth', default=5, type=int,
                        help='The number of transformer blocks in backbone encoder')
    parser.add_argument('-bdd', '--backbone_decoder_depth', default=1, type=int,
                        help='The number of transformer blocks in backbone decoder')
    parser.add_argument('-si', '--save_interval', default=100, type=int,
                        help='The number of epochs that saves the model')
    parser.add_argument('-d', '--dataset', default=0, type=int,
                        help='The id of dataset to use, options: 0, 1, 2, 3, 4, 5, 6, 7')
    parser.add_argument('-sp', '--sampling', default="iid",
                        help='Options: iid, dir')
    parser.add_argument('--alpha', default=0.1, type=float,
                        help='The required parameter for dir sampling, which decides the statistical heterogenity')
    parser.add_argument('-p', '--phase', default="pretrain",
                        help='specify the codes to pretrain or finetune')
    # `type=bool` would map any non-empty string (including "False") to True,
    # so parse the value explicitly. A bare `-a` (no value) also enables it.
    parser.add_argument('-a', '--aggregate', default=False, type=_str2bool, nargs='?', const=True,
                        help='whether to aggregate model on server or not')
    parser.add_argument('-ft', '--finetune_load_ckpt', default="./checkpoint/federated/fed_checkpoint_cat.pth",  help='the path to checkpoint of the pretrained backbone')
    return parser
    
def _data_len_ratios(client_ids, datasets):
    """Return each client's share of the combined sample count.

    Args:
        client_ids: Iterable of client IDs used as keys into ``datasets``.
        datasets: Mapping from client ID to a sized dataset.

    Returns:
        Dict mapping each client ID (in iteration order) to
        ``len(datasets[cid]) / total``, where ``total`` is the summed length
        of all listed clients' datasets.
    """
    lengths = {cid: len(datasets[cid]) for cid in client_ids}
    total = sum(lengths.values())
    return {cid: n / total for cid, n in lengths.items()}


def main(args):
    """Run FedMAE federated pretraining or server-side finetuning.

    With ``args.phase == "pretrain"`` this runs ``args.rounds`` federated
    rounds. Each round samples ``args.num_of_Clients`` clients, trains each with
    ``pretrain_MAE``, keeps the ``args.backbone_encoder_depth`` clients with
    the lowest reported loss, and periodically (plus on the last round) writes
    aggregated checkpoints under ``./checkpoint/federated/``. Any other phase
    calls ``finetune_VIT`` on the labelled subset.

    Args:
        args: Namespace produced by ``get_argparser()``.

    Raises:
        ValueError: If ``args.ratio`` is out of range, if pretraining is
            requested with no unlabelled data (``ratio == 1``), if fewer
            clients are sampled than the backbone has encoder blocks, or if
            finetuning on the client dataset is requested with no labelled
            data (``ratio == 0``).
    """
    print("Federated Learning using FedMAE")

    since1 = time.time()

    dataset_names = ["cifar10", "cifar100", "food101", "imagenet", "mini", "imagenet32", "road_sign", "inat"]
    client_dataset = dataset_names[args.dataset]
    server_dataset = dataset_names[args.dataset]

    mean, std = find_dataset_mean_std(client_dataset)

    print("Chosen dataset is %s" % client_dataset)

    train_transform, val_transform = init_pretrain_transform(mean, std)

    train_dataset, _, num_classes = init_dataset(train_transform, val_transform, dataset=client_dataset)

    clientIDlist = [f'clientID_{i+1}' for i in range(args.totalNum_clients)]

    # A fraction in [0, 1], or an int (presumably an absolute labelled-sample
    # count understood by divide_dataset -- confirm). Raised explicitly rather
    # than asserted so the check survives `python -O`.
    if not (0 <= args.ratio <= 1 or isinstance(args.ratio, int)):
        raise ValueError("Specified ratio should be in the range [0, 1] or an int instance")

    # Bound on every path so the finetune branch never hits a NameError;
    # None means "no labelled indices", the same value the cross-dataset
    # finetune branch below passes to finetune_VIT.
    super_train_idxs = None
    unsupervised_datasets = None
    if args.ratio == 0:
        unsuper_train_idxs = np.arange(len(train_dataset))
        np.random.shuffle(unsuper_train_idxs)
        unsupervised_datasets = get_unsuper_datasets(train_dataset, unsuper_train_idxs, clientIDlist, sampling=args.sampling, alpha=args.alpha, dataset=client_dataset, num_classes=num_classes)
        print("Choose %s Sampling!" % args.sampling)
    elif args.ratio == 1:
        super_train_idxs = np.arange(len(train_dataset))
        # NOTE(review): supervised_dataset is not used below; the call is kept
        # in case get_super_dataset has side effects.
        supervised_dataset = get_super_dataset(train_dataset, super_train_idxs)
    else:
        super_train_idxs, unsuper_train_idxs = divide_dataset(train_dataset, args.ratio, client_dataset)
        supervised_dataset = get_super_dataset(train_dataset, super_train_idxs)
        unsupervised_datasets = get_unsuper_datasets(train_dataset, unsuper_train_idxs, clientIDlist, sampling=args.sampling, alpha=args.alpha, dataset=client_dataset, num_classes=num_classes)
        print("Choose %s Sampling!" % args.sampling)

    print("Dataset under ratio %s is created!" % args.ratio)

    torch.cuda.empty_cache()

    model_mae_depth = args.backbone_encoder_depth

    # Sampling slot k (0-based position among the sampled clients) is mapped to
    # a contiguous run of client_encoder_depth indices, starting where the
    # previous slot's run ended. An alternative, uneven layout once used here:
    #   client_mae_depths = [1, 1, 1, 2, 2, 4, 4]
    #   client_mae_vit_indexs = {0: [0], 1: [0], 2: [0], 3: [1, 2], 4: [1, 2],
    #                            5: [3, 4, 5, 6], 6: [3, 4, 5, 6]}
    client_mae_depths = [args.client_encoder_depth] * args.num_of_Clients
    client_mae_vit_indexs = {}
    start = 0
    for slot, depth in enumerate(client_mae_depths):
        client_mae_vit_indexs[slot] = list(range(start, start + depth))
        start += depth

    if args.phase == "pretrain":
        if unsupervised_datasets is None:
            raise ValueError("Pretraining needs unlabelled client data; --ratio must be below 1")

        # One client is kept per backbone encoder block. This is loop-invariant,
        # so validate it before any client is trained rather than after round 1.
        chosen_num_of_clients = model_mae_depth
        if chosen_num_of_clients > args.num_of_Clients:
            raise ValueError("The number of clients cannot be greater than the number of sampling clients")

        interval = 1
        t_eps = args.rounds // interval
        r_eps = args.num_of_local_epochs

        print("Federated Learning will run for %s rounds" % args.rounds)
        for i in range(args.rounds):
            print("#"*30)
            print("Federated Learning Round %s/%s" % (i+1, args.rounds))
            print("#"*30)
            randomClientIDs = random.sample(clientIDlist, args.num_of_Clients)

            data_len_ratios = _data_len_ratios(randomClientIDs, unsupervised_datasets)

            cur_rd = i + 1

            ckpts = {}
            for j, clientID in enumerate(randomClientIDs):
                if args.aggregate:
                    ckpt = pretrain_MAE(clientID, unsupervised_datasets[clientID], [args.client_encoder_depth, args.client_decoder_depth], t_eps=t_eps, r_eps=r_eps, round=i, interval=interval)
                else:
                    # Without server aggregation, sampling slot j (not the
                    # client ID) owns a persistent checkpoint file.
                    client_load_ckpt_path = './checkpoint/federated/fed_checkpoint_%s.pth' % j
                    ckpt = pretrain_MAE(clientID, unsupervised_datasets[clientID], [args.client_encoder_depth, args.client_decoder_depth], t_eps=t_eps, r_eps=r_eps, round=i, interval=interval, load_path=client_load_ckpt_path, save_path=client_load_ckpt_path)
                ckpts[clientID] = ckpt

            # Keep the chosen_num_of_clients clients with the lowest local loss
            # (stable sort, so ties keep sampling order).
            client_losses = sorted(((cid, ckpts[cid]['loss']) for cid in ckpts), key=lambda item: item[-1])
            chosen_IDs = [cid for cid, _loss in client_losses[:chosen_num_of_clients]]
            chosen_ckpts = {cid: ckpts[cid] for cid in chosen_IDs}
            chosen_data_len_ratios = _data_len_ratios(chosen_IDs, unsupervised_datasets)

            if cur_rd % args.save_interval == 0:
                ckpt_path = './checkpoint/federated/fed_checkpoint_%srd.pth' % cur_rd
                # Presumably writes a fresh backbone checkpoint to ckpt_path,
                # which aggregate_from_clients then loads -- verify in train_client.
                pretrain_MAE(None, None, [model_mae_depth, args.backbone_decoder_depth], t_eps=t_eps, r_eps=r_eps, round=-1, interval=interval, load_path="", save_path=ckpt_path)
                aggregate_from_clients(chosen_IDs, chosen_ckpts, chosen_data_len_ratios, client_mae_vit_indexs, load_path=ckpt_path)

            if i == args.rounds - 1:
                # Final round: always save both the averaged model and the
                # concatenated backbone, regardless of args.aggregate.
                ckpt_path_a = './checkpoint/federated/fed_checkpoint_avg.pth'
                federatedModelWeightUpdate(randomClientIDs, ckpts, data_len_ratios, save_path=ckpt_path_a)

                ckpt_path_c = './checkpoint/federated/fed_checkpoint_cat.pth'
                pretrain_MAE(None, None, [model_mae_depth, args.backbone_decoder_depth], t_eps=t_eps, r_eps=r_eps, round=-1, interval=interval, load_path="", save_path=ckpt_path_c)
                aggregate_from_clients(chosen_IDs, chosen_ckpts, chosen_data_len_ratios, client_mae_vit_indexs, load_path=ckpt_path_c)
            elif args.aggregate:
                federatedModelWeightUpdate(randomClientIDs, ckpts, data_len_ratios)

    else:
        mean, std = find_dataset_mean_std(server_dataset)
        if server_dataset != client_dataset:
            super_train_idxs = None
        elif super_train_idxs is None:
            raise ValueError("Finetuning on the client dataset needs labelled data; --ratio must be greater than 0")
        finetune_VIT(super_train_idxs, mean, std, server_dataset, args.ratio, model_mae_depth, load_path=args.finetune_load_ckpt)

    time_elapsed1 = time.time() - since1
    print('Federated Learning complete in {:.0f}m {:.0f}s'.format(time_elapsed1 // 60, time_elapsed1 % 60))



if __name__ == "__main__":
    # Script entry point: parse the CLI flags, then hand them to the driver.
    cli_parser = get_argparser()
    args = cli_parser.parse_args()
    main(args)