"""Client implementation for federated learning."""

import hashlib
from typing import Callable, Dict, List, Optional, Tuple

import flwr as fl
import torch
from flwr.common.typing import NDArrays, Properties, Scalar
from torch.utils.data import DataLoader

from model import return_model, test, train
from utils import get_parameters, set_parameters

class FlowerClient(fl.client.NumPyClient):
    """Flower NumPy client that trains and evaluates a local PyTorch model.

    ``fit`` reports train loss/accuracy (plus validation loss/accuracy when
    ``train`` returns them); ``evaluate`` reports loss/accuracy on the local
    test split. The model is left on CPU between calls.
    """

    def __init__(
        self,
        cid: str,
        net: torch.nn.Module,
        trainloader: DataLoader,
        valloader: DataLoader,
        testloader: DataLoader,
        device: torch.device,
        num_epochs: int,
        learning_rate: float,
        num_batches: Optional[int] = None,
        method: str = 'base',
        dataset_id: Optional[int] = None,
    ) -> None:
        """Store the model, data loaders and training hyper-parameters.

        Args:
            cid: Flower client id.
            net: Model to train and evaluate.
            trainloader: Local training data; its dataset size is reported
                as the example count from ``fit``.
            valloader: Local validation data, forwarded to ``train``.
            testloader: Local test data used by ``evaluate``.
            device: Device forwarded to ``train`` and ``test``.
            num_epochs: Local epochs per ``fit`` call. A falsy value (0 or
                None) falls back to ``len(trainloader)``.
            learning_rate: Learning rate forwarded to ``train``.
            num_batches: Optional value forwarded to ``train`` as
                ``n_batches``.
            method: Method name forwarded to ``train`` and ``test``.
            dataset_id: Index of this client's data partition. When set, it
                is reported through ``get_properties`` and ``fit`` metrics.
        """
        self.cid = cid
        self.device = device
        self.net = net
        self.trainloader = trainloader
        self.valloader = valloader
        self.testloader = testloader
        self.method = method
        # NOTE(review): len(trainloader) is the number of *batches*, so a
        # falsy num_epochs means "run that many epochs". Confirm this
        # fallback is intended.
        self.num_epochs = num_epochs or len(trainloader)
        self.learning_rate = learning_rate
        self.num_batches = num_batches
        self.dataset_id = dataset_id

    def get_properties(self, config: Dict[str, Scalar]) -> Properties:
        """Return client properties.

        ``dataset_id`` is included only when it is set, because ``None`` is
        not a valid Flower ``Scalar`` (bool, bytes, float, int, str).
        """
        if self.dataset_id is None:
            return {}
        return {"dataset_id": self.dataset_id}

    def get_parameters(self, config: Dict[str, Scalar]) -> NDArrays:
        """Return the current model weights, moving the model to CPU first."""
        self.net.to("cpu")
        return get_parameters(self.net)

    def fit(
        self,
        parameters: NDArrays,
        config: Dict[str, Scalar]
    ) -> Tuple[NDArrays, int, Dict[str, Scalar]]:
        """Train the received global weights on the local training data.

        Args:
            parameters: Global model weights sent by the server.
            config: Server fit config. It must contain ``"current_round"``
                and ``"unlearning_round"``, which are forwarded to ``train``.

        Returns:
            The updated weights, the number of local training examples, and
            a metrics dict with ``train_loss``/``train_accuracy``,
            ``dataset_id`` (when set) and ``val_loss``/``val_accuracy`` (when
            ``train`` returned both).

        Raises:
            KeyError: If a required key is missing from ``config``.
        """
        # Load the weights while on CPU; ``train`` is given the target device.
        self.net.to("cpu")
        set_parameters(self.net, parameters)

        train_loss, train_acc, val_loss, val_acc = train(
            self.net,
            self.trainloader,
            self.valloader,
            method=self.method,
            epochs=self.num_epochs,
            learning_rate=self.learning_rate,
            device=self.device,
            n_batches=self.num_batches,
            server_round=config["current_round"],
            unlearning_round=config["unlearning_round"],
        )

        # Bring the model back to CPU before extracting its weights.
        self.net.to("cpu")
        metrics: Dict[str, Scalar] = {
            "train_loss": train_loss,
            "train_accuracy": train_acc,
        }
        # None is not a valid metric value, so only report a real dataset_id.
        if self.dataset_id is not None:
            metrics["dataset_id"] = self.dataset_id
        if val_loss is not None and val_acc is not None:
            metrics.update({
                "val_loss": val_loss,
                "val_accuracy": val_acc,
            })

        return get_parameters(self.net), len(self.trainloader.dataset), metrics

    def evaluate(
        self,
        parameters: NDArrays,
        config: Dict[str, Scalar]
    ) -> Tuple[float, int, Dict[str, Scalar]]:
        """Evaluate the received weights on the local test data.

        Returns:
            The test loss as a float, the number of local test examples, and
            ``{"accuracy": <float>}``.
        """
        self.net.to("cpu")
        set_parameters(self.net, parameters)

        loss, accuracy = test(self.net, self.testloader, method=self.method, device=self.device)

        self.net.to("cpu")
        return float(loss), len(self.testloader.dataset), {"accuracy": float(accuracy)}

