from quantizer import *
from server import *
from client import *
from dataset_manager import *
from model_manager import *
from trainer import *

def _apply_qnn_bits(model, bits):
    """Switch on quantized forward/backward passes for eligible conv layers.

    Every ``torch.nn.Conv2d`` submodule of *model* whose qualified name
    contains ``'layer'`` but not ``'downsample'`` gets ``quantizeFwd`` and
    ``quantizeBwd`` set to ``True`` and ``abits``/``wbits`` set to *bits*.
    *model* is modified in place.
    """
    # Local import: the surrounding file only sees torch through star imports.
    import torch

    for name, module in model.named_modules():
        if not isinstance(module, torch.nn.Conv2d):
            continue
        if 'layer' in name and 'downsample' not in name:
            module.quantizeFwd = True
            module.quantizeBwd = True
            module.abits = bits
            module.wbits = bits
            print(module.abits)


def _history_to_columns(history):
    """Regroup a sequence of per-record dicts into a dict of per-key lists.

    The keys are taken from the first record. An empty history yields an
    empty dict instead of raising ``IndexError``.
    """
    if not history:
        return {}
    return {key: [record[key] for record in history] for key in history[0]}


def run(setups, dataset_name, log_period, **kwargs):
    """Train one ``Trainer`` per entry of *setups* and collect their histories.

    Args:
        setups: Mapping from a case label to that case's parameter dict.
            Required keys: ``'algorithm'``, ``'client count'``,
            ``'local step'``, ``'group count'``, ``'time_limit'``, ``'lr'``,
            ``'gpu_ids'`` and ``'quantizer'`` (a dict of keyword arguments
            for ``get_quantizer``).
            Optional keys: ``'swt'`` (assigned to
            ``Trainer.server_waiting_time``), ``'sit'`` (assigned to
            ``Server.server_interaction_time``), ``'qnnbits'`` (bit width
            applied to eligible conv layers), ``'server_averaging'`` and
            ``'client_averaging'`` (both default to ``True``).
        dataset_name: Passed to ``get_datasets``, ``get_model`` and the
            ``Trainer``.
        log_period: Passed through to the ``Trainer``.
        **kwargs: Forwarded to ``get_datasets``.

    Returns:
        A ``(logs, trainers)`` tuple. ``logs[case][key]`` is the list of
        values recorded under ``key`` across the history returned by
        ``trainer.train``; it stays ``{}`` for a case whose history is
        empty. ``trainers`` is always an empty list unless the ``append``
        below is re-enabled.

    Raises:
        KeyError: If a case is missing one of the required keys.
    """
    import copy
    import time

    trainers = []
    logs = {case: {} for case in setups}

    train_sets_list, test_set = get_datasets(dataset_name, **kwargs)
    initial_model = get_model(dataset_name)

    for case, case_params in setups.items():
        print(f"--- {case!s} ---")
        # perf_counter is monotonic, so the reported duration cannot be
        # skewed by system clock adjustments.
        start = time.perf_counter()

        ## Setting run parameters
        algorithm = case_params['algorithm']
        client_count = case_params['client count']
        local_step = case_params['local step']
        group_count = case_params['group count']
        time_limit = case_params['time_limit']
        lr = case_params['lr']
        gpu_ids = case_params['gpu_ids']

        # NOTE: these are class-level attributes, so an override set by one
        # case stays in effect for every later case that does not set its own.
        if 'swt' in case_params:
            Trainer.server_waiting_time = case_params['swt']
        if 'sit' in case_params:
            Server.server_interaction_time = case_params['sit']

        # Quantization flags are written onto a private copy of the model so
        # they do not leak into later cases that share ``initial_model``.
        case_model = initial_model
        if 'qnnbits' in case_params:
            case_model = copy.deepcopy(initial_model)
            _apply_qnn_bits(case_model, case_params['qnnbits'])

        server_averaging = case_params.get('server_averaging', True)
        client_averaging = case_params.get('client_averaging', True)

        quantizer = get_quantizer(**case_params["quantizer"])

        print(case_model)
        trainer = Trainer(
            algorithm=algorithm,
            dataset_name=dataset_name,
            client_count=client_count,
            train_sets_list=train_sets_list,
            test_set=test_set,
            local_step=local_step,
            group_count=group_count,
            quantizer=quantizer,
            initial_model=case_model,
            log_period=log_period,
            gpu_ids=gpu_ids,
            server_averaging=server_averaging,
            client_averaging=client_averaging,
        )

        ## Keeping every trainer alive is very memory consuming, so they are
        ## not retained by default; uncomment the next line if you need them.
        #trainers.append(trainer)

        history = trainer.train(lr=lr, time_limit=time_limit)
        logs[case].update(_history_to_columns(history))

        end = time.perf_counter()
        print(f"Finished in {end - start}")

    return logs, trainers
