import concurrent.futures
import copy
import os
from logging import WARNING
from typing import Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from flwr.common import (
    EvaluateIns,
    EvaluateRes,
    FitIns,
    FitRes,
    MetricsAggregationFn,
    NDArrays,
    Parameters,
    Scalar,
    ndarrays_to_parameters,
    parameters_to_ndarrays,
)
from flwr.common.logger import log
from flwr.server.client_manager import ClientManager
from flwr.server.client_proxy import ClientProxy
from flwr.server.strategy.aggregate import aggregate, weighted_loss_avg
from flwr.server.strategy.fedavg import FedAvg
from flwr.server.strategy.strategy import Strategy
from sklearn import decomposition

from src.simplex import SolutionSimplex
from src.client import flwr_set_parameters, flwr_get_parameters
from src.networks import net_fn
from src.simplex_layers import SimplexConv, SimplexLinear
from src.util import compute_riesz_s_energy, projection_simplex

# Plot styling cycles: ten base entries each, repeated ten times (100 entries),
# presumably so that plots with many series never run out of styles.
# NOTE(review): the marker cycle contains "2" twice - confirm that is intended.
_MARKER_CYCLE = ["o", "1", "*", "2", "x", ".", "<", ">", "2", "3"]
_COLOR_CYCLE = ["r", "b", "g", "c", "m", "y", "brown", "purple", "cyan", "w"]
MARKERS = _MARKER_CYCLE * 10
COLORS = _COLOR_CYCLE * 10

# Logged when the configured client minimums are inconsistent with each other.
WARNING_MIN_AVAILABLE_CLIENTS_TOO_LOW = """
Setting `min_available_clients` lower than `min_fit_clients` or
`min_evaluate_clients` can cause the server to fail when there are too few clients
connected to the server. `min_available_clients` must be set to a value larger
than or equal to the values of `min_fit_clients` and `min_evaluate_clients`.
"""


class FedAvgCustom(FedAvg):
    """FedAvg subclass whose fit aggregation is delegated unchanged to ``FedAvg``.

    The override exists as a hook: an extra "grad_var" metric (total variance of
    the clients' last layer, every 10th round) was once sketched here but is
    currently disabled.
    """

    def aggregate_fit(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, FitRes]],
        failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
    ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
        """Return the parent's aggregated parameters and metrics as-is.

        Args:
            server_round: Index of the current federated round.
            results: ``(client, FitRes)`` pairs that completed successfully.
            failures: Failed client results or raised exceptions.

        Returns:
            The ``(parameters, metrics)`` pair produced by ``FedAvg.aggregate_fit``.
        """
        aggregated = super().aggregate_fit(server_round, results, failures)
        # Unpack explicitly so an unexpected return shape fails right here.
        parameters_aggregated, metrics_aggregated = aggregated
        return parameters_aggregated, metrics_aggregated


class FLOCO(Strategy):
    """Configurable FedAvg strategy implementation."""

    # pylint: disable=too-many-arguments,too-many-instance-attributes, line-too-long
    def __init__(
        self,
        *,
        fraction_fit: float = 1.0,
        fraction_evaluate: float = 1.0,
        min_fit_clients: int = 2,
        min_evaluate_clients: int = 2,
        min_available_clients: int = 2,
        evaluate_fn: Optional[
            Callable[
                [int, NDArrays, Dict[str, Scalar]],
                Optional[Tuple[float, Dict[str, Scalar]]],
            ]
        ] = None,
        on_fit_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None,
        on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None,
        accept_failures: bool = True,
        initial_parameters: Optional[Parameters] = None,
        fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
        evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
        xp_name: Optional[str] = None,
        cfg=None,
        subspace=None,
        writer=None,
    ) -> None:
        """Store the strategy configuration and create the solution simplex.

        Args:
            fraction_fit: Fraction of available clients sampled for training.
            fraction_evaluate: Fraction of available clients sampled for evaluation.
            min_fit_clients: Lower bound on the number of training clients.
            min_evaluate_clients: Lower bound on the number of evaluation clients.
            min_available_clients: Number of connected clients required before sampling.
            evaluate_fn: Optional server-side evaluation callback.
            on_fit_config_fn: Optional callback returning the fit config for a round.
            on_evaluate_config_fn: Optional callback returning the evaluate config for a round.
            accept_failures: Whether rounds with client failures are still aggregated.
            initial_parameters: Optional initial global parameters.
            fit_metrics_aggregation_fn: Optional aggregator for client fit metrics.
            evaluate_metrics_aggregation_fn: Optional aggregator for client evaluate metrics.
            xp_name: Experiment name; used for output paths and passed to ``SolutionSimplex``.
            cfg: Experiment configuration object (read via attribute access).
            subspace: Server-side model whose parameters are read and updated each round.
            writer: Optional logging/summary writer; only stored here.
        """
        super().__init__()

        if min_fit_clients > min_available_clients or min_evaluate_clients > min_available_clients:
            log(WARNING, WARNING_MIN_AVAILABLE_CLIENTS_TOO_LOW)

        self.fraction_fit = fraction_fit
        self.fraction_evaluate = fraction_evaluate
        self.min_fit_clients = min_fit_clients
        self.min_evaluate_clients = min_evaluate_clients
        self.min_available_clients = min_available_clients
        self.evaluate_fn = evaluate_fn
        self.on_fit_config_fn = on_fit_config_fn
        self.on_evaluate_config_fn = on_evaluate_config_fn
        self.accept_failures = accept_failures
        self.initial_parameters = initial_parameters
        self.fit_metrics_aggregation_fn = fit_metrics_aggregation_fn
        self.evaluate_metrics_aggregation_fn = evaluate_metrics_aggregation_fn

        self.xp_name = xp_name
        self.cfg = cfg
        self.subspace = subspace
        # `initial_parameters` defaults to None; decoding None would raise, so
        # only decode when parameters were actually supplied.
        self.current_weights = (
            parameters_to_ndarrays(initial_parameters) if initial_parameters is not None else None
        )
        self.writer = writer

        self.solution_simplex = SolutionSimplex(cfg=self.cfg, xp_name=self.xp_name)
        # cid -> flattened statistics, filled by aggregate_fit in the round
        # before the subspace phase starts.
        self.client_gradient_dict = {}

        # State written by the configure_* methods. Initialised here so that
        # reading it before the first write yields an empty value instead of
        # an AttributeError.
        self.centers = []
        self.selected_client_cids = []
        self.last_selected_client_cids = []
        self.all_alphas = {}

    def num_fit_clients(self, num_available_clients: int) -> Tuple[int, int]:
        """Return the sample size and the required number of available clients."""
        num_clients = int(num_available_clients * self.fraction_fit)
        return max(num_clients, self.min_fit_clients), self.min_available_clients

    def num_evaluation_clients(self, num_available_clients: int) -> Tuple[int, int]:
        """Use a fraction of available clients for evaluation."""
        num_clients = int(num_available_clients * self.fraction_evaluate)
        return max(num_clients, self.min_evaluate_clients), self.min_available_clients

    def initialize_parameters(self, client_manager: ClientManager) -> Optional[Parameters]:
        """Initialize global model parameters."""
        initial_parameters = self.initial_parameters
        self.initial_parameters = None  # Don't keep initial parameters in memory
        return initial_parameters

    def evaluate(self, server_round: int, parameters: Parameters) -> Optional[Tuple[float, Dict[str, Scalar]]]:
        """Evaluate model parameters using an evaluation function."""
        alphas = []
        if self.cfg.strategy.subspace_start <= server_round:
            alphas = []
            for _, tmp_alpha in enumerate(self.centers):
                alphas.append(tmp_alpha)

        if self.evaluate_fn is None:
            # No evaluation function provided
            return None
        parameters_ndarrays = flwr_get_parameters(self.subspace)
        eval_res = self.evaluate_fn(server_round, parameters_ndarrays, {'centers': alphas})
        if eval_res is None:
            return None
        loss, metrics = eval_res
        return loss, metrics

    def configure_fit(
        self, server_round: int, parameters: Parameters, client_manager: ClientManager
    ) -> List[Tuple[ClientProxy, FitIns]]:
        """Configure the next round of training.

        Samples the clients for this round and pairs each with the instruction
        object produced by ``sample_client_weights``. Two rounds are special:

        * ``server_round + 1 == cfg.strategy.subspace_start``: the regular
          sample is remembered in ``last_selected_client_cids`` and then
          replaced by a sample of ``cfg.num_clients`` clients.
        * ``server_round == cfg.strategy.subspace_start``: the statistics in
          ``client_gradient_dict`` are turned into simplex regions first.

        Args:
            server_round: Index of the current federated round.
            parameters: Current global parameters, forwarded to
                ``sample_client_weights``.
            client_manager: Source of available clients.

        Returns:
            One ``(client, instructions)`` pair per sampled client.

        Raises:
            ValueError: In the ``subspace_start`` round, if no client statistics
                were collected or no projection offset gives a finite energy.
        """
        # Custom fit config, merged into every client's instructions below.
        base_config = {}
        if self.on_fit_config_fn is not None:
            base_config = self.on_fit_config_fn(server_round) or {}

        # Sample clients
        sample_size, min_num_clients = self.num_fit_clients(client_manager.num_available())
        clients = client_manager.sample(num_clients=sample_size, min_num_clients=min_num_clients)
        self.selected_client_cids = [client.cid for client in clients]

        subspace_start = self.cfg.strategy.subspace_start
        if subspace_start == server_round:
            self._build_simplex_regions()
        elif subspace_start == (server_round + 1):
            # Sample all clients to get most up to date gradient, or simplex information
            self.last_selected_client_cids = copy.deepcopy(self.selected_client_cids)
            clients = client_manager.sample(num_clients=self.cfg.num_clients, min_num_clients=self.cfg.num_clients)
            self.selected_client_cids = [client.cid for client in clients]

        fit_ins_all, train_alphas = self.sample_client_weights(
            server_round=server_round,
            parameters=parameters,
            client_ids=self.selected_client_cids,
            train=True,
        )
        self.all_alphas = train_alphas

        # The config from on_fit_config_fn used to be computed and then dropped.
        # Merge it in; keys set by the sampler ("server_round", "alpha") win.
        if base_config:
            for fit_ins in fit_ins_all:
                fit_ins.config = {**base_config, **fit_ins.config}

        # Return client/config pairs (both lists follow the sampling order).
        return list(zip(clients, fit_ins_all))

    def _build_simplex_regions(self) -> None:
        """Turn the stored client statistics into simplex regions and centers.

        Steps: order the statistics by numeric client id, reduce them with PCA
        to ``cfg.rule.num_points`` components, pick the simplex-projection
        offset ``z`` with the lowest finite Riesz s-energy, then register the
        projected points with ``self.solution_simplex``.

        Raises:
            ValueError: If there are no client statistics or no offset yields a
                finite energy.
        """
        if not self.client_gradient_dict:
            raise ValueError(
                "No client statistics available to build the simplex regions; "
                "they are collected by aggregate_fit in the round before `subspace_start`."
            )

        cids = list(self.client_gradient_dict)
        # Client ids are expected to look like "<int>_<suffix>".
        numeric_ids = np.array([int(cid.split('_')[0]) for cid in cids])

        # IMPORTANT STEP: sort the client statistics by numeric client id, so
        # row i of the projected points belongs to the i-th smallest client id.
        order = np.argsort(numeric_ids)
        client_statistics = np.array([self.client_gradient_dict[cid] for cid in cids])[order]
        # TODO(review): the meaning of "client id modulo 5" as a cluster label
        # is undocumented; the labels are only written to disk below.
        cluster_labels = (numeric_ids % 5)[order]

        save_stats = self.cfg.save_clustering_stats
        out_dir = f"final_runs/sampling_region_plots/{self.xp_name}"

        # Projection
        print(' ... PROJECTING GRADIENTS ... ')
        if save_stats:
            os.makedirs(out_dir, exist_ok=True)
            # np.save stores the dict as a pickled 0-d object array, so loading
            # it back requires allow_pickle=True.
            np.save(os.path.join(out_dir, 'raw_client_gradients.npy'), self.client_gradient_dict)

        # Dim. reduction
        client_statistics = decomposition.PCA(
            n_components=self.cfg.rule.num_points,
        ).fit_transform(client_statistics)
        print('... finished PCA')

        # Offset parameter optimization: keep only the best candidate instead
        # of materialising all 1000 projections.
        z_low = self.cfg.dataset_model.z_interval[0]
        z_high = self.cfg.dataset_model.z_interval[1]
        best_statistics = None
        best_log_energy = np.inf
        for z in np.linspace(z_low, z_high, 1000):
            candidate = projection_simplex(client_statistics, z=z, axis=1)
            # Renormalise every row so it sums to one.
            candidate /= candidate.sum(1).reshape(-1, 1)
            _, log_energy = compute_riesz_s_energy(candidate, d=2)
            if np.isfinite(log_energy) and log_energy < best_log_energy:
                best_statistics = candidate
                best_log_energy = log_energy
        print('... finished parameter optimization')

        if best_statistics is None:
            # Previously this case indexed with None and silently returned the
            # whole stack of candidates with an extra axis.
            raise ValueError(
                f"No simplex-projection offset in [{z_low}, {z_high}] produced a finite Riesz s-energy."
            )

        client_statistics = best_statistics
        if save_stats:
            np.save(os.path.join(out_dir, 'final_client_gradients.npy'), client_statistics)
            np.save(os.path.join(out_dir, 'cluster_labels.npy'), cluster_labels)

        self.solution_simplex.set_solution_simplex_regions(
            projected_points=client_statistics,
            rho=self.cfg.rule.rho,
        )
        self.centers = [simplex_region.center_simplex for simplex_region in self.solution_simplex.simplex_regions]

    def configure_evaluate(
        self, server_round: int, parameters: Parameters, client_manager: ClientManager
    ) -> List[Tuple[ClientProxy, EvaluateIns]]:
        """Configure the next round of evaluation."""
        # Do not configure federated evaluation if fraction eval is 0.
        if self.fraction_evaluate == 0.0:
            return []
        # Parameters and config
        config = {}
        if self.on_evaluate_config_fn is not None:
            # Custom evaluation config function provided
            config = self.on_evaluate_config_fn(server_round)

        # Sample clients
        sample_size, min_num_clients = self.num_evaluation_clients(client_manager.num_available())
        clients = client_manager.sample(
            num_clients=sample_size,
            min_num_clients=min_num_clients,
        )

        self.selected_client_cids = [client.cid for client in clients]

        evaluate_ins_all, test_alphas = self.sample_client_weights(
            server_round=server_round,
            parameters=parameters,
            client_ids=self.selected_client_cids,
            train=False,
        )
        self.all_alphas = test_alphas

        return [(client, evaluate_ins_all[i]) for i, client in enumerate(clients)]

    def aggregate_fit(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, FitRes]],
        failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
    ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
        """Aggregate fit results and apply them to the server-side subspace.

        The weighted average of the client payloads is scaled by
        ``cfg.alpha_g`` and subtracted from the current subspace parameters
        (so the payloads are presumably parameter deltas - confirm against the
        client implementation).

        In the round right before ``cfg.strategy.subspace_start``, only the
        clients remembered in ``last_selected_client_cids`` are aggregated,
        while statistics of every returned client are stored in
        ``client_gradient_dict`` for the later simplex construction.

        Args:
            server_round: Index of the current federated round.
            results: ``(client, FitRes)`` pairs that completed successfully.
            failures: Failed client results or raised exceptions.

        Returns:
            ``(aggregated_parameters, metrics)``; ``(None, {})`` when there is
            nothing to aggregate or failures are not accepted.
        """
        if not results:
            return None, {}
        # Do not aggregate if there are failures and failures are not accepted
        if not self.accept_failures and failures:
            return None, {}

        # Decode every client payload exactly once.
        decoded = [
            (client.cid, parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples)
            for client, fit_res in results
        ]

        if self.cfg.strategy.subspace_start == (server_round + 1):
            # Aggregate only the originally sampled clients. A sampled client
            # that failed is simply absent from `results`, so skip it instead
            # of raising an IndexError.
            by_cid = {cid: (weights, num_examples) for cid, weights, num_examples in decoded}
            weights_results = [
                by_cid[cid] for cid in self.last_selected_client_cids if cid in by_cid
            ]

            # Save client gradients/losses for later clustering: the last
            # `num_points` arrays of each payload, flattened and concatenated.
            num_points = self.cfg.rule.num_points
            for cid, weights, _ in decoded:
                self.client_gradient_dict[cid] = np.concatenate(
                    [weights[-i].flatten() for i in range(1, num_points + 1)]
                )

            if not weights_results:
                log(WARNING, "None of the originally sampled clients returned a result; skipping aggregation")
                return None, {}
        else:
            weights_results = [(weights, num_examples) for _, weights, num_examples in decoded]

        delta_t = aggregate(weights_results)
        tmp_agg_params: NDArrays = [
            x - y * self.cfg.alpha_g for x, y in zip(flwr_get_parameters(self.subspace), delta_t)
        ]
        flwr_set_parameters(self.subspace, tmp_agg_params)
        parameters_aggregated = ndarrays_to_parameters(tmp_agg_params)
        # NOTE(review): __init__ stores a list of ndarrays in `current_weights`,
        # while this stores a Parameters object - confirm what readers expect.
        self.current_weights = parameters_aggregated

        # Aggregate custom metrics if aggregation fn was provided
        metrics_aggregated = {}
        if self.fit_metrics_aggregation_fn:
            fit_metrics = [(res.num_examples, res.metrics) for _, res in results]
            metrics_aggregated = self.fit_metrics_aggregation_fn(fit_metrics)
        elif server_round == 1:  # Only log this warning once
            log(WARNING, "No fit_metrics_aggregation_fn provided")

        # Compute Weight Delta Variances: mean total variance across clients
        # over the last `num_points` arrays of the aggregated payloads.
        if self.cfg.grad_var and server_round % self.cfg.eval_freq == 0:
            per_layer_variance = [
                gradient_total_variance([w[-i] for w, _ in weights_results])
                for i in range(1, self.cfg.rule.num_points + 1)
            ]
            metrics_aggregated["grad_var"] = np.mean(per_layer_variance)

        return parameters_aggregated, metrics_aggregated

    def aggregate_evaluate(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, EvaluateRes]],
        failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]],
    ) -> Tuple[Optional[float], Dict[str, Scalar]]:
        """Aggregate evaluation losses using weighted average."""
        if not results:
            return None, {}
        # Do not aggregate if there are failures and failures are not accepted
        if not self.accept_failures and failures:
            return None, {}
        # Aggregate loss
        loss_aggregated = weighted_loss_avg(
            [(evaluate_res.num_examples, evaluate_res.loss) for _, evaluate_res in results]
        )
        # Aggregate custom metrics if aggregation fn was provided
        metrics_aggregated = {}
        if self.evaluate_metrics_aggregation_fn:
            eval_metrics = [(res.num_examples, res.metrics) for _, res in results]
            metrics_aggregated = self.evaluate_metrics_aggregation_fn(eval_metrics)
        elif server_round == 1:  # Only log this warning once
            log(WARNING, "No evaluate_metrics_aggregation_fn provided")
        return loss_aggregated, metrics_aggregated

    def get_simple_net_params(self, conv_comb_weights):
        """Load weights into a fresh single-point network and read them back.

        Args:
            conv_comb_weights: Parameters to load into the new network.

        Returns:
            The parameters of the new network, as returned by
            ``flwr_get_parameters``.
        """
        model_cfg = self.cfg.dataset_model
        net_kwargs = {
            "dataset_name": model_cfg.dataset_name,
            "num_classes": model_cfg.num_classes,
            "network_arch": model_cfg.network_arch,
            "network_type": "POINT",
            "num_points": 1,
            "seed": self.cfg.seed,
            "device": self.cfg.device,
        }
        point_net = net_fn(**net_kwargs)
        flwr_set_parameters(point_net, parameters=conv_comb_weights)
        return flwr_get_parameters(point_net)

    def sample_params_line(self, tmp_client_label_dist, uniform=True, train=True, idx=0):
        if uniform:
            tmp_train_task_id = 0
            alpha = np.random.uniform(0.0, 1.0)
        else:
            # tmp_label_dist_lda = self.lda.transform(tmp_client_label_dist.reshape(1, -1))
            tmp_label_dist_lda = tmp_client_label_dist.reshape(1, -1)

            tmp_train_task_id = self.cluster_model.predict(tmp_label_dist_lda)[0]

            splits = np.array_split(np.arange(0, 1, 0.1), self.cfg.rule.rule_arg)
            sampling_regions = []
            for region in splits:
                sampling_regions.append([region[0], region[-1] + 0.1])
            tmp_sampling_region = sampling_regions[tmp_train_task_id]
            
            # At train we always sample uniformly from cluster region
            if train:
                alpha = np.random.uniform(*tmp_sampling_region)
            else:
                # If testing, the first sample should always be the cluster region center
                if idx == 0:
                    alpha = tmp_sampling_region[0] + ((tmp_sampling_region[1] - tmp_sampling_region[0]) / 2)
                else:
                    alpha = np.random.uniform(*tmp_sampling_region)

        return alpha, tmp_train_task_id
    
    def sample_params_simplex(self, cfg, client_id_str, uniform=True, train=True):
        client_id = int(client_id_str.split('_')[0])
        # Sample uniformly from simplex
        if uniform:
            alpha_simplex, alpha_cartesian, sampled_region_id = self.solution_simplex.sample_uniform(client_id)
        else:
            # If testing, sample from the center of the simplex subregion
            if not train:
                alpha_simplex, alpha_cartesian, sampled_region_id = self.solution_simplex.get_client_center(client_id)
            else:
                alpha_simplex, alpha_cartesian, sampled_region_id = self.solution_simplex.get_client_subregion(client_id)

        return alpha_simplex, sampled_region_id, alpha_cartesian

    def sample_client_weights(self, server_round, parameters, client_ids, train=True):
        """Build per-client instructions carrying the subspace and an alpha.

        Every client receives the same serialized parameters of
        ``self.subspace`` plus a config holding ``"server_round"`` and its own
        ``"alpha"`` (simplex coordinates as a plain list).

        Args:
            server_round: Index of the current federated round; alphas are drawn
                uniformly while it is below ``cfg.strategy.subspace_start``.
            parameters: Current global parameters (see the review note below).
            client_ids: Client ids of the form ``"<int>_<suffix>"``.
            train: True for fit instructions, False for evaluate instructions.

        Returns:
            ``(ins_all, client_alphas)``: the instructions in ``client_ids``
            order, and a dict mapping each client id to its list of alphas.
        """
        # Before the subspace phase starts, alphas are sampled uniformly.
        uniform = server_round < self.cfg.strategy.subspace_start
        config = {"server_round": server_round}

        # NOTE(review): this network is built and loaded but never read again
        # in this method; what is sent to clients comes from `self.subspace`.
        # It is kept because `net_fn` (called with a seed) may have side
        # effects - confirm before removing.
        net = net_fn(
            dataset_name=self.cfg.dataset_model.dataset_name,
            num_classes=self.cfg.dataset_model.num_classes,
            network_arch=self.cfg.dataset_model.network_arch,
            network_type=self.cfg.strategy.network_type,
            num_points=self.cfg.rule.num_points,
            seed=self.cfg.seed,
            device=self.cfg.device,
        )
        flwr_set_parameters(net, parameters=parameters_to_ndarrays(parameters))

        # Sample clients, assign specific alpha region to it
        client_alphas = {client_id_str: [] for client_id_str in client_ids}

        # Leaving the `with` block waits for all submitted futures to finish.
        with concurrent.futures.ThreadPoolExecutor(max_workers=None) as executor:
            future_to_cid = {
                executor.submit(
                    self.sample_params_simplex, self.cfg, client_id_str, uniform=uniform, train=train
                ): client_id_str
                for client_id_str in client_ids
            }

        # Store each alpha under the id of the client it was sampled for,
        # rather than rebuilding a key from the returned region id (which only
        # worked for ids of the exact form "<region_id>_client").
        for future, client_id_str in future_to_cid.items():
            client_alpha_simplex, _region_id, _alpha_cartesian = future.result()
            # A plain list is required for gRPC serialization.
            client_alphas[client_id_str].append(client_alpha_simplex.tolist())

        # Use the instruction type matching the phase; both take
        # (parameters, config).
        ins_cls = FitIns if train else EvaluateIns
        tmp_subspace_params = ndarrays_to_parameters(flwr_get_parameters(self.subspace))
        ins_all = [
            ins_cls(tmp_subspace_params, {**config, "alpha": client_alphas[client_id][0]})
            for client_id in client_ids
        ]

        return ins_all, client_alphas

def compute_gradient_signal_to_noise_ratio(weight_results):
    """Return the signal-to-noise ratio of a set of per-client arrays.

    Each array is flattened into one observation of a parameter vector. The
    ratio is the squared norm of the mean vector across clients divided by the
    total variance (sum of the per-parameter sample variances, ``ddof=1``).

    The previous version used ``transpose(0, 1)``, which is a no-op, so the
    mean and variance were taken over parameters within each client instead of
    over clients. This now matches the axis convention of
    ``gradient_total_variance``.

    Args:
        weight_results: Sequence of equally sized numpy arrays, one per client.

    Returns:
        The ratio as a numpy float. With zero variance or fewer than two
        clients the result is inf/nan, following numpy division semantics.
    """
    # Shape (n_clients, n_params); float64 matches what np.cov computed in.
    stacked = np.array([weight.flatten() for weight in weight_results], dtype=np.float64)
    mean_vec = stacked.mean(axis=0)
    sq_mean = mean_vec @ mean_vec
    # Trace of the covariance matrix, without building the matrix itself.
    variance = stacked.var(axis=0, ddof=1).sum()
    return sq_mean / variance

def gradient_total_variance(weight_results):
    """Return the total variance across clients of a set of arrays.

    Each array is flattened into one observation of a parameter vector; the
    result is the sum over parameters of the sample variance across clients
    (``ddof=1``), i.e. the trace of the covariance matrix.

    The variances are computed directly rather than via ``np.cov``, which
    would allocate an ``n_params x n_params`` matrix only to read its
    diagonal, and which failed for single-parameter inputs.

    Args:
        weight_results: Sequence of equally sized numpy arrays, one per client.

    Returns:
        The total variance as a numpy float (nan with fewer than two clients).
    """
    # Shape (n_clients, n_params); float64 matches what np.cov computed in.
    stacked = np.array([weight.flatten() for weight in weight_results], dtype=np.float64)
    return stacked.var(axis=0, ddof=1).sum()

def is_weight_module(module):
    """Tell whether ``module`` is one of the standard torch layers listed below.

    Args:
        module: Object to test, typically a ``torch.nn.Module``.

    Returns:
        True for Conv2d, BatchNorm2d, Linear, LSTM and Embedding instances
        (including subclasses), False otherwise.
    """
    nn = torch.nn
    weight_layer_types = (nn.Conv2d, nn.BatchNorm2d, nn.Linear, nn.LSTM, nn.Embedding)
    return isinstance(module, weight_layer_types)


def is_subspace_module(module):
    """Tell whether ``module`` is one of the project's simplex layers.

    Args:
        module: Object to test.

    Returns:
        True for ``SimplexConv`` and ``SimplexLinear`` instances (including
        subclasses), False otherwise.
    """
    simplex_layer_types = (SimplexConv, SimplexLinear)
    return isinstance(module, simplex_layer_types)
