# FedAvg strategy modified to track per-round client (dataset) IDs reported in
# client fit metrics, and to support client- and sample-level unlearning.

import utils
import copy
import torch
import os
import numpy as np
import ray
import argparse
import random
from logging import WARNING, INFO
from typing import Callable, Dict, List, Optional, Tuple, Union

from flwr.common import (
    EvaluateIns,
    EvaluateRes,
    FitIns,
    FitRes,
    MetricsAggregationFn,
    NDArrays,
    Parameters,
    Scalar,
    ndarrays_to_parameters,
    parameters_to_ndarrays,
)
from flwr.server.client_manager import ClientManager
from flwr.server.client_proxy import ClientProxy
from flwr.server.strategy import FedAvg
from flwr.server.strategy.aggregate import aggregate, weighted_loss_avg
from flwr.common.logger import log

# Import the new data preprocessing method
from dataset.dataset import create_federated_dataloaders
import model

def sample_unlearning_clients(available_clients, sample_size, args):
    """Sample client IDs, rejecting any draw that contains an unlearning client.

    Uses the same drawing scheme as ``sample_clients``: with replacement and
    de-duplicated for ``args.method == 'FATS'``, otherwise without replacement.
    The draw is repeated until it contains none of ``args.unlearning_clients``.
    At least one draw is always made, so an empty ``unlearning_clients`` list
    still yields a normal sample.

    Args:
        available_clients: Sequence of candidate client IDs.
        sample_size: Number of draws (FATS) or sample size (other methods).
        args: Namespace providing ``unlearning_clients`` and ``method``.

    Returns:
        A new list of sampled client IDs. None of them is in
        ``args.unlearning_clients``. Duplicates are removed in first-draw order.

    Raises:
        ValueError: If there are not enough non-unlearning clients for the
            rejection sampling to ever succeed.
    """
    excluded = set(args.unlearning_clients)
    with_replacement = args.method == 'FATS'
    eligible_count = sum(1 for cid in available_clients if cid not in excluded)

    # Rejection sampling can only terminate if an acceptable draw is possible.
    if sample_size > 0 and (
        eligible_count == 0 or (not with_replacement and eligible_count < sample_size)
    ):
        raise ValueError(
            f"Cannot sample {sample_size} client(s): only {eligible_count} "
            "available client(s) are not scheduled for unlearning."
        )

    while True:
        if with_replacement:
            draws = random.choices(available_clients, k=sample_size)
            # dict.fromkeys de-duplicates while keeping a deterministic order;
            # set() order depends on string hash randomization.
            sampled_clients = list(dict.fromkeys(draws))
        else:
            sampled_clients = random.sample(available_clients, sample_size)
        if not any(cid in excluded for cid in sampled_clients):
            return sampled_clients

def sample_clients(available_clients, sample_size, args):
    """Sample client IDs for a training round.

    For ``args.method == 'FATS'``, ``sample_size`` clients are drawn *with*
    replacement and duplicates are dropped, so fewer unique IDs may be
    returned. Otherwise a sample without replacement is taken. Its size is
    capped at the number of available clients, because ``random.sample``
    would raise otherwise.

    Args:
        available_clients: Sequence of candidate client IDs.
        sample_size: Number of draws (FATS) or requested sample size.
        args: Namespace providing ``method``.

    Returns:
        List of unique sampled client IDs in first-draw order. It is empty
        when there are no available clients.
    """
    if not available_clients:
        # random.choices on an empty population raises IndexError.
        return []
    if args.method == 'FATS':
        draws = random.choices(available_clients, k=sample_size)
        # dict.fromkeys keeps a reproducible order; set() order varies with
        # string hash randomization even under a fixed random seed.
        return list(dict.fromkeys(draws))
    return random.sample(available_clients, min(sample_size, len(available_clients)))

class FedAvgSameClients(FedAvg):
    """FedAvg strategy modified to track client IDs using client metrics.

    On top of plain FedAvg this strategy:

    * records, for every server round, the dataset IDs reported by the
      participating clients (``clients_records``) and the resulting global
      model as NumPy arrays (``parameters_records``);
    * stops sampling clients scheduled for unlearning once their unlearning
      round is reached. Only clients whose dataset ID is already known from an
      earlier fit result can be filtered out;
    * at an unlearning round, replaces the round's model with one of two
      alternatives:
      - client-level: the model recorded just before the unlearned clients
        first participated;
      - sample-level (when ``args.unlearning_samples`` is set): a model
        retrained from that point on reduced data;
    * evaluates the new global model on the same clients that trained in the
      round.
    """

    def __init__(
        self,
        *,
        args: argparse.Namespace,
        fraction_fit: float = 1.0,
        fraction_evaluate: float = 1.0,
        min_fit_clients: int = 2,
        min_evaluate_clients: int = 2,
        min_available_clients: int = 2,
        evaluate_fn: Optional[
            Callable[
                [int, NDArrays, Dict[str, Scalar]],
                Optional[Tuple[float, Dict[str, Scalar]]],
            ]
        ] = None,
        on_fit_config_fn: Optional[Callable[[int, int], Dict[str, Scalar]]] = None,
        on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None,
        accept_failures: bool = True,
        initial_parameters: Optional[Parameters] = None,
        fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
        evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
    ) -> None:
        """Create the strategy.

        Args:
            args: Experiment configuration. Read here: ``num_rounds``, plus
                ``unlearning_clients`` and ``unlearning_rounds``, which are
                parallel lists. Other methods also read ``random_seed``,
                ``method``, ``unlearning_samples`` and the dataset/training
                settings.
            initial_parameters: Required. Recorded as the round-0 model and
                used as the fallback model.
            Remaining arguments are forwarded unchanged to ``FedAvg``.

        Raises:
            ValueError: If ``initial_parameters`` is None.
        """
        if initial_parameters is None:
            raise ValueError("initial_parameters must be provided and not None.")
        super().__init__(
            fraction_fit=fraction_fit,
            fraction_evaluate=fraction_evaluate,
            min_fit_clients=min_fit_clients,
            min_evaluate_clients=min_evaluate_clients,
            min_available_clients=min_available_clients,
            evaluate_fn=evaluate_fn,
            on_fit_config_fn=on_fit_config_fn,
            on_evaluate_config_fn=on_evaluate_config_fn,
            accept_failures=accept_failures,
            initial_parameters=initial_parameters,
            fit_metrics_aggregation_fn=fit_metrics_aggregation_fn,
            evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn,
        )
        self.args = args

        # One slot per round, including round 0 (the initial model).
        # None means "nothing recorded for this round yet".
        self.clients_records = {i: None for i in range(0, args.num_rounds + 1)}
        self.parameters_records = {i: None for i in range(0, args.num_rounds + 1)}
        self._current_round_clients: List[Tuple[ClientProxy, FitIns]] = []
        # Round 0 holds the initial model and has no participants.
        self.parameters_records[0] = parameters_to_ndarrays(initial_parameters)
        self.clients_records[0] = []
        # Fallback model returned when a round has no recorded parameters.
        self.reset_parameters = parameters_to_ndarrays(initial_parameters)

        # Dataset IDs are compared as strings throughout (see aggregate_fit).
        self.unlearning_clients = [str(cid) for cid in args.unlearning_clients]
        # unlearning_rounds[i] is the round at which unlearning_clients[i] is removed.
        self.unlearning_rounds = args.unlearning_rounds
        if len(self.unlearning_clients) != len(self.unlearning_rounds):
            # zip() below silently truncates the longer list.
            log(
                WARNING,
                "unlearning_clients (%s) and unlearning_rounds (%s) differ in length; "
                "extra entries are ignored.",
                len(self.unlearning_clients),
                len(self.unlearning_rounds),
            )
        self.unlearned_clients = set()  # Dataset IDs already unlearned
        # Maps Flower client IDs (ClientProxy.cid) to dataset IDs.
        self.client_dataset_map = {}

    def evaluate(
        self, server_round: int, parameters: Parameters
    ) -> Optional[Tuple[float, Dict[str, Scalar]]]:
        """Evaluate ``parameters`` with ``evaluate_fn``, or return None if unset."""
        if self.evaluate_fn is None:
            return None
        parameters_ndarrays = parameters_to_ndarrays(parameters)
        eval_res = self.evaluate_fn(server_round, parameters_ndarrays, {})
        if eval_res is None:
            return None
        loss, metrics = eval_res
        return loss, metrics

    def configure_fit(
        self, server_round: int, parameters: Parameters, client_manager: ClientManager
    ) -> List[Tuple[ClientProxy, FitIns]]:
        """Configure the next round of training.

        The method proceeds as follows:

        1. Seeds the global ``random`` module with
           ``args.random_seed + server_round`` so sampling is reproducible.
        2. Marks clients whose unlearning round has been reached as unlearned.
        3. Samples among connected clients whose known dataset ID has not been
           unlearned.

        Every sampled client receives the same ``FitIns``. Returns an empty
        list (the round is skipped) when no eligible client is available.
        """
        random.seed(self.args.random_seed + server_round)

        config = {}
        if self.on_fit_config_fn is not None:
            # Tell the config function whether this round triggers unlearning.
            is_unlearning_round = server_round in self.unlearning_rounds
            config = self.on_fit_config_fn(server_round, is_unlearning_round)

        # Mark clients as unlearned *before* sampling so they are excluded
        # from this round onward.
        for client_id, round_num in zip(self.unlearning_clients, self.unlearning_rounds):
            if round_num <= server_round and client_id not in self.unlearned_clients:
                self.unlearned_clients.add(client_id)
                log(INFO, "Marked client %s as unlearned by round %s", client_id, server_round)

        log(INFO, "Currently unlearned clients: %s", self.unlearned_clients)

        fit_ins = FitIns(parameters, config)

        sample_size, min_num_clients = self.num_fit_clients(client_manager.num_available())
        client_manager.wait_for(min_num_clients)

        available_cids = list(client_manager.clients.keys())
        log(INFO, "Available clients before filtering: %s", available_cids)

        # Clients never seen in a fit result have no mapping (None) and stay
        # eligible; only known dataset IDs can be filtered out.
        available_cids = [
            cid for cid in available_cids
            if self.client_dataset_map.get(cid) not in self.unlearned_clients
        ]
        log(INFO, "Available clients after filtering: %s", available_cids)

        if not available_cids:
            log(WARNING, "No eligible clients available in round %s.", server_round)
            self._current_round_clients = []
            return []

        # sample_size was derived from *all* connected clients. Sampling
        # without replacement from the filtered list would raise ValueError if
        # it were not capped. FATS samples with replacement and needs no cap.
        if self.args.method != 'FATS':
            sample_size = min(sample_size, len(available_cids))

        sampled_cids = sample_clients(available_cids, sample_size, self.args)
        log(INFO, "Sampled clients: %s", sampled_cids)

        fit_clients = [(client_manager.clients[cid], fit_ins) for cid in sampled_cids]
        self._current_round_clients = fit_clients
        return fit_clients

    def aggregate_fit(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, FitRes]],
        failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
    ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
        """Aggregate fit results using weighted average and handle unlearning.

        Records this round's participants (dataset IDs) and FedAvg-aggregated
        parameters. Then applies any unlearning scheduled for ``server_round``
        (see ``_apply_unlearning``).

        Returns:
            The parameters recorded for this round, falling back to the
            initial model if none were recorded, together with the aggregated
            fit metrics.

            ``(None, {})`` is returned in two cases, and tells the Flower server
            to keep its current model:
            - failures occurred and ``accept_failures`` is False;
            - there are no results and no unlearning changed the model.
        """
        if not self.accept_failures and failures:
            return None, {}

        if results:
            self._record_fit_results(server_round, results)
        else:
            log(INFO, "No results to aggregate in round %s.", server_round)
            # Nobody contributed, so the global model carries over unchanged.
            # Record that explicitly. Leaving None here would make later
            # participation searches and parameter restores skip this round.
            if server_round <= self.args.num_rounds:
                self.clients_records[server_round] = []
                self.parameters_records[server_round] = self.parameters_records.get(
                    server_round - 1
                )

        restored = self._apply_unlearning(server_round)

        if not results and not restored:
            # Nothing to aggregate and nothing to roll back: keep the server's
            # current model.
            return None, {}

        parameters_arrays = self.parameters_records.get(server_round)
        if parameters_arrays is None:
            parameters_arrays = self.reset_parameters
        parameters_aggregated = ndarrays_to_parameters(parameters_arrays)

        metrics_aggregated: Dict[str, Scalar] = {}
        if self.fit_metrics_aggregation_fn:
            if results:
                fit_metrics = [(res.num_examples, res.metrics) for _, res in results]
                metrics_aggregated = self.fit_metrics_aggregation_fn(fit_metrics)
        elif server_round == 1:
            log(WARNING, "No fit_metrics_aggregation_fn provided")

        return parameters_aggregated, metrics_aggregated

    def _record_fit_results(
        self, server_round: int, results: List[Tuple[ClientProxy, FitRes]]
    ) -> None:
        """Update the cid->dataset-ID map and record participants and the FedAvg model."""
        weights_results = []
        sampled_client_ids = []
        for client_proxy, fit_res in results:
            weights_results.append(
                (parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples)
            )
            dataset_id = fit_res.metrics.get('dataset_id')
            if dataset_id is None:
                log(WARNING, "Client %s did not provide a dataset_id.", client_proxy.cid)
                continue
            # Normalise to str so it matches self.unlearning_clients.
            dataset_id = str(dataset_id)
            sampled_client_ids.append(dataset_id)
            self.client_dataset_map[client_proxy.cid] = dataset_id

        if server_round <= self.args.num_rounds:
            self.clients_records[server_round] = sampled_client_ids
            self.parameters_records[server_round] = aggregate(weights_results)
            log(
                INFO,
                "Round %s: clients_records %s",
                server_round,
                self.clients_records[server_round],
            )

    def _earliest_participation_rounds(
        self, client_ids: List[str], last_round: int
    ) -> Dict[str, int]:
        """Map each client to the first recorded round (0..last_round) it took part in.

        Clients that never appear in ``clients_records`` are omitted.
        """
        earliest_rounds: Dict[str, int] = {}
        for client_id in client_ids:
            for round_num in range(0, last_round + 1):
                # Rounds with nothing recorded are stored as None.
                recorded = self.clients_records.get(round_num) or []
                if client_id in {str(cid) for cid in recorded}:
                    earliest_rounds[client_id] = round_num
                    log(INFO, "Client %s first participated in round %s", client_id, round_num)
                    break
        return earliest_rounds

    def _apply_unlearning(self, server_round: int) -> bool:
        """Replace this round's recorded model for clients unlearned at ``server_round``.

        Returns:
            True if ``parameters_records[server_round]`` was replaced.
        """
        clients_to_unlearn = [
            client_id
            for client_id, round_num in zip(self.unlearning_clients, self.unlearning_rounds)
            if round_num == server_round
        ]
        if not clients_to_unlearn:
            return False

        log(INFO, "Unlearning clients %s at round %s", clients_to_unlearn, server_round)

        # Include the current round in the search. A client whose dataset ID
        # was not yet known could not be filtered in configure_fit, so it may
        # have contributed to this round's aggregate.
        earliest_rounds = self._earliest_participation_rounds(clients_to_unlearn, server_round)
        if not earliest_rounds:
            log(INFO, "No rounds affected.")
            return False

        overall_earliest_round = min(earliest_rounds.values())
        log(INFO, "Determined earliest affected round: %s", overall_earliest_round)

        # parameters_records[r] is the model *after* round r. The last model
        # free of these clients' contributions is therefore the one recorded
        # for round r - 1.
        restore_round = max(0, overall_earliest_round - 1)
        log(INFO, "Restoring parameters from round: %s", restore_round)

        parameters_arrays = self.parameters_records.get(restore_round)
        if parameters_arrays is None:
            log(
                WARNING,
                "No parameters recorded for round %s; unlearning skipped.",
                restore_round,
            )
            return False

        if self.args.unlearning_samples not in [None, 'None']:
            log(INFO, "Performing sample-level unlearning")
            weights_results_recovered = self.recover_parameters_at_sample_level(
                parameters_arrays,
                clients_to_unlearn,
                self.args.unlearning_samples,
                self.args,
            )
            aggregated_parameters = aggregate(weights_results_recovered)
        else:
            log(INFO, "Performing client-level unlearning")
            aggregated_parameters = parameters_arrays

        self.parameters_records[server_round] = aggregated_parameters
        return True

    def configure_evaluate(
        self, server_round: int, parameters: Parameters, client_manager: ClientManager
    ) -> List[Tuple[ClientProxy, EvaluateIns]]:
        """Evaluate the current global model on the clients that trained this round.

        ``parameters`` is the global model the server passes in after
        aggregation. The stored ``FitIns`` still carries the model sent
        *before* training, so only its config is reused; evaluating its
        parameters would report stale metrics.
        """
        if self.fraction_evaluate == 0.0:
            return []
        return [
            (client_proxy, EvaluateIns(parameters, fit_ins.config))
            for client_proxy, fit_ins in self._current_round_clients
        ]

    def aggregate_evaluate(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, EvaluateRes]],
        failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]],
    ) -> Tuple[Optional[float], Dict[str, Scalar]]:
        """Aggregate evaluation losses (weighted by ``num_examples``) and metrics."""
        if not results:
            return None, {}
        if not self.accept_failures and failures:
            return None, {}

        loss_aggregated = weighted_loss_avg(
            [(res.num_examples, res.loss) for _, res in results]
        )

        metrics_aggregated = {}
        if self.evaluate_metrics_aggregation_fn:
            eval_metrics = [(res.num_examples, res.metrics) for _, res in results]
            metrics_aggregated = self.evaluate_metrics_aggregation_fn(eval_metrics)
            log(
                INFO,
                "server_round %s, metrics_aggregated %s",
                str(server_round),
                str(metrics_aggregated),
            )
        elif server_round == 1:
            log(WARNING, "No evaluate_metrics_aggregation_fn provided")

        return loss_aggregated, metrics_aggregated

    def recover_parameters_at_sample_level(self, parameters, unlearning_clients, unlearning_samples, args):
        """Retrain every client from ``parameters`` after dropping part of the unlearned data.

        Args:
            parameters: Model weights (list of NumPy arrays) to start from.
            unlearning_clients: Dataset IDs (strings) whose training data is
                reduced.
            unlearning_samples: Fraction of each unlearned client's training
                set to drop. The first ``int(len * fraction)`` samples are
                removed. May be a numeric string, e.g. from the command line.
            args: Experiment configuration. It is also passed as ``config`` to
                ``create_federated_dataloaders``.

        Returns:
            A list of ``(weights, num_examples)`` pairs, one per client loader,
            ready for ``aggregate``.
        """
        net_init = model.return_model(args.dataset, args.num_classes)
        utils.set_parameters(net_init, ndarrays_to_parameters(parameters))

        client_loaders, _eval_loader = create_federated_dataloaders(
            config=args,  # assumes args carries the attributes the loader expects
            dataset=args.dataset,
            sampling_type=args.distribution_type,
            dataset_fraction=args.dataset_fraction,
            batch_size=args.batch_size,
            random_seed=args.random_seed,
            method=args.method,
        )

        # Key loaders by string so they match the dataset IDs used elsewhere.
        client_loaders = {str(cid): loader for cid, loader in client_loaders.items()}

        # A string fraction would otherwise be repeated by `int * str`.
        fraction = float(unlearning_samples)
        for cid in unlearning_clients:
            if cid not in client_loaders:
                continue
            train_dataset = client_loaders[cid]['train'].dataset
            partition_size = len(train_dataset)
            num_samples_to_remove = int(partition_size * fraction)
            if num_samples_to_remove > 0:
                # Keep only the samples after the ones to be unlearned.
                remaining_indices = list(range(num_samples_to_remove, partition_size))
                client_loaders[cid]['train'] = torch.utils.data.DataLoader(
                    torch.utils.data.Subset(train_dataset, remaining_indices),
                    batch_size=args.batch_size,
                    shuffle=True,
                )

        # Retrain a fresh copy of the starting model on every client's data.
        weights_results = []
        for c_idx, loaders in client_loaders.items():
            net = copy.deepcopy(net_init)
            train_loader = loaders['train']
            model.train(
                net,
                train_loader,
                [],
                method=args.method,
                epochs=args.epochs_per_round,
                learning_rate=args.learning_rate,
                device=args.device,
                n_batches=args.batches_per_round,
                server_round=0,
                unlearning_round=1,
            )
            weights_results.append((utils.get_parameters(net), len(train_loader.dataset)))

        return weights_results
