#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Provide client training and server evaluation utilities for federated learning."""

from collections import OrderedDict
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
from sklearn import metrics
from torch import nn
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter

# Type alias for a model's state dictionary
StateDict = Dict[str, torch.Tensor]


class DatasetSplit(Dataset):
    """Present a chosen subset of another dataset as a dataset of its own.

    Position ``i`` of the split maps to position ``indices[i]`` of the wrapped
    dataset; samples are fetched lazily and nothing is copied.
    """

    def __init__(self, dataset: Dataset, indices: List[int]):
        """
        Args:
            dataset: Source dataset to draw samples from.
            indices: Positions in ``dataset`` that make up this split, in order.
        """
        self.dataset = dataset
        self.indices = indices

    def __len__(self) -> int:
        """Return how many samples the split contains."""
        size = len(self.indices)
        return size

    def __getitem__(self, item: int) -> Tuple[torch.Tensor, int]:
        """Fetch the ``(features, label)`` pair stored at split position ``item``."""
        source_position = self.indices[item]
        features, target = self.dataset[source_position]
        return features, target


class ClientUpdater:
    """Manage the local optimisation routine for a single federated client.

    The client owns the subset of ``dataset`` selected by ``indices``. ``train``
    fits a model on that subset and ``evaluate`` scores a model on the *same*
    subset (there is no separate held-out split), so the accuracy it reports is
    a local training accuracy.
    """

    def __init__(
        self,
        device: torch.device,
        dataset: Dataset,
        indices: List[int],
        batch_size: int,
        logger: Optional[SummaryWriter] = None,
        eval_batch_size: Optional[int] = None,
    ):
        """
        Args:
            device: The device to train on ('cpu' or 'cuda').
            dataset: The full training dataset.
            indices: The indices of the data assigned to this client.
            batch_size: The local batch size for training.
            logger: An optional TensorBoard logger.
            eval_batch_size: Batch size used by ``evaluate``. ``None`` (default)
                evaluates the whole partition in a single batch; pass a smaller
                value to bound memory use on large partitions. The reported loss
                and accuracy do not depend on this value.
        """
        self.device = device
        self.logger = logger
        self.loss_fn = nn.CrossEntropyLoss()
        # Used as TensorBoard's ``global_step``; without it every logged scalar
        # lands on step 0 and the loss curve collapses onto a single x value.
        self._log_step = 0

        self.has_data = len(indices) > 0
        if self.has_data:
            self.train_loader, self.test_loader = self._create_data_loaders(
                dataset, indices, batch_size, eval_batch_size
            )

    def _create_data_loaders(
        self,
        dataset: Dataset,
        indices: List[int],
        batch_size: int,
        eval_batch_size: Optional[int] = None,
    ) -> Tuple[DataLoader, DataLoader]:
        """Build the shuffled training loader and the ordered evaluation loader.

        Both loaders read the same subset of ``dataset``; they differ only in
        batch size and shuffling.

        Returns:
            ``(train_loader, test_loader)``.
        """
        subset = DatasetSplit(dataset, indices)
        train_loader = DataLoader(subset, batch_size=batch_size, shuffle=True)
        test_batch_size = eval_batch_size if eval_batch_size is not None else len(indices)
        test_loader = DataLoader(subset, batch_size=test_batch_size, shuffle=False)
        return train_loader, test_loader

    def train(
        self,
        model: nn.Module,
        learning_rate: float,
        local_epochs: int,
        momentum: float = 0.5,
        max_grad_norm: Optional[float] = 10.0,
    ) -> Tuple[StateDict, float, float]:
        """
        Run local training on the client's data partition.

        Args:
            model: Model instance to optimise locally (modified in place and
                moved to ``self.device``).
            learning_rate: Learning rate used by the SGD optimiser.
            local_epochs: Number of local epochs to execute.
            momentum: SGD momentum.
            max_grad_norm: Gradient-norm clipping threshold; ``None`` disables
                clipping.

        Returns:
            Tuple of (updated weights, mean of the per-epoch mean batch losses,
            accuracy on the client's partition after training). A client with
            no data returns ``(model.state_dict(), 0.0, 0.0)`` without training.
        """
        if not self.has_data:
            return model.state_dict(), 0.0, 0.0

        model.to(self.device)
        model.train()

        optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum)
        epoch_losses: List[float] = []

        for _ in range(local_epochs):
            batch_losses: List[float] = []
            for images, labels in self.train_loader:
                images, labels = images.to(self.device), labels.to(self.device)

                optimizer.zero_grad()
                loss = self.loss_fn(model(images), labels)
                loss.backward()
                if max_grad_norm is not None:
                    # Clip to guard against exploding gradients.
                    nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
                optimizer.step()

                loss_value = loss.item()  # single host sync per batch
                batch_losses.append(loss_value)
                if self.logger is not None:
                    self.logger.add_scalar("loss", loss_value, global_step=self._log_step)
                    self._log_step += 1

            # A loader that yields no batches would otherwise divide by zero.
            if batch_losses:
                epoch_losses.append(sum(batch_losses) / len(batch_losses))

        accuracy, _ = self.evaluate(model)
        avg_loss = sum(epoch_losses) / len(epoch_losses) if epoch_losses else 0.0

        return model.state_dict(), avg_loss, accuracy

    def evaluate(self, model: nn.Module) -> Tuple[float, float]:
        """
        Evaluate the model on the client's partition.

        This is the same set of samples used by ``train``, not a held-out split.

        Args:
            model: Model instance to evaluate (moved to ``self.device`` and left
                in eval mode).

        Returns:
            Tuple containing (accuracy, loss), where loss is the mean per-sample
            cross-entropy. ``(0.0, 0.0)`` when the client has no data.
        """
        if not self.has_data:
            return 0.0, 0.0

        model.to(self.device)
        model.eval()

        loss_sum = 0.0
        correct = 0
        seen = 0

        with torch.no_grad():
            for images, labels in self.test_loader:
                images, labels = images.to(self.device), labels.to(self.device)
                logits = model(images)
                n = labels.size(0)
                # loss_fn uses the default 'mean' reduction; re-weight by batch
                # size so the result is a true per-sample mean for any batching.
                loss_sum += self.loss_fn(logits, labels).item() * n
                correct += (torch.argmax(logits, dim=1) == labels).sum().item()
                seen += n

        if seen == 0:
            return 0.0, 0.0
        return correct / seen, loss_sum / seen


class ServerEvaluator:
    """Evaluate a model separately on each class of an evaluation dataset.

    At construction, the evaluation indices are grouped by label and one
    non-shuffled ``DataLoader`` is built per label present (ascending label
    order). ``evaluate_by_class`` then reports accuracy and mean per-sample loss
    for each of those classes in the same order.
    """

    def __init__(
        self,
        device: torch.device,
        dataset: Dataset,
        indices: List[int],
        batch_size: int,
    ):
        """
        Args:
            device: The device for evaluation ('cpu' or 'cuda').
            dataset: The evaluation dataset.
            indices: The indices of the dataset to be used for evaluation.
            batch_size: The batch size for evaluation.
        """
        self.device = device
        self.loss_fn = nn.CrossEntropyLoss()
        self.class_loaders = self._create_class_based_loaders(dataset, indices, batch_size)

    def _create_class_based_loaders(
        self, dataset: Dataset, indices: List[int], batch_size: int
    ) -> Dict[int, DataLoader]:
        """Group ``indices`` by label and build one evaluation loader per class.

        Labels are taken from ``dataset.targets`` when the dataset exposes it.
        Otherwise they are read by indexing the dataset, but only at the
        requested ``indices``.

        Returns:
            Ordered mapping from label to a non-shuffled ``DataLoader`` over that
            class's samples, in ascending label order. Labels absent from
            ``indices`` get no entry.
        """
        targets = getattr(dataset, "targets", None)
        if targets is not None:
            subset_labels = np.asarray(targets)[indices]
        else:
            # Only fetch the samples we need; each lookup may decode a full sample.
            subset_labels = np.array([dataset[idx][1] for idx in indices])

        # np.unique sorts, which fixes the class order of the returned loaders.
        indices_by_class = OrderedDict((cls, []) for cls in np.unique(subset_labels))
        for idx, label in zip(indices, subset_labels):
            indices_by_class[label].append(idx)

        return OrderedDict(
            (cls, DataLoader(DatasetSplit(dataset, cls_indices), batch_size=batch_size, shuffle=False))
            for cls, cls_indices in indices_by_class.items()
        )

    def evaluate_by_class(self, model: nn.Module) -> Tuple[List[float], List[float]]:
        """
        Measure per-class performance for the supplied model.

        Args:
            model: Model instance to evaluate (moved to ``self.device`` and left
                in eval mode).

        Returns:
            Tuple of (accuracies, losses) aligned with the iteration order of
            ``self.class_loaders``. Each loss is the mean per-sample
            cross-entropy over that class's samples.
        """
        model.to(self.device)
        model.eval()

        all_accuracies: List[float] = []
        all_losses: List[float] = []

        with torch.no_grad():
            for loader in self.class_loaders.values():
                loss_sum = 0.0
                correct = 0
                seen = 0
                for images, labels in loader:
                    images, labels = images.to(self.device), labels.to(self.device)
                    logits = model(images)
                    n = labels.size(0)
                    # loss_fn returns a batch mean; weight by batch size so a
                    # short final batch does not skew the class average.
                    loss_sum += self.loss_fn(logits, labels).item() * n
                    correct += (torch.argmax(logits, dim=1) == labels).sum().item()
                    seen += n

                # Loaders built by this class are never empty; this only guards
                # against an externally supplied empty loader.
                if seen == 0:
                    continue
                all_accuracies.append(correct / seen)
                all_losses.append(loss_sum / seen)

        return all_accuracies, all_losses