def create_client_fn(
    trainloaders: List[DataLoader],
    valloaders: List[DataLoader],
    testloaders: List[DataLoader],
    device: torch.device = torch.device("cpu"),
    method: str = 'base',
    num_epochs: int = 0,
    learning_rate: float = 0.001,
    dataset: str = "FeMNIST",
    num_classes: int = 62,
    num_batches: Optional[int] = None,
) -> Callable[[str], "FlowerClient"]:
    """Build the ``client_fn`` factory Flower uses to instantiate clients.

    Each client id is mapped deterministically to one data partition. The
    cid is hashed with SHA-256 and the digest is taken modulo the number of
    partitions. The mapping is stable across processes, but it is *not*
    one-to-one: different cids may share a partition, and some partitions
    may never be selected.

    Args:
        trainloaders: Per-partition training loaders, indexed by dataset_id.
        valloaders: Per-partition validation loaders, using the same indexing.
        testloaders: Per-partition test loaders, using the same indexing.
        device: Device each client trains and evaluates on.
        method: Method name forwarded to each client.
        num_epochs: Local epochs per round. A falsy value lets the client
            fall back to ``len(trainloader)``.
        learning_rate: Learning rate forwarded to each client.
        dataset: Dataset name passed to ``return_model``.
        num_classes: Number of classes passed to ``return_model``.
        num_batches: Optional batch limit forwarded to each client.

    Returns:
        A function that maps a client id to a new ``FlowerClient`` with a
        freshly built model.

    Raises:
        ValueError: If ``trainloaders`` is empty, or if ``valloaders`` or
            ``testloaders`` has fewer entries than ``trainloaders``.
    """
    num_datasets = len(trainloaders)
    # Fail fast here rather than with ZeroDivisionError/IndexError later,
    # inside client_fn, depending on which cid hashes where.
    if num_datasets == 0:
        raise ValueError("trainloaders must contain at least one DataLoader")
    if len(valloaders) < num_datasets or len(testloaders) < num_datasets:
        raise ValueError(
            "valloaders and testloaders must provide a loader for every "
            f"trainloaders entry (got {num_datasets} train, "
            f"{len(valloaders)} val, {len(testloaders)} test)"
        )

    def client_fn(cid: str) -> "FlowerClient":
        """Create the client for ``cid``, bound to its hashed data partition."""
        # SHA-256 is used instead of the built-in hash(), which is salted per
        # process for str, so a cid maps to the same partition in every worker.
        digest = hashlib.sha256(cid.encode('utf-8')).hexdigest()
        dataset_id = int(digest, 16) % num_datasets

        return FlowerClient(
            cid=cid,
            net=return_model(dataset, num_classes),
            trainloader=trainloaders[dataset_id],
            valloader=valloaders[dataset_id],
            testloader=testloaders[dataset_id],
            device=device,
            num_epochs=num_epochs,
            learning_rate=learning_rate,
            num_batches=num_batches,
            method=method,
            dataset_id=dataset_id,
        )

    return client_fn
