import os
import pickle
from datetime import datetime
import pandas as pd
import wandb

import hydra
from hydra.core.hydra_config import HydraConfig
from hydra.utils import call, instantiate
from omegaconf import DictConfig, OmegaConf
import flwr as fl
from flwr.server.server import Server
from flwr.server.client_manager import SimpleClientManager
from pathlib import Path

import warnings
from sentry_sdk.hub import SentryHubDeprecationWarning

warnings.filterwarnings("ignore", category=SentryHubDeprecationWarning)

from datasets_local import load_datasets
from utils_local import gen_evaluate_fn, fedavg_gen_evaluate_fn, fedspars_gen_evaluate_fn

# Hydra configs address classes through `_target_` keys, which pylint would
# otherwise report as protected-member access.
# pylint: disable=protected-access

# `cfg.strategy._target_` values evaluated with `fedspars_gen_evaluate_fn`.
_SPARS_TARGETS = ("Spars_strategy_cser.CSER", "Spars_strategy_topk.FedSpars")
# `cfg.strategy._target_` values evaluated with `fedavg_gen_evaluate_fn` and
# recorded with the reduced results schema (see `_build_results`).
_FEDAVG_TARGETS = ("strategy_fedavg.FedAvg", "strategy_fedprox.FedProx")


def _init_wandb(cfg: DictConfig) -> None:
    """Log in to Weights & Biases and start a run in project ``sa-pef``.

    The run is named ``cfg.wandb_name`` and records the main hyperparameters
    of this experiment as its config.
    """
    wandb.login()
    wandb.init(
        # Project where this run will be logged.
        project="sa-pef",
        name=cfg.wandb_name,
        # Hyperparameters and run metadata.
        config={
            "architecture": cfg.model._target_,
            "dataset": cfg.dataset_name,
            "approach": cfg.approach,
            "alpha": cfg.alpha,
            "learning_rate": cfg.learning_rate,
            "num_rounds": cfg.num_rounds,
            "num_clients": cfg.num_clients,
            "clients_per_round": cfg.clients_per_round,
            "local_epochs": cfg.num_epochs,
        },
    )


def _make_evaluate_fn(cfg: DictConfig, testloader):
    """Return the centralized evaluation function matching ``cfg.strategy``.

    Sparsifying strategies (``_SPARS_TARGETS``) get the evaluator that also
    receives the compression settings; FedAvg/FedProx (``_FEDAVG_TARGETS``)
    get the dense evaluator; any other strategy falls back to
    ``gen_evaluate_fn``. All evaluators run on ``cfg.server_device``.
    """
    target = cfg.strategy._target_
    device = cfg.server_device
    if target in _SPARS_TARGETS:
        return fedspars_gen_evaluate_fn(
            testloader,
            device=device,
            model=cfg.model,
            approach=cfg.approach,
            comp_type=cfg.comp_type,
            sparsify_by=cfg.sparsify_by,
            partitioning=cfg.partitioning,
            num_clients=cfg.num_clients,
            clients_per_round=cfg.clients_per_round,
            cfg_wandb=cfg.wandb,
        )
    if target in _FEDAVG_TARGETS:
        return fedavg_gen_evaluate_fn(
            testloader, device=device, model=cfg.model, cfg_wandb=cfg.wandb
        )
    return gen_evaluate_fn(testloader, device=device, model=cfg.model, cfg_wandb=cfg.wandb)


def _build_results(cfg: DictConfig, history, time_taken) -> pd.DataFrame:
    """Build a one-row DataFrame summarising a finished simulation.

    Args:
        cfg: The Hydra run config.
        history: The object returned by ``fl.simulation.start_simulation``.
        time_taken: Wall-clock duration of the simulation (a ``timedelta``).

    FedAvg/FedProx runs get the base columns. Every other strategy also gets
    ``alpha_r``, ``sparsify_by`` and the ``grad_norm_sq``, ``residual_energy``,
    ``grad_mismatch_sq`` and ``rho_r`` fit metrics. Column order is fixed so
    rows appended to an existing CSV line up with its header.

    Raises:
        KeyError: if ``history`` lacks a required metric (e.g. ``accuracy``
            or ``uplink_bits_total``).
    """
    fedavg_like = cfg.strategy._target_ in _FEDAVG_TARGETS
    fit_metrics = history.metrics_distributed_fit

    row = {
        "time_taken": time_taken,
        "dataset": cfg.dataset_name,
        "partitioning": cfg.partitioning,
        "num_clients": cfg.num_clients,
        "clients_per_round": cfg.clients_per_round,
        "num_rounds": cfg.num_rounds,
        "approach": cfg.approach,
    }
    if not fedavg_like:
        row["alpha_r"] = cfg.alpha_r
    row["alpha"] = cfg.alpha
    row["local_epochs"] = cfg.num_epochs
    if not fedavg_like:
        row["sparsify_by"] = cfg.sparsify_by
    row["learning_rate"] = cfg.learning_rate
    row["losses"] = history.losses_centralized
    row["accs"] = history.metrics_centralized["accuracy"]
    if not fedavg_like:
        for key in ("grad_norm_sq", "residual_energy", "grad_mismatch_sq", "rho_r"):
            row[key] = fit_metrics[key]
    row["uplink_bits_total"] = fit_metrics["uplink_bits_total"]

    # Wrap each value in a list so list-valued metrics stay in a single cell.
    return pd.DataFrame({column: [value] for column, value in row.items()})


def _append_results_csv(results: pd.DataFrame, csv_path: Path) -> None:
    """Append *results* to the CSV at *csv_path*.

    Creates parent directories as needed. The header is written only when the
    file is new or empty. If the file already has a header, rows are appended
    without rewriting it:

    * same columns in a different order: *results* is reordered to match;
    * different column set (e.g. a FedAvg row appended to a file created by a
      sparsifying run): a warning is emitted, because the values will not
      line up with the header.
    """
    csv_path.parent.mkdir(parents=True, exist_ok=True)
    header_needed = not csv_path.exists() or csv_path.stat().st_size == 0
    if not header_needed:
        existing_columns = list(pd.read_csv(csv_path, nrows=0).columns)
        if set(existing_columns) == set(results.columns):
            results = results[existing_columns]
        else:
            warnings.warn(
                f"{csv_path} has columns {existing_columns}, but this run produces "
                f"{list(results.columns)}; the appended row will not line up with "
                "the header. Use a separate save_path per strategy family.",
                stacklevel=2,
            )
    results.to_csv(csv_path, mode="a", index=False, header=header_needed)


@hydra.main(config_path="conf", config_name="base", version_base=None)
def main(cfg: DictConfig) -> None:
    """Run one Flower simulation described by *cfg* and record its results.

    Steps:
        1. Load the client train/val loaders and the central test loader.
        2. Build the client factory from ``cfg.client_fn``.
        3. Optionally start a Weights & Biases run (``cfg.wandb``).
        4. Pick the evaluation function by strategy and instantiate the
           strategy and server.
        5. Run ``fl.simulation.start_simulation`` and time it.
        6. Pickle the returned history to ``history.pkl`` in the Hydra output dir.
        7. Append a one-row summary to the CSV at ``cfg.save_path``, resolved
           against the original working directory.

    The W&B run, if any, is always finished. It is marked failed
    (``exit_code=1``) if any step after it started raises.
    """
    from hydra.utils import to_absolute_path

    print(OmegaConf.to_yaml(cfg))

    # 1. Prepare the dataset
    print("Preparing the dataset...")
    trainloaders, valloaders, testloader = load_datasets(
        config=cfg.dataset,
        num_clients=cfg.num_clients,
        val_ratio=cfg.dataset.val_split,
    )

    # 2. Define the clients
    client_fn = call(cfg.client_fn, trainloaders, valloaders, model=cfg.model)

    # 3. Optional experiment tracking
    use_wandb = bool(cfg.wandb)
    if use_wandb:
        _init_wandb(cfg)

    exit_code = 1
    try:
        # 4. Strategy and server
        evaluate_fn = _make_evaluate_fn(cfg, testloader)
        strategy = instantiate(cfg.strategy, evaluate_fn=evaluate_fn, net=cfg.model)
        server = Server(strategy=strategy, client_manager=SimpleClientManager())

        # 5. Start simulation. `strategy` is not passed separately: the server
        # already owns it, and Flower ignores (and warns about) a strategy
        # given alongside a server.
        start = datetime.now()
        history = fl.simulation.start_simulation(
            server=server,
            client_fn=client_fn,
            num_clients=cfg.num_clients,
            config=fl.server.ServerConfig(num_rounds=cfg.num_rounds),
            client_resources={
                "num_cpus": cfg.client_resources.num_cpus,
                "num_gpus": cfg.client_resources.num_gpus,
            },
        )
        time_taken = datetime.now() - start
        print(history)

        # 6. Save the raw history next to this run's Hydra outputs
        output_dir = HydraConfig.get().runtime.output_dir
        print(output_dir)
        with open(os.path.join(output_dir, "history.pkl"), "wb") as f_ptr:
            pickle.dump(history, f_ptr)

        # 7. Append the summary row. The path is resolved against the original
        # cwd so all runs share one CSV even if Hydra changed the cwd.
        results = _build_results(cfg, history, time_taken)
        _append_results_csv(results, Path(to_absolute_path(cfg.save_path)))
        exit_code = 0
    finally:
        # 8. Close wandb, even on failure, so the run is not left dangling
        if use_wandb:
            wandb.finish(exit_code=exit_code)
    
if __name__ == "__main__":
    # `cfg` is injected by the @hydra.main decorator from the composed config
    # and any command-line overrides, so main() is called with no arguments.
    main()  # pylint: disable=no-value-for-parameter