import sys
import logging
from cvxpy import loggamma
import fedml
import torch
from data_loader import load_partition_data_census
from fedml.simulation import SimulatorSingleProcess as Simulator
from standard_trainer import StandardTrainer
import wandb
import pathlib
import os
import time

from model import TwoNN


# Feature count for each census prediction task, keyed by task name.
# main() passes the entry for args.task as the first argument to TwoNN
# (presumably the model's input dimension — confirm against model.TwoNN).
census_input_shape_dict = {
    'income': 54,
    'health': 89,
    'employment': 98,
}


# os.environ["WANDB_MODE"] = "offline"
def load_data(args, num_users=51):
    """Load the census data partitioned across all federated clients.

    Side effects on ``args``:
      * ``args.users`` is overwritten with the client ids ``0 .. num_users - 1``.
      * ``args.client_num_in_total`` is set to the client count returned by
        ``load_partition_data_census``.

    Args:
        args: FedML argument namespace. ``args.dataset`` is logged, and the
            whole namespace is forwarded to ``load_partition_data_census``.
        num_users: number of client ids to create. Defaults to 51, the value
            previously hard-coded here (presumably one client per census
            region/state — TODO confirm against data_loader).

    Returns:
        ``(dataset, class_num)``. ``dataset`` is the 8-element list
        ``[train_data_num, test_data_num, train_data_global, test_data_global,
        train_data_local_num_dict, train_data_local_dict, test_data_local_dict,
        class_num]``, which ``main`` passes to the FedML simulator.
    """
    fedml.logging.info("load_data. dataset_name = %s", args.dataset)

    # all clients in FL
    args.users = list(range(num_users))
    # load_partition_data_census returns a 14-tuple. The validation and
    # "unselected" splits are unpacked only to document the order; they are
    # not used by the simulator dataset built below.
    (
        client_num,
        _,
        train_data_num,
        test_data_num,
        train_data_global,
        test_data_global,
        val_data_global,
        train_data_local_num_dict,
        test_data_local_num_dict,
        train_data_local_dict,
        test_data_local_dict,
        val_data_local_dict,
        class_num,
        unselected_data_local_dict,
    ) = load_partition_data_census(args.users, args)

    args.client_num_in_total = client_num
    dataset = [
        train_data_num,
        test_data_num,
        train_data_global,
        test_data_global,
        train_data_local_num_dict,
        train_data_local_dict,
        test_data_local_dict,
        class_num,
    ]
    return dataset, class_num





def main():
    """Run one FedML census simulation and save the trained global model.

    Steps:
      1. Initialise FedML (``fedml.init()``) to obtain the run arguments.
      2. Create the results folder ``results/<task>/run_<random_seed>``.
      3. Load the partitioned census data.
      4. Build the model.
      5. Run the single-process simulator.
      6. Save the global model weights to
         ``<run_folder>/<save_model_name>.pt``, for the optimizers handled below.

    Raises:
        ValueError: if ``args.model`` is not a supported model name.
    """
    # init FedML framework
    args = fedml.init()
    args.run_folder = 'results/{}/run_{}'.format(args.task, args.random_seed)
    pathlib.Path(args.run_folder).mkdir(parents=True, exist_ok=True)
    start_time = time.time()
    device = fedml.device.get_device(args)
    dataset, output_dim = load_data(args)
    print('load dataset time {}'.format(time.time() - start_time))

    if args.model == 'two-layer':
        # First TwoNN argument is the task's feature count (presumably the
        # input dimension — confirm against model.TwoNN).
        model = TwoNN(census_input_shape_dict[args.task], args.num_hidden, output_dim)
    else:
        # Fail with a clear message instead of an UnboundLocalError on `model`.
        raise ValueError(
            "Unsupported model {!r}; only 'two-layer' is implemented".format(args.model)
        )
    print('load model time {}'.format(time.time() - start_time))

    # start training
    simulator = Simulator(args, device, dataset, model)
    simulator.run()

    print('finishing time {}'.format(time.time() - start_time))
    print(simulator.fl_trainer)

    # Where the trained global model lives differs between optimizer trainer
    # implementations, so choose the state to save per optimizer.
    optimizer = args.federated_optimizer
    if optimizer in ('FedNova', 'SCAFFOLD'):
        state = simulator.fl_trainer.model_global.state_dict()
    elif optimizer in ('FedProx', 'FedOpt', 'FedDyn', 'Mime'):
        state = simulator.fl_trainer.model_trainer.get_model_params()
    else:
        # TODO: FedAvg and other optimizers are not handled yet (the attribute
        # holding their global model still needs checking). Warn rather than
        # silently skipping the checkpoint.
        logging.warning(
            "No model saved: federated_optimizer %r has no save rule", optimizer
        )
        return
    torch.save(state, os.path.join(args.run_folder, "%s.pt" % args.save_model_name))


if __name__ == '__main__':
    # Start the simulation only when run as a script, not when imported.
    main()