import torch.nn.functional as F

from misc.utils import *
from data.loader import DataLoader
from modules.logger import Logger

class ServerModule:
    def __init__(self, args, sd, gpu_server):
        self.args = args
        self._args = vars(self.args)
        self.gpu_id = gpu_server
        self.sd = sd
        self.loader = DataLoader(self.args, is_server=True)
        self.logger = Logger(self.args, self.gpu_id, is_server=True)

    def get_active(self, mask):
        active = np.absolute(mask) >= self.args.l1
        return active.astype(float)

    def aggregate(self, local_weights, ratio=None):
        st = time.time()
        aggr_theta = OrderedDict([(k,None) for k in local_weights[0].keys()])
        if ratio is not None:
            for name, params in aggr_theta.items():
                if self.args.mask_aggr:
                    if 'mask' in name:
                        # get active
                        acti = [ratio[i]*self.get_active(lw[name])+1e-8 for i, lw in enumerate(local_weights)]
                        # get element_wise ratio
                        elem_wise_ratio = acti/np.sum(acti, 0)
                        # perform element_wise aggr
                        aggr_theta[name] = np.sum([theta[name]*elem_wise_ratio[j] for j, theta in enumerate(local_weights)], 0)
                    else:
                        aggr_theta[name] = np.sum([theta[name]*ratio[j] for j, theta in enumerate(local_weights)], 0)
                else:
                    aggr_theta[name] = np.sum([theta[name]*ratio[j] for j, theta in enumerate(local_weights)], 0)
        else:
            ratio = 1/len(local_weights)
            for name, params in aggr_theta.items():
                aggr_theta[name] = np.sum([theta[name] * ratio for j, theta in enumerate(local_weights)], 0)
        # self.logger.print(f'weight aggregation done ({round(time.time()-st, 3)} s)')
        return aggr_theta

    @torch.no_grad()
    def evaluate(self):
        if not self.args.eval_global:
            return 0, np.mean([0])

        with torch.no_grad():
            target, pred, loss = [], [], []
            for i, batch in enumerate(self.loader.te_loader):
                batch = batch.cuda(self.gpu_id)
                y_hat, lss = self.validation_step(batch, batch.test_mask)
                pred.append(y_hat[batch.test_mask])
                target.append(batch.y[batch.test_mask])
                loss.append(lss)
            acc = self.accuracy(torch.stack(pred).view(-1, self.args.n_clss), torch.stack(target).view(-1))
        return acc, np.mean(loss)

    @torch.no_grad()
    def validate(self):
        if not self.args.eval_global:
            return 0, np.mean([0])

        with torch.no_grad():
            target, pred, loss = [], [], []
            for i, batch in enumerate(self.loader.va_loader):
                batch = batch.cuda(self.gpu_id)
                y_hat, lss = self.validation_step(batch, batch.val_mask)
                pred.append(y_hat[batch.val_mask])
                target.append(batch.y[batch.val_mask])
                loss.append(lss)
            acc = self.accuracy(torch.stack(pred).view(-1, self.args.n_clss), torch.stack(target).view(-1))
        return acc, np.mean(loss)

    @torch.no_grad()
    def validation_step(self, batch, mask=None):
        self.model.eval()
        y_hat = self.model(batch)
        if torch.sum(mask).item() == 0: return y_hat, 0.0
        lss = F.cross_entropy(y_hat[mask], batch.y[mask])
        return y_hat, lss.item()

    @torch.no_grad()
    def accuracy(self, preds, targets):
        if targets.size(0) == 0: return 1.0
        with torch.no_grad():
            preds = preds.max(1)[1]
            acc = preds.eq(targets).sum().item() / targets.size(0)
        return acc

    def save_log(self):
        save(self.args.log_path, f'server.txt', {
            'args': self._args,
            'log': self.log
        })

class ClientModule:
    """Client-side base class shared by concrete federated clients.

    Holds the worker's data loader and logger, switches between client ids,
    and provides evaluation helpers. Subclasses must implement
    ``init_state``, ``save_state`` and ``load_state`` and assign
    ``self.model``, ``self.optimizer`` and ``self.log``, which this class
    reads but never sets.
    """

    def __init__(self, args, w_id, g_id, sd):
        self.sd = sd
        self.gpu_id = g_id
        self.worker_id = w_id
        self.args = args
        self._args = vars(self.args)
        self.loader = DataLoader(self.args)
        self.logger = Logger(self.args, self.gpu_id)

    def switch_state(self, client_id):
        """Point loader and logger at ``client_id``, then load or initialise its state."""
        self.client_id = client_id
        self.loader.switch(client_id)
        self.logger.switch(client_id)
        if self.is_initialized():
            # NOTE(review): presumably gives a concurrent writer time to
            # finish the checkpoint before it is read — confirm.
            time.sleep(0.1)
            self.load_state()
        else:
            self.init_state()

    def is_initialized(self):
        """Return True if ``<args.checkpt_path>/<client_id>_state.pt`` exists."""
        return os.path.exists(os.path.join(self.args.checkpt_path, f'{self.client_id}_state.pt'))

    # Abstract hooks: plain methods (not properties), because
    # ``switch_state`` calls them.
    def init_state(self):
        """Create fresh state for the current client. Must be overridden."""
        raise NotImplementedError()

    def save_state(self):
        """Persist the current client's state. Must be overridden."""
        raise NotImplementedError()

    def load_state(self):
        """Restore the current client's saved state. Must be overridden."""
        raise NotImplementedError()

    @torch.no_grad()
    def _eval_loader(self, loader, mask_attr):
        """Score every batch of ``loader`` on the nodes selected by ``batch.<mask_attr>``.

        Returns:
            ``(accuracy, mean of per-batch losses)``.
        """
        target, pred, loss = [], [], []
        for batch in loader:
            batch = batch.cuda(self.gpu_id)
            mask = getattr(batch, mask_attr)
            y_hat, lss = self.validation_step(batch, mask)
            pred.append(y_hat[mask])
            target.append(batch.y[mask])
            loss.append(lss)
        # torch.cat, not torch.stack: batches may select different numbers of nodes.
        acc = self.accuracy(torch.cat(pred).view(-1, self.args.n_clss),
                            torch.cat(target).view(-1))
        return acc, np.mean(loss)

    @torch.no_grad()
    def evaluate(self, mode='global'):
        """Return ``(accuracy, mean loss)`` over ``test_mask`` nodes.

        Args:
            mode: ``'global'`` uses ``loader.te_loader``; ``'local'`` uses
                ``loader.pa_loader``.

        Returns ``(0, 0.0)`` for ``'global'`` when ``args.eval_global`` is false.

        Raises:
            ValueError: for any other ``mode``.
        """
        if mode == 'global' and not self.args.eval_global:
            return 0, np.mean([0])
        if mode == 'global':
            loader = self.loader.te_loader
        elif mode == 'local':
            loader = self.loader.pa_loader
        else:
            raise ValueError(f"unknown evaluation mode: {mode!r}")
        return self._eval_loader(loader, 'test_mask')

    @torch.no_grad()
    def evaluate_neighbor(self):
        """Return ``(accuracy, mean loss)`` on ``loader.ne_loader`` over ``test_mask`` nodes."""
        return self._eval_loader(self.loader.ne_loader, 'test_mask')

    @torch.no_grad()
    def validate(self, mode='global'):
        """Return ``(accuracy, mean loss)`` over ``val_mask`` nodes.

        Args:
            mode: ``'global'`` uses ``loader.va_loader``; ``'local'`` uses
                ``loader.pa_loader``.

        Returns ``(0, 0.0)`` for ``'global'`` when ``args.eval_global`` is false.

        Raises:
            ValueError: for any other ``mode``.
        """
        if mode == 'global' and not self.args.eval_global:
            return 0, np.mean([0])
        if mode == 'global':
            loader = self.loader.va_loader
        elif mode == 'local':
            loader = self.loader.pa_loader
        else:
            raise ValueError(f"unknown validation mode: {mode!r}")
        return self._eval_loader(loader, 'val_mask')

    @torch.no_grad()
    def validation_step(self, batch, mask=None):
        """Run ``self.model`` in eval mode on ``batch``.

        Args:
            batch: object exposing ``y`` and accepted by ``self.model``.
            mask: boolean selector over the rows of the model output; None
                selects every row.

        Returns:
            ``(y_hat, loss)`` where ``loss`` is the cross-entropy on the masked
            rows as a Python float, or ``0.0`` when the mask selects nothing.
        """
        self.model.eval()
        y_hat = self.model(batch)
        if mask is None:
            mask = torch.ones(y_hat.size(0), dtype=torch.bool, device=y_hat.device)
        if torch.sum(mask).item() == 0:
            return y_hat, 0.0
        lss = F.cross_entropy(y_hat[mask], batch.y[mask])
        return y_hat, lss.item()

    @torch.no_grad()
    def accuracy(self, preds, targets):
        """Fraction of rows where ``argmax(preds, 1) == targets``; 1.0 when empty."""
        if targets.size(0) == 0:
            return 1.0
        preds = preds.max(1)[1]
        return preds.eq(targets).sum().item() / targets.size(0)

    def get_lr(self):
        """Return the learning rate of the optimizer's first param group."""
        return self.optimizer.param_groups[0]['lr']

    def save_log(self):
        """Write ``args`` and ``self.log`` to ``<args.log_path>/client_<client_id>.txt``."""
        save(self.args.log_path, f'client_{self.client_id}.txt', {
            'args': self._args,
            'log': self.log
        })

    def get_optimizer_state(self, optimizer):
        """Snapshot ``optimizer.state_dict()['state']`` as numpy arrays.

        Returns:
            ``{param_key: {name: np.ndarray}}``; non-tensor entries are dropped,
            and every param key is kept even if its inner dict ends up empty.
        """
        state = {}
        for param_key, param_values in optimizer.state_dict()['state'].items():
            # clone() so arrays never alias live optimizer buffers on CPU.
            state[param_key] = {
                name: value.clone().detach().cpu().numpy()
                for name, value in param_values.items()
                if torch.is_tensor(value)
            }
        return state
