import torch
import torch.optim as optim
import time
import torch.nn as nn
from dataset import EmnistDataset
from resnet import build_model, ResNet18
from torch.utils.data import DataLoader
from flwr.common import FitRes, Status
from util import get_parameters, set_parameters
from flwr.common import ndarrays_to_parameters
from flwr.common import Code
from hyper_params import DEVICE, LAMBDA
import torch.nn.functional as F

# Default mini-batch size for local client training; used as the default
# value of local_train's ``Btc`` argument.
Batch: int = 16

def local_train(cid, params, global_params, server_round, client_count, E=5, learning_rate=0.005, Btc=Batch) -> FitRes:
    """Train one client's local model and return its updated weights.

    The local model is trained with SGD on the client's CSV shard. When
    ``global_params`` is given, a ResNet18 "proxy" loaded with those weights
    acts as a frozen teacher. A KL-divergence term pulling the local
    predictions toward the proxy's is then added to the cross-entropy loss,
    weighted by ``LAMBDA``.

    Args:
        cid: Client id. Selects the data file
            ``clientdata/femnist_client_<cid>_ALPHA_1.0.csv`` and is passed to
            ``build_model``.
        params: Weights loaded into the local model (via ``set_parameters``)
            before training starts.
        global_params: Weights for the proxy/teacher model, or ``None`` to
            train with cross-entropy only.
        server_round: Only used in the progress message (printed as
            ``server_round + 1``).
        client_count: Only used in the progress message.
        E: Number of local epochs.
        learning_rate: SGD learning rate.
        Btc: Mini-batch size of the shuffled ``DataLoader``.

    Returns:
        A ``FitRes`` with status ``Code.OK``, the trained local weights,
        ``num_examples`` equal to the local dataset size, and empty metrics.
    """
    print(f"Server round {server_round+1}, training on the {client_count}-th client, id = {cid}")
    dataset = EmnistDataset("clientdata/femnist_client_" + str(cid) + "_ALPHA_1.0.csv")
    trainloader = DataLoader(dataset, Btc, shuffle=True)
    localmodel = build_model(cid, device=DEVICE)
    set_parameters(localmodel, params)

    # Only build the teacher when there is something to distil from; there is
    # no reason to allocate a ResNet18 on DEVICE that is never run.
    use_distillation = global_params is not None
    proxy_model = None
    if use_distillation:
        proxy_model = ResNet18().to(DEVICE)
        set_parameters(proxy_model, global_params)
        # The proxy is a fixed teacher: eval() makes its BatchNorm layers use
        # the loaded running statistics rather than per-batch statistics of
        # the local mini-batch.
        # NOTE(review): assumes set_parameters also restores BatchNorm buffers
        # (e.g. via state_dict) — confirm in util.py.
        proxy_model.eval()

    optimizer = optim.SGD(localmodel.parameters(), lr=learning_rate)
    criterion = torch.nn.CrossEntropyLoss()
    # NOTE(review): the default reduction="mean" averages over batch *and*
    # class dimensions; PyTorch recommends "batchmean" for the true KL value.
    # Kept unchanged because LAMBDA is presumably tuned against this scale.
    criterion2 = torch.nn.KLDivLoss()

    start = time.perf_counter()
    localmodel.train()
    for _epoch in range(E):
        for samples, labels in trainloader:
            samples, labels = samples.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            # The model returns a pair; only the first element (the logits
            # fed to CrossEntropyLoss) is used here.
            outputs, _ = localmodel(samples)
            loss = criterion(outputs, labels)
            if use_distillation:
                with torch.no_grad():
                    y_global, _ = proxy_model(samples)
                # KLDivLoss takes log-probabilities as input and probabilities
                # as target. dim=1 is the class axis (assumes (batch, classes)
                # logits) and matches what the deprecated dim-less calls chose
                # implicitly for 2-D input.
                loss2 = criterion2(F.log_softmax(outputs, dim=1), F.softmax(y_global, dim=1))
                loss = loss + LAMBDA * loss2
            loss.backward()
            optimizer.step()
    elapsed = time.perf_counter() - start
    print(f"Training done, time cost = {elapsed} seconds\n")

    parameters_updated = get_parameters(localmodel)
    # Drop the model references explicitly before returning.
    del localmodel
    del proxy_model

    status = Status(code=Code.OK, message="Success")
    return FitRes(status=status, parameters=ndarrays_to_parameters(parameters_updated), num_examples=len(dataset),
                  metrics={})

