import os
# These variables are set before `import torch` further down so that they are
# already in the environment when torch is imported.
# NOTE(review): both assignments unconditionally overwrite any value provided
# by the caller's environment, and the device list is hard-coded to four GPUs
# -- confirm that this is intended for every machine this script runs on.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
import time
import torch
import argparse
import random
import copy
import numpy as np
import prepare_dataset as pdata
import distributed as dist
from train_worker import pretrain_MAE, pretrain_CNN
from evaluation import finetune_VIT, finetune_CNN, train_VIT, train_CNN
from torch.utils.data import random_split
from sklearn.model_selection import train_test_split
import util.misc as misc
import baseline_models as bm
import util.debug as debug
import json
import math

def get_argparser():
    """Build the command-line parser for the performance-study script.

    The options are declared as a table of ``(flags, keyword arguments)``
    pairs and registered in order, so the help output lists them in the same
    order as the table.

    Returns:
        argparse.ArgumentParser: parser with all study options registered.
    """
    option_table = (
        # Federated-learning schedule
        (('-rd', '--rounds'),
         dict(default=25, type=int,
              help='The number of rounds for federated learning')),
        (('-nw', '--num_total_workers'),
         dict(default=20, type=int,
              help='The total number of workers in the network')),
        (('-le', '--num_of_local_epochs'),
         dict(default=2, type=int,
              help='The number of epochs in local training')),

        (('-sp', '--save_path'),
         dict(default='./vit_ckpt_study/', type=str, help='checkpoint save path')),

        # Data, aggregation and model family
        (('-d', '--dataset'),
         dict(default=4, type=int,
              help='The id of dataset to use: 0 - CIFAR10; 1 - CIFAR100; 2 - Food101; 3 - ImageNet; 4 - Mini-ImageNet; 5 - Road-Sign, 6 - Mini-INAT;')),
        (('-ra', '--ratio'),
         dict(default=1.0, type=float,
              help='The ratio of labelled images')),
        (('-samp', '--sampling'),
         dict(default="iid",
              help='The way of samping, iid or dir')),
        (('--alpha',),
         dict(default=1e-1, type=float,
              help='The required parameter for dir sampling, which decides the statistical heterogenity')),
        (('-ri', '--record_interval'),
         dict(default=100, type=int,
              help='The interval of saving checkpoint')),
        (('-agg', '--agg'),
         dict(default=3, type=int,
              help='The mode of aggregation: 0 - average weights; 1 - data volume weights; 2 - round weights; 3 - our weights')),
        (('-bl', '--baseline'),
         dict(default=0, type=int,
              help='The baselines: 0 - ViT; 1 - ResNet')),

        # Model Size Study
        (('-sc', '--scenario'),
         dict(default=0, type=int,
              help='The scenarios: 0 - Centralized; 1 - Federated')),
        (('-dp', '--depths_range'),
         dict(default=1, type=int,
              help='The number of model size options')),
        (('-scp', '--start_model_ckpt_path'),
         dict(default='./checkpoint_start/', type=str,
              help='The save path of start model checkpoints (Used for ensuring same starting weights)')),
        (('-logp', '--logs_save_path'),
         dict(default='./performance_study_logs/', type=str,
              help='The save path of training logs')),

        # Experiment selection
        (('-p', '--phase'),
         dict(default="gap_study",
              help='specify the codes to: gap_study, gap_bridge')),
        (('-scn', '--start_client_number'),
         dict(default=2, type=int,
              help='The starting number of clients')),
        (('-eid', '--experiment_id'),
         dict(default=0, type=int,
              help='The id of experiment: 0 - study-n; 1 - study-d; 2 - bridge-n; 3 - bridge-m; 4 - bridge-d; 5 - bridge-T')),
        (('--seed',),
         dict(default=0, type=int,
              help='The seed used for reproducibility')),
    )

    parser = argparse.ArgumentParser(description="Arguements of decentralized masked autoencoder")
    for flags, options in option_table:
        parser.add_argument(*flags, **options)
    return parser

    
def _start_ckpt_path(args, depths, dataset):
    """Return the path of the shared starting checkpoint for one model size.

    ``depths`` is a three-element list for the ViT baseline
    (``args.baseline == 0``), of which elements 0 and 2 appear in the file
    name as the block count and embedding size; for the CNN baseline it is a
    single depth number.
    """
    if args.baseline == 0:
        return '%sglobal/start_model_%sblock_%sembed_%s.pth' % (args.start_model_ckpt_path, depths[0], depths[2], dataset)
    return '%sglobal/start_CNN_%sdepth_%s.pth' % (args.start_model_ckpt_path, depths, dataset)


def main(args):
    """Run the centralized or federated study selected by ``args``.

    The function seeds the random number generators, creates the output
    directories, loads the dataset (the "mini" dataset is additionally split
    80/20 into training and validation parts), makes sure a shared starting
    checkpoint exists for every model size, builds the per-worker index
    splits, runs the training routine matching ``args.phase``,
    ``args.scenario`` and ``args.baseline``, and finally writes the collected
    record with ``save_training_log``.

    Args:
        args: namespace produced by ``get_argparser().parse_args()``.

    Raises:
        ValueError: when ``args.phase`` is not "gap_study" and
            ``args.start_client_number`` is not smaller than
            ``args.num_total_workers``.
    """
    print ("Run Codes for Generalization Perfomance Study")

    since = time.time()

    # Seed every random number generator imported by this file so that
    # --seed covers Python's `random` module as well as torch and numpy.
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # Always create the checkpoint sub-directories. The previous version only
    # created them when save_path already existed, so a fresh save_path ended
    # up without its 'global/' and 'finetune/' folders.
    for sub_dir in ('worker/', 'global/', 'finetune/'):
        os.makedirs('%s%s' % (args.save_path, sub_dir), exist_ok=True)

    if not os.path.exists(args.logs_save_path):
        os.makedirs('%s' % (args.logs_save_path), exist_ok=True)

    # Test for the 'global/' sub-directory itself: the parent may exist
    # without it, in which case the start checkpoints have nowhere to go.
    start_global_dir = '%sglobal/' % (args.start_model_ckpt_path)
    if not os.path.exists(start_global_dir):
        os.makedirs(start_global_dir, exist_ok=True)
        print('create')

    dataset_names = ["cifar10", "cifar100", "food101", "imagenet", "mini", "road_sign", "inat"]
    dataset = dataset_names[args.dataset]

    mean, std = pdata.find_dataset_mean_std(dataset)
    print("Chosen dataset is %s" % dataset)

    # initialize image transform
    train_transform, val_transform = pdata.init_pretrain_transform(mean, std)

    # initialize dataset
    train_dataset, val_dataset, num_classes = pdata.init_dataset(train_transform, val_transform, dataset=dataset)

    # parent_dataset is only needed (and only built) for mini-imagenet;
    # create_cen_fed_scenarios treats None the same as "not passed".
    parent_dataset = None
    if dataset == "mini":
        # random split mini-imagenet dataset for generating training and validation datasets
        train_size = int(0.8 * len(train_dataset))
        val_size = len(train_dataset) - train_size
        parent_dataset = copy.deepcopy(train_dataset)
        train_dataset, val_dataset = random_split(train_dataset, [train_size, val_size])

    # initialize IDs of workers/clients
    workerIDlist = [f'workerID_{i+1}' for i in range(args.num_total_workers)]
    for wid in workerIDlist:
        os.makedirs('%sworker/%s/' % (args.save_path, wid), exist_ok=True)

    torch.cuda.empty_cache()

    ratios = [args.ratio]
    num_T = args.rounds
    this_r_eps = args.num_of_local_epochs

    is_vit = args.baseline == 0
    if is_vit:
        # create depths for ViT
        depths_list = [[i + 1, 1, 768] for i in range(args.depths_range)]
    elif args.depths_range == 1:
        # create depths for ResNet
        depths_list = [18]
    else:
        depths_list = [2 + 16 * i for i in range(args.depths_range)]

    # Make sure every model size has a starting checkpoint, so that all runs
    # of the same size begin from the same weights.
    build_start_model = train_VIT if is_vit else train_CNN
    for depths in depths_list:
        ckpt_path = _start_ckpt_path(args, depths, dataset)
        if not os.path.exists(ckpt_path):
            build_start_model(args, None, None, depths, 1, num_classes, load_path=None, save_path=ckpt_path)

    # "gap_study" trains on all workers' data; any other phase value is
    # handled as the bridging study that starts from a subset of the clients.
    bridging = args.phase != "gap_study"
    if bridging and not (args.start_client_number < args.num_total_workers):
        raise ValueError(
            "start_client_number (%s) must be smaller than num_total_workers (%s)"
            % (args.start_client_number, args.num_total_workers))

    # Data split across workers, plus the index list used by centralized training.
    super_train_idxs_list = {}
    each_client_data_idxs = {}
    central_train_idxs = {}
    for ratio in ratios:
        create_cen_fed_scenarios(super_train_idxs_list, each_client_data_idxs, ratio, dataset, train_dataset, workerIDlist, parent_dataset)
        if bridging:
            # Centralized bridging runs only see the data of the first
            # start_client_number workers.
            central_train_idxs[ratio] = [
                data_idx
                for i in range(args.start_client_number)
                for data_idx in each_client_data_idxs[ratio][workerIDlist[i]]
            ]
        else:
            central_train_idxs[ratio] = super_train_idxs_list[ratio]

    record = {}
    if args.scenario == 0:
        # Centralized Training
        run_centralized = centralized_vit_training if is_vit else centralized_cnn_training
        ctr = 0
        for ratio in ratios:
            for depths in depths_list:
                load_ckpt_path = _start_ckpt_path(args, depths, dataset)
                if is_vit:
                    save_ckpt_path = '%sglobal/cen_model_%sblock_%sembed.pth' % (args.save_path, depths[0], depths[2])
                else:
                    save_ckpt_path = '%sglobal/cen_CNN_%sdepth.pth' % (args.save_path, depths)

                run_centralized(args, record, depths, load_ckpt_path, save_ckpt_path, train_dataset, val_dataset, central_train_idxs[ratio], num_T, this_r_eps, num_classes, ctr)

                ctr += 1
    else:
        # Federated Training
        run_federated = federated_vit_training if is_vit else federated_cnn_training
        if not bridging:
            # Compute the average weights for model aggregation
            data_len_ratios = compute_average_weights(workerIDlist)
        ctr = 0
        for ratio in ratios:
            for depths in depths_list:
                load_ckpt_path = _start_ckpt_path(args, depths, dataset)
                if is_vit:
                    model_stem = 'decen_model_%sblock_%sembed' % (depths[0], depths[2])
                else:
                    model_stem = 'decen_CNN_%sdepth' % (depths)

                if not bridging:
                    model_name = '%s.pth' % model_stem
                    run_federated(args, record, load_ckpt_path, model_name, workerIDlist, ratio, depths, num_classes, train_dataset, val_dataset, num_T, this_r_eps, data_len_ratios, each_client_data_idxs, ctr)
                    ctr += 1
                    continue

                # Bridging: one run per client count, from start_client_number
                # up to and including all workers.
                # NOTE(review): these runs are keyed by the client count `cn`,
                # so with several model sizes or ratios a later run replaces
                # the record entry of an earlier one -- confirm intended.
                for cn in range(args.start_client_number, len(workerIDlist)+1):
                    new_wid_list = workerIDlist[:cn]
                    # Compute the average weights for model aggregation
                    data_len_ratios = compute_average_weights(new_wid_list)
                    model_name = '%s_%sclients.pth' % (model_stem, cn)

                    run_federated(args, record, load_ckpt_path, model_name, new_wid_list, ratio, depths, num_classes, train_dataset, val_dataset, num_T, this_r_eps, data_len_ratios, each_client_data_idxs, cn)

    # Record the training statistics
    record_string_to_number(record)
    save_training_log(args, dataset, record)

    time_elapsed = time.time() - since
    print('Training completes in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))

def create_cen_fed_scenarios(super_train_idxs_list, each_client_data_idxs, ratio, dataset_name, train_dataset, workerIDlist, parent_dataset=None):
    """Pick the training indices for ``ratio`` and split them across workers.

    The selected indices are stored in ``super_train_idxs_list[ratio]``. They
    are then cut into contiguous, equally sized slices in worker order and
    stored as ``each_client_data_idxs[ratio][wid]``; the last worker also
    receives the remainder left over by the integer division.

    Args:
        super_train_idxs_list: dict updated in place, ``ratio -> indices``.
        each_client_data_idxs: dict updated in place,
            ``ratio -> {worker id -> indices}``.
        ratio: fraction of the data to use. Values below 1.0 delegate the
            selection to ``pdata.divide_dataset``; otherwise every index of
            ``train_dataset`` is used.
        dataset_name: dataset identifier; "mini" additionally requires
            ``parent_dataset`` when ``ratio < 1.0``.
        train_dataset: sized dataset the indices refer to.
        workerIDlist: non-empty sequence of worker ids.
        parent_dataset: forwarded to ``pdata.divide_dataset`` for "mini".

    Raises:
        ValueError: if ``workerIDlist`` is empty, or if ``parent_dataset`` is
            missing for the "mini" dataset with ``ratio < 1.0``.
    """
    # Fail early with a clear message instead of a ZeroDivisionError below.
    if not workerIDlist:
        raise ValueError("workerIDlist must contain at least one worker")

    if ratio < 1.0:
        if dataset_name == "mini":
            # Explicit None test: a plain truthiness check would also reject
            # an empty dataset object, and an assert disappears under -O.
            if parent_dataset is None:
                raise ValueError("parent_dataset is required for the 'mini' dataset when ratio < 1.0")
            super_train_idxs, _ = pdata.divide_dataset(train_dataset, ratio, dataset_name, parent_dataset)
        else:
            super_train_idxs, _ = pdata.divide_dataset(train_dataset, ratio, dataset_name)
    else:
        super_train_idxs = list(range(len(train_dataset)))
    super_train_idxs_list[ratio] = super_train_idxs

    num_workers = len(workerIDlist)
    each_data_size = len(super_train_idxs) // num_workers
    per_worker = {}
    for position, wid in enumerate(workerIDlist):
        start = position * each_data_size
        # The last worker takes everything that is left, including the remainder.
        stop = None if position == num_workers - 1 else start + each_data_size
        per_worker[wid] = super_train_idxs[start:stop]
    each_client_data_idxs[ratio] = per_worker


def compute_average_weights(workerIDlist):
    """Give every worker the same aggregation weight, ``1 / number of workers``.

    Returns:
        dict: ``worker id -> weight``; empty when ``workerIDlist`` is empty.
    """
    worker_count = len(workerIDlist)
    # The division sits inside the comprehension so that an empty list
    # yields an empty dict without ever dividing by zero.
    return {wid: 1.0 / float(worker_count) for wid in workerIDlist}

def update_train_record(record, ctr, data_amount, num_params, train_losses):
    """Start a fresh entry ``record[ctr]`` holding the training statistics.

    Any existing entry under ``ctr`` is replaced. The entry stores
    ``data_amount`` as "D", ``num_params`` as "N" and the loss list as
    "train_losses".
    """
    record[ctr] = {
        "D": data_amount,
        "N": num_params,
        'train_losses': train_losses,
    }

def update_val_record(record, ctr, test_loss, test_acc):
    """Add the validation loss and accuracy to the existing entry ``record[ctr]``.

    Raises:
        KeyError: if ``record`` has no entry for ``ctr`` yet.
    """
    entry = record[ctr]
    entry.update(test_loss=test_loss, test_acc=test_acc)

def centralized_vit_training(args, cen_record, depths, load_ckpt_path, save_ckpt_path, train_dataset, val_dataset, train_idxs, num_T, this_r_eps, num_classes, ctr):
    """Train one ViT on the pooled training indices and record the results.

    The run trains for ``num_T * this_r_eps`` epochs on the subset of
    ``train_dataset`` selected by ``train_idxs``, then evaluates on
    ``val_dataset``. Training and validation statistics are written to
    ``cen_record[ctr]``.

    Raises:
        AssertionError: if ``load_ckpt_path`` is missing, or if
            ``save_ckpt_path`` is missing after the initialisation call.
    """
    assert os.path.exists(load_ckpt_path)

    # Initialisation pass: no dataset is given, so this call presumably just
    # loads the start checkpoint and writes it to save_ckpt_path
    # -- TODO confirm in train_VIT.
    train_VIT(args, None, None, depths, 1, num_classes, load_path=load_ckpt_path, save_path=save_ckpt_path)

    pooled_dataset = pdata.get_super_dataset(train_dataset, train_idxs)
    assert os.path.exists(save_ckpt_path)

    # Training: continue from, and overwrite, the checkpoint at save_ckpt_path.
    total_epochs = num_T * this_r_eps
    _, training_loss, param_count = train_VIT(args, None, pooled_dataset, depths, total_epochs, num_classes, load_path=save_ckpt_path, save_path=save_ckpt_path)
    samples_seen = len(pooled_dataset) * num_T * this_r_eps
    update_train_record(cen_record, ctr, samples_seen, param_count, [training_loss])

    # Validation of the trained checkpoint.
    val_loss, val_acc = train_VIT(args, None, None, depths, 1, num_classes, dataset_val=val_dataset, eval_model=True, load_path=save_ckpt_path)
    update_val_record(cen_record, ctr, val_loss, val_acc)

def centralized_cnn_training(args, cen_record, depths, load_ckpt_path, save_ckpt_path, train_dataset, val_dataset, train_idxs, num_T, this_r_eps, num_classes, ctr):
    """Train one CNN on the pooled training indices and record the results.

    The run trains for ``num_T * this_r_eps`` epochs on the subset of
    ``train_dataset`` selected by ``train_idxs``, then evaluates on
    ``val_dataset``. Training and validation statistics are written to
    ``cen_record[ctr]``.

    Raises:
        AssertionError: if ``load_ckpt_path`` is missing, or if
            ``save_ckpt_path`` is missing after the initialisation call.
    """
    assert os.path.exists(load_ckpt_path)

    # Initialisation pass: no dataset is given, so this call presumably just
    # loads the start checkpoint and writes it to save_ckpt_path
    # -- TODO confirm in train_CNN.
    train_CNN(args, None, None, depths, 1, num_classes, load_path=load_ckpt_path, save_path=save_ckpt_path)

    pooled_dataset = pdata.get_super_dataset(train_dataset, train_idxs)
    assert os.path.exists(save_ckpt_path)

    # Training: continue from, and overwrite, the checkpoint at save_ckpt_path.
    total_epochs = num_T * this_r_eps
    _, training_loss, param_count = train_CNN(args, None, pooled_dataset, depths, total_epochs, num_classes, load_path=save_ckpt_path, save_path=save_ckpt_path)
    samples_seen = len(pooled_dataset) * num_T * this_r_eps
    update_train_record(cen_record, ctr, samples_seen, param_count, [training_loss])

    # Validation of the trained checkpoint.
    val_loss, val_acc = train_CNN(args, None, None, depths, 1, num_classes, dataset_val=val_dataset, eval_model=True, load_path=save_ckpt_path)
    update_val_record(cen_record, ctr, val_loss, val_acc)

def federated_vit_training(args, decen_record, load_ckpt_path, model_name, workerIDlist, ratio, depths, num_classes, train_dataset, val_dataset, num_T, this_r_eps, data_len_ratios, each_client_data_idxs, ctr):
    """Run ``num_T`` rounds of federated ViT training and record the results.

    Every worker checkpoint is first initialised from ``load_ckpt_path``.
    Each round calls ``dist.neighbours_avg_agg`` with the worker list, their
    weights and the global checkpoint path, and then trains every worker for
    ``this_r_eps`` epochs on its own data slice, loading from the global
    checkpoint and saving to the worker's checkpoint. After the last round
    the aggregation is called once more and the global checkpoint is
    evaluated on ``val_dataset``.

    ``decen_record[ctr]`` receives the total number of samples processed
    ("D"), the parameter count reported by ``train_VIT`` ("N"), the per-round
    mean of the workers' training losses, and the validation loss/accuracy.

    Raises:
        ValueError: if ``workerIDlist`` is empty or ``num_T`` is below 1.
            Without at least one worker and one round there is no training
            result to record; this used to surface as a ZeroDivisionError or
            UnboundLocalError part-way through the run.
        AssertionError: if ``load_ckpt_path`` or the global checkpoint is
            missing when it is needed.
    """
    if not workerIDlist:
        raise ValueError("federated_vit_training needs at least one worker")
    if num_T < 1:
        raise ValueError("federated_vit_training needs at least one round (num_T >= 1)")

    global_save_path = '%sglobal/%s' % (args.save_path, model_name)
    for wid in workerIDlist:
        save_ckpt_path = '%sworker/%s/%s' % (args.save_path, wid, model_name)
        assert(os.path.exists(load_ckpt_path))
        # initialize ViT on clients
        _ = train_VIT(args, wid, None, depths, 1, num_classes, load_path=load_ckpt_path, save_path=save_ckpt_path)
        if wid == workerIDlist[0]:
            # initialize ViT on server
            # NOTE(review): this call saves to the first worker's checkpoint
            # path rather than global_save_path; the global checkpoint is
            # presumably first written by the aggregation call below
            # -- confirm this is intended.
            _ = train_VIT(args, None, None, depths, 1, num_classes, load_path=load_ckpt_path, save_path=save_ckpt_path)

    data_amount = 0
    train_losses = []
    for T_counter in range(num_T):
        # server model aggregation
        _ = dist.neighbours_avg_agg(args, workerIDlist, data_len_ratios, model_name, global_save_path)
        train_loss_sum = 0
        for wid in workerIDlist:
            save_ckpt_path = '%sworker/%s/%s' % (args.save_path, wid, model_name)
            # find client dataset
            client_dataset = pdata.get_super_dataset(train_dataset, each_client_data_idxs[ratio][wid])
            assert(os.path.exists(global_save_path))
            # client training
            _, train_loss, num_params = train_VIT(args, wid, client_dataset, depths, this_r_eps, num_classes, load_path=global_save_path, save_path=save_ckpt_path)
            data_amount += len(client_dataset) * this_r_eps
            train_loss_sum += train_loss
        # one entry per round: the mean of the workers' training losses
        train_losses.append(train_loss_sum / len(workerIDlist))

    # final server model aggregation
    _ = dist.neighbours_avg_agg(args, workerIDlist, data_len_ratios, model_name, global_save_path)
    update_train_record(decen_record, ctr, data_amount, num_params, train_losses)

    # validation
    test_loss, test_acc = train_VIT(args, None, None, depths, 1, num_classes, dataset_val=val_dataset, eval_model=True, load_path=global_save_path)
    update_val_record(decen_record, ctr, test_loss, test_acc)

def federated_cnn_training(args, decen_record, load_ckpt_path, model_name, workerIDlist, ratio, depths, num_classes, train_dataset, val_dataset, num_T, this_r_eps, data_len_ratios, each_client_data_idxs, ctr):
    """Run ``num_T`` rounds of federated CNN training and record the results.

    Every worker checkpoint is first initialised from ``load_ckpt_path``.
    Each round calls ``dist.neighbours_avg_agg`` with the worker list, their
    weights and the global checkpoint path, and then trains every worker for
    ``this_r_eps`` epochs on its own data slice, loading from the global
    checkpoint and saving to the worker's checkpoint. After the last round
    the aggregation is called once more and the global checkpoint is
    evaluated on ``val_dataset``.

    ``decen_record[ctr]`` receives the total number of samples processed
    ("D"), the parameter count reported by ``train_CNN`` ("N"), the per-round
    mean of the workers' training losses, and the validation loss/accuracy.

    Raises:
        ValueError: if ``workerIDlist`` is empty or ``num_T`` is below 1.
            Without at least one worker and one round there is no training
            result to record; this used to surface as a ZeroDivisionError or
            UnboundLocalError part-way through the run.
        AssertionError: if ``load_ckpt_path`` or the global checkpoint is
            missing when it is needed.
    """
    if not workerIDlist:
        raise ValueError("federated_cnn_training needs at least one worker")
    if num_T < 1:
        raise ValueError("federated_cnn_training needs at least one round (num_T >= 1)")

    global_save_path = '%sglobal/%s' % (args.save_path, model_name)
    for wid in workerIDlist:
        save_ckpt_path = '%sworker/%s/%s' % (args.save_path, wid, model_name)
        assert(os.path.exists(load_ckpt_path))
        # initialize CNN on clients
        _ = train_CNN(args, wid, None, depths, 1, num_classes, load_path=load_ckpt_path, save_path=save_ckpt_path)
        if wid == workerIDlist[0]:
            # initialize CNN on server
            # NOTE(review): this call saves to the first worker's checkpoint
            # path rather than global_save_path; the global checkpoint is
            # presumably first written by the aggregation call below
            # -- confirm this is intended.
            _ = train_CNN(args, None, None, depths, 1, num_classes, load_path=load_ckpt_path, save_path=save_ckpt_path)

    data_amount = 0
    train_losses = []
    for T_counter in range(num_T):
        # server model aggregation
        _ = dist.neighbours_avg_agg(args, workerIDlist, data_len_ratios, model_name, global_save_path)
        train_loss_sum = 0
        for wid in workerIDlist:
            save_ckpt_path = '%sworker/%s/%s' % (args.save_path, wid, model_name)
            # find client dataset
            client_dataset = pdata.get_super_dataset(train_dataset, each_client_data_idxs[ratio][wid])
            assert(os.path.exists(global_save_path))
            # client training
            _, train_loss, num_params = train_CNN(args, wid, client_dataset, depths, this_r_eps, num_classes, load_path=global_save_path, save_path=save_ckpt_path)
            data_amount += len(client_dataset) * this_r_eps
            train_loss_sum += train_loss
        # one entry per round: the mean of the workers' training losses
        train_losses.append(train_loss_sum / len(workerIDlist))

    # final server model aggregation
    _ = dist.neighbours_avg_agg(args, workerIDlist, data_len_ratios, model_name, global_save_path)
    update_train_record(decen_record, ctr, data_amount, num_params, train_losses)

    # validation
    test_loss, test_acc = train_CNN(args, None, None, depths, 1, num_classes, dataset_val=val_dataset, eval_model=True, load_path=global_save_path)
    update_val_record(decen_record, ctr, test_loss, test_acc)

def record_string_to_number(record):
    """Format the numeric fields of every record entry as strings, in place.

    Despite the name, the conversion goes from numbers to text: "N" becomes
    a string with two decimals and each value in "train_losses" becomes a
    string with four decimals. Other keys are left untouched.

    Raises:
        AssertionError: if an entry lacks "N" or "train_losses".
    """
    for entry in record.values():
        assert "N" in entry
        assert "train_losses" in entry
        entry["N"] = "%.2f" % entry["N"]
        formatted_losses = []
        for loss in entry["train_losses"]:
            formatted_losses.append("%.4f" % loss)
        entry["train_losses"] = formatted_losses

def save_training_log(args, dataset_name, record):
    """Write ``record`` as JSON to a log file named after the experiment.

    The file is created in ``args.logs_save_path`` and is named
    ``<model>_<dataset><experiment suffix>_<scene>.txt``. The model is "Vit"
    or "CNN" depending on ``args.baseline``, the scene is "cen" or
    "decen<num_total_workers>c" depending on ``args.scenario``, and the
    experiment suffix is chosen by ``args.experiment_id``. An existing file
    with the same name is overwritten.
    """
    scene = "cen" if args.scenario == 0 else "decen%sc" % args.num_total_workers
    model_name = "Vit" if args.baseline == 0 else "CNN"

    # Suffixes that do not depend on any other option; the two remaining
    # cases are formatted on demand so that only the needed option is read.
    fixed_suffixes = {0: "_n", 1: "_d", 3: "_m_bridge", 4: "_d_bridge"}
    experiment_id = args.experiment_id
    if experiment_id in fixed_suffixes:
        exp_name = fixed_suffixes[experiment_id]
    elif experiment_id == 2:
        exp_name = "_n_bridge_%sc" % args.start_client_number
    else:
        exp_name = "_T%s_bridge" % args.rounds

    log_path = '%s%s_%s%s_%s.txt' % (args.logs_save_path, model_name, dataset_name, exp_name, scene)
    with open(log_path, 'w') as file:
        file.write(json.dumps(record))


if __name__ == '__main__':
    # Script entry point: parse the command line and hand the options to main().
    main(get_argparser().parse_args())