from collections import OrderedDict
from copy import deepcopy
from functools import reduce
from typing import Dict, List, Optional, Tuple, Union
from typing import Dict
import torch
import torch.nn.functional as F
import wandb
import torch.nn as nn
import math


from flwr.common import (
    FitRes,
    Metrics,
    NDArrays,
    Parameters,
    Scalar,
    ndarrays_to_parameters,
    parameters_to_ndarrays,
)

# ------------------------------------------------------------
# helper: compute γ given the *next* server round number
# ------------------------------------------------------------
def gamma_for_round(r: int,
                    warmup_start: float = 1.00,   # γ0
                    warmup_end:   float = 0.20,   # γ∞
                    warmup_len:   int   = 30,     # when γ reaches γ∞
                    freeze_len:   int   = 10):    # keep γ0 for rounds 1…freeze_len
    """
    Return γ for round ``r`` using a three-phase, piece-wise linear schedule.

        r ≤ freeze_len               : γ = warmup_start (γ0)
        freeze_len < r < warmup_len  : γ moves linearly from γ0 towards γ∞
        r ≥ warmup_len               : γ = warmup_end (γ∞)

    With the default arguments γ0 > γ∞, so the linear phase is a decrease.
    The freeze check is evaluated first, so it wins if both bounds match.
    """
    if r <= freeze_len:
        gamma = warmup_start
    elif r >= warmup_len:
        gamma = warmup_end
    else:
        # Fraction of the ramp (freeze_len → warmup_len) already covered.
        ramp_rounds = warmup_len - freeze_len
        fraction = (r - freeze_len) / ramp_rounds
        gamma = warmup_start + fraction * (warmup_end - warmup_start)
    return gamma

def cosine_eta_warmfloor(r, total_R, eta_max=0.01, eta_min=0.003, warmup_R=10):
    """
    Learning-rate schedule: linear warm-up followed by cosine decay to a floor.

    Args:
        r: Current round number (expected to start at 1).
        total_R: Total number of rounds; the cosine phase ends here.
        eta_max: Peak rate reached at the end of warm-up.
        eta_min: Floor the cosine decay settles on.
        warmup_R: Number of warm-up rounds.

    Returns:
        ``eta_max * r / warmup_R`` while ``r <= warmup_R``; afterwards a
        half-cosine from ``eta_max`` down to ``eta_min``. Once the decay has
        finished the value stays at ``eta_min``.
    """
    if r <= warmup_R:
        # Linear warm-up; max(1, ...) avoids dividing by zero.
        return eta_max * (r / max(1, warmup_R))
    # Progress through the decay phase. It is clamped to 1 because an
    # unclamped cosine turns around after pi and would push the rate back up
    # for r > total_R (or oscillate when total_R <= warmup_R).
    t = min(1.0, (r - warmup_R) / max(1, total_R - warmup_R))
    return eta_min + 0.5 * (eta_max - eta_min) * (1 + math.cos(math.pi * t))


def flatten_resnet_parameters(state_dict):
    """
    Flatten and concatenate every tensor of a state_dict into one 1-D vector.

    Tensors are taken in the mapping's iteration order, flattened, cast to
    float32 and concatenated.

    ``reshape(-1)`` is used instead of ``view(-1)`` so that non-contiguous
    tensors (for example transposed or channels-last weights) are flattened
    instead of raising a RuntimeError.
    """
    return torch.cat([tensor.reshape(-1).float() for tensor in state_dict.values()])

def reconstruct_parameters(flat_params, shapes, sizes, trained_weights):
    """
    Rebuild named tensors from a flat parameter vector.

    Walks ``shapes`` in iteration order; for every key it takes the next
    ``sizes[key]`` elements of ``flat_params``, views them as ``shapes[key]``
    and casts them to the dtype of ``trained_weights[key]``.

    Returns an OrderedDict with the same key order as ``shapes``.
    """
    rebuilt = OrderedDict()
    cursor = 0
    for name in shapes:
        stop = cursor + sizes[name]
        chunk = flat_params[cursor:stop]
        target_shape = shapes[name]
        rebuilt[name] = chunk.view(target_shape).to(trained_weights[name].dtype)
        cursor = stop
    return rebuilt

def segment_resnet_parameters(flat_params, num_segments):
    """
    Split ``flat_params`` into ``num_segments`` consecutive slices.

    Every element lands in exactly one slice. When the length is not a
    multiple of ``num_segments``, the leading slices each receive one extra
    element, so slice lengths differ by at most one.
    """
    base_len, extras = divmod(len(flat_params), num_segments)

    pieces = []
    cursor = 0
    for position in range(num_segments):
        # The first `extras` slices absorb the remainder, one element each.
        width = base_len + 1 if position < extras else base_len
        pieces.append(flat_params[cursor:cursor + width])
        cursor += width
    return pieces


def preprocess_weights(weights: Dict):
    """
    Flatten and split every entry of ``weights``.

    For each ``key -> state_dict`` pair the state_dict is flattened with
    ``flatten_resnet_parameters`` and the flat vector is cut into two
    segments with ``segment_resnet_parameters``.

    Returns a tuple ``(flat_params_n, segmented_params)``; both are dicts
    keyed like ``weights``.
    """
    flat_params_n = {}
    segmented_params = {}
    for owner, state in weights.items():
        flat_vector = flatten_resnet_parameters(state)
        flat_params_n[owner] = flat_vector
        segmented_params[owner] = segment_resnet_parameters(flat_vector, num_segments=2)
    return flat_params_n, segmented_params

def extract_shared_segments(clients_dict, client_shared_segments):
    """
    Pick, for every client, the single segment it is assigned to share.

    ``client_shared_segments`` maps client id -> segment id. The result maps
    client id -> ``{segment id: clients_dict[client id][segment id]}``.
    """
    return {
        cid: {seg_id: clients_dict[cid][seg_id]}
        for cid, seg_id in client_shared_segments.items()
    }

def aggregate_data_by_key(shared_data):
    """
    Average the values clients sent, grouped by key.

    ``shared_data`` maps client id -> ``{key: value}``. Values sharing a key
    are summed into a clone of the first one seen (so inputs are not
    modified) and divided by how many clients contributed to that key.
    Values must support ``.clone()`` and in-place ``+=`` / ``/=``
    (presumably torch tensors; confirm against callers).

    Returns a dict mapping key -> mean value.
    """
    sums = {}
    hits = {}

    for per_client in shared_data.values():
        for seg_key in per_client:
            contribution = per_client[seg_key]
            if seg_key not in sums:
                # First contribution: clone so the accumulation below never
                # touches the client's own tensor.
                sums[seg_key] = contribution.clone()
                hits[seg_key] = 1
                continue
            sums[seg_key] += contribution
            hits[seg_key] += 1

    for seg_key, n_contributors in hits.items():
        sums[seg_key] /= n_contributors

    return sums

def handle_partial_updates(client_weights, client_segment_map):
    """
    Flatten ``client_weights``, split it in two, and extract shared segments.

    The weights are flattened with ``flatten_resnet_parameters``, cut into
    two segments with ``segment_resnet_parameters`` and handed to
    ``extract_shared_segments`` together with ``client_segment_map``.

    NOTE(review): ``segment_resnet_parameters`` returns a list of segments,
    while ``extract_shared_segments`` indexes its first argument by client id
    and then by segment id. Confirm this call matches the intended layout.
    """
    flat_vector = flatten_resnet_parameters(client_weights)
    halves = segment_resnet_parameters(flat_vector, num_segments=2)
    return extract_shared_segments(halves, client_segment_map)

def update_client_models(clients_segmented_params, aggregated_segments, client_segment_map):
    """
    Swap each client's shared segment for the aggregated one and re-join.

    For every client, the segment at index ``client_segment_map[client_id]``
    is replaced by ``aggregated_segments[that index]``; all other segments
    are kept unchanged. The segments are then concatenated with ``torch.cat``.

    Returns a dict mapping client id -> concatenated flat tensor.
    """
    merged_lists = {}
    for cid, own_segments in clients_segmented_params.items():
        shared_idx = client_segment_map[cid]
        merged_lists[cid] = [
            aggregated_segments[shared_idx] if pos == shared_idx else segment
            for pos, segment in enumerate(own_segments)
        ]

    return {cid: torch.cat(pieces) for cid, pieces in merged_lists.items()}


from flwr.common.typing import NDArray
from flwr.common import Array
import numpy as np

def ndarray_to_array(ndarray: NDArray):
    """Wrap a NumPy ndarray in an ``Array`` (raw bytes, dtype name, shape)."""
    raw_bytes = ndarray.tobytes()
    dtype_name = str(ndarray.dtype)
    dims = list(ndarray.shape)
    # stype tags the payload kind; a deserialiser could dispatch on it.
    return Array(
        data=raw_bytes,
        dtype=dtype_name,
        stype="numpy.ndarray",
        shape=dims,
    )


def basic_array_deserialisation(array: Array):
    """Rebuild a NumPy array from an ``Array``'s data, dtype and shape."""
    flat = np.frombuffer(buffer=array.data, dtype=array.dtype)
    return flat.reshape(array.shape)

def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
    """Return the example-weighted mean of the clients' ``"accuracy"`` values.

    :param metrics: List of tuples (num_examples, metrics); every metrics
        dict must contain an ``"accuracy"`` entry.
    :return: ``{"accuracy": sum(n_i * acc_i) / sum(n_i)}``
    """
    # Each client's accuracy scaled by how many examples it was measured on.
    weighted_hits = np.array([count * client["accuracy"] for count, client in metrics])
    sample_counts = np.array([count for count, _ in metrics])

    total_weighted = weighted_hits.sum()
    total_samples = sample_counts.sum()
    return {"accuracy": total_weighted / total_samples}

import copy

def combine_updates(global_params, local_residual, alpha=0.5, param_keys=None):
    """
    Blend sparse per-layer residual updates into the global parameters.

    For every selected index ``idx`` of a layer the new value is
    ``alpha * global[idx] + (1 - alpha) * update``.

    Args:
      global_params (list of np.ndarray): Global model parameters as received from the server.
                                          The list and its arrays are not modified.
      local_residual (dict): Dictionary with layer names (module names) as keys. For each layer,
                             local_residual[layer_name] should be a dict containing:
                               - "weight": a torch.Tensor of updates for the selected units.
                               - "indices": a list of selected indices (integers).
                               - Optionally, "bias": a torch.Tensor of updates for bias.
                                 A missing key and a ``None`` value both mean "no bias update".
      alpha (float): The weight for the global model (1 - alpha for the residual update).
      param_keys (list of str): Ordered list of parameter keys (as in self.net.state_dict().keys()).
                                This is used to map a layer name to the corresponding global parameter.

    Returns:
      updated_params (list of np.ndarray): The updated parameters. Arrays that received an
      update are fresh copies; untouched arrays are the same objects as in ``global_params``.

    Raises:
      ValueError: If ``param_keys`` is None.
    """
    if param_keys is None:
        raise ValueError("param_keys must be provided (the ordering of self.net.state_dict() keys).")

    # Resolve key -> list position once (first occurrence wins, like list.index).
    key_to_pos = {}
    for position, key in enumerate(param_keys):
        key_to_pos.setdefault(key, position)

    # Copying the list alone would still share the ndarrays with the caller,
    # so each array is copied the first time it is about to be written.
    updated_params = list(global_params)
    already_copied = set()

    def _blend(list_pos, indices, rows):
        """Apply the convex combination to the selected rows of one array."""
        if list_pos not in already_copied:
            updated_params[list_pos] = updated_params[list_pos].copy()
            already_copied.add(list_pos)
        target = updated_params[list_pos]
        for pos, idx in enumerate(indices):
            # detach() keeps this working for tensors that still track gradients.
            row = rows[pos].detach().cpu().numpy()
            target[idx] = alpha * target[idx] + (1 - alpha) * row

    for module_name, updates in local_residual.items():
        weight_pos = key_to_pos.get(module_name + ".weight")
        if weight_pos is not None:
            _blend(weight_pos, updates["indices"], updates["weight"])

        # decode_client_updates stores "bias": None for layers without a bias,
        # so checking only for the key's presence is not enough.
        bias_rows = updates.get("bias")
        bias_pos = key_to_pos.get(module_name + ".bias")
        if bias_rows is not None and bias_pos is not None:
            _blend(bias_pos, updates["indices"], bias_rows)

    return updated_params

### Server-side Decoding:

# On the server, rebuild each client's update by parsing its `List[np.ndarray]` sequentially. For instance:
def decode_client_updates(selected_updates):
    """
    Decode each client's flat list of arrays into per-layer update dicts.

    Every client entry is read sequentially as repeated groups of
    ``[layer name bytes, indices, weight values, (optional) bias values]``.
    An array directly after the weights is treated as the bias when it is
    1-D and has as many elements as ``indices``.

    NOTE(review): a following layer-name array that happens to be 1-D with
    that same length would also match the bias test. Confirm the encoder
    rules this out.

    Returns a list (one item per client) of dicts mapping layer name ->
    ``{"indices": list, "weight": Tensor, "bias": Tensor or None}``.
    """
    decoded_clients = []
    for arrays in selected_updates:
        n_arrays = len(arrays)
        layers = {}
        cursor = 0
        while cursor < n_arrays:
            # Mandatory triple: name, indices, weights.
            name = arrays[cursor].tobytes().decode('utf-8')
            index_arr = arrays[cursor + 1]
            weight_arr = arrays[cursor + 2]
            cursor += 3

            # Optional fourth array: the bias, recognised by its shape.
            bias_arr = None
            if cursor < n_arrays:
                candidate = arrays[cursor]
                if candidate.ndim == 1 and candidate.size == len(index_arr):
                    bias_arr = candidate
                    cursor += 1

            layers[name] = {
                "indices": index_arr.tolist(),
                "weight": torch.tensor(weight_arr),
                "bias": None if bias_arr is None else torch.tensor(bias_arr),
            }
        decoded_clients.append(layers)
    return decoded_clients


def get_base_key(full_key: str) -> str:
    """Return the first two dot-separated components of ``full_key``.

    Keys with fewer than two components are returned unchanged.
    """
    # Only the first two parts are needed, so stop splitting after them.
    leading_parts = full_key.split('.', 2)[:2]
    return '.'.join(leading_parts)

import os
import pickle
from typing import Dict, List

import matplotlib.pyplot as plt

# Character vocabulary for the Shakespeare dataset. A character's position in
# this string is the integer index that letter_to_vec / word_to_indices return.
ALL_LETTERS = (
    "\n !\"&'(),-.0123456789:;>?ABCDEFGHIJKLMNOPQRSTUVWXYZ[]abcdefghijklmnopqrstuvwxyz}"
)


def _one_hot(
    index: int,
    size: int,
) -> List:
    """Return one-hot vector with given size and value 1 at given index."""
    vec = [0 for _ in range(size)]
    vec[int(index)] = 1
    return vec


def letter_to_vec(
    letter: str,
) -> int:
    """Return the position of ``letter`` in ``ALL_LETTERS`` (-1 if absent)."""
    return ALL_LETTERS.find(letter)


def word_to_indices(
    word: str,
) -> List:
    """Map every character of ``word`` to its position in ``ALL_LETTERS``.

    Parameters
    ----------
        word: string.

    Returns
    -------
        indices: int list with length len(word); characters that are not in
        ``ALL_LETTERS`` map to -1.
    """
    return [ALL_LETTERS.find(character) for character in word]


def update_ema(
    prev_ema: float,
    current_value: float,
    smoothing_weight: float,
) -> float:
    """Advance an exponential moving average by one observation.

    Parameters
    ----------
    prev_ema : float
        The previous EMA value, or None when there is no history yet.
    current_value : float
        The newly observed value.
    smoothing_weight : float
        Weight given to ``prev_ema``; ``current_value`` gets
        ``1 - smoothing_weight``.

    Returns
    -------
    float
        ``current_value`` when ``prev_ema`` is None, otherwise
        ``(1 - smoothing_weight) * current_value + smoothing_weight * prev_ema``.
    """
    if prev_ema is not None:
        fresh_share = 1 - smoothing_weight
        return fresh_share * current_value + smoothing_weight * prev_ema
    # No history yet: the first observation seeds the average.
    return current_value


def save_graph_params(data_info: Dict):
    """Pickle ``data_info`` to ``<path>/<algo>.pkl`` without overwriting.

    Parameters
    ----------
    data_info : Dict
        Parameter dictionary of a completed experiment. Its ``'path'`` and
        ``'algo'`` entries determine the output file; the whole dictionary is
        what gets pickled.

    Raises
    ------
    ValueError
        If the target file already exists.
    """
    target = f"{data_info['path']}/{data_info['algo']}.pkl"

    try:
        # "x" creates the file exclusively, so the existence check and the
        # creation happen in one step. A separate os.path.exists() check
        # leaves a window in which another writer's file could be overwritten.
        file = open(target, "xb")
    except FileExistsError as err:
        raise ValueError(f"'{target}' is already exists!") from err

    with file:
        pickle.dump(data_info, file)


def plot_from_pkl(directory="."):
    """Plot accuracy and loss curves from every ``.pkl`` file in a directory.

    Each pickle is expected to provide ``data["accuracy"]["accuracy"]`` and
    ``data["loss"]`` as sequences of 2-item pairs; the second item of each
    pair is plotted. The figure is saved as ``result_graph.png`` inside
    ``directory`` and then shown.

    Parameters
    ----------
    directory : str
        Graph params directory path for Femnist or Shakespeare
    """
    palette = {
        "fedavg.pkl": "#66CC00",
        "fedavg_meta.pkl": "#3333CC",
        "fedmeta_maml.pkl": "#FFCC00",
        "fedmeta_meta_sgd.pkl": "#CC0000",
    }

    # NOTE: unpickling can execute arbitrary code; only use trusted directories.
    results = {}
    for name in os.listdir(directory):
        if not name.endswith(".pkl"):
            continue
        with open(os.path.join(directory, name), "rb") as handle:
            results[name] = pickle.load(handle)

    def draw_series(series_of):
        """Plot one line per result file on the current subplot."""
        for name in sorted(results):
            label = name[:-4] if name.endswith(".pkl") else name
            plt.plot(
                series_of(results[name]),
                label=label,
                color=palette.get(name, "black"),
                linewidth=3,
            )

    plt.figure(figsize=(7, 12))

    # Top panel: accuracy.
    plt.subplot(2, 1, 1)
    draw_series(lambda run: [acc for _, acc in run["accuracy"]["accuracy"]])
    plt.title("Accuracy")
    plt.grid(True)
    plt.legend()

    # Bottom panel: loss.
    plt.subplot(2, 1, 2)
    draw_series(lambda run: [loss for _, loss in run["loss"]])
    plt.title("Loss")
    plt.legend()
    plt.grid(True)

    plt.tight_layout()
    plt.savefig(f"{directory}/result_graph.png")
    plt.show()

