"""Defines the client class and support functions for FedAvg."""

from typing import Callable, Dict, List, OrderedDict

import flwr as fl
import torch
from flwr.common import Scalar, Context
from hydra.utils import instantiate
from omegaconf import DictConfig
from torch.utils.data import DataLoader
import numpy as np

from utils_helper import cosine_eta_warmfloor
from models import test, train_fedavg, train_fedavg_adam
from utils_collect_res import compute_metrics_before_send, append_jsonl, _comm_bits_from_payload


# pylint: disable=too-many-instance-attributes
class FlowerClientFedAvg(fl.client.NumPyClient):
    """Flower client implementing FedAvg.

    The client wraps a local ``torch.nn.Module`` together with its training
    and validation data loaders. On ``fit`` it loads the global weights, runs
    ``train_fedavg`` locally and returns the updated weights; on ``evaluate``
    it loads the global weights and scores them with ``test``.

    Attributes
    ----------
    net : torch.nn.Module
        Local model whose ``state_dict`` is exchanged with the server.
    trainloader, valloader : DataLoader
        Local training / validation data.
    device : torch.device
        Device handed to the train and test routines.
    num_epochs : int
        Number of local epochs passed to ``train_fedavg``.
    learning_rate : float
        Fallback learning rate, used when the server config carries no
        ``"server_round"`` entry to drive the cosine schedule.
    momentum, weight_decay : float
        Optimizer hyper-parameters forwarded to ``train_fedavg``.
    context : Context
        Flower context the client was created from (stored, not otherwise
        used by the methods of this class).
    """

    # Settings of the round-indexed cosine learning-rate schedule used in
    # ``fit``. Kept as class attributes so a subclass can retune the schedule
    # without copying the whole method.
    ETA_MAX = 0.01
    ETA_MIN = 0.003
    TOTAL_ROUNDS = 400
    WARMUP_ROUNDS = 10

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        net: torch.nn.Module,
        trainloader: DataLoader,
        valloader: DataLoader,
        device: torch.device,
        num_epochs: int,
        learning_rate: float,
        momentum: float,
        weight_decay: float,
        context: Context,
    ) -> None:
        """Store the model, data loaders and training hyper-parameters."""
        self.net = net
        self.trainloader = trainloader
        self.valloader = valloader
        self.device = device
        self.num_epochs = num_epochs
        self.learning_rate = learning_rate
        self.momentum = momentum
        self.weight_decay = weight_decay
        self.context = context

    def get_parameters(self, config: Dict[str, Scalar]):
        """Return the current local model parameters.

        Every entry of the model's ``state_dict`` (parameters and buffers) is
        moved to the CPU and converted to a NumPy array, in ``state_dict``
        order. ``config`` is accepted for interface compatibility and ignored.
        """
        return [tensor.cpu().numpy() for tensor in self.net.state_dict().values()]

    def set_parameters(self, parameters):
        """Set the local model parameters using given ones.

        ``parameters`` is matched positionally against the keys of the
        model's ``state_dict``. Loading is best-effort: an array that cannot
        be converted is reported and skipped, and if a strict load fails the
        load is retried with ``strict=False``.

        Raises
        ------
        ValueError
            If not a single array could be converted.
        """
        keys = list(self.net.state_dict().keys())
        parameters = list(parameters)

        # ``zip`` below stops at the shorter sequence, so surface a mismatch
        # instead of letting it pass unnoticed.
        if len(parameters) != len(keys):
            print(
                f"Parameter count mismatch: received {len(parameters)} arrays "
                f"for {len(keys)} state_dict entries"
            )

        state_dict = OrderedDict()
        for key, value in zip(keys, parameters):
            try:
                # ``torch.tensor`` copies the data and keeps the array's own
                # dtype, so integer buffers such as ``num_batches_tracked``
                # and 0-d arrays need no special casing;
                # ``load_state_dict`` casts to the destination dtype.
                state_dict[key] = torch.tensor(value)
            except (TypeError, ValueError, RuntimeError) as e:
                print(f"Error converting parameter {key}: {e}")

        if not state_dict:
            raise ValueError("State dictionary is empty!")

        try:
            self.net.load_state_dict(state_dict, strict=True)
        except RuntimeError as e:
            # Missing or unexpected keys: fall back to a partial load rather
            # than aborting the round. Shape mismatches still raise here.
            print(f"Error loading state dict with strict=True: {e}")
            self.net.load_state_dict(state_dict, strict=False)

    def fit(self, parameters, config: Dict[str, Scalar]):
        """Implement distributed fit function for a given client for FedAvg.

        Loads ``parameters`` into the local model, trains it with
        ``train_fedavg`` and returns the updated weights.

        The learning rate follows ``cosine_eta_warmfloor`` driven by
        ``config["server_round"]``. When the server does not send a round
        number, the client's constructor ``learning_rate`` is used instead.

        Returns
        -------
        tuple
            ``(weights, num_examples, metrics)`` where ``metrics`` holds
            ``"uplink_bits_total"``: the size in bits of the weights sent back.
        """
        self.set_parameters(parameters)

        if "server_round" in config:
            learning_rate = cosine_eta_warmfloor(
                r=config["server_round"],
                total_R=self.TOTAL_ROUNDS,
                eta_max=self.ETA_MAX,
                eta_min=self.ETA_MIN,
                warmup_R=self.WARMUP_ROUNDS,
            )
        else:
            learning_rate = self.learning_rate

        train_fedavg(
            self.net,
            self.trainloader,
            self.device,
            self.num_epochs,
            learning_rate,
            self.momentum,
            self.weight_decay,
        )
        final_p_np = self.get_parameters({})

        # ``nbytes`` is in bytes; the metric is named in bits, so convert.
        uplink_bits_per_client = 8 * sum(int(arr.nbytes) for arr in final_p_np)
        return (
            final_p_np,
            len(self.trainloader.dataset),
            {"uplink_bits_total": uplink_bits_per_client},
        )

    def evaluate(self, parameters, config: Dict[str, Scalar]):
        """Evaluate using given parameters.

        Loads ``parameters`` into the local model and runs ``test`` on the
        validation loader. Returns ``(loss, num_examples, {"accuracy": acc})``.
        ``config`` is accepted for interface compatibility and ignored.
        """
        self.set_parameters(parameters)
        loss, acc = test(self.net, self.valloader, self.device)
        return float(loss), len(self.valloader.dataset), {"accuracy": float(acc)}


