import os
import argparse
import random
import sys

import numpy as np
import torch
from mpi4py import MPI

from client import Client
from trainer import Trainer
from privacy_checker import check_privacy
import utils

# Command-line configuration. Every option becomes an attribute of `args`;
# the script later overwrites `args.device` with a torch.device and adds
# derived attributes such as `args.n_class` and `args.in_channel`.
parser = argparse.ArgumentParser()
parser.add_argument("--dataset", help="Name of the dataset: [mnist, fashion-mnist, cifar10, cifar100].",
                    type=str, choices=['mnist', 'fashion-mnist', 'cifar10', 'cifar100'], default="mnist")
parser.add_argument("--algorithm",
                    help="Algorithm: ['Regular', 'Joint', 'FedAvg','DFedEM', 'Federico', 'FedFomo', 'FedAvg+', 'CFL'].",
                    type=str, choices=['Regular', 'Joint', 'FedAvg', 'DFedEM', 'Federico', 'FedFomo', 'FedAvg+', 'CFL'],
                    default="FedAvg")
parser.add_argument("--n_components", help='number of components in the mixture of distribution', type=int,
                    default=1)
parser.add_argument("--n_clusters", help="number of components/clusters; default is -1",
                    type=int, default=-1)
parser.add_argument("--frac", help="fraction of training dataset used",
                    type=float, default=0.5)
parser.add_argument("--class_seed", help="Random seed for data partition. Set to -1 for class_seed==seed",
                    type=int, default=-1)
parser.add_argument("--private_model_type", help="Private model architecture.",
                    type=str, choices=['LeNet5', 'MLP', 'CNN', 'CNN1', 'CNN2', 'ResNet18'], default="MLP")
parser.add_argument("--result_path", help="Where to save results.",
                    type=str, default="./results")
parser.add_argument("--data_path", help="Where to find the data.",
                    type=str, default="./datasets")
parser.add_argument("--n_clients", help="Number of clients.",
                    type=int, default=2)
parser.add_argument("--n_neighbors", help="Number of neighbors of Federico to get model from in each round.",
                    type=int, default=1)
parser.add_argument("--use_private_SGD", help="[int as bool] Use private SGD or not.",
                    type=int, default=0)
parser.add_argument("--optimizer", help="Optimizer.",
                    type=str, default='adam')
parser.add_argument("--n_epochs", help="Number of model backprop epochs per round.",
                    type=int, default=1)
parser.add_argument("--n_local_epochs", help="Number of model backprop epochs per round for FedFomo only.",
                    type=int, default=1)
parser.add_argument("--lr", help="Learning rate.",
                    type=float, default=0.001)
parser.add_argument("--momentum", help="Momentum for SGD.",
                    type=float, default=0.9)
parser.add_argument("--delta", help="delta parameter for DP SGD.",
                    type=float, default=0.00001)
parser.add_argument("--noise_multiplier", help="Gaussian noise deviation for DP SGD.",
                    type=float, default=1.0)
parser.add_argument("--l2_norm_clip", help="L2 norm maximum for clipping in DP SGD.",
                    type=float, default=1.0)
parser.add_argument("--n_rounds", help="Number of FL rounds.",
                    type=int, default=300)
parser.add_argument("--batch_size", help="Batch size during training.",
                    type=int, default=50)
# The script resolves this int into a torch.device right after parsing.
parser.add_argument("--device",
                    help="Which cuda device to use (-1 forces CPU). Ignored when there are at "
                         "least as many GPUs as MPI ranks, in which case rank i uses cuda:i.",
                    type=int, default=0)
parser.add_argument("--seed", help="Random seed.",
                    type=int, default=1)
parser.add_argument("--verbose", help="Verbose level.",
                    type=int, default=0)
parser.add_argument("--eval_ratio", help="Ratio of evaluation data in each client",
                    type=float, default=0.2)
parser.add_argument("--cw_ratio", help="Ratio of component weight training data in each client for FedFomo",
                    type=float, default=0.2)
parser.add_argument("--cw_momentum", help="Momentum update for client weights",
                    type=float, default=0.9)
# Implicit string concatenation instead of a backslash continuation, which
# would embed the next line's indentation inside the help text.
parser.add_argument("--greedy_eps",
                    help="Epsilon parameter for the epsilon-greedy sampling in training. "
                         "Smaller epsilon results in more greedy-like sampling.",
                    type=float, default=1.0)
parser.add_argument('--rerun', help="Rerun if result exists", action='store_true')
args = parser.parse_args()

comm = MPI.COMM_WORLD

# Turn the integer --device option into a torch.device for this MPI rank:
#   * -1, or no usable CUDA          -> CPU
#   * at least one GPU per MPI rank  -> rank i is pinned to cuda:i
#   * otherwise                      -> every rank uses cuda:<--device>
if args.device == -1 or not torch.cuda.is_available():
    _device_name = 'cpu'
elif torch.cuda.device_count() >= comm.size:
    _device_name = f"cuda:{comm.rank}"
else:
    _device_name = f"cuda:{args.device}"
args.device = torch.device(_device_name)

result_path = utils.get_path(args, args.algorithm, args.class_seed)
result_file = os.path.join(result_path, f"seed_{args.seed}.npz")

# Skip configurations that already have a saved result, unless --rerun.
# Rank 0 makes the decision and broadcasts it so every rank takes the same
# branch. If each rank checked on its own (e.g. with a stale view of a shared
# filesystem), some ranks could exit while the others block forever in the
# collective calls below.
skip_run = None
if comm.rank == 0:
    skip_run = not args.rerun and os.path.exists(result_file)
skip_run = comm.bcast(skip_run, root=0)
if skip_run:
    if comm.rank == 0:
        print("Already have a result for this configuration, quitting")
    sys.exit(0)

log_path = os.path.join(result_path, "logs")
# Only rank 0 creates the directory. exist_ok avoids a FileExistsError if
# another job creates it concurrently (presumably other seeds share this
# result_path). The barrier stops the other ranks from opening log files
# before the directory exists.
if comm.rank == 0:
    os.makedirs(log_path, exist_ok=True)
comm.Barrier()

# Data preparation: load the train/test split. The same test set is later
# handed to the trainer on every rank.
train_X, train_y, test_X, test_y = utils.get_data(args)
test_data = test_X, test_y

# Model dimensions derived from the training data: number of distinct labels,
# and dimension 1 of the inputs (assumed to be the channel axis -- TODO confirm
# the layout returned by utils.get_data).
args.n_class = len(np.unique(train_y))
args.in_channel = train_X.shape[1]
client_data_list, client_major_classes = utils.partition_data(train_X, train_y, args)

# Seed Python, NumPy and torch RNGs. This runs *after* get_data/partition_data,
# so any randomness inside those helpers is not governed by --seed
# (presumably --class_seed controls the partition; verify in utils).
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    _seed_rng(args.seed)

# Clients: each MPI rank writes its own log file inside the logs directory.
logger = utils.get_logger(os.path.join(log_path, f"seed_{args.seed}_client_{comm.rank}.log"))

# The 'Joint' baseline trains on the union of all client partitions; every
# other algorithm uses only this rank's own partition.
if args.algorithm != 'Joint':
    client_data = client_data_list[comm.rank]
else:
    client_data = (
        np.concatenate([part[0] for part in client_data_list], axis=0),
        np.concatenate([part[1] for part in client_data_list], axis=0),
    )

# Privacy-budget report, currently disabled:
# if args.use_private_SGD and comm.rank == 0:
#     epsilon, alpha = check_privacy(args)
#     logger.info(f"Expected privacy use is ε={epsilon:.2f} and δ={args.delta:.5f} at α={alpha:.2f}")

rank_major_classes = client_major_classes[comm.rank]
client = Client(client_data, args, rank_major_classes)
logger.info(
    f"Client {comm.rank} data size: {len(client.private_data[1])}; "
    f"Training data size {len(client.private_train_data[1])}; "
    f"Testing data size {len(client.private_test_data[1])}"
)
logger.info(f"Private data classes: {rank_major_classes}")
trainer = Trainer(args)

# Run training on every rank, then attach the class partition for the record.
results = trainer.train(client, test_data, comm, logger, args)
results['client_major_classes'] = client_major_classes

# Only rank 0 writes the result archive and reports timing.
if comm.rank == 0:
    output_file = os.path.join(result_path, f"seed_{args.seed}.npz")
    np.savez(output_file, **results)
    logger.info(f"Total training time {results['training_time']}")
    if 'comm time' in results:
        logger.info(f"Total comm time {results['comm time']}")
