"""
An aggressive version that reduces the influence of the GAN but emphasizes
EP-prior regularization.

Move the EP-prior to discriminator ...
"""


import json
import math
import os
import random
import sys
import time
#!!! please modify these hyperparameters manually
import warnings
from shutil import copy

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import tqdm
from sru import SRU

from compute_priors import compute_priors
from dataGenSequences_sru import dataGenSequences
from lib.ops import Dense

# Small constant added to denominators and variances throughout this file to
# keep divisions finite.
EPS = 1e-5
# Constants used by get_ep_prior() to build the arguments of calc_statistics().
SQRT2 = math.sqrt(2)
POW2_3O2 = math.pow(2, 1.5)

# NOTE: this silences every warning category for the whole process.
warnings.filterwarnings('ignore')
# this depend on the feature you applied
mfccDim=40
# NOTE(review): the seed is drawn at random and is never printed or saved, so
# a finished run cannot be reproduced -- consider logging it.
seed = random.randint(0,10000)
np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed) # cpu 
torch.cuda.manual_seed_all(seed) # gpu 
torch.backends.cudnn.deterministic = True


# Everything below runs the training pipeline at module level, so importing
# this file is refused.
if __name__ != '__main__':
    raise ImportError ('This script can only be run, and can\'t be imported')

# Command line: seven mandatory arguments, optionally followed by the three
# federated-learning settings (either all three of them or none).
if len(sys.argv) not in (8, 11):
    raise TypeError ('USAGE: train.py data_cv ali_cv data_tr ali_tr gmm_dir dnn_dir init_lr [clients turns epoch_per_turn]')


data_cv = sys.argv[1]
ali_cv  = sys.argv[2]
data_tr = sys.argv[3]
ali_tr  = sys.argv[4]
gmm     = sys.argv[5]
exp     = sys.argv[6]
init_lr = float(sys.argv[7])
if len(sys.argv) == 11:
    # A malformed integer now raises ValueError instead of being silently
    # replaced by the defaults below.
    clients = int(sys.argv[8])
    turns = int(sys.argv[9])
    epoch_per_turn = int(sys.argv[10])
else:
    clients = 16
    turns = 12
    epoch_per_turn = 1

##!!! please modify these hyperparameters manually
## Learning parameters
# Number of channels of the noise tensor fed to the generator.
z_size              = 64
# NOTE(review): only stored in `learning`; multi_generator() is built without
# num_gens and therefore uses its own default (also 10).
num_gens            = 10
# Output width of the discriminator heads and length of the EP-prior vectors.
num_dics            = 16
# Number of noise samples drawn for each discriminator / generator step.
g_batch_size        = 128
# Weights of noise_loss() in the generator and discriminator objectives.
g_noise_loss_lambda = 3e-2
d_noise_loss_lambda = 3e-2
# Weight of the generator loss computed through the history-averaged head.
d_hist_loss_lambda  = 1
# Weights of ep_prior_loss() for the generator, discriminator and classifier.
g_ep_prior_lambda   = 1e-3
d_ep_prior_lambda   = 1e-3
s_ep_prior_lambda   = 1e-3
# Global scale applied to the total discriminator / generator loss.
d_loss_lambda       = 1e-1
g_loss_lambda       = 1e-1
# Standard deviation handed to noise_loss() as `alpha`.
noise_std           = np.sqrt(0.02)

# Configuration shared by the data pipeline and the networks; it is also
# dumped to <exp>/learning.json once 'targetDim' has been added.
# NOTE(review): several keys (e.g. 'minEpoch', 'dilDepth', 'minValError',
# 'modelOrder', 'historyNum') are not read anywhere in this file.
learning = {'rate' : init_lr,
            'singFeaDim' : mfccDim, 
            'minEpoch' : 30,
            'batchSize' : 128,#40 at first
            'timeSteps' : 20,
            'dilDepth' : 1,
            'minValError' : 0,
            'left' : 0,
            'right': 4,
            'hiddenDim' : 1280,
            'modelOrder' : 1,
            'layerNum': 12,# 12 at first
            'historyNum' : 1,
            'z_size':z_size,
            'num_gens': num_gens,
            'num_dics': num_dics,
            'disDim':40}

## Copy final model and tree from GMM directory
os.makedirs(exp, exist_ok=True)
for _gmm_artifact in ('final.mdl', 'tree'):
    copy(gmm + '/' + _gmm_artifact, exp)


def return_my_self(batch):
    """Collate a batch of ``(x, y)`` tensor pairs into two stacked tensors.

    Used as ``collate_fn`` of the data loaders.

    Args:
        batch: iterable of 2-tuples ``(x, y)`` of tensors; all ``x`` must have
            the same shape, and likewise all ``y``.

    Returns:
        ``(torch.stack(xs), torch.stack(ys))``.

    Raises:
        Exception: whatever unpacking or ``torch.stack`` raised. The offending
            batch is printed first. (Previously the process exited with
            status 0, which made a failed run look successful.)
    """
    xs = []
    ys = []
    try:
        for xx, yy in batch:
            xs.append(xx)
            ys.append(yy)
        return torch.stack(xs), torch.stack(ys)
    except Exception:
        print("Error")
        print(batch)
        raise

# Session ids assigned to each client, keyed by the supported client counts.
_SESSIONS_BY_CLIENT_COUNT = {
    16: [["S03"], ["S04"], ["S05"], ["S06"],
         ["S07"], ["S17"], ["S08"], ["S16"],
         ["S12"], ["S13"], ["S18"], ["S22"],
         ["S19"], ["S20"], ["S23"], ["S24"]],
    8: [["S03", "S04"], ["S05", "S06"],
        ["S07", "S17"], ["S08", "S16"],
        ["S12", "S13"], ["S18", "S22"],
        ["S19", "S20"], ["S23", "S24"]],
    4: [["S03", "S04", "S08", "S16"],
        ["S05", "S12", "S19", "S23", "S24"],
        ["S06", "S07", "S17"],
        ["S13", "S18", "S22", "S20"]],
}
# Any other number of clients has no session layout defined.
if clients not in _SESSIONS_BY_CLIENT_COUNT:
    raise NotImplementedError
sessions = _SESSIONS_BY_CLIENT_COUNT[clients]

## Compute priors
compute_priors(exp, ali_tr, ali_cv)

# Width of one network input frame: the current frame plus its left/right
# context frames, each mfccDim wide (with the defaults above: 5 x 40 = 200).
feaDim = mfccDim * (learning['left'] + learning['right'] + 1)
# Width of the discriminator representation appended to the classifier input.
disDim = learning['disDim']


# One training dataset per client: split i (1-based) lives in <data_tr>/<i>
# and is restricted to the sessions assigned to that client.
trDatasets = []
for _split in range(1, clients + 1):
    trDatasets.append(
        dataGenSequences(f"{data_tr}/{_split}", ali_tr, gmm,
                         learning['batchSize'], learning['timeSteps'],
                         feaDim, learning['left'], learning['right'],
                         my_sess=sessions[_split - 1]))
cvDataset = dataGenSequences(data_cv, ali_cv, gmm,
                             learning['batchSize'], learning['timeSteps'],
                             feaDim, learning['left'], learning['right'])

# Recommend shuffle=False, because this iterator's shuffle can only work on the single split
trGens = [
    data.DataLoader(_dataset, batch_size=learning['batchSize'],
                    shuffle=False, num_workers=0, collate_fn=return_my_self)
    for _dataset in trDatasets
]
# FedAvg weights: each client counts proportionally to its number of frames.
client_weights = [_dataset.numFeats for _dataset in trDatasets]

cvGen = data.DataLoader(cvDataset, batch_size=learning['batchSize'],
                        shuffle=False, num_workers=0, collate_fn=return_my_self)


## load the configurations from the training data
learning['targetDim'] = trDatasets[0].outputFeatDim

_learning_json = json.dumps(learning)
with open(exp + '/learning.json', 'w') as json_file:
    json_file.write(_learning_json)

def FedAvg(params, weights=None):
    """Average per-client parameter lists position by position.

    Args:
        params: sequence with one entry per client; each entry is a sequence
            of values (tensors or numbers) in the same order for every client.
        weights: optional sequence with one non-negative weight per client.
            ``None`` or an empty sequence means uniform weighting.

    Returns:
        A list whose k-th element is the weighted average of the k-th value
        of every client.

    Raises:
        ValueError: if the number of weights differs from the number of
            clients. Previously ``zip`` silently truncated the longer one
            while the normaliser still summed every weight, which produced a
            wrongly scaled average.
    """
    if weights is None or len(weights) == 0:
        weights = [1] * len(params)
    if len(params) != len(weights):
        raise ValueError(
            f"FedAvg got {len(params)} parameter sets but {len(weights)} weights")
    return [weighted_sum(group, weights) for group in zip(*params)]

def weighted_sum(inputs, weights):
    """Return ``sum(x * w / sum(weights))`` over paired inputs and weights.

    Raises:
        ZeroDivisionError: if the weights sum to zero.
    """
    weights_sum = sum(weights)
    return sum(x * (w / weights_sum) for x, w in zip(inputs, weights))

def noise_loss(model, noise_sampler, alpha):
    """Return the inner product between the parameters and fresh Gaussian noise.

    Every tensor of ``noise_sampler`` is refilled in place with samples from
    N(0, alpha) and multiplied element-wise with the matching parameter of
    ``model``; the sums of those products are added up. Pairs are formed with
    ``zip``, so surplus parameters or buffers are ignored.
    """
    total = 0
    for weight, buf in zip(model.parameters(), noise_sampler):
        buf.normal_(mean=0, std=alpha)
        total = total + (weight * buf).sum()
    return total

def get_sghmc_noise(model):
    """Allocate one zero-filled noise buffer per parameter of ``model``.

    Each buffer has the shape of its parameter and is created directly on the
    parameter's device, so the function also works when the model is not on
    CUDA. The buffers use the default dtype and do not require gradients;
    noise_loss() later refills them in place.

    Returns:
        A list of tensors in ``model.parameters()`` order.
    """
    return [torch.zeros(p.size(), device=p.device) for p in model.parameters()]

def mean(a):
    """Return the arithmetic mean of the list ``a``.

    Raises:
        TypeError: if ``a`` is not a list (checked explicitly so the guard
            survives ``python -O``, unlike the former ``assert``).
        ZeroDivisionError: if ``a`` is empty.
    """
    if not isinstance(a, list):
        raise TypeError(f"mean() expects a list, got {type(a).__name__}")
    return sum(a) / len(a)

def calc_statistics(x, y, a, b):
    """Return ``sigmoid(a * (x + b) / (sqrt(1 + pi/8 * a^2 * y) + EPS))``.

    ``x`` and ``y`` are tensors, ``a`` and ``b`` scalars.
    NOTE(review): this has the form of the usual pi/8 probit approximation of
    a sigmoid averaged over a Gaussian with mean ``x`` and variance ``y`` --
    confirm against the derivation.
    """
    scale_sq = math.pi / 8
    numerator = a * (x + b)
    denominator = (1 + scale_sq * a * a * y).sqrt()
    return (numerator / (denominator + EPS)).sigmoid()

def calc_template(a, b, c, d):
    """Return ``(a + b) * c - b * d``."""
    first = (a + b) * c
    second = b * d
    return first - second

def get_ep_prior(theta_m, theta_s, fx):
    """Update the EP-prior mean/variance vectors from a batch of features.

    Args:
        theta_m: current prior mean (broadcastable against the pooled ``fx``).
        theta_s: current prior variance (same shape as ``theta_m``).
        fx: feature batch; it is passed through ReLU, averaged over dim 0 and
            squeezed before use.

    Returns:
        ``(new_m, new_s)`` where ``new_m = s_1`` and
        ``new_s = relu(s_2 - s_1**2) + EPS``; ``s_1`` and ``s_2`` are built
        from three calc_statistics() terms combined with calc_template().

    Raises:
        RuntimeError: if the new variance contains a NaN or negative entry.
            The former code printed diagnostics and called ``exit(0)``, which
            ended the whole run with a *success* status.
    """
    fx = torch.relu(fx).mean(0).squeeze()
    fxsq = fx * fx
    cm, cs = fx * theta_m, fxsq * theta_s
    e_1 = calc_statistics(cm, cs, 1, 0)
    e_2 = calc_statistics(cm, cs, 4 - 2 * SQRT2, math.log(SQRT2 + 1))
    e_3 = calc_statistics(cm, cs, 6 * (1 - 1 / POW2_3O2), math.log(POW2_3O2 - 1))

    _p_1 = calc_template(cm, cs, e_1, e_2)
    _p_2 = calc_template(cm, 2 * cs, e_2, e_3)
    s_0 = e_1
    s_1 = _p_1 / (s_0 * fx + EPS)
    s_2 = (cs * e_1 + calc_template(cm, cs, _p_1, _p_2)) / (s_0 * fxsq + EPS)

    # The ReLU clamps small negative variances caused by rounding; EPS keeps
    # the result strictly positive.
    new_m, new_s = s_1, torch.relu(s_2 - s_1 * s_1) + EPS

    # After the clamp only NaNs can make this comparison fail.
    if not bool((new_s >= 0).all()):
        raise RuntimeError(
            "negative value found in theta_s\n"
            f"fx {fx.sum().item()}, cm {cm.sum().item()}, cs {cs.sum().item()}\n"
            f"e1 {e_1.sum().item()}, e2 {e_2.sum().item()}, e3 {e_3.sum().item()}\n"
            f"s0 {s_0.sum().item()}, s1 {s_1.sum().item()}, s2 {s_2.sum().item()}\n"
            f"theta_s {new_s.sum().item()}")

    return new_m, new_s

def ep_prior_loss(model, m, s, observed=1e7):
    """Gaussian-prior penalty pulling every parameter of ``model`` towards ``m``.

    Computes ``sum_p ||p - m||^2 / (2 * (s + EPS)) / observed``.

    Args:
        model: module whose parameters are regularised.
        m: prior mean (a scalar or 0-dim tensor in this file).
        s: prior variance, same kind as ``m``.
        observed: normaliser; the callers pass a parameter count.

    Returns:
        The scaled penalty (a tensor when the model has parameters).
    """
    loss = 0
    for p in model.parameters():
        # Broadcasting ``m`` gives the same value as the former
        # ``ones_like(p).cuda() * m`` without allocating a tensor per
        # parameter and without assuming CUDA.
        loss = loss + torch.norm(p - m, 2) ** 2
    loss = loss / (2 * (s + EPS))
    return loss / observed

def merge_ep_prior(client_m, client_s):
    """Merge the per-client EP priors into one (mean, variance) pair.

    Each client is weighted by ``1 / (s + EPS)``. The merged variance is the
    inverse of the summed weights (plus EPS) and the merged mean is that
    variance times the weighted sum of the client means.
    """
    inv_vars = [1 / (var + EPS) for var in client_s]
    merged_s = 1 / (sum(inv_vars) + EPS)
    weighted_means = sum([mu * inv for mu, inv in zip(client_m, inv_vars)])
    merged_m = merged_s * weighted_means
    return merged_m, merged_s
    
class multi_generator(nn.Module):
    """
        Generative Network

        ``num_gens`` independent input heads (registered as ``G_0`` ...
        ``G_{num_gens-1}`` and also listed in ``self.gs``) feed one shared
        ``main`` stack. A batch of noise is split along dim 0 into contiguous
        chunks, one chunk per head.
    """

    def __init__(self, hidden_size, noise_size=100, out_size=feaDim, timestep=200, num_gens=10):
        """
        Args:
            hidden_size: channel count produced by every head.
            noise_size: channel count of the input noise.
            out_size: feature width of the generated frames.
            timestep: stored on the instance but not used by the layers.
            num_gens: number of heads.
        """
        super(multi_generator, self).__init__()
        self.hidden_size = hidden_size
        self.noise_size = noise_size
        self.out_size = out_size
        self.timestep = timestep

        # Plain list kept for iteration; the heads are registered as
        # sub-modules through setattr so the state_dict keys stay 'G_<i>.*'.
        self.gs = []
        for i in range(num_gens):
            g = nn.Sequential( #hidden x 4
                nn.ConvTranspose1d(self.noise_size, self.hidden_size, 4, 2),
                nn.BatchNorm1d(self.hidden_size),
                nn.ReLU(inplace=True)
            )
            setattr(self, f'G_{i}', g)
            self.gs.append(g)

        self.main = nn.Sequential( # hidden/2 x 10
                nn.ConvTranspose1d(self.hidden_size, self.hidden_size // 2, 4, 2),
                nn.BatchNorm1d(self.hidden_size // 2),
                nn.ReLU(inplace=True),
                # out_size x 20
                nn.ConvTranspose1d(self.hidden_size // 2, self.out_size, 2, 2),
                nn.BatchNorm1d(self.out_size),
        )

        # N(0, 0.02) weight / zero bias initialisation. The original test was
        # against ConvTranspose2d only, which never matches the 1-D layers
        # used here, so the initialisation silently never ran.
        for m in self.modules():
            if isinstance(m, (nn.ConvTranspose1d, nn.ConvTranspose2d)):
                m.weight.data.normal_(0.0, 0.02)
                if m.bias is not None:
                    m.bias.data.zero_()

    def forward(self, x):
        """Map noise of shape (batch, noise_size, 1) to (batch, 20, out_size)."""
        # Ceil-divide so that at most len(self.gs) chunks are produced.
        sp_size = (len(x) - 1) // len(self.gs) + 1
        y = []
        for _x, _g in zip(torch.split(x, sp_size, dim=0), self.gs):
            y.append(_g(_x))
        y = torch.cat(y, dim=0)
        output = self.main(y)
        return output.transpose(1, 2)


class discriminator(nn.Module):
    """Feature extractor (Dense -> SRU -> Dense) with several small heads.

    ``forward`` returns a per-frame representation; the heads operate on
    that representation after mean-pooling over 20 frames:

    * ``ds``           - discriminator scores (``num_dics`` outputs),
    * ``ds_hist_avg``  - frozen running average of ``ds`` (see ``update_hist``),
    * ``get_z``        - features handed to get_ep_prior(),
    * ``get_pi``       - selection weights computed from the EP-prior (m, s).
    """

    def __init__(self, input_size=feaDim, hidden_size=1024, output_size=1095, 
                  num_dics=learning['num_dics'], 
                  num_layers=learning['layerNum'], 
                  z_size=learning['num_dics']):
        super(discriminator, self).__init__()

        self.hidden_size = hidden_size
        self.num_dics = num_dics
        self.num_layers = num_layers
        # Stored only; no layer of this class is sized with output_size.
        self.output_size = output_size
        self.Dense_layer1 = Dense(input_size, self.hidden_size)

        self.sru = SRU(input_size=self.hidden_size, 
                       hidden_size=self.hidden_size,
                       num_layers=self.num_layers,
                       dropout=0.1, 
                       use_tanh=True)

        self.Dense_layer3 = Dense(self.hidden_size, disDim)

        self.ds = nn.Sequential(
            Dense(disDim, 64),
            nn.ReLU(inplace=True),
            Dense(64, num_dics)
        ) # discriminator = backbone + dis

        # hist model 
        self.ds_hist_avg = nn.Sequential(
            Dense(disDim, 64),
            nn.ReLU(inplace=True),
            Dense(64, num_dics)
        )
        # The history head is never trained by back-propagation; it is only
        # overwritten by update_hist().
        for param in self.ds_hist_avg.parameters():
            param.requires_grad = False
        
        self.dropout = nn.Dropout(p=0.1)
        # NOTE(review): self.eps is not read anywhere in this class.
        self.eps = 1e-3
        # Number of snapshots folded into ds_hist_avg so far (its random
        # initialisation counts as the first one).
        self.len_hist = 1.0

        # NOTE(review): with a falsy z_size, get_z / get_pi are never created
        # and forward_to_get_z / forward_to_get_pi would raise AttributeError.
        if z_size:
            self.z_size = z_size
            self.get_z = nn.Sequential(
                Dense(disDim, 64),
                nn.ReLU(inplace=True),
                Dense(64, self.z_size)
            )
            self.get_pi = nn.Sequential(
                Dense(self.z_size * 2, 64),
                nn.ReLU(inplace=True),
                Dense(64, self.z_size)
            )



        # NOTE(review): no nn.Conv2d is created directly in this class, so
        # this loop only has an effect if Dense or SRU contain one -- verify.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                m.weight.data.normal_(0.0, 0.02)
                if m.bias is not None:
                    m.bias.data.zero_()

    def forward(self, x, hidden=None):
        """Return the per-frame representation and the SRU state.

        ``x`` is a 3-D tensor (batch, time, features). The result is flattened
        to (batch * time, ...) after Dense_layer3 -- presumably disDim wide,
        assuming Dense(in, out) maps the last dimension; TODO confirm.
        """
        # The default state is sized from the global config, not from
        # self.num_layers / self.hidden_size, and always lives on CUDA.
        if hidden is None:
            hidden = torch.zeros(
                        learning['layerNum'], 
                        learning['batchSize'], 
                        learning['hiddenDim']
                            
            ).cuda()

        b, t, h = x.size()
        x = torch.reshape(x, (b*t, h))
        x = self.Dense_layer1(x)
        x = self.dropout(x)
        x = torch.reshape(x, (b, t, self.hidden_size))
        # SRU is fed time-major input: (time, batch, hidden).
        x = x.permute(1, 0, 2)
        x, hidden_after = self.sru(x, hidden)
        x = x.permute(1, 0, 2)
        b, t, h = x.size()
        x = torch.reshape(x, (b * t, h))
        x = self.Dense_layer3(x)
        x = torch.relu(x)
        x = self.dropout(x)

        return x, hidden_after

    def forward_to_get_z(self, x):
        """Mean-pool the representation over 20 frames and apply ``get_z``."""
        # NOTE(review): 20 is hard-coded; it equals learning['timeSteps'].
        x = self.get_z(x.view(-1, 20, disDim).mean(dim=1)) # b x t x z_size > mean pooling to b x z_size
        return x

    def forward_to_get_pi(self, m, s):
        """Map the concatenated (m, s) to a hard Gumbel-softmax sample."""
        output = self.get_pi(torch.cat([m, s]).unsqueeze(0))
        return F.gumbel_softmax(output, tau=0.1, hard=True).squeeze(0)

    def forward_ds(self, x):
        """Mean-pool the representation over 20 frames and apply ``ds``."""
        # x, hidden_after = self.forward_main(x, hidden)
        x = self.ds(x.view(-1, 20, disDim).mean(dim=1))
        return x

    def forward_by_hist(self, x):
        """Same as forward_ds but through the history-averaged head."""
        # x, hidden_after = self.forward_main(x, hidden)
        x = self.ds_hist_avg(x.view(-1, 20, disDim).mean(dim=1))
        return x

    def update_hist(self):
        """Fold the current ``ds`` weights into the running average head."""
        self.len_hist += 1
        # Incremental mean: new_avg = old_avg * (1 - 1/n) + current / n.
        alpha = 1.0 / self.len_hist
        for trg, src in zip(self.ds_hist_avg.parameters(), self.ds.parameters()):
            trg.data = trg.data * (1 - alpha) + src.data * alpha


class SRU_ProbGAN_EP(nn.Module):
    """Frame classifier: Dense -> SRU -> Dense(1024) -> Dense(output_size)."""

    def __init__(self, input_size=feaDim, hidden_size=1024, output_size=1095, 
                  num_layers=learning['layerNum']):
        super(SRU_ProbGAN_EP, self).__init__()

        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.output_size = output_size
        self.Dense_layer1 = Dense(input_size, self.hidden_size)

        self.sru = SRU(input_size=self.hidden_size, 
                       hidden_size=self.hidden_size,
                       num_layers=self.num_layers,
                       dropout=0.1, 
                       use_tanh=True)

        self.Dense_layer3 = Dense(self.hidden_size, 1024)

        self.Dense_layer4 = Dense(1024, output_size)
        
        self.dropout = nn.Dropout(p=0.1)
        # NOTE(review): eps and len_hist are not read anywhere in this class.
        self.eps = 1e-3
        self.len_hist = 1.0

        # NOTE(review): no nn.Conv2d is created directly in this class, so
        # this loop only has an effect if Dense or SRU contain one -- verify.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                m.weight.data.normal_(0.0, 0.02)
                if m.bias is not None:
                    m.bias.data.zero_()

    def forward(self, x, hidden=None):
        """Return per-frame outputs of shape (batch * time, output_size) and the SRU state.

        ``x`` is a 3-D tensor (batch, time, features).
        """
        # The default state is sized from the global config, not from
        # self.num_layers / self.hidden_size, and always lives on CUDA.
        if hidden is None:
            hidden = torch.zeros(
                        learning['layerNum'], 
                        learning['batchSize'], 
                        learning['hiddenDim']
                            
            ).cuda()

        b, t, h = x.size()
        x = torch.reshape(x, (b*t, h))
        x = self.Dense_layer1(x)
        x = self.dropout(x)
        x = torch.reshape(x, (b, t, self.hidden_size))
        # SRU is fed time-major input: (time, batch, hidden).
        x = x.permute(1, 0, 2)
        x, hidden_after = self.sru(x, hidden)
        x = x.permute(1, 0, 2)
        b, t, h = x.size()
        x = torch.reshape(x, (b * t, h))
        x = self.Dense_layer3(x)
        x = torch.relu(x)
        x = self.dropout(x)
        x = self.Dense_layer4(x).view(-1, self.output_size)

        return x, hidden_after


# If you run this code on CPU, please remove the '.cuda()'
# Classifier: its input is the acoustic feature frame concatenated with the
# disDim-wide representation produced by D (see train_one_epoch / val).
model = SRU_ProbGAN_EP(input_size=feaDim+disDim, hidden_size=learning['hiddenDim'],
                    output_size=learning['targetDim']).cuda()
# NOTE(review): num_gens is not forwarded, so G uses its default of 10 heads.
G = multi_generator(hidden_size=learning['hiddenDim'], noise_size=z_size).cuda()
# The EP-prior heads of D are sized with num_dics, not with the global z_size.
D = discriminator(input_size=feaDim, hidden_size=learning['hiddenDim'],
                    output_size=learning['targetDim'], z_size=num_dics).cuda()
#model.load_state_dict(torch.load(exp + '/dnn.nnet.pth'))

# Parameter counts; the trainable counts are reused as the `observed`
# normaliser of ep_prior_loss().
pytorch_total_params = sum(p.numel() for p in G.parameters())
print("Total G parameters:", pytorch_total_params)
N_G_PARAMS = sum(p.numel() for p in G.parameters() if p.requires_grad)
print("Total G parameters to update:", N_G_PARAMS)

pytorch_total_params = sum(p.numel() for p in D.parameters())
print("Total D parameters:", pytorch_total_params)
N_D_PARAMS = sum(p.numel() for p in D.parameters() if p.requires_grad)
print("Total D parameters to update:", N_D_PARAMS)

pytorch_total_params = sum(p.numel() for p in model.parameters())
print("Total SRU parameters:", pytorch_total_params)
N_S_PARAMS = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Total SRU parameters to update:", N_S_PARAMS)


loss_classify = nn.CrossEntropyLoss().cuda()
criterion_mse = nn.MSELoss().cuda()


# Least-squares GAN objectives: phi_1 scores real samples against the real
# label, phi_2 scores fakes against the fake label (both for D), and phi_3
# scores fakes against the real label (for G). Unused arguments are ignored.
print("Use objective: LS")
phi_1 = lambda dreal, lreal, lfake: criterion_mse(dreal, lreal)
phi_2 = lambda dfake, lreal, lfake: criterion_mse(dfake, lfake)
phi_3 = lambda dfake, lreal, lfake: criterion_mse(dfake, lreal)

# One Adam optimizer per network, all with the same learning rate and betas.
optimizer = torch.optim.Adam(model.parameters(), lr=learning['rate'], betas=(0.5, 0.999))
optimizerG = torch.optim.Adam(G.parameters(), lr=learning['rate'], betas=(0.5, 0.999))
optimizerD = torch.optim.Adam(D.parameters(), lr=learning['rate'], betas=(0.5, 0.999))


# gamma=1 keeps the learning rate constant, i.e. the schedule is switched off.
# step_size is clamped to at least 1: with turns * epoch_per_turn < 2 the
# integer division gave 0 and scheduler.step() raised ZeroDivisionError
# before the first checkpoint was written.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                        step_size=max(1, turns * epoch_per_turn // 2),
                                        gamma=1) # turn of the LR scheduling


# fixed_noise = torch.FloatTensor(10 * 10, z_size, 1).normal_(0, 1).cuda()
# fixed_noise.requires_grad = False

#optimizer = torch.optim.SGD( model.parameters(),lr=learning['rate'],momentum=0.5,nesterov=True)
#scheduler = torch.optim.lr_scheduler.StepLR(optimizer,step_size=3,gamma=0.1)

def train_one_epoch(model, D, G, optimizer, optimizerD, optimizerG, 
                    tms, tss, client, epoch, hidden):
    """Run one local epoch for ``client``: D step, G step, classifier step.

    For every full-size batch of ``trGens[client]`` the function
      1. updates D with the least-squares GAN loss plus the noise loss,
      2. updates G with the GAN loss (current and history-averaged heads),
         the noise loss and the EP-prior penalty,
      3. updates the classifier and D jointly with cross-entropy plus their
         EP-prior penalties,
    and finally stores the refreshed EP prior in ``tms[client]`` /
    ``tss[client]``. Batches whose size differs from learning['batchSize']
    are skipped.

    Args:
        model, D, G: classifier, discriminator and generator (on CUDA).
        optimizer, optimizerD, optimizerG: their optimizers.
        tms, tss: per-client lists of EP-prior means / variances; entry
            ``client`` is read and overwritten in place.
        client: index into ``trGens`` / ``trDatasets`` / ``tms`` / ``tss``.
        epoch: not used inside the function.
        hidden: initial recurrent state for the first processed batch.

    Returns:
        ``model`` (the same object that was passed in).
    """
    acc = 0

    model.train()
    D.train()
    G.train()

    G_noise_sampler = [get_sghmc_noise(g) for g in G.gs]
    D_noise_sampler = get_sghmc_noise(D)

    # Recurrent state carried from one processed batch to the next. Seeding it
    # with `hidden` also covers the case where batch 0 is skipped, which used
    # to leave `hidden_after` undefined.
    hidden_after = hidden

    for batch_idx, (x, y) in enumerate(tqdm.tqdm(trGens[client])):
        # If you run this code on CPU, please remove the '.cuda()'
        x = x.cuda()
        y = y.cuda()
        b, t, h = x.size()
        if b != learning['batchSize']:
            continue

        hidden_before = hidden_after.cuda()

        # ---------------- Update D network ----------------
        G.eval()
        model.eval()
        D.train()

        real_rep, _ = D(x, hidden_before) # get representation

        # train discriminator with real data
        label_real = torch.ones(max(b, g_batch_size), requires_grad=False).cuda()
        D_real_result = D.forward_ds(real_rep).mean(-1) # b
        D_real_loss = phi_1(D_real_result, label_real[:b], None)
        # train discriminator with fake data
        noise = torch.randn((g_batch_size, z_size, 1)).cuda()
        label_fake = torch.zeros(g_batch_size, requires_grad=False).cuda()
        G_result = G(noise)
        fake_rep, _ = D(G_result, hidden_before)
        D_fake_result = D.forward_ds(fake_rep).mean(-1)
        D_fake_loss = phi_2(D_fake_result, None, label_fake)
        # calculate GAN loss
        D_train_loss = D_real_loss + D_fake_loss
        D_noise_loss = d_noise_loss_lambda * noise_loss(model=D,
            noise_sampler=D_noise_sampler, alpha=noise_std)
        D_train_loss += D_noise_loss
        D_train_loss *= d_loss_lambda
        if batch_idx % 100 == 0:
            print(f"D_train_loss: {D_train_loss:.4f}")
        # back propagation
        D.zero_grad()
        D_train_loss.backward()
        # Clip between backward() and step(); the original clipped at the end
        # of the iteration, after every step, where it had no effect.
        nn.utils.clip_grad_norm_(parameters=D.parameters(),
                                 max_norm=100, norm_type=2)
        optimizerD.step()

        # ---------------- Update G network ----------------
        D.eval()
        model.eval()
        G.train()
        noise = torch.randn((g_batch_size, z_size, 1)).cuda()
        G_result = G(noise)
        fake_rep, _ = D(G_result)
        D_fake_result = D.forward_ds(fake_rep).mean(-1)
        G_train_loss = phi_3(D_fake_result, label_real, label_fake)
        G_noise_loss = g_noise_loss_lambda * \
                       sum([noise_loss(
                                model=g, noise_sampler=s, alpha=noise_std)
                            for g, s in zip(G.gs, G_noise_sampler)])
        if batch_idx % 100 == 0:
            print(f"G_train_loss: {G_train_loss:.4f} G_noise_loss: {G_noise_loss:.4f}")
        G_train_loss += G_noise_loss
        D_fake_result_hist = D.forward_by_hist(fake_rep).mean(-1)
        G_train_loss_by_hist = phi_3(D_fake_result_hist, label_real, label_fake)
        G_train_loss += G_train_loss_by_hist * d_hist_loss_lambda

        # calculate EP loss
        f_x = D.forward_to_get_z(fake_rep)
        m, s = get_ep_prior(tms[client], tss[client], f_x)
        pis = D.forward_to_get_pi(m, s)
        # zip() stops at the shortest input, so only the first len(G.gs)
        # entries of pis / m / s are used here.
        # NOTE(review): the generator penalty is normalised by N_D_PARAMS;
        # N_G_PARAMS looks like the intended value -- confirm before changing.
        G_ep_prior_loss = sum(
            [pi * ep_prior_loss(g, mm, ss, observed=N_D_PARAMS)
                for pi, g, mm, ss, in zip(pis, G.gs, m, s)])
        G_train_loss += G_ep_prior_loss * g_ep_prior_lambda
        G_train_loss *= g_loss_lambda

        if batch_idx % 100 == 0:
            print(f"G total loss {G_train_loss:.4f}")
        # Back propagation
        G.zero_grad()
        G_train_loss.backward()
        nn.utils.clip_grad_norm_(parameters=G.parameters(),
                                 max_norm=100, norm_type=2)
        optimizerG.step()

        if batch_idx % 10 == 0:
            # update history discriminators aggregations
            D.update_hist()

        # ---------------- Update Classifier ----------------
        model.train()
        D.train()
        G.eval()
        # Classifier optimization
        real_rep, _ = D(x, hidden_before)
        x = torch.cat([x, real_rep.view(b, t, -1)], dim=-1)
        output, hidden_after = model(x, hidden_before)
        y_batch_size, y_time_steps = y.size()
        y = torch.reshape(y, tuple([y_batch_size * y_time_steps]))
        y = y.long()
        loss = loss_classify(output, y)

        # calculate EP loss
        f_x = D.forward_to_get_z(real_rep)
        m, s = get_ep_prior(tms[client], tss[client], f_x)
        pis = D.forward_to_get_pi(m, s)
        D_ep_prior_loss = sum(
            [pi * ep_prior_loss(D, mm, ss, observed=N_D_PARAMS)
                for pi, mm, ss, in zip(pis, m, s)])
        S_ep_prior_loss = sum(
            [pi * ep_prior_loss(model, mm, ss, observed=N_S_PARAMS)
                for pi, mm, ss, in zip(pis, m, s)])

        if batch_idx % 100 == 0:
            print(
                f"CE Loss: {loss:.4f}\t"
                f"EP loss: {D_ep_prior_loss:.4f}")
        loss += D_ep_prior_loss * d_ep_prior_lambda
        loss += S_ep_prior_loss * s_ep_prior_lambda

        if batch_idx % 100 == 0:
            print(f"Total_SRU_loss: {loss:.4f}")

        # G also receives gradients from this backward pass; they are
        # discarded by G.zero_grad() in the next iteration.
        model.zero_grad()
        D.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(parameters=model.parameters(),
                                 max_norm=100, norm_type=2)
        nn.utils.clip_grad_norm_(parameters=D.parameters(),
                                 max_norm=100, norm_type=2)
        optimizer.step()
        optimizerD.step()

        _, pred = torch.max(output.data, 1)
        # Cut the graph so the next batch does not back-propagate into this one.
        hidden_after = hidden_after.detach()
        acc += ((pred == y).sum()).cpu().numpy()

        # update EP prior
        tms[client] = m.clone().detach()
        tss[client] = s.clone().detach()

    print("Train acc:", acc)
    print("Train numFeats:", trDatasets[client].numFeats)
    print("Train accuracy: %f"%(acc/trDatasets[client].numFeats))
    return model


def val(model, D, train_loader, hidden=None):
    """Evaluate the classifier on ``train_loader`` and report loss / accuracy.

    Args:
        model: classifier network.
        D: discriminator whose representation is appended to the input.
        train_loader: loader to evaluate on (despite its name, the callers in
            this file pass the validation loader ``cvGen``).
        hidden: initial recurrent state; zeros when ``None``.

    Returns:
        ``(mean_loss, accuracy)``: the mean cross-entropy over the processed
        batches and the number of correct frames divided by
        ``cvDataset.numFeats``.

    Raises:
        ZeroDivisionError: if no batch of size learning['batchSize'] was seen.
    """
    if hidden is None:
        hidden = torch.zeros(
                learning['layerNum'], 
                learning['batchSize'], 
                learning['hiddenDim']
            ).cuda()

    model.eval()
    # D contains dropout; without eval() it stayed in whatever mode training
    # left it in (train), which made the validation numbers noisy.
    D.eval()
    # Same effect as the former per-batch optimizer.zero_grad(), without
    # depending on the global optimizer.
    model.zero_grad()

    acc = 0
    val_loss = 0
    num_batches = 0
    # Seeded with `hidden` so a skipped first batch cannot leave it undefined.
    hidden_after = hidden
    for batch_idx, (x, y) in enumerate(tqdm.tqdm(train_loader)):
        # If you run this code on CPU, please remove the '.cuda()'
        x = x.cuda()
        y = y.cuda()
        b, t, h = x.size()
        if b != learning['batchSize']:
            continue

        hidden_before = hidden_after.cuda()

        with torch.no_grad():
            real_rep, _ = D(x, hidden_before)
            output, hidden_after = model(
                       torch.cat([x, real_rep.view(b,t,-1)], dim=-1), 
                       hidden_before)

        _, pred = torch.max(output.data, 1)
        y_batch_size, y_time_steps = y.size()
        y = torch.reshape(y, tuple([y_batch_size * y_time_steps]))
        y = y.long()
        loss = loss_classify(output, y)
        val_loss += float(loss.item())
        num_batches += 1
        acc += ((pred == y).sum()).cpu().numpy()

    print('Valid acc:', acc)
    print('Valid numFeats:', cvDataset.numFeats)
    val_acc = acc/cvDataset.numFeats
    print("Valid accuracy: %f" % (val_acc))
    print("Valid lOSS: %f" % (val_loss / num_batches))
    return float(val_loss / num_batches), val_acc


def train_net_local(model, D, G, optimizer, optimizerD, optimizerG, tms, tss,
                    client, vals, turn, max_epoch):
    """Train one client for ``max_epoch`` epochs, then validate.

    Early stopping on the validation loss is disabled in this version, so
    every epoch always runs.

    Args:
        model, D, G, optimizer, optimizerD, optimizerG, tms, tss, client:
            forwarded unchanged to train_one_epoch().
        vals: per-client list of validation losses; entry ``client`` is
            overwritten with the loss measured after training.
        turn: index of the federated round (used for logging only).
        max_epoch: number of local epochs.

    Returns:
        ``vals`` (the same list, updated in place).
    """
    # Zero initial recurrent state shared by every epoch and by validation.
    # Created before the loop so that it also exists when max_epoch == 0,
    # which used to raise a NameError at the val() call below.
    h = torch.zeros(
            learning['layerNum'], 
            learning['batchSize'], 
            learning['hiddenDim']
        )

    for epoch in range(max_epoch):
        print(f"====== Client: {client} trun-ep: {turn}-{epoch} =====")

        time_start = time.time()
        train_one_epoch(model, D, G, optimizer, optimizerD, optimizerG, 
                        tms, tss, client, epoch, h)
        time_cost = time.time() - time_start
        print("Time Cost : %f"%(time_cost))

    val_loss_after, _ = val(model, D, cvGen, h)
    vals[client]=val_loss_after
    return vals


def _snapshot_params(module):
    """Return detached CPU copies of every parameter of ``module``."""
    # detach(): the snapshots must not drag autograd history into FedAvg.
    # copy=True: guarantees an independent tensor even for CPU parameters.
    return [p.detach().to('cpu', copy=True) for p in module.parameters()]


def _load_params(module, values):
    """Overwrite the parameters of ``module`` in place with ``values``."""
    for p, v in zip(module.parameters(), values):
        # copy_ moves the data to the parameter's device when needed.
        p.data.copy_(v)


def train_fl_net(model, D, G, optimizer, optimizerD, optimizerG, 
                 clients, turns, epoch_per_turn):
    """Federated training loop: local training per client, then FedAvg.

    In every turn each client starts from the current server parameters,
    trains locally through train_net_local(), and its resulting parameters
    are collected. The server parameters become the FedAvg of the client
    parameters weighted by ``client_weights``, the per-client EP priors are
    merged with merge_ep_prior(), the merged networks are validated and the
    three state dicts are saved under ``exp``.

    NOTE(review): only ``parameters()`` are broadcast and averaged. Module
    buffers (e.g. BatchNorm running statistics in G) and the optimizer states
    are shared by consecutive clients -- confirm this is intended.

    Args:
        model, D, G: classifier, discriminator and generator.
        optimizer, optimizerD, optimizerG: their optimizers.
        clients: number of clients.
        turns: number of federated rounds.
        epoch_per_turn: local epochs per client and round.
    """
    server_model_parameters = _snapshot_params(model)
    server_G_parameters = _snapshot_params(G)
    server_D_parameters = _snapshot_params(D)

    # Initial EP prior: zero mean, unit variance.
    theta_m = torch.zeros(num_dics, requires_grad=False).cuda()
    theta_s = torch.ones(num_dics, requires_grad=False).cuda()

    vals = [10000 for i in range(clients)]
    for turn in range(turns):
        print('>>> current learning rate:', optimizer.param_groups[0]['lr'])
        client_model_parameters = []
        client_G_parameters = []
        client_D_parameters = []
        # Every client starts the turn from the merged prior.
        tms = [theta_m.clone() for _ in range(clients)]
        tss = [theta_s.clone() for _ in range(clients)]
        for client in range(clients):
            model.zero_grad()
            # broadcasting parameters
            _load_params(model, server_model_parameters)
            _load_params(G, server_G_parameters)
            _load_params(D, server_D_parameters)
            print(f"Runing turn {turn} on client {client}.")
            train_net_local(model, D, G, optimizer, optimizerD, optimizerG, 
                            tms, tss, client, vals, turn, epoch_per_turn)
            client_model_parameters.append(_snapshot_params(model))
            client_G_parameters.append(_snapshot_params(G))
            client_D_parameters.append(_snapshot_params(D))

        server_model_parameters = FedAvg(client_model_parameters, client_weights)
        server_D_parameters = FedAvg(client_D_parameters, client_weights)
        server_G_parameters = FedAvg(client_G_parameters, client_weights)
        theta_m, theta_s = merge_ep_prior(tms, tss)

        scheduler.step()
        _load_params(model, server_model_parameters)
        _load_params(G, server_G_parameters)
        _load_params(D, server_D_parameters)

        val_loss_after, _ = val(model, D, cvGen)
        print(f"Save the {turn} model with valid loss {val_loss_after}")
        torch.save(model.state_dict(), exp + '/dnn.nnet.pth')
        torch.save(D.state_dict(), exp + '/dnn.nnetD.pth')
        torch.save(G.state_dict(), exp + '/dnn.nnetG.pth')
    print("Finished ... ")



# Entry point: run the complete federated training.
train_fl_net(
    model, D, G,
    optimizer, optimizerD, optimizerG,
    clients=clients,
    turns=turns,
    epoch_per_turn=epoch_per_turn,
)