# pylint: disable=too-many-arguments
def gen_client_fn(
    trainloaders: List[DataLoader],
    valloaders: List[DataLoader],
    num_epochs: int,
    learning_rate: float,
    model: DictConfig,
    momentum: float = 0.9,
    weight_decay: float = 1e-5,
) -> Callable[[Context], fl.client.Client]:  # pylint: disable=too-many-arguments
    """Generate the client function that creates the FedAvg flower clients.

    Parameters
    ----------
    trainloaders: List[DataLoader]
        A list of DataLoaders, each pointing to the dataset training partition
        belonging to a particular client.
    valloaders: List[DataLoader]
        A list of DataLoaders, each pointing to the dataset validation partition
        belonging to a particular client.
    num_epochs : int
        The number of local epochs each client should run the training for before
        sending it to the server.
    learning_rate : float
        The base learning rate passed to each client's constructor.
    model : DictConfig
        Configuration node handed to ``hydra.utils.instantiate`` to build a
        fresh model for every client.
    momentum : float
        The momentum value passed to each client's constructor.
    weight_decay : float
        The weight decay value passed to each client's constructor.

    Returns
    -------
    Callable[[Context], fl.client.Client]
        The client function that creates the FedAvg flower clients. It reads
        ``context.node_config["partition-id"]`` to pick the client's loaders.
    """

    def client_fn(context: Context) -> fl.client.Client:
        """Create a Flower client representing a single organization.

        Raises
        ------
        IndexError
            If the partition id has no matching train/validation loader.
        """
        # Resolve the partition once, and validate it before building the
        # model. A negative id would otherwise wrap around and silently hand
        # this client another client's data.
        partition_id = int(context.node_config["partition-id"])
        num_partitions = min(len(trainloaders), len(valloaders))
        if not 0 <= partition_id < num_partitions:
            raise IndexError(
                f"partition-id {partition_id} is out of range for "
                f"{num_partitions} available partitions"
            )

        # Load model
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        net = instantiate(model).to(device)

        # Note: each client gets a different trainloader/valloader, so each client
        # will train and evaluate on their own unique data
        trainloader = trainloaders[partition_id]
        valloader = valloaders[partition_id]

        return FlowerClientFedAvg(
            net,
            trainloader,
            valloader,
            device,
            num_epochs,
            learning_rate,
            momentum,
            weight_decay,
            context,
        ).to_client()

    return client_fn
