import torch
import torch.optim as optim
import time
from dataset import EmnistDataset
from resnet import build_model
from generator import Generator
from torch.utils.data import DataLoader
from flwr.common import FitRes, Status
from util import set_parameters, get_filters, set_filters
from flwr.common import ndarrays_to_parameters
from flwr.common import Code
from hyper_params import DEVICE

Batch: int = 16  # default mini-batch size; used as local_train's `Btc` argument

def local_train(cid, params, gen_param, server_round, client_count, E=5, learning_rate=0.005, Btc=Batch) -> FitRes:
    """Train one client's local model on its FEMNIST shard and package the result.

    The model is built with ``build_model(cid, ...)``, loaded with ``params``
    via ``set_filters`` and trained with plain SGD on cross-entropy. When a
    generator is supplied, an extra cross-entropy term is added on
    ``localmodel.predict(generator(labels))``.

    Args:
        cid: Client id; selects ``clientdata/femnist_client_<cid>_ALPHA_1.0.csv``
            and is passed to ``build_model``.
        params: Global parameters applied to the local model with ``set_filters``.
        gen_param: ``None`` to train without a generator, otherwise parameters
            loaded into a fresh ``Generator`` with ``set_parameters``.
        server_round: Current round; printed as ``server_round + 1``, so
            presumably zero-based (TODO confirm against the caller).
        client_count: Position of this client in the round; only used in the log line.
        E: Number of local epochs.
        learning_rate: SGD learning rate.
        Btc: Mini-batch size; the last incomplete batch of each epoch is dropped.

    Returns:
        ``FitRes`` with status OK, the updated parameters from ``get_filters``,
        ``num_examples=len(dataset)`` and ``metrics={"classifier": ...}`` holding
        the arrays returned by ``get_classifier``.
    """
    print(f"Server round {server_round+1}, training on the {client_count}-th client, id = {cid}")
    dataset = EmnistDataset(f"clientdata/femnist_client_{cid}_ALPHA_1.0.csv")
    trainloader = DataLoader(dataset, Btc, shuffle=True, drop_last=True)
    localmodel = build_model(cid, device=DEVICE)
    set_filters(localmodel, params)

    # `is None` (identity) rather than `== None`: if gen_param were a NumPy
    # array, `==` would broadcast elementwise and make the `if` ambiguous.
    if gen_param is None:
        generator = None
    else:
        generator = Generator().to(DEVICE)
        set_parameters(generator, gen_param)

    # Built once; the original constructed the criterion/optimizer twice.
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = optim.SGD(localmodel.parameters(), lr=learning_rate)
    start = time.time()
    localmodel.train()

    for _epoch in range(E):
        for samples, labels in trainloader:
            samples, labels = samples.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            # localmodel returns a pair; only the first element (logits fed to
            # CrossEntropyLoss) is used here.
            outputs, _ = localmodel(samples)
            loss = criterion(outputs, labels)
            if generator is not None:
                # The generator is not optimised here, so run it without
                # autograd: otherwise its parameters accumulate gradients that
                # are never used or cleared. Gradients of this term still flow
                # into localmodel through predict().
                with torch.no_grad():
                    z_syn = generator(labels)
                loss = loss + criterion(localmodel.predict(z_syn), labels)
            loss.backward()
            optimizer.step()

    elapsed = time.time() - start
    print(f"Training done, time cost = {elapsed} seconds\n")

    parameters_updated = get_filters(localmodel)
    classifier = get_classifier(localmodel)
    del localmodel, generator

    status = Status(code=Code.OK, message="Success")
    # NOTE(review): num_examples counts the full dataset even though
    # drop_last=True means a few samples per epoch are never seen.
    return FitRes(
        status=status,
        parameters=ndarrays_to_parameters(parameters_updated),
        num_examples=len(dataset),
        metrics={"classifier": classifier},
    )


def get_classifier(net:torch.nn.Module, names=("fc.weight", "fc.bias")):
    """Return the classifier-head tensors of ``net`` as NumPy arrays.

    Entries come from ``net.state_dict()`` in its iteration order. Only keys
    listed in ``names`` are kept; by default these are the ``fc`` layer's
    weight and bias. Names the model does not have are skipped.

    Each array is an independent copy. For a CPU tensor, ``.cpu()`` is a no-op
    and ``.numpy()`` shares memory with it, and ``state_dict()`` tensors share
    storage with the live parameters. Without the copy, the returned "snapshot"
    would change when the model is trained further, but only on CPU, because
    GPU tensors are already copied by ``.cpu()``.
    """
    wanted = set(names)
    return [
        value.cpu().numpy().copy()
        for key, value in net.state_dict().items()
        if key in wanted
    ]
