import time
import copy
import torch
import numpy as np
import torch.nn.functional as F

from misc.utils import *
from models.nets import *
from modules.federated import ClientModule

class Client(ClientModule):
    """Client-side worker that trains a ``MaskedGCN`` and reports to the server.

    The client communicates through the shared mapping ``self.sd``: it reads
    the ``'global'`` entry (or an ``'adaptive_<client_id>'`` entry when one
    exists) plus the ``'proxy'`` input, and writes its own results under
    ``self.client_id``.

    Attributes such as ``args``, ``sd``, ``gpu_id``, ``client_id``, ``loader``
    and ``logger``, and the methods ``validate``, ``evaluate``, ``get_lr`` and
    ``save_log``, are not defined in this class; they are expected to come
    from ``ClientModule``.
    """

    def __init__(self, args, w_id, g_id, sd):
        """Build the model on GPU ``g_id`` and cache its parameter list.

        Args:
            args: Run configuration; ``n_feat``, ``n_dims``, ``n_clss`` and
                ``l1`` are forwarded to ``MaskedGCN``.
            w_id: Forwarded unchanged to ``ClientModule.__init__``.
            g_id: CUDA device index the model is moved to.
            sd: Shared mapping used to exchange data with the server.
        """
        super().__init__(args, w_id, g_id, sd)
        self.model = MaskedGCN(
            self.args.n_feat,
            self.args.n_dims,
            self.args.n_clss,
            self.args.l1,
            self.args,
        ).cuda(g_id)
        self.parameters = list(self.model.parameters())

    def init_state(self):
        """Create a fresh Adam optimizer and an empty metric log."""
        self.optimizer = torch.optim.Adam(
            self.parameters,
            lr=self.args.base_lr,
            weight_decay=self.args.weight_decay,
        )
        # 'ep_*' lists get one entry per evaluation inside train() (including
        # the evaluation done before the first epoch); 'rnd_*' lists get one
        # entry per call to train().
        self.log = {
            'lr': [], 'train_lss': [],
            'ep_local_val_lss': [], 'ep_local_val_acc': [],
            'rnd_local_val_lss': [], 'rnd_local_val_acc': [],
            'ep_local_test_lss': [], 'ep_local_test_acc': [],
            'rnd_local_test_lss': [], 'rnd_local_test_acc': [],
            'ep_global_val_lss': [], 'ep_global_val_acc': [],
            'rnd_global_val_lss': [], 'rnd_global_val_acc': [],
            'ep_global_test_lss': [], 'ep_global_test_acc': [],
            'rnd_global_test_lss': [], 'rnd_global_test_acc': [],
            'rnd_sparsity': [], 'ep_sparsity': [],
        }

    def save_state(self):
        """Checkpoint optimizer, model and log to ``<client_id>_state.pt``."""
        state = {
            'optimizer': self.optimizer.state_dict(),
            'model': get_state_dict(self.model),
            'log': self.log,
        }
        torch_save(self.args.checkpt_path, f'{self.client_id}_state.pt', state)

    def load_state(self):
        """Restore model, optimizer and log from ``<client_id>_state.pt``.

        ``self.optimizer`` must already exist (see ``init_state``).
        """
        loaded = torch_load(self.args.checkpt_path, f'{self.client_id}_state.pt')
        set_state_dict(self.model, loaded['model'], self.gpu_id)
        self.optimizer.load_state_dict(loaded['optimizer'])
        self.log = loaded['log']

    def on_receive_message(self, curr_rnd):
        """Load the weights published for round ``curr_rnd``.

        Uses the ``'adaptive_<client_id>'`` entry of ``self.sd`` when present
        and the ``'global'`` entry otherwise. The ``'global'`` model is
        additionally kept as tensors in ``self.global_w``.
        """
        self.curr_rnd = curr_rnd
        adaptive_key = f'adaptive_{self.client_id}'
        source_key = adaptive_key if adaptive_key in self.sd else 'global'
        self.update(self.sd[source_key])
        self.global_w = convert_np_to_tensor(self.sd['global']['model'], self.gpu_id)

    def update(self, update):
        """Install ``update['model']`` and remember it as ``self.prev_w``.

        ``self.prev_w`` is the reference point of the L2 proximity term used
        in ``train``. The weights are loaded with ``skip_stat=True`` and
        ``skip_mask=True``.
        """
        # NOTE(review): presumably these flags keep local statistics and masks
        # untouched; confirm against set_state_dict.
        self.prev_w = convert_np_to_tensor(update['model'], self.gpu_id)
        set_state_dict(self.model, update['model'], self.gpu_id, skip_stat=True, skip_mask=True)

    def on_round_begin(self, curr_rnd):
        """Train locally, then publish the results to ``self.sd``.

        ``curr_rnd`` is accepted for interface compatibility but not used;
        ``train`` reads the round stored by ``on_receive_message``.
        """
        self.train()
        self.transfer_to_server()

    def _run_evaluations(self):
        """Return validation/test accuracy and loss for both evaluation modes.

        The keys are the suffixes used by the ``'ep_*'`` / ``'rnd_*'`` log
        lists, so the result can be passed straight to ``_record_metrics``.
        """
        val_global_acc, val_global_lss = self.validate(mode='global')
        val_local_acc, val_local_lss = self.validate(mode='local')
        test_global_acc, test_global_lss = self.evaluate(mode='global')
        test_local_acc, test_local_lss = self.evaluate(mode='local')
        return {
            'global_val_acc': val_global_acc,
            'global_val_lss': val_global_lss,
            'global_test_acc': test_global_acc,
            'global_test_lss': test_global_lss,
            'local_val_acc': val_local_acc,
            'local_val_lss': val_local_lss,
            'local_test_acc': test_local_acc,
            'local_test_lss': test_local_lss,
        }

    def _record_metrics(self, prefix, metrics):
        """Append every metric to the log list named ``<prefix>_<metric key>``."""
        for key, value in metrics.items():
            self.log[f'{prefix}_{key}'].append(value)

    def _collect_masks(self):
        """Return the mask tensors whose sparsity is tracked during training."""
        if self.args.mask_rank == -1:
            # state_dict() hands out detached tensors, which is what we want
            # for bookkeeping that must not build an autograd graph.
            return [
                param for name, param in self.model.state_dict().items()
                if 'mask' in name
            ]
        # measuring sparsity per epoch for this case is not working
        # (note carried over from the original code; the masks are read once
        # here, before any epoch runs).
        return [module.mask for module in self.model.children()]

    def _mask_sparsity(self):
        """Return the fraction of mask entries with ``|value| < args.l1``.

        Returns ``0.0`` when no masks are tracked.
        """
        n_pruned = 0
        n_total = 0
        for mask in self.masks:
            n_pruned = n_pruned + (torch.abs(mask) < self.args.l1).sum()
            n_total += mask.numel()
        if n_total == 0:
            return 0.0
        # int() converts either the initial 0 or the accumulated 0-dim tensor.
        return int(n_pruned) / n_total

    def _regularization(self):
        """Return the L1 mask penalty plus the L2 proximity penalty.

        ``keep_vars=True`` is essential: a plain ``state_dict()`` returns
        tensors detached from autograd, so penalties built from them would
        change the reported loss without ever influencing the gradients.
        """
        masks_in_state = self.args.mask_rank == -1
        penalty = 0
        for name, param in self.model.state_dict(keep_vars=True).items():
            if 'mask' in name and masks_in_state:
                penalty = penalty + torch.norm(param.float(), 1) * self.args.l1
            elif 'conv' in name or 'clsif' in name:
                # No proximity term in the first round (curr_rnd == 0).
                if self.curr_rnd > 0:
                    penalty = penalty + torch.norm(param.float() - self.prev_w[name], 2) * self.args.loc_l2

        if not masks_in_state:
            for module in self.model.children():
                penalty = penalty + torch.norm(module.mask.float(), 1) * self.args.l1
        return penalty

    def train(self):
        """Run ``args.n_eps`` local epochs and log metrics before/after each.

        Each batch loss is the cross-entropy on ``batch.train_mask`` nodes
        plus the penalties from ``_regularization``. After every epoch the
        mask sparsity and all validation/test metrics are logged under the
        ``'ep_*'`` keys; the values of the last evaluation are also logged
        under the ``'rnd_*'`` keys before the log is saved.
        """
        st = time.time()
        metrics = self._run_evaluations()
        self.logger.print(
            f'rnd:{self.curr_rnd+1}, ep:{0}, '
            f'val_global_loss:{metrics["global_val_lss"].item():.4f}, val_global_acc:{metrics["global_val_acc"]:.4f}, '
            f'val_local_loss:{metrics["local_val_lss"].item():.4f}, val_local_acc:{metrics["local_val_acc"]:.4f}, lr:{self.get_lr()} ({time.time()-st:.2f}s)'
        )
        self._record_metrics('ep', metrics)

        self.masks = self._collect_masks()
        # Measured up front so 'rnd_sparsity' is defined even when n_eps == 0.
        sparsity = self._mask_sparsity()

        for ep in range(self.args.n_eps):
            st = time.time()
            self.model.train()
            for batch in self.loader.pa_loader:
                self.optimizer.zero_grad()
                batch = batch.cuda(self.gpu_id)
                y_hat = self.model(batch)
                train_lss = F.cross_entropy(y_hat[batch.train_mask], batch.y[batch.train_mask])
                train_lss = train_lss + self._regularization()
                train_lss.backward()
                self.optimizer.step()

            sparsity = self._mask_sparsity()
            metrics = self._run_evaluations()
            self.logger.print(
                f'rnd:{self.curr_rnd+1}, ep:{ep+1}, '
                f'test_global_acc:{metrics["global_test_acc"]:.4f}, test_local_acc:{metrics["local_test_acc"]:.4f}, lr:{self.get_lr()} ({time.time()-st:.2f}s)'
            )
            # Loss of the last batch of this epoch.
            self.log['train_lss'].append(train_lss.item())
            self._record_metrics('ep', metrics)
            self.log['ep_sparsity'].append(sparsity)

        self._record_metrics('rnd', metrics)
        self.log['rnd_sparsity'].append(sparsity)
        self.save_log()

    @torch.no_grad()
    def get_proxy_output(self):
        """Return the model output on ``self.sd['proxy']`` averaged over dim 0.

        The model is switched to eval mode and called with ``is_proxy=True``.

        Returns:
            numpy.ndarray: The averaged output, copied to host memory.
        """
        self.model.eval()
        proxy_in = self.sd['proxy'].cuda(self.gpu_id)
        proxy_out = self.model(proxy_in, is_proxy=True).mean(dim=0)
        return proxy_out.detach().cpu().numpy()

    def transfer_to_server(self):
        """Publish model weights, training-set size and proxy output.

        The payload is stored in ``self.sd`` under ``self.client_id``.
        """
        self.sd[self.client_id] = {
            'model': get_state_dict(self.model),
            'train_size': len(self.loader.partition),
            'proxy': self.get_proxy_output(),
        }




    
    
