import os
import wandb
import torch
from torch.utils.data import DataLoader
from typing import Callable, Dict, Optional, Tuple
from collections import OrderedDict
from hydra.utils import instantiate
from omegaconf import DictConfig
import numpy as np
from flwr.common import (
    Code,
    FitIns,
    FitRes,
    Parameters,
    Scalar,
    ndarrays_to_parameters,
    parameters_to_ndarrays,
)
from flwr.common.typing import (
    Callable,
    Dict,
    GetParametersIns,
    List,
    NDArrays,
    Optional,
    Tuple,
    Union,
)

from utils_helper import segment_resnet_parameters, reconstruct_parameters, flatten_resnet_parameters
from models import test

def gen_evaluate_fn(
    testloader: DataLoader,
    device: torch.device,
    model: DictConfig,
    cfg_wandb: bool,
) -> Callable[
    [int, NDArrays, Dict[str, Scalar]], Optional[Tuple[float, Dict[str, Scalar]]]
]:
    """Generate the function for centralized, segment-wise evaluation.

    The returned ``evaluate`` callback handles a model whose flattened
    parameters are split into two segments:

    * At ``server_round == 0``, ``parameters_ndarrays`` is treated as the full
      model. Its arrays are paired, in order, with the keys of a freshly
      instantiated network's ``state_dict``.
    * At ``server_round > 0``, the checkpoint saved for round
      ``server_round - 1`` is reloaded. Only ``parameters_ndarrays[0]`` is
      then used: it overwrites segment ``(server_round + 1) % 2`` of the
      reloaded model, and the other segment keeps its checkpointed values.

    After testing, the model is saved to
    ``server_model_checkpoints/model_checkpoint_round_{server_round}.pth`` so
    the next round can reload it.

    Parameters
    ----------
    testloader : DataLoader
        The dataloader to test the model with.
    device : torch.device
        The device to test the model on.
    model : DictConfig
        Config passed to ``instantiate`` to build a fresh network on every call.
    cfg_wandb : bool
        When true, ``acc`` and ``loss`` are logged with ``wandb.log``.

    Returns
    -------
    Callable[ [int, NDArrays, Dict[str, Scalar]],
               Optional[Tuple[float, Dict[str, Scalar]]] ]
    The centralized evaluation function.
    """
    # Relative to the current working directory. It does not depend on the
    # round, so it is built once instead of on every call.
    save_dir = "server_model_checkpoints"

    def evaluate(server_round: int, parameters_ndarrays: NDArrays, config: Dict[str, Scalar]) -> Optional[Tuple[float, Dict[str, Scalar]]]:
        """Rebuild the global model for ``server_round`` and test it.

        ``config`` is unused. Returns ``(loss, {"accuracy": accuracy})`` as
        produced by ``test``. For ``server_round > 0`` the previous round's
        checkpoint must already exist, otherwise ``torch.load`` raises.
        """
        net = instantiate(model)
        print("Model instantiated")

        # exist_ok=True avoids the race in "if not exists(): makedirs()".
        os.makedirs(save_dir, exist_ok=True)

        if server_round > 0:
            load_checkpoint_path = f"{save_dir}/model_checkpoint_round_{server_round-1}.pth"
            # The checkpoint was saved after net.to(device). load_state_dict
            # copies into the net's own tensors anyway, so mapping to CPU avoids
            # a needless allocation on `device` and lets the file load even if
            # that device is unavailable here.
            checkpoint = torch.load(
                load_checkpoint_path, map_location="cpu", weights_only=True
            )
            net.load_state_dict(checkpoint)
            print(f"Model checkpoint loaded from {load_checkpoint_path}")

        if server_round == 0:
            # Full model: one array per state_dict entry, in key order. zip
            # stops at the shorter input; any gap shows up as a missing-key
            # error in the strict load below.
            aggregated_state_dict = OrderedDict()
            for k, v in zip(net.state_dict().keys(), parameters_ndarrays):
                try:
                    if k.endswith('num_batches_tracked'):
                        # BatchNorm step counters must stay integer.
                        tensor = torch.tensor(v, dtype=torch.long)
                    else:
                        tensor = torch.Tensor(v)
                    aggregated_state_dict[k] = tensor
                except Exception as e:
                    print(f"Error converting parameter {k}: {e}")
        else:
            print(f"I was right {server_round} with {len(parameters_ndarrays[0])} parameters")
            # Two segments of the previous round's flattened model (presumably
            # a list of 1-D tensors -- TODO confirm against utils_helper).
            segments = segment_resnet_parameters(
                flatten_resnet_parameters(net.state_dict()), num_segments=2
            )

            # Only parameters_ndarrays[0] is used. It replaces the segment
            # selected by the round's parity; the other segment is kept.
            received_segment = torch.Tensor(parameters_ndarrays[0])
            segments[(server_round + 1) % 2] = received_segment

            # Concatenate the segments back into one flat parameter vector.
            modified_flat_params = torch.cat(segments)

            print("State dict created")
            # Split the flat vector back into tensors shaped like the model's.
            prev_state_dict = net.state_dict()
            shapes = {k: v.shape for k, v in prev_state_dict.items()}
            sizes = {k: v.numel() for k, v in prev_state_dict.items()}
            aggregated_state_dict = reconstruct_parameters(
                modified_flat_params, shapes, sizes, prev_state_dict
            )

        try:
            net.load_state_dict(aggregated_state_dict, strict=True)
        except Exception as e:
            # Best effort: report the mismatch and load whatever keys match.
            print(f"Error loading state dict with strict=True: {e}")
            net.load_state_dict(aggregated_state_dict, strict=False)

        print("Model loaded with aggregated state dict, Trying to move to device", device)
        net.to(device)
        print("Model moved to device")
        loss, accuracy = test(net, testloader, device=device)
        if cfg_wandb:
            wandb.log({"acc": accuracy, "loss": loss})

        # Save this round's model; the next round reloads it.
        save_checkpoint_path = f"{save_dir}/model_checkpoint_round_{server_round}.pth"
        torch.save(net.state_dict(), save_checkpoint_path)
        print(f"Model checkpoint saved at {save_checkpoint_path}")

        return loss, {"accuracy": accuracy}

    return evaluate


def fedavg_gen_evaluate_fn(
    testloader: DataLoader,
    device: torch.device,
    model: DictConfig,
    cfg_wandb: bool,
) -> Callable[
    [int, NDArrays, Dict[str, Scalar]], Optional[Tuple[float, Dict[str, Scalar]]]
]:
    """Build the server-side evaluation callback used with plain FedAvg.

    Each call of the returned callback:

    * builds a fresh network from ``model``;
    * fills it with the received arrays, paired in order with the network's
      ``state_dict`` keys;
    * moves it to ``device`` and scores it on ``testloader``.

    Nothing is written to disk.

    Parameters
    ----------
    testloader : DataLoader
        Data the global model is scored on.
    device : torch.device
        Device the network is moved to before testing.
    model : DictConfig
        Config handed to ``instantiate`` to build the network.
    cfg_wandb : bool
        When equal to ``True``, ``acc`` and ``loss`` are sent to ``wandb.log``.

    Returns
    -------
    Callable[ [int, NDArrays, Dict[str, Scalar]],
               Optional[Tuple[float, Dict[str, Scalar]]] ]
    The callback, which returns ``(loss, {"accuracy": accuracy})``.
    """

    def _as_tensor(key, array):
        """Convert one received array to a tensor (step counters as int64)."""
        if key.endswith('num_batches_tracked'):
            return torch.tensor(array, dtype=torch.long)
        return torch.Tensor(array)

    def evaluate(
        server_round: int, parameters_ndarrays: NDArrays, config: Dict[str, Scalar]
    ) -> Optional[Tuple[float, Dict[str, Scalar]]]:
        # pylint: disable=unused-argument
        """Score the received global parameters on ``testloader``."""
        network = instantiate(model)

        incoming = OrderedDict()
        # zip stops at the shorter input; any gap shows up as a missing-key
        # error in the strict load below.
        for k, array in zip(network.state_dict().keys(), parameters_ndarrays):
            try:
                incoming[k] = _as_tensor(k, array)
            except Exception as e:
                print(f"Error converting parameter {k}: {e}")

        try:
            network.load_state_dict(incoming, strict=True)
        except Exception as e:
            # Best effort: fall back to a partial load instead of aborting.
            print(f"Error loading state dict with strict=True: {e}")
            network.load_state_dict(incoming, strict=False)

        network.to(device)
        loss, accuracy = test(network, testloader, device=device)
        if cfg_wandb == True:  # noqa: E712 -- keeps the original equality test
            wandb.log({"acc": accuracy, "loss": loss})
        return loss, {"accuracy": accuracy}

    return evaluate


def fedspars_gen_evaluate_fn(
    testloader: DataLoader,
    device: torch.device,
    model: DictConfig,
    approach: str,
    comp_type: str,
    sparsify_by: float,
    partitioning: str,
    num_clients: int,
    clients_per_round: int,
    cfg_wandb: bool,
) -> Callable[
    [int, NDArrays, Dict[str, Scalar]], Optional[Tuple[float, Dict[str, Scalar]]]
]:
    """Generate the function for centralized evaluation of sparsified runs.

    The returned callback loads the full set of received arrays into a fresh
    network, paired in order with its ``state_dict`` keys. It then tests the
    network. At round 0 only, it also saves the model to a checkpoint
    directory whose name encodes the run configuration.

    Parameters
    ----------
    testloader : DataLoader
        The dataloader to test the model with.
    device : torch.device
        The device to test the model on.
    model : DictConfig
        Config passed to ``instantiate`` to build a fresh network on every call.
    approach, comp_type, sparsify_by, partitioning, num_clients, clients_per_round
        Used only to build the checkpoint directory name
        ``server_model_checkpoints_{approach}_{comp_type}_{sparsify_by}_``
        ``{partitioning}_{num_clients}_{clients_per_round}``.
    cfg_wandb : bool
        When true, ``acc`` and ``loss`` are logged with ``wandb.log``.

    Returns
    -------
    Callable[ [int, NDArrays, Dict[str, Scalar]],
               Optional[Tuple[float, Dict[str, Scalar]]] ]
    The centralized evaluation function.
    """
    # Depends only on the factory arguments, so it is built once, not per round.
    save_dir = (
        f"server_model_checkpoints_{approach}_{comp_type}_{sparsify_by}"
        f"_{partitioning}_{num_clients}_{clients_per_round}"
    )

    def evaluate(server_round: int, parameters_ndarrays: NDArrays, config: Dict[str, Scalar]) -> Optional[Tuple[float, Dict[str, Scalar]]]:
        """Load the received parameters into a fresh model and test it.

        ``config`` is unused. Returns ``(loss, {"accuracy": accuracy})`` as
        produced by ``test``.
        """
        net = instantiate(model)
        print("Model instantiated")

        # Full model: one array per state_dict entry, in key order. zip stops
        # at the shorter input; any gap shows up as a missing-key error in the
        # strict load below.
        aggregated_state_dict = OrderedDict()
        for k, v in zip(net.state_dict().keys(), parameters_ndarrays):
            try:
                if k.endswith('num_batches_tracked'):
                    # BatchNorm step counters must stay integer.
                    tensor = torch.tensor(v, dtype=torch.long)
                else:
                    tensor = torch.Tensor(v)
                aggregated_state_dict[k] = tensor
            except Exception as e:
                print(f"Error converting parameter {k}: {e}")

        try:
            net.load_state_dict(aggregated_state_dict, strict=True)
        except Exception as e:
            # Best effort: report the mismatch and load whatever keys match.
            print(f"Error loading state dict with strict=True: {e}")
            net.load_state_dict(aggregated_state_dict, strict=False)

        print("Model loaded with aggregated state dict, Trying to move to device", device)
        net.to(device)
        print("Model moved to device")
        loss, accuracy = test(net, testloader, device=device)
        if cfg_wandb:
            wandb.log({"acc": accuracy, "loss": loss})

        # Only the round-0 model is checkpointed; in Flower, round 0 evaluates
        # the initial parameters. The directory is needed only here, and
        # exist_ok=True avoids the exists()-then-makedirs() race.
        if server_round == 0:
            os.makedirs(save_dir, exist_ok=True)
            save_checkpoint_path = f"{save_dir}/model_checkpoint_round_{server_round}.pth"
            torch.save(net.state_dict(), save_checkpoint_path)
            print(f"Model checkpoint saved at {save_checkpoint_path}")

        return loss, {"accuracy": accuracy}

    return evaluate