import torch
import torch.optim as optim
from resnet import build_model
import time
from dataset import EmnistDataset
from torch.utils.data import DataLoader
from flwr.common import FitRes, Status
from util import get_parameters, set_parameters
from flwr.common import ndarrays_to_parameters
from flwr.common import Code
from proto_operators import agg_func, get_covariance
from hyper_params import DEVICE, LAMBDA, Batch

def local_train(cid, params, server_round, client_count, E=5, learning_rate=0.005, Btc=Batch, classifier_weights=None) -> FitRes:
    """Train one client's model locally and return its weights and per-class feature statistics.

    The client's CSV is loaded from ``clientdata/femnist_client_<cid>_ALPHA_1.0.csv``,
    a model is built with ``build_model`` and initialised from ``params``, and
    SGD with cross-entropy loss runs for ``E`` epochs. During the final epoch
    the second output of the model (called ``protos`` here) is collected per
    sample and grouped by label.

    Args:
        cid: Client identifier; used in the dataset path and passed to
            ``build_model``.
        params: Model parameters applied via ``set_parameters`` before training.
        server_round: Zero-based round index (printed as ``server_round + 1``).
        client_count: Ordinal of this client within the round (logging only).
        E: Number of local epochs.
        learning_rate: SGD learning rate.
        Btc: Batch size handed to the ``DataLoader``.
        classifier_weights: Optional pair ``(weight, bias)`` of NumPy arrays.
            When given, ``LAMBDA`` times the sum of the L2 distances between
            ``localmodel.fc.weight`` / ``localmodel.fc.bias`` and these arrays
            is added to the loss. When ``None`` no such term is added.

    Returns:
        A ``FitRes`` with status ``Code.OK``, the updated parameters,
        ``num_examples=len(dataset)`` and a metrics dict with keys
        ``"proto"`` (result of ``agg_func`` on the collected prototypes),
        ``"cov"`` (result of ``get_covariance`` on the flattened prototypes)
        and ``"count"`` (number of collected samples per label).
    """
    print(f"Server round {server_round+1}, training on the {client_count}-th client, id = {cid}")
    dataset = EmnistDataset(f"clientdata/femnist_client_{cid}_ALPHA_1.0.csv")
    trainloader = DataLoader(dataset, Btc, shuffle=True)
    localmodel = build_model(cid, device=DEVICE)
    set_parameters(localmodel, params)
    optimizer = optim.SGD(localmodel.parameters(), lr=learning_rate)
    criterion = torch.nn.CrossEntropyLoss()

    # perf_counter is monotonic, so the reported duration cannot be skewed by
    # wall-clock adjustments.
    start = time.perf_counter()
    localmodel.train()

    # Plain dicts (not defaultdict) so downstream helpers and the metrics
    # payload receive the same container type as before.
    prototypes = {}
    feature_space = {}
    count_byclass = {}

    # Identity check: `== None` would be evaluated element-wise if an ndarray
    # were ever passed here.
    use_prox = classifier_weights is not None
    if use_prox:
        # Tensors created by from_numpy do not require grad, so they act as
        # fixed reference points for the proximal term.
        weight = torch.from_numpy(classifier_weights[0]).to(DEVICE)
        bias = torch.from_numpy(classifier_weights[1]).to(DEVICE)

    for e in range(E):
        # Prototypes are only gathered during the final epoch.
        collect = e >= E - 1
        for samples, labels in trainloader:
            samples, labels = samples.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            outputs, protos = localmodel(samples)

            if collect:
                # One host transfer per batch instead of one .item() per sample.
                # NOTE(review): assumes `labels` is a 1-D tensor of class indices.
                for p, label_idx in zip(protos, labels.tolist()):
                    # Detach so the stored tensors do not keep every batch's
                    # autograd graph alive until the function returns.
                    p = p.detach()
                    prototypes.setdefault(label_idx, []).append(p)
                    feature_space.setdefault(label_idx, []).append(p.flatten())
                    count_byclass[label_idx] = count_byclass.get(label_idx, 0) + 1

            loss = criterion(outputs, labels)
            if use_prox:
                prox = (localmodel.fc.weight - weight).norm(2) + (localmodel.fc.bias - bias).norm(2)
                loss = loss + LAMBDA * prox
            loss.backward()
            optimizer.step()

    elapsed = time.perf_counter() - start
    print(f"Training done, time cost = {elapsed} seconds\n")

    proto_average_byclass = agg_func(prototypes)

    parameters_updated = get_parameters(localmodel)
    status = Status(code=Code.OK, message="Success")
    return FitRes(
        status=status,
        parameters=ndarrays_to_parameters(parameters_updated),
        num_examples=len(dataset),
        metrics={
            "proto": proto_average_byclass,
            "cov": get_covariance(feature_space),
            "count": count_byclass,
        },
    )