import time
import torch
import torch.nn.functional as F
from misc.utils import *
from data.loader import DataLoader
from modules.logger import Logger
from collections import OrderedDict
import numpy as np

class ServerModule:
    """Server-side helper: holds a data loader/logger and aggregates client weights."""

    def __init__(self, args, sd, gpu_server):
        self.args = args
        self._args = vars(self.args)
        self.gpu_id = gpu_server
        self.sd = sd
        self.loader = DataLoader(self.args)
        self.logger = Logger(self.args, self.gpu_id, is_server=True)

    @staticmethod
    def _device():
        """Return the CUDA device when available, otherwise CPU; aggregates live here."""
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def get_active(self, mask):
        """Return a float array that is 1.0 where |mask| >= args.l1 and 0.0 elsewhere."""
        active = np.absolute(mask) >= self.args.l1
        return active.astype(float)

    def aggregate(self, local_weights, ratio=None, params_to_update=None):
        """Weighted sum of selected parameters across clients.

        Args:
            local_weights: sequence of per-client mappings name -> tensor
                (numpy arrays / lists are converted to tensors).
            ratio: per-client weights indexed like ``local_weights``; uniform
                ``1 / len(local_weights)`` when None.
            params_to_update: names to aggregate; defaults to every key of the
                first client containing ``'lora_'``.

        Returns:
            OrderedDict name -> aggregated tensor on the server device.
        """
        device = self._device()
        local_weights = self.convert_weights_to_tensor(local_weights, device)

        if params_to_update is None:
            params_to_update = [k for k in local_weights[0].keys() if 'lora_' in k]
        if ratio is None:
            ratio = [1 / len(local_weights)] * len(local_weights)

        aggr_theta = OrderedDict()
        for name in params_to_update:
            acc = torch.zeros_like(local_weights[0][name], device=device)
            for j, weights in enumerate(local_weights):
                # Client tensors may live on another device (e.g. CPU state
                # while the accumulator is on CUDA); align before adding.
                acc += weights[name].to(device) * ratio[j]
            aggr_theta[name] = acc

        return aggr_theta

    @staticmethod
    def _pad_to(param, shape, device):
        """Return a zero tensor of ``shape`` with ``param`` copied into its leading corner."""
        padded = torch.zeros(shape, device=device)
        padded[tuple(slice(0, s) for s in param.shape)] = param
        return padded

    def aggregate_all(self, local_weights, ratio=None):
        """Weighted sum of every parameter, zero-padding clients with smaller shapes.

        For each key of the first client: if every client holds a tensor, each
        tensor is zero-padded to the per-dimension maximum shape and the padded
        tensors are summed with ``ratio`` weights (uniform when None).
        Non-tensor entries are copied from the first client.

        Raises:
            ValueError: if clients disagree on the rank of a tensor parameter.
        """
        device = self._device()
        local_weights = self.convert_weights_to_tensor(local_weights, device)

        if ratio is None:
            ratio = [1 / len(local_weights)] * len(local_weights)

        aggr_theta = OrderedDict()
        for name in local_weights[0].keys():
            params = [theta[name] for theta in local_weights]
            if not all(isinstance(p, torch.Tensor) for p in params):
                aggr_theta[name] = params[0]
                continue

            ranks = {p.dim() for p in params}
            if len(ranks) != 1:
                raise ValueError(
                    f"cannot aggregate '{name}': clients disagree on tensor rank {sorted(ranks)}"
                )
            # Per-dimension maximum so every client's tensor fits; picking the
            # shape with the largest numel fails e.g. for (2, 5) vs (4, 3).
            target_shape = tuple(max(sizes) for sizes in zip(*(p.shape for p in params)))
            padded = [self._pad_to(p, target_shape, device) for p in params]
            aggr_theta[name] = torch.sum(
                torch.stack([p * ratio[j] for j, p in enumerate(padded)]), dim=0
            )

        return aggr_theta

    def convert_weights_to_tensor(self, client_weights, device):
        """Return copies of each client's mapping with list/ndarray values turned into tensors on ``device``.

        Other values (including existing tensors) are passed through unchanged.
        """
        tensor_weights = []
        for weights in client_weights:
            tensor_weights.append({
                k: torch.tensor(v, device=device) if isinstance(v, (list, np.ndarray)) else v
                for k, v in weights.items()
            })
        return tensor_weights
    


class ClientModule:
    """Base class for a client worker; subclasses implement the state hooks.

    Subclasses must override ``init_state``, ``save_state``, ``load_state`` and
    ``load_state1``. They are plain methods, not properties, so the base
    versions raise ``NotImplementedError`` when called.
    """

    def __init__(self, args, w_id, g_id, sd):
        self.sd = sd
        self.gpu_id = g_id
        self.worker_id = w_id
        self.args = args
        self._args = vars(self.args)
        self.loader = DataLoader(self.args)
        self.logger = Logger(self.args, self.gpu_id)

    def switch_state(self, client_id):
        """Point loader/logger at ``client_id`` and load its checkpoint or initialise fresh state."""
        self.client_id = client_id
        self.loader.switch(client_id)
        self.logger.switch(client_id)
        if self.is_initialized():
            # NOTE(review): short pause before reading the checkpoint; presumably
            # gives a concurrent writer time to finish the file — confirm.
            time.sleep(0.1)
            self.load_state()
        else:
            self.init_state()

    def chance_state(self, client_state, client_id):
        """Switch to ``client_id`` using in-memory state ``client_state[client_id]``.

        A non-empty state is passed to ``load_state1``; an empty (or None)
        state triggers ``init_state``.
        """
        self.client_id = client_id
        self.loader.switch(client_id)
        self.logger.switch(client_id)
        state = client_state[client_id]
        if state:
            self.load_state1(state)
        else:
            self.init_state()

    def is_initialized(self):
        """Return True if ``{client_id}_state.pt`` exists under ``args.checkpt_path``."""
        return os.path.exists(os.path.join(self.args.checkpt_path, f'{self.client_id}_state.pt'))

    # --- Subclass hooks. Previously wrapped in @property, which made
    # ``load_state1(self, loaded)`` raise TypeError on attribute access and
    # evaluated the others on lookup rather than on call.
    def init_state(self):
        """Create fresh state for the current client (subclass hook)."""
        raise NotImplementedError()

    def save_state(self):
        """Persist the current client's state (subclass hook)."""
        raise NotImplementedError()

    def load_state(self):
        """Restore the current client's state from its checkpoint (subclass hook)."""
        raise NotImplementedError()

    def load_state1(self, loaded):
        """Restore the current client's state from the in-memory ``loaded`` object (subclass hook)."""
        raise NotImplementedError()

    def get_lr(self):
        """Return the learning rate of the optimizer's first param group."""
        return self.optimizer.param_groups[0]['lr']

    def save_log(self):
        """Write args and ``self.log`` to ``client_{client_id}.txt`` under ``args.log_path``.

        Uses ``save``, which is not defined here (presumably from misc.utils).
        """
        save(self.args.log_path, f'client_{self.client_id}.txt', {
            'args': self._args,
            'log': self.log
        })

    def get_optimizer_state(self, optimizer):
        """Return ``{param_key: {name: ndarray}}`` copies of the optimizer's tensor state.

        Non-tensor entries are skipped. Values are cloned before conversion, so
        the returned arrays do not share memory with the optimizer.
        """
        state = {}
        for param_key, param_values in optimizer.state_dict()['state'].items():
            state[param_key] = {
                name: value.clone().detach().cpu().numpy()
                for name, value in param_values.items()
                if torch.is_tensor(value)
            }
        return state
