import os
# Restrict CUDA to device indices 0-7 for this process. This assignment runs
# before `import torch` below -- presumably so the setting is in place before
# CUDA is initialised; keep the ordering.
# NOTE(review): the list is hard-coded to eight devices regardless of the
# --num_gpus command-line value.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3,4,5,6,7"

import subprocess
import torch
from multiprocessing import Pool
from Common_functions import weights_agg
import pickle
import argparse


def str2bool(v):
    """Convert a command-line token to a ``bool`` (usable as an argparse ``type``).

    Args:
        v: Either an actual ``bool`` (returned unchanged) or a string.
            Matching is case-insensitive and ignores surrounding whitespace.

    Returns:
        ``True`` for 'true', 't', 'yes', 'y', '1';
        ``False`` for 'false', 'f', 'no', 'n', '0'.

    Raises:
        argparse.ArgumentTypeError: If the string is not a recognised
            boolean spelling (including the empty string).
    """
    if isinstance(v, bool):
        return v

    token = v.strip().lower()
    # Real tuples are required here: `x in ('true')` is a *substring* test
    # against the string 'true', which accepts '', 'e', 'ru', ...
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')
def run_client(client_id, gpu, encoder_runs, save_path_encoder='Encoder_weights_saved', cuda=True):
    """Run ``Encoder_Local_Train.py`` for one client and wait for it to exit.

    Args:
        client_id: Value passed to the script as ``--client_id``.
        gpu: Value passed to the script as ``--gpu``.
        encoder_runs: Value passed to the script as ``--encoder_runs``.
        save_path_encoder: Value passed to the script as ``--save_path_encoder``.
        cuda: Value passed to the script as ``--cuda`` (formatted with ``str``).

    Raises:
        subprocess.CalledProcessError: If the client script exits with a
            non-zero status, so a failed client is not silently ignored.
    """
    # Pass an argv list instead of a shell string: every value reaches the
    # script as exactly one argument, even if it contains spaces or shell
    # metacharacters, and no shell is involved.
    cmd = [
        "python3", "Encoder_Local_Train.py",
        "--client_id", str(client_id),
        "--gpu", str(gpu),
        "--encoder_runs", str(encoder_runs),
        "--save_path_encoder", str(save_path_encoder),
        "--cuda", str(cuda),
    ]
    subprocess.run(cmd, check=True)

def main(args):
    """Run the client scripts round by round and aggregate their encoders.

    In every round, ``run_client`` is invoked once per client, at most
    ``args.num_gpus`` clients at a time; the i-th client of a batch is given
    GPU index i. Once all clients of the round have finished, each client's
    saved model is loaded, the encoders' state dicts are combined with
    ``weights_agg`` and the result is pickled to ``args.save_path_encoder``.

    Args:
        args: Namespace providing ``num_clients``, ``num_gpus``,
            ``num_rounds``, ``save_path``, ``save_path_encoder`` and
            ``start_from_saved_checkpoint``.

    Raises:
        ValueError: If ``args.num_gpus`` is not a positive integer.
    """
    num_clients = args.num_clients
    num_gpus = args.num_gpus
    num_rounds = args.num_rounds
    save_path = args.save_path
    save_path_encoder = args.save_path_encoder

    if num_gpus <= 0:
        raise ValueError(f"num_gpus must be a positive integer, got {num_gpus}")

    # The round index handed to the clients starts at 1 instead of 0 when
    # starting from a saved checkpoint.
    first_round = 1 if args.start_from_saved_checkpoint else 0

    # The batching is identical in every round, so build it once. Slicing in
    # steps of num_gpus never yields an empty trailing batch.
    client_ids = range(num_clients)
    batches = [
        client_ids[start:start + num_gpus]
        for start in range(0, num_clients, num_gpus)
    ]

    for encoder_runs in range(first_round, first_round + num_rounds):
        # One pool per round; starmap blocks, so a batch only starts after
        # the previous one has completely finished.
        with Pool(num_gpus) as pool:
            for batch in batches:
                tasks = [
                    (client_id, gpu, encoder_runs, save_path_encoder)
                    for gpu, client_id in enumerate(batch)
                ]
                pool.starmap(run_client, tasks)

        # Keep only each encoder's state dict so the remainder of every
        # loaded model can be released immediately.
        # NOTE(review): torch.load unpickles whole model objects here; only
        # load files written by trusted client runs.
        encoder_states = [
            torch.load(
                f"{save_path}_Model_for_client_{client_id}.pt",
                map_location='cuda:0',
            ).encoder.state_dict()
            for client_id in client_ids
        ]

        agg_weights = weights_agg(encoder_states)
        with open(save_path_encoder, 'wb') as fp:
            pickle.dump(agg_weights, fp)
        # Drop the per-client tensors before the next round starts.
        del encoder_states

if __name__ == "__main__":
    # Command-line entry point: parse the run configuration and start main().
    parser = argparse.ArgumentParser(
        description='Run Encoder_Local_Train.py for every client in each '
                    'round and aggregate the resulting encoder weights.')
    # nargs='?' with const=True lets the option also be given as a bare flag;
    # an explicit value ("true"/"false") is still accepted.
    parser.add_argument('--start_from_saved_checkpoint', type=str2bool,
                        nargs='?', const=True, default=False,
                        help='if true, the round index passed to the clients '
                             'starts at 1 instead of 0')
    parser.add_argument('--num_clients', type=int, default=100,
                        help='number of clients to run in every round')
    parser.add_argument('--save_path', type=str, default='Save_models/',
                        help="prefix of the per-client model files, which are "
                             "read from <save_path>_Model_for_client_<id>.pt")
    parser.add_argument('--num_rounds', type=int, default=10,
                        help='number of rounds to run')
    parser.add_argument('--num_gpus', type=int, default=8,
                        help='number of clients launched concurrently; each '
                             'client in a batch gets a GPU index from 0 to '
                             'num_gpus-1')
    parser.add_argument('--save_path_encoder', type=str, default='Encoder_weights_saved',
                        help='file the aggregated encoder weights are pickled '
                             'to; also passed to every client')
    args = parser.parse_args()
    main(args)
