import time
import torch
import torch.nn.functional as F

from misc.utils import *
from models.nets import *
from modules.federated import ClientModule
import wandb
import os

import global_var as gvr

class Client(ClientModule):

    def __init__(self, args, w_id, g_id, sd):
        """Create the client and remember its configuration and device id.

        All four arguments are forwarded unchanged to ``ClientModule``;
        ``args`` and ``g_id`` are additionally stored on the instance because
        the methods of this class read them directly.
        """
        super().__init__(args, w_id, g_id, sd)
        self.g_id = g_id
        self.args = args
        
    def init_state(self):
        """Build the local model, its optimizer and the logging containers.

        Creates, in order: ``self.model`` (an ``LGCN`` placed on GPU
        ``self.g_id``), ``self.parameters`` / ``self.optimizer`` (Adam over all
        model parameters), ``self.log`` (in-memory metric history) and
        ``self.csv_log`` (template row for the CSV report).
        """
        args = self.args
        device = self.g_id

        # Per-client initial value read from the loader
        # (presumably the initial curvature of the manifold -- confirm in LGCN).
        init_k = torch.tensor(self.loader.init_k[self.client_id]).cuda(device)
        self.model = LGCN(init_k, args.n_feat, args.n_dims, args.n_clss, args).cuda(device)

        self.parameters = list(self.model.parameters())
        self.optimizer = torch.optim.Adam(
            self.parameters,
            lr=args.base_lr,
            weight_decay=args.weight_decay,
        )

        # Metric history; every entry starts as its own empty list.
        log_keys = (
            'lr', 'train_lss',
            'ep_local_val_lss', 'ep_local_val_acc',
            'rnd_local_val_lss', 'rnd_local_val_acc',
            'ep_local_test_lss', 'ep_local_test_acc',
            'rnd_local_test_lss', 'rnd_local_test_acc',
            'rnd_sparsity', 'ep_sparsity',
        )
        self.log = {key: [] for key in log_keys}

        notes = (
            f'Init_k: {self.loader.init_k}; '
            f'Learnable_k: {args.learnable_k}; '
            f'Optimizer: {args.optimizer}; '
            f'Classifier: {args.classifier}; '
            f'Loss: {args.lss_func}; '
            f'Rescale: {args.rescale}; '
            f'loc_l2:{args.loc_l2}'
        )

        # Template for one CSV row. Insertion order defines the column order,
        # so the keys are added in exactly the order the report expects.
        # Entries holding an empty list are overwritten in train() before the
        # row is written.
        row = {}
        row['model'] = [args.model]
        row['dataset'] = [args.dataset]
        row['mode'] = [args.mode]
        row['nclients'] = [args.n_clients]
        row['lr'] = [args.base_lr]
        row['seed'] = [args.seed]
        row['dims'] = [args.n_dims]
        row['nfeat'] = [args.n_feat]
        row['nclss'] = [args.n_clss]
        row['loc-l2'] = [args.loc_l2 * 10]  # stored scaled by 10
        row['nsamps'] = []
        row['rnd'] = []
        row['gpu'] = [self.gpu_id]
        row['cid'] = []
        row['ep'] = []
        row['c'] = []
        row['val_lss'] = []
        row['test_lss'] = []
        row['val_acc'] = []
        row['test_acc'] = []
        row['notes'] = [notes]
        self.csv_log = row

    def save_state(self):
        """Write optimizer state, model weights and the metric log to the
        client's checkpoint file ``<client_id>_state.pt``."""
        snapshot = {
            'optimizer': self.optimizer.state_dict(),
            'model': get_state_dict(self.model),
            'log': self.log,
        }
        filename = f'{self.client_id}_state.pt'
        torch_save(self.args.checkpt_path, filename, snapshot)

    def load_state(self):
        """Restore model weights, optimizer state and the metric log from the
        client's checkpoint file ``<client_id>_state.pt``."""
        filename = f'{self.client_id}_state.pt'
        checkpoint = torch_load(self.args.checkpt_path, filename)

        # Same restore order as the save layout: model, optimizer, then log.
        set_state_dict(self.model, checkpoint['model'], self.gpu_id)
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.log = checkpoint['log']
    
    def on_receive_message(self, curr_rnd):
        """Record the current round and load the server's global weights.

        Args:
            curr_rnd: index of the communication round that is starting.

        Raises:
            AssertionError: if the data loader is bound to a different client
                than this worker is currently serving.
        """
        self.curr_rnd = curr_rnd
        assert self.loader.client_id == self.client_id

        # NOTE(review): this used to branch on whether
        # f'personalized_{self.client_id}' was present in self.sd, but both
        # branches loaded self.sd['global']. The dead branch is collapsed here;
        # if personalized weights were meant to be used, load
        # self.sd[f'personalized_{self.client_id}'] instead -- confirm intent.
        self.update(self.model, self.sd['global'])
            

    def update(self, model, update):
        """Apply a received update to ``model`` and remember its weights.

        ``update['model']`` is converted to tensors on ``self.gpu_id`` and kept
        as ``self.prev_w`` (read by the penalty term in ``train``); the same
        weights are then merged into ``model`` via ``set_fedhyp_state_dict``.
        """
        received = update['model']
        self.prev_w = convert_np_to_tensor(received, self.gpu_id)
        self.set_fedhyp_state_dict(model, received, self.gpu_id, skip_stat=True)

    # Merge a received state dict into the client model, overwriting only part of its parameters.
    def set_fedhyp_state_dict(self, model, state_dict, gpu_id, skip_stat=False, skip_mask=False):
        """Merge a received state dict into ``model``.

        Entries whose key contains ``'weight.weight'`` are only partially
        replaced: column 0 of the model's current tensor is kept and columns
        1.. are taken from the received tensor. Every other entry is replaced
        as a whole.

        Args:
            model: module whose parameters are updated in place.
            state_dict: received weights, converted with
                ``convert_np_to_tensor`` before use.
            gpu_id: device id passed to the conversion.
            skip_stat: forwarded to ``convert_np_to_tensor``.
            skip_mask: forwarded to ``convert_np_to_tensor``.
        """
        incoming = convert_np_to_tensor(
            state_dict,
            gpu_id,
            skip_stat=skip_stat,
            skip_mask=skip_mask,
            model=model.state_dict(),
        )

        merged = model.state_dict()
        for name, tensor in incoming.items():
            if 'weight.weight' not in name:
                merged[name] = tensor
                continue
            # Keep the local first column, take the rest from the server,
            # e.g. a (128, 1434) weight on Cora.
            local = merged[name]
            merged[name] = torch.cat((local[:, :1], tensor[:, 1:]), dim=1)
        model.load_state_dict(merged)


    def on_round_begin(self):
        """Run local training for the round, then publish the result to the server."""
        # Order matters: transfer_to_server() reads the model that train() has just updated.
        self.train()
        self.transfer_to_server()

    def train(self):
        """Evaluate, run ``args.n_eps`` local epochs and record the metrics.

        Steps, in order:
          1. Evaluate the current model on the validation and test splits and
             log the result as epoch 0.
          2. For every epoch, optimise on each batch of
             ``self.loader.pa_loader`` and re-evaluate.
          3. Append the round-level metrics, persist the log with
             ``save_log`` and, when ``args.csv`` is set, append a CSV row for
             epoch 0 and one for the last epoch.

        From the second round on (``curr_rnd != 0``) a proximal penalty that
        pulls the parameters towards the last received weights
        (``self.prev_w``) is added to the training loss.
        """
        st = time.time()
        val_local_acc, val_local_lss = self.validate(mode='valid')
        test_local_acc, test_local_lss = self.validate(mode='test')
        self._log_wandb(val_local_acc, test_local_acc, ep=0)

        self.logger.print(
            f'rnd: {self.curr_rnd+1}, ep: {0}, '
            f'val_local_loss: {val_local_lss.item():.4f}, val_local_acc: {val_local_acc:.4f}, lr: {self.get_lr()} ({time.time()-st:.2f}s)'
        )
        self.log['ep_local_val_acc'].append(val_local_acc)
        self.log['ep_local_val_lss'].append(val_local_lss)
        self.log['ep_local_test_acc'].append(test_local_acc)
        self.log['ep_local_test_lss'].append(test_local_lss)

        if self.args.csv:
            self.csv_log['nsamps'] = 0  # TODO: report the real sample count for the epoch-0 row
            self._write_csv_row(0, val_local_acc, val_local_lss, test_local_acc, test_local_lss)

        # Number of the last finished epoch; stays 0 when n_eps == 0 so the
        # final CSV row can still be written.
        last_ep = 0
        for ep in range(self.args.n_eps):
            st = time.time()
            self.model.train()
            train_lss = None  # stays None if the loader yields no batch
            for batch in self.loader.pa_loader:
                self.optimizer.zero_grad()

                batch = self._prepare_batch(batch)
                if self.args.csv:
                    self.csv_log['nsamps'] = len(batch.x)
                y_hat = self.model(batch, self.loader.k, self.gpu_id)

                train_lss = self.model.compute_lss(y_hat[batch.train_mask], batch.y[batch.train_mask])
                if self.curr_rnd != 0:
                    # Out-of-place add so the loss tensor is not modified in place.
                    train_lss = train_lss + self._proximal_penalty()

                # NOTE(review): retain_graph=True is kept from the original
                # code; it is only required if the model reuses graph tensors
                # across backward passes -- confirm and drop if not needed.
                train_lss.backward(retain_graph=True)
                self.optimizer.step()
                if self.args.optimizer == 'dual':
                    # optimizer_cls is not created in this class; presumably
                    # it is provided elsewhere when the 'dual' mode is used.
                    self.optimizer_cls.step()

            val_local_acc, val_local_lss = self.validate(mode='valid')
            test_local_acc, test_local_lss = self.validate(mode='test')
            self._log_wandb(val_local_acc, test_local_acc, ep=ep + 1)

            self.logger.print(
                f'rnd:{self.curr_rnd+1}, ep:{ep+1}, '
                f'val_local_loss: {val_local_lss.item():.4f}, val_local_acc: {val_local_acc:.4f}, lr: {self.get_lr()} ({time.time()-st:.2f}s)'
            )
            self.logger.print(
                f'rnd:{self.curr_rnd+1}, ep:{ep+1}, '
                f'test_local_loss: {test_local_lss.item():.4f}, test_local_acc: {test_local_acc:.4f}, lr: {self.get_lr()} ({time.time()-st:.2f}s)'
            )
            if train_lss is not None:
                self.log['train_lss'].append(train_lss.item())
            self.log['ep_local_val_acc'].append(val_local_acc)
            self.log['ep_local_val_lss'].append(val_local_lss)
            self.log['ep_local_test_acc'].append(test_local_acc)
            self.log['ep_local_test_lss'].append(test_local_lss)
            last_ep = ep + 1

        self.log['rnd_local_val_acc'].append(val_local_acc)
        self.log['rnd_local_val_lss'].append(val_local_lss)
        self.log['rnd_local_test_acc'].append(test_local_acc)
        self.log['rnd_local_test_lss'].append(test_local_lss)
        self.save_log()

        if self.args.csv:
            self._write_csv_row(last_ep, val_local_acc, val_local_lss, test_local_acc, test_local_lss)

    def _log_wandb(self, val_acc, test_acc, ep):
        """Send validation/test accuracy to wandb when ``args.wandb`` is set.

        ``ep`` is 0 for the evaluation before training and ``epoch + 1``
        afterwards, so every round occupies ``n_eps + 1`` consecutive steps.
        """
        if not self.args.wandb:
            return
        step = self.curr_rnd * (self.args.n_eps + 1) + ep
        cid = self.loader.client_id
        wandb.log({f'test_acc_{cid}': test_acc}, step=step)
        wandb.log({f'val_acc_{cid}': val_acc}, step=step)

    def _prepare_batch(self, batch):
        """Convert a loader batch to the model's input format and move it to the GPU.

        ``edge_index`` is replaced by a sparse COO tensor with value 1 per
        edge, and a zero column is prepended to the node features ``x``.
        """
        edge_index = batch.edge_index
        # Helper tensors are created on the same device as their inputs so the
        # method does not rely on the loader producing CPU tensors.
        edge_weight = torch.ones(edge_index.shape[1], device=edge_index.device)
        # NOTE(review): the sparse size is inferred from the largest index; if
        # the highest-numbered node has no edge the matrix is smaller than the
        # number of nodes -- confirm whether an explicit size is needed.
        batch.edge_index = torch.sparse_coo_tensor(edge_index, edge_weight)

        zero_col = torch.zeros(len(batch.x), 1, device=batch.x.device)
        batch.x = torch.cat([zero_col, batch.x], dim=1)

        return batch.cuda(self.gpu_id)

    def _proximal_penalty(self):
        """Return ``loc_l2`` times the summed L2 distance to ``self.prev_w``.

        Only parameters whose name is a key of ``self.prev_w`` contribute.
        """
        penalty = 0.
        # named_parameters() instead of state_dict(): tensors returned by
        # state_dict() are detached, which would turn the penalty into a
        # constant that never influences the gradients.
        for name, param in self.model.named_parameters():
            if name in self.prev_w:
                penalty = penalty + torch.norm(param.float() - self.prev_w[name], 2) * self.args.loc_l2
        return penalty

    def _write_csv_row(self, ep, val_acc, val_lss, test_acc, test_lss):
        """Fill the per-row fields of ``self.csv_log`` and append it to ``args.csv_path``.

        The header is written only when the file does not exist yet.
        """
        import pandas as pd  # imported lazily: only needed when CSV logging is on

        if self.loader.client_id != self.client_id:
            self.logger.print(
                f'[warning] loader.client_id ({self.loader.client_id}) != client_id ({self.client_id})'
            )
        self.csv_log['c'] = self.model.manifold.k.data
        self.csv_log['cid'] = self.loader.client_id
        self.csv_log['rnd'] = self.curr_rnd + 1
        self.csv_log['ep'] = ep
        self.csv_log['val_lss'] = val_lss
        self.csv_log['val_acc'] = val_acc
        self.csv_log['test_lss'] = test_lss
        self.csv_log['test_acc'] = test_acc

        path = self.args.csv_path
        write_header = not os.path.exists(path)
        pd.DataFrame(self.csv_log).to_csv(path, mode='a', header=write_header)
            

    def transfer_to_server(self):
        """Publish this client's result under ``self.sd[self.client_id]``.

        The payload holds the parameters selected by ``gvr.HYP_AGG_KEYWORDS``
        (``'model'``), the full state dict (``'whole_model'``), the
        ``'manifold.k'`` entry of the state dict (``'curvature'``) and the
        size of the local partition (``'train_size'``).
        """
        shared_part = get_partial_state_dict(self.model, gvr.HYP_AGG_KEYWORDS)
        full_state = get_state_dict(self.model)
        # Taken from a separate get_state_dict() call so the entry does not
        # alias a tensor inside 'whole_model'.
        curvature = get_state_dict(self.model)['manifold.k']
        n_train = len(self.loader.partition)

        payload = {
            'model': shared_part,
            'whole_model': full_state,
            'curvature': curvature,
            'train_size': n_train,
        }
        self.sd[self.client_id] = payload