"""Main module for running FEMNIST experiments."""
import os
# Restrict this process to one GPU. Both variables are assigned here, ahead of
# the `import torch` further down, so they are already in the environment when
# the CUDA-using libraries are loaded.
# PCI_BUS_ID makes CUDA enumerate devices in PCI bus order.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"  # see issue #152
# NOTE(review): GPU index "5" is hard-coded and unconditionally overwrites any
# CUDA_VISIBLE_DEVICES value inherited from the caller's environment — confirm
# this is intended before running on a different machine.
os.environ["CUDA_VISIBLE_DEVICES"] = "5"

import pathlib
import time
import json
import flwr
import flwr as fl
import hydra
import pandas as pd
import argparse
import torch
import numpy as np
from omegaconf import OmegaConf, open_dict
from torch.utils.data import DataLoader
from typing import Type, Union, Dict, Callable, Optional, Tuple, List
from logging import INFO

from model import return_model, test, train
from client import create_client_fn
from dataset.dataset import create_federated_dataloaders
from flwr.common import ndarrays_to_parameters
from flwr.common.logger import log
from flwr.common.typing import NDArrays, Scalar
from strategy import FedAvgSameClients
from utils import setup_seed, weighted_average, plot_metric_from_history, mia_attack, set_parameters

def fit_config(server_round: int, unlearning_round: int = -1) -> Dict[str, Scalar]:
    """Build the per-round configuration dict handed to clients for fitting.

    Args:
        server_round: Number of the federated round the config is built for.
        unlearning_round: Value stored under ``"unlearning_round"``; defaults
            to ``-1`` when the caller does not supply one.

    Returns:
        A dict with the keys ``"current_round"`` and ``"unlearning_round"``.
    """
    return dict(current_round=server_round, unlearning_round=unlearning_round)

def get_evaluate_fn(
    net: torch.nn.Module,
    trainloader: DataLoader,
    central_testloader: DataLoader,
    testloader: DataLoader,
    args: argparse.Namespace
):
    """Build the callback used for server-side (centralized) evaluation.

    Args:
        net: Model that receives the parameters under evaluation.
        trainloader: Accepted for interface compatibility; not read by the
            returned callback.
        central_testloader: Data the model is scored on via ``test``.
        testloader: Accepted for interface compatibility; not read by the
            returned callback.
        args: Run settings; the callback reads ``method``, ``device``,
            ``num_rounds``, ``results_dir_path``, ``dataset`` and ``affix``.

    Returns:
        A function ``evaluate(server_round, new_parameters, config)`` that
        returns ``(loss, metrics)``.
    """
    def evaluate(
        server_round: int,
        new_parameters: NDArrays,
        config: Dict[str, Scalar],
    ) -> Optional[Tuple[float, Dict[str, Scalar]]]:
        # Load the candidate weights into the shared model, then score it.
        set_parameters(net, new_parameters)
        loss, accuracy = test(net, central_testloader, method=args.method, device=args.device)
        results = {'accuracy': accuracy}

        # Every round except the last one reports accuracy only.
        if server_round != args.num_rounds:
            return loss, results

        # Last round: emit placeholder MIA metrics (all zeros) so downstream
        # consumers that expect them keep working, and persist them as JSON.
        placeholder = dict(acc=0.0, precision=0.0, recall=0.0)
        with open(f"{args.results_dir_path}/{args.dataset}{args.affix}_mia_metrics.json", 'w') as f:
            json.dump(placeholder, f)
        results['mia_metrics'] = placeholder
        return loss, results

    return evaluate

def _select_client_resources(device: torch.device) -> Dict[str, Union[int, float]]:
    """Return the per-client resource request passed to the simulation."""
    if device.type == "cuda":
        # Each client asks for only a small GPU fraction so that many clients
        # can train at the same time; tune to the available hardware.
        return {"num_gpus": 0.01}
    return {"num_cpus": 1}


def _distributed_history_frame(history) -> pd.DataFrame:
    """Flatten the distributed metrics and losses of ``history`` into a DataFrame."""
    columns = {}
    for metric, round_values in history.metrics_distributed.items():
        columns["distributed_test_" + metric] = [value for _, value in round_values]
    for metric, round_values in history.metrics_distributed_fit.items():
        columns["distributed_" + metric] = [value for _, value in round_values]
    columns["distributed_test_loss"] = [value for _, value in history.losses_distributed]
    return pd.DataFrame.from_dict(columns)


@hydra.main(config_path="conf", config_name="table2_cifar100_FATS_baseline.yaml", version_base=None)
def main(cfg: OmegaConf):
    """Run one federated simulation described by the Hydra config and save its results.

    The function seeds the run, builds the federated dataloaders and the model,
    configures the aggregation strategy, runs the Flower simulation, and then
    writes the history (CSV and ``.npy``), a metric plot and the resolved
    configuration (YAML) into ``cfg.results_dir_path``.

    Args:
        cfg: Configuration supplied by ``@hydra.main``. A timestamp is added to
            it as ``cfg.affix`` and used in the names of all output files.
    """
    # Ensure reproducibility
    setup_seed(cfg.random_seed)
    with open_dict(cfg):
        # Timestamp suffix keeps outputs of separate runs from colliding.
        cfg.affix = time.strftime("%Y%m%d-%H%M%S")
    log(INFO, "config: %s", cfg)

    # Make sure the output directory exists before anything is written to it.
    results_dir_path = pathlib.Path(cfg.results_dir_path)
    results_dir_path.mkdir(parents=True, exist_ok=True)

    device = torch.device(cfg.device)

    # Create datasets for federated learning
    trainloaders, valloaders, testloaders, central_testloader = create_federated_dataloaders(
        config=cfg,
        dataset=cfg.dataset,
        sampling_type=cfg.distribution_type,
        dataset_fraction=cfg.dataset_fraction,
        batch_size=cfg.batch_size,
        train_fraction=cfg.train_fraction,
        validation_fraction=cfg.validation_fraction,
        test_fraction=cfg.test_fraction,
        random_seed=cfg.random_seed,
        method=cfg.method,
        min_samples_per_client=cfg.min_samples_per_client,
    )

    net = return_model(cfg.dataset, cfg.num_classes).to(device)
    log(INFO, "net: %s", net)
    # Round-trip the weights through set_parameters, then snapshot them as the
    # initial global parameters handed to the strategy.
    set_parameters(net, [val.cpu().numpy() for _, val in net.state_dict().items()])
    net_parameters = [val.cpu().numpy() for _, val in net.state_dict().items()]
    total_n_clients = len(trainloaders)
    log(INFO, "Total number of clients: %s", total_n_clients)

    # Create the client creation function
    client_fn = create_client_fn(
        trainloaders=trainloaders,
        valloaders=valloaders,
        testloaders=testloaders,
        device=device,
        method=cfg.method,
        num_epochs=cfg.epochs_per_round,
        learning_rate=cfg.learning_rate,
        dataset=cfg.dataset,
        num_classes=cfg.num_classes,
        num_batches=cfg.batches_per_round,
    )

    # Keyword arguments shared by both strategy classes.
    strategy_kwargs = dict(
        min_available_clients=total_n_clients,
        fraction_fit=0.001,
        min_fit_clients=cfg.num_clients_per_round,
        fraction_evaluate=0.001,
        min_evaluate_clients=cfg.num_clients_per_round,
        evaluate_fn=get_evaluate_fn(net, trainloaders, central_testloader, testloaders, cfg),
        on_fit_config_fn=fit_config,
        fit_metrics_aggregation_fn=weighted_average,
        evaluate_metrics_aggregation_fn=weighted_average,
        initial_parameters=ndarrays_to_parameters(net_parameters),
    )
    if cfg.same_train_test_clients:
        strategy = FedAvgSameClients(args=cfg, **strategy_kwargs)
    else:
        # Stock FedAvg has no `args` parameter; forwarding it would raise
        # TypeError, so it is only passed to the custom strategy above.
        strategy = fl.server.strategy.FedAvg(**strategy_kwargs)

    # Start simulation with multiple clients training simultaneously
    history = fl.simulation.start_simulation(
        client_fn=client_fn,
        num_clients=total_n_clients,
        config=fl.server.ServerConfig(num_rounds=cfg.num_rounds),
        strategy=strategy,
        client_resources=_select_client_resources(device),
    )

    # Save and plot results
    results_df = _distributed_history_frame(history)
    results_df.to_csv(results_dir_path / f"history_{cfg.dataset}_{cfg.affix}.csv")
    # np.save appends ".npy" to the given name.
    np.save(f"{cfg.results_dir_path}/history_{cfg.dataset}_{cfg.affix}", history)

    plot_metric_from_history(
        history,
        cfg.results_dir_path,
        cfg.dataset + cfg.affix,
        metric_type='distributed',
    )
    # Save configuration to a yaml file
    with open(f"{cfg.results_dir_path}/{cfg.dataset}{cfg.affix}.yaml", 'w') as f:
        OmegaConf.save(cfg, f)

    log(INFO, "save history done")

# Script entry point: main() is called without arguments because the
# @hydra.main decorator applied to it supplies `cfg`.
if __name__ == "__main__":
    main()
