# -*- coding: utf-8 -*-
import os, numpy, random, time
import torch

import argparse
import server_layer as trainer_server
import worker as trainer_worker
import add_parser_arguments

from utils import log

def new_arguments(parser):
    """Register this script's command-line flags on *parser* and parse them.

    ``parse_known_args`` is used so that flags this function does not know
    about (e.g. ones registered on the same parser by other modules) are
    ignored here instead of raising an error.

    Args:
        parser: ``argparse.ArgumentParser`` to add the flags to (mutated).

    Returns:
        The ``argparse.Namespace`` of the recognised arguments.
    """
    # Method description
    parser.add_argument('--method', type=str, default='FIARSE', help='Running algorithm')

    # Dataset
    parser.add_argument('--root', type=str, default='~/dataset/', help='The root of dataset')
    parser.add_argument('--dataset', type=str, default='cifar10', help='The name of dataset used')
    # Partitioner options noted by the author: 'balanced', 'dirichlet', 'pathological', 'iid'
    parser.add_argument('--partitioner', type=str, default='dirichlet', help='How to partition the dataset')

    # Model
    parser.add_argument('--model', type=str, default='BetaResNet18_sbn', help='The name of model used')

    # Other settings
    parser.add_argument('--num-workers', type=int, default=100, help='Total number of workers')
    parser.add_argument('--num-part', type=int, default=10, help='Number of participants')

    parser.add_argument('--bsz', type=int, default=32, help='Batch size for training dataset')
    parser.add_argument('--seed', type=int, default=0, help='Seed for randomization')
    # Plain store with nargs='+' (not action='extend'): with 'extend', values
    # given on the command line were appended *after* the default 0 (and as
    # strings), so gpu_idx[0] was always 0. Now user indices replace [0].
    parser.add_argument('--gpu-idx', type=int, default=[0], nargs='+', help='Index of GPU')

    parser.add_argument('--save-log', default='LOG.txt')
    parser.add_argument('--save-results', default='RESULTS.txt')
    # type=float so command-line values are numbers, like the default 0.
    parser.add_argument('--lambda0', type=float, default=0)
    parser.add_argument('--lambda1', type=float, default=0)
    return parser.parse_known_args()[0]


if __name__ == "__main__":
    # Script entry point:
    #   1. parse the CLI, then override most settings with the hard-coded
    #      experiment configuration below;
    #   2. create/truncate the per-run output files;
    #   3. build the server (rank 0) and one Worker per client;
    #   4. run training.
    import importlib
    import sys

    time_run_start = time.time()
    parser = argparse.ArgumentParser(description='Model Training')
    # First pass registers this script's flags on `parser`; the second pass
    # re-parses once add_parser_arguments has registered its own flags.
    args = new_arguments(parser)
    add_parser_arguments.new_arguements(parser)
    args = parser.parse_known_args()[0]

    # set random seed
    # NOTE(review): this unconditionally overrides any --seed passed on the
    # command line; the run title below records the value actually used.
    args.seed = 0
    torch.manual_seed(args.seed)
    numpy.random.seed(args.seed)
    random.seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(args.seed)
    # Favour run-to-run reproducibility over cuDNN autotuning speed.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # ---- Hard-coded experiment configuration (overrides CLI values) ----
    # Assignment order is kept as-is: it determines the order of the
    # "Arguments:" section written to the log file below.
    args.method = 'FedRKMGC'
    args.m = 1.5  # 1.5
    args.alpha_p = 500  # 500

    args.gpu_idx = [0]
    args.T = 1000
    args.lr = 0.01
    args.momentum = 0.8
    args.partitioner = 'dirichlet'  # 'balanced', 'dirichlet', 'pathological', 'iid'

    args.model_size = ['1.0']
    args.model_dist = ['100']
    args.case = 7

    args.num_workers = 100
    args.num_part = 10
    args.bsz = 20
    args.K_unit = 'epochs'  # iterations, epochs, total_size
    args.K = 5

    args.delta = 0.03
    args.lrd = 1
    args.lambdai_coef = 1
    args.prox_coef = 0

    args.dataset = 'cifar100'
    args.model = 'BetaResNet18_sbn'
    args.save_freq = 1

    # ==================================================================================

    # Run tag embedded in every output file name.
    # NOTE(review): '0.3' and the trailing '_62' are hard-coded; '0.3' looks
    # like a Dirichlet concentration configured elsewhere — confirm it still
    # matches the partitioner actually used.
    title = f'C{args.case}_{args.dataset}[{args.method}_D{args.delta}_lr{args.lrd}_{args.lambdai_coef}_{args.prox_coef}_m{args.m}_al{args.alpha_p}]lr{args.lr}_m{args.momentum}_{args.num_part}_{args.num_workers}_R{args.T}{args.K_unit}{args.K}_B{args.bsz}{args.partitioner}0.3_S{args.seed}_F{args.save_freq}_{args.model}_62'
    save_name = '0save_' + title
    save_log = '0LOG_' + title + '.txt'
    save_acc = '0acc_' + title + '.txt'
    save_size = '0size_' + title + '.txt'
    save_worker = '0worker_' + title + '.txt'
    save_train_loss = '0LT_' + title + '.txt'
    save_eval_loss = '0LE_' + title + '.txt'
    save_eval_acc = '0LA_' + title + '.txt'

    args.save_log = save_log
    args.save_name = save_name
    args.save_acc = save_acc
    args.save_size = save_size
    args.save_worker = save_worker
    args.save_train_loss = save_train_loss
    args.save_eval_loss = save_eval_loss
    args.save_eval_acc = save_eval_acc

    # Log header: dump every (overridden) argument for this run.
    with open(save_log, "w+") as f:
        f.write('LOG\n\n\n')
        f.write("Arguments:\n")
        f.writelines(f"{arg}: {value}\n" for arg, value in vars(args).items())
        f.write("\n\n\n")
    # Create (or truncate) the remaining result files so each run starts
    # with empty files; opening in "w+" mode alone does the truncation.
    for result_path in (save_acc, save_size, save_worker,
                        save_train_loss, save_eval_loss, save_eval_acc):
        with open(result_path, "w+"):
            pass

    # Make the data.<dataset>.data / model.<dataset>.model packages importable
    # from the parent of the current working directory.
    sys.path.insert(1, '../')
    dataset = importlib.import_module('data.{}.data'.format(args.dataset))
    model_module = importlib.import_module('model.{}.model'.format(args.dataset))
    model = getattr(model_module, args.model)()

    # Client ranks are 1..num_workers; rank 0 is reserved for the server.
    workers = numpy.arange(args.num_workers) + 1
    dataset_server = dataset.ServerLoader(parser=parser, partitioner=args.partitioner, workers=workers, dataset_root=args.root)
    dataset_client = dataset.ClientLoader(parser=parser, partitioner=args.partitioner, ranks=workers, workers=workers, tags=['train', 'test'], dataset_root=args.root)

    # The devices are identical for every participant, so build them once.
    cpu = torch.device('cpu')
    gpu = torch.device('cuda:{}'.format(args.gpu_idx[0])) if torch.cuda.is_available() else torch.device('cpu')

    # NOTE(review): the same `model` instance is handed to the server and to
    # every worker; presumably they copy it internally — confirm.
    test_data_loader = dataset_server.get_loader(tag='test', batch_size=args.bsz)
    server = trainer_server.Server(num_workers=args.num_workers, num_part=args.num_part, args=args, model=model, train_data_loader=None, test_data_loader=test_data_loader, multiprocessing=False, cpu=cpu, gpu=gpu)

    worker_trainers = []
    for rank in range(1, args.num_workers + 1):
        train_data_loader = dataset_client.get_loader(rank=rank, tag='train', batch_size=args.bsz)
        test_data_loader = dataset_client.get_loader(rank=rank, tag='test', batch_size=args.bsz)
        worker_trainers.append(
            trainer_worker.Worker(rank=rank, args=args, model=model, train_data_loader=train_data_loader, test_data_loader=test_data_loader, multiprocessing=False, cpu=cpu, gpu=gpu)
        )

    server.start(worker_trainers=worker_trainers)
    elapsed = time.time() - time_run_start  # wall-clock seconds for the whole run
    log(args.save_log, f'time_run: {elapsed/60:.2f}min')
    log(args.save_log, f'{args.save_log}')
