import copy
import os
from collections import OrderedDict
from typing import List

import numpy as np
import torch
from flwr.client import NumPyClient

from src._train import train
from src.util import test, get_label_dist


class FLOCOClient(NumPyClient):
    """Flower NumPy client wrapping a global model and an optional local model.

    The global model is loaded with the server parameters on every ``fit`` /
    ``evaluate`` call. When ``local_net`` is provided (e.g. for DITTO-style
    personalization), its state dict is written to
    ``<client_state_save_path>/last_state_dict.pth`` after each ``fit`` and
    restored from there, if present, at the start of ``fit`` and ``evaluate``.
    """

    # File name of the persisted local-model state dict inside client_state_save_path.
    _LOCAL_STATE_FILENAME = "last_state_dict.pth"

    def __init__(self, client_name, net, local_net, trainloader, valloader, cfg, experiment_name):
        """Store the client's models, data loaders and configuration.

        Args:
            client_name: Identifier of this client; part of the local-model save path.
            net: Global model whose parameters are exchanged with the server.
            local_net: Optional personalized model; a falsy value disables all
                local-model loading/saving.
            trainloader: Training data; also used to compute the label distribution.
            valloader: Validation data used by ``evaluate``.
            cfg: Experiment config; reads ``dataset_model.num_classes``,
                ``strategy.strategy_name`` and ``device``.
            experiment_name: Part of the local-model save path.
        """
        self.client_name = client_name
        self.trainloader = trainloader
        self.valloader = valloader
        self.cfg = cfg
        self.label_dist = get_label_dist(trainloader, cfg.dataset_model.num_classes)
        self.global_net = net
        self.local_net = local_net
        self.client_state_save_path = f"trained_models/local_client_models/{experiment_name}/{client_name}"
        # Global parameters as received in the most recent ``fit`` call.
        self.last_params = None

    @property
    def _local_state_file(self):
        """Path of this client's persisted local-model state dict."""
        return os.path.join(self.client_state_save_path, self._LOCAL_STATE_FILENAME)

    def _load_local_state(self):
        """Restore ``self.local_net`` from disk if a saved state dict file exists."""
        # Check the file itself, not just its directory: the directory can exist
        # without the file, which previously made torch.load raise.
        if not os.path.isfile(self._local_state_file):
            return
        # map_location="cpu" lets a state dict saved from another device load here;
        # load_state_dict then copies the values onto the module's own device.
        state_dict = torch.load(self._local_state_file, map_location="cpu")
        self.local_net.load_state_dict(state_dict)

    def _save_local_state(self):
        """Write ``self.local_net``'s state dict to disk, creating the directory."""
        os.makedirs(self.client_state_save_path, exist_ok=True)
        torch.save(self.local_net.state_dict(), self._local_state_file)

    def fit(self, parameters, config):
        """Train starting from the server's parameters and report the result.

        Args:
            parameters: Global model parameters as NumPy arrays in state_dict order.
            config: Round config; must contain ``"server_round"`` and may contain
                ``"alpha"`` (forwarded to ``train`` as ``client_alphas``).

        Returns:
            Tuple ``(params, len(self.trainloader), metrics)``. For the ``"FLOCO"``
            strategy ``params`` is ``received - trained`` per array; otherwise it is
            the trained parameters. ``metrics`` is the dict returned by ``train``
            with ``"train_label_dist"`` added.
        """
        flwr_set_parameters(self.global_net, parameters)
        # Deep copy: for CPU tensors ``.numpy()`` shares memory with the model, so
        # without a copy this snapshot would change if training updates in place.
        self.last_params = copy.deepcopy(flwr_get_parameters(self.global_net))

        # Local models (e.g. DITTO) continue from their last saved state.
        if self.local_net:
            self._load_local_state()

        metrics_dict = train(
            server_round=config['server_round'],
            global_net=self.global_net,
            local_net=self.local_net,
            trainloader=self.trainloader,
            cfg=self.cfg,
            client_alphas=config.get("alpha", None)
        )

        if self.local_net:
            self._save_local_state()

        trained_params = flwr_get_parameters(self.global_net)

        if self.cfg.strategy.strategy_name == "FLOCO":
            # FLOCO returns the difference (received - trained) instead of the weights.
            params = [before - after for before, after in zip(self.last_params, trained_params)]
        else:
            params = trained_params

        metrics_dict["train_label_dist"] = self.label_dist

        # NOTE(review): if trainloader is a torch DataLoader, len() is the number of
        # batches, not examples; this is reported as Flower's num_examples — confirm intended.
        return params, len(self.trainloader), metrics_dict

    def evaluate(self, parameters, config):
        """Evaluate the local model if present, otherwise the global model.

        Args:
            parameters: Global model parameters as NumPy arrays in state_dict order;
                always loaded into the global model.
            config: Round config; may contain ``"alpha"`` (forwarded to ``test``).

        Returns:
            Tuple ``(metrics["val_loss"], len(self.valloader), metrics)`` where
            ``metrics`` is the dict returned by ``test``.
        """
        flwr_set_parameters(self.global_net, parameters)

        if self.local_net:
            self._load_local_state()
            net = self.local_net
        else:
            net = self.global_net

        metrics_dict = test(
            net=net,
            valloader=self.valloader,
            client_alpha=config.get("alpha", None),
            num_classes=self.cfg.dataset_model.num_classes,
            device=self.cfg.device,
            strategy_name=self.cfg.strategy.strategy_name,
        )
        return metrics_dict["val_loss"], len(self.valloader), metrics_dict


def flwr_get_parameters(net) -> List[np.ndarray]:
    """Return ``net``'s ``state_dict`` tensors as NumPy arrays, in key order.

    Tensors are moved to CPU first. For tensors already on CPU the returned
    arrays share memory with the model, so copy them if you need a snapshot
    that must not change when the model is later updated in place.
    """
    # state_dict() yields detached tensors, so .numpy() is safe even for parameters
    # that require grad; only the values are needed, the key order gives alignment.
    return [tensor.cpu().numpy() for tensor in net.state_dict().values()]


def flwr_set_parameters(net, parameters: List[np.ndarray]):
    """Load NumPy arrays into ``net``, matched to ``state_dict`` entries by position.

    Args:
        net: Module whose ``state_dict`` entries are overwritten.
        parameters: One array per ``state_dict`` entry, in the order produced by
            ``flwr_get_parameters``.

    Raises:
        ValueError: If the number of arrays differs from the number of
            ``state_dict`` entries. Without this check, ``zip`` would silently
            drop surplus arrays.
        RuntimeError: Propagated from ``load_state_dict`` (e.g. shape mismatch).
    """
    keys = list(net.state_dict().keys())
    parameters = list(parameters)
    if len(parameters) != len(keys):
        raise ValueError(
            f"Expected {len(keys)} parameter arrays for {type(net).__name__}, "
            f"got {len(parameters)}"
        )
    # np.atleast_1d turns 0-d arrays (e.g. BatchNorm's num_batches_tracked) into
    # shape (1,); load_state_dict's backward-compat path accepts a 1-element
    # tensor for a 0-d entry.
    state_dict = OrderedDict(
        (key, torch.tensor(np.atleast_1d(value))) for key, value in zip(keys, parameters)
    )
    net.load_state_dict(state_dict, strict=True)