import torch
import os
import argparse
import numpy as np

import sys
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from core.utils import fix_all_seed, parse_key_value_pairs, convert_dict_values
from core.data import DatasetEMNIST, DatasetFemnist
from core.data import federated_dataset_clusterisation
from core.cross_device import StoppingCriterion, StoppingCriterionAveraged, CrossDeviceOptimizerBase
from core.models import CNN22
from core.prox import CohortOptimizerProx, MimeLite
from core.sampler import NiceSampler, BlockSampler, StratifiedSampler

# Registries mapping command-line choice names to implementations. Their keys
# are used as the `choices` of the corresponding argparse options.

# Cohort-level (cross-device) optimizers. Further variants (GD, LBFGS, FedAvg,
# ProxFedAvg) are currently disabled; their classes are not imported here.
COHORT_OPTIMIZERS = dict(
    Prox=CohortOptimizerProx,
    MimeLite=MimeLite,
)

# torch optimizers selectable for both the worker and the server side.
OPTIMIZERS = dict(
    Adam=torch.optim.Adam,
    SGD=torch.optim.SGD,
)

# Worker samplers selectable via --sampler.
SAMPLERS = dict(
    Nice=NiceSampler,
    Block=BlockSampler,
    Stratified=StratifiedSampler,
)

# Dataset name -> (dataset class, iid_degree passed to its constructor).
# Disabled variants: EMNIST_nonIID (DatasetEMNIST, 2), EMNIST_IID
# (DatasetEMNIST, 47), FEMNIST_nonIID (DatasetFemnist, 1000).
DATASETS = dict(
    FEMNIST_original=(DatasetFemnist, 200),
)


# Command-line interface. The registry dicts defined above supply the
# `choices` for the algorithm, optimizer, sampler and dataset options.
parser = argparse.ArgumentParser(description="Run training algorithm with specified parameters")

# seed
parser.add_argument("--seed", type=int, default=42, help="Random seed")
# logging
parser.add_argument("--folder", type=str, default="logs/", help="Folder to save logs")
parser.add_argument("--log_every", type=int, default=5, help="Log every n epochs")
parser.add_argument("--wandb", action="store_true", help="Use wandb")
parser.add_argument("--plot", action="store_true", help="Plot the results")
# algorithm
parser.add_argument("--epochs", type=int, default=200, help="Number of epochs")
# cohort optimizer
parser.add_argument("--cohort_optimizer", type=str, default="Prox", choices=COHORT_OPTIMIZERS.keys(), help="Type of cross device algorithm")
parser.add_argument("--batch_size", type=int, default=128, help="Batch size for cohort optimizer")
parser.add_argument("--gamma", type=float, default=None, help="Gamma parameter for the Prox optimizer")
# worker optimizer params (hparams are KEY=VALUE strings, parsed later by parse_key_value_pairs)
parser.add_argument("--worker_optimizer", type=str, default="Adam", choices=OPTIMIZERS.keys(), help="Type of optimizer")
parser.add_argument("--worker_optimizer_steps", type=int, default=1, help="Number of steps for worker optimizer")
parser.add_argument("--worker_optimizer_hparams", metavar="KEY=VALUE", nargs="+", help="Hyperparameters for worker optimizer")
# server optimizer params (hparams are KEY=VALUE strings, parsed later by parse_key_value_pairs)
parser.add_argument("--server_optimizer", type=str, default="SGD", choices=OPTIMIZERS.keys(), help="Type of optimizer")
parser.add_argument("--server_optimizer_steps", type=int, default=1, help="Number of steps for server optimizer (local communication rounds).")
parser.add_argument("--server_optimizer_hparams", metavar="KEY=VALUE", nargs="+", help="Hyperparameters for server optimizer")
# sampling
parser.add_argument("--sampler", type=str, default="Nice", choices=SAMPLERS.keys(), help="Type of sampler")
parser.add_argument("--minibatch_size", type=int, default=1, help="Mini batch size for worker optimizer")
parser.add_argument("--clusters_number", type=int, default=1, help="Number of clusters for clusterization")
parser.add_argument("--min_cluster_size", type=int, default=None, help="Minimum cluster size for clusterization")
parser.add_argument("--max_cluster_size", type=int, default=None, help="Maximum cluster size for clusterization")
parser.add_argument("--clusters_probabilities", type=float, nargs="+", default=None, help="Probabilities of clusters")
# hardware params
parser.add_argument("--device", type=str, default="cpu", choices=["cpu", "cuda"], help="Type of accelerator")
parser.add_argument("--process_count", type=int, default=1, help="Number of processes for dataloader")
# dataset
parser.add_argument("--dataset", type=str, default="FEMNIST_original", choices=DATASETS.keys(), help="Dataset to use")
parser.add_argument("--workers_count", type=int, default=10000, help="Number of workers")
# tolerance params (a stopping criterion is built only when --tolerance_name is given)
parser.add_argument("--tolerance_name", type=str, default=None, help="Tolerance name stopping optimizer")
parser.add_argument("--tolerance", type=float, default=None, help="Tolerance stopping optimizer")
parser.add_argument("--tolerance_window_size", type=int, default=1, help="Tolerance window size stopping optimizer")
parser.add_argument("--tolerance_max", action="store_true", help="Tolerance max stopping optimizer")

args = parser.parse_args()

# Reject flag combinations argparse cannot express on its own, so a bad
# invocation fails here with a usage message instead of deep inside a
# constructor later in the script.
if args.tolerance_name is not None and args.tolerance is None:
    # The stopping criterion receives args.tolerance as its threshold;
    # without --tolerance it would be handed None.
    parser.error("--tolerance is required when --tolerance_name is set")
if args.sampler == "Block" and args.clusters_probabilities is None:
    # BlockSampler receives np.array(args.clusters_probabilities); for None
    # that is a 0-d object array, not a vector of probabilities.
    parser.error("--clusters_probabilities is required when --sampler is Block")

# Ensure the log folder exists. exist_ok=True avoids the check-then-create
# race of a separate os.path.exists() test.
os.makedirs(args.folder, exist_ok=True)

# Optional stopping criterion, later passed to the algorithm as
# `stopping_criterion`; stays None unless --tolerance_name is given.
tolerance = None
if args.tolerance_name is not None:
    # Third argument is the negation of --tolerance_max (presumably
    # "minimize" vs "maximize" the tracked metric -- confirm against
    # core.cross_device.StoppingCriterionAveraged).
    tolerance = StoppingCriterionAveraged(
        args.tolerance_name,
        args.tolerance,
        not args.tolerance_max,
        args.tolerance_window_size,
    )

# Run identification: every value in this dict (in insertion order) becomes a
# component of the log filename built below.
params = {
    "seed": args.seed,
    "epochs": args.epochs,
    "local_epochs": args.server_optimizer_steps,
    "workers_count": args.workers_count,
    "log_every": args.log_every,
    "dataset": args.dataset,
    "cohort_optimizer": args.cohort_optimizer,
    "gamma": args.gamma,
    "sampler": args.sampler,
    "minibatch_size": args.minibatch_size,
    "worker_optimizer": args.worker_optimizer,
    "worker_optimizer_steps": args.worker_optimizer_steps,
    "worker_optimizer_hparams": args.worker_optimizer_hparams,
    "server_optimizer": args.server_optimizer,
    "server_optimizer_steps": args.server_optimizer_steps,
    "server_optimizer_hparams": args.server_optimizer_hparams,
}

# Keyword arguments handed to the cohort optimizer through
# CrossDeviceOptimizerBase(cohort_optimizer_hparams=...). The server step
# count is deliberately not included here: it reaches the algorithm as
# `local_epochs` instead.
COHORT_OPTIMIZER_HPARAMS = {
    "worker_optimizer": OPTIMIZERS[args.worker_optimizer],
    "worker_optimizer_steps": args.worker_optimizer_steps,
    "worker_optimizer_hparams": convert_dict_values(parse_key_value_pairs(args.worker_optimizer_hparams)),
    "server_optimizer": OPTIMIZERS[args.server_optimizer],
    "server_optimizer_hparams": convert_dict_values(parse_key_value_pairs(args.server_optimizer_hparams)),
    "gamma": args.gamma,
    "batch_size": args.batch_size,
}

# Log file: <folder>/<value1>-<value2>-...csv over the entries of `params`.
# os.path.join avoids the doubled separator that plain "+ '/'" produced with
# the default "logs/" folder.
# NOTE(review): batch_size, clustering options, device and tolerance are not
# part of the name, so runs differing only in those write to the same CSV.
params['filename'] = os.path.join(
    args.folder,
    '-'.join(f'{value}' for value in params.values()),
) + '.csv'

params['cohort_optimizer_hparams'] = COHORT_OPTIMIZER_HPARAMS

# Runtime objects and settings added after the filename is built, so they do
# not appear in it. "cohort_optimizer" is overwritten here with the optimizer
# class itself.
params.update({
    "tolerance": tolerance,
    "device": args.device,
    "plot": args.plot,
    "tqdm": True,
    "loss": torch.nn.CrossEntropyLoss(),
    "cohort_optimizer": COHORT_OPTIMIZERS[args.cohort_optimizer],
    "process_count": args.process_count,
    "batch_size": args.batch_size,
})

# Build the federated dataset: the train split comes with per-worker data
# indices; the test split's second return value is discarded.
dataset_class, iid_degree = DATASETS[args.dataset]
dataset = dataset_class(workers_count=params["workers_count"], iid_degree=iid_degree, seed=params["seed"])
dataset_train, workers_data_indices = dataset(train=True)
dataset_test, _ = dataset(train=False)
NUM_CLASSES = dataset.classes_count()

# Re-seed (core.utils.fix_all_seed) right before creating the model.
fix_all_seed(params["seed"])
params["model"] = CNN22(NUM_CLASSES).to(params["device"])

# Cluster assignment of workers; consumed by the Block and Stratified samplers.
workers_to_cluster = federated_dataset_clusterisation(
    dataset_train,
    workers_data_indices,
    args.clusters_number,
    min_cluster_size=args.min_cluster_size,
    max_cluster_size=args.max_cluster_size,
    seed=params["seed"],
)


def _make_sampler(sampler_name):
    """Construct the sampler registered as `sampler_name` over all workers.

    Raises:
        ValueError: if there is no construction recipe for `sampler_name`.
    """
    seed = params["seed"]
    recipes = {
        "Nice": lambda ids: SAMPLERS[sampler_name](ids, args.minibatch_size, seed=seed),
        "Block": lambda ids: SAMPLERS[sampler_name](
            ids, workers_to_cluster, np.array(args.clusters_probabilities), seed=seed
        ),
        "Stratified": lambda ids: SAMPLERS[sampler_name](ids, workers_to_cluster, seed=seed),
    }
    recipe = recipes.get(sampler_name)
    if recipe is None:
        raise ValueError(f"Sampler {sampler_name} is not supported")
    return recipe(np.arange(params["workers_count"]))


params["sampler"] = _make_sampler(args.sampler)

# Optional Weights & Biases configuration forwarded to the algorithm; None
# when --wandb is not given.
wandb_params = None
if args.wandb:
    # Replace the placeholder entity/project with your own W&B workspace.
    wandb_params = {
        "entity": "example",
        "project": "example",
        "config": params,
        "name": os.path.basename(params["filename"]),
        # The original left this value empty (a SyntaxError that stopped the
        # whole script from running). Read the API key from the environment
        # rather than hard-coding a secret in source; None if unset.
        "key": os.environ.get("WANDB_API_KEY"),
    }

algorithm = CrossDeviceOptimizerBase(
    workers_count=params["workers_count"],
    batch_size=params["batch_size"],
    epochs=params["epochs"],
    seed=params["seed"],
    sampler=params["sampler"],
    model=params["model"],
    loss=params["loss"],
    cohort_optimizer=params["cohort_optimizer"],
    cohort_optimizer_hparams=params["cohort_optimizer_hparams"],
    local_epochs=params["local_epochs"],
    device=params["device"],
    plot=params["plot"],
    tqdm=params["tqdm"],
    log_every=params["log_every"],
    logs_filename=params["filename"],
    process_count=params["process_count"],
    stopping_criterion=params["tolerance"],
    wandb_params=wandb_params,
)

algorithm.train(dataset_train, workers_data_indices, dataset_test)
