import torch
import torch.optim as optim
import time
from resnet import build_model, ResNet18
from dataset import EmnistDataset, EmnistDataset_synthetic
from torch.utils.data import DataLoader
from flwr.common import FitRes, Status
from util import set_parameters, get_parameters
from flwr.common import ndarrays_to_parameters
from flwr.common import Code
from proto_operators import agg_func
from transpose_conv import TC_net
from hyper_params import Z_DIM, DEVICE, LAMBDA

# Default mini-batch size; used as the default value of the `Btc` argument of
# local_train and local_extra_train.
Batch = 16

def local_train(cid, params, server_round, client_count, Config, E=5, learning_rate=0.005, Btc=Batch, proto_Z=None) -> FitRes:
    """Run one round of local training for client ``cid`` and return a ``FitRes``.

    The local model is trained with SGD for ``E`` epochs on the client's CSV
    data. When ``proto_Z`` is given, a prototype term is added to the
    cross-entropy loss: for every sample, the MSE between the model's second
    output and ``proto_Z[label]`` divided by ``sigmas[label]``; the sum is
    averaged over the batch and weighted by ``LAMBDA``. During the last epoch
    the model's second output is collected per label and aggregated with
    ``agg_func``. If ``proto_Z`` is given, the TC network and the sigmas are
    then updated through ``train_TC_parameters``.

    Args:
        cid: Client id; selects the CSV file and is passed to ``build_model``.
        params: Parameters loaded into the local model via ``set_parameters``.
        server_round: Round index; logged as ``server_round + 1``.
        client_count: Ordinal of this client within the round (logging only).
        Config: Mapping providing ``"tc"`` (parameters loaded into ``TC_net``)
            and ``"sigma"`` (per-label values, indexed by integer label).
        E: Number of local epochs.
        learning_rate: SGD learning rate.
        Btc: Mini-batch size passed to the ``DataLoader``.
        proto_Z: Optional per-label target prototypes, indexed by integer
            label. ``None`` disables the prototype term and the TC update.

    Returns:
        A ``FitRes`` carrying the updated model parameters,
        ``num_examples=len(dataset)`` and the metrics ``"proto"`` (result of
        ``agg_func``), ``"count"`` (samples seen per label during the last
        epoch), ``"tc"`` and ``"sigma"``.
    """
    epoch = E
    print(f"Server round {server_round+1}, training on the {client_count}-th client, id = {cid}")
    dataset = EmnistDataset("clientdata/femnist_client_"+ str(cid) + "_ALPHA_1.0.csv")
    trainloader = DataLoader(dataset, Btc, shuffle=True)
    localmodel = build_model(cid, device=DEVICE)
    set_parameters(localmodel, params)
    optimizer = optim.SGD(localmodel.parameters(), lr=learning_rate)

    tc_param = Config["tc"]
    sigmas = Config['sigma']
    tc_model = TC_net(in_channels=Z_DIM).to(DEVICE)
    set_parameters(tc_model, tc_param)

    criterion = torch.nn.CrossEntropyLoss()
    criterion2 = torch.nn.MSELoss()
    use_proto_loss = proto_Z is not None
    prototypes = {}
    count_byclass = {}

    time1 = time.time()
    localmodel.train()

    for e in range(epoch):
        # Prototypes are only collected while running the final epoch.
        collect_protos = e == epoch - 1
        for samples, labels in trainloader:
            samples, labels = samples.to(DEVICE), labels.to(DEVICE)
            # One host transfer per batch instead of one .item() per sample.
            label_ids = labels.tolist()
            optimizer.zero_grad()
            outputs, protos = localmodel(samples)

            if collect_protos:
                for p, label_idx in zip(protos, label_ids):
                    # Detach so the stored tensors do not keep every batch's
                    # autograd graph alive until the function returns.
                    prototypes.setdefault(label_idx, []).append(p.detach())
                    count_byclass[label_idx] = count_byclass.get(label_idx, 0) + 1

            loss = criterion(outputs, labels)
            if use_proto_loss:
                loss2 = 0.0
                for p, label_idx in zip(protos, label_ids):
                    c = proto_Z[label_idx]
                    loss2 += criterion2(p.float(), c.float()) / sigmas[label_idx]
                # Average over the actual batch size: the last batch may hold
                # fewer than Btc samples.
                loss = loss + (loss2 / len(label_ids)) * LAMBDA
            loss.backward()
            optimizer.step()
    time2 = time.time()
    print(f"Training done, time cost = {time2-time1} seconds\n")

    proto_average_byclass = agg_func(prototypes)

    if use_proto_loss:
        tc_params, sd = train_TC_parameters(proto_Z, tc_model, localmodel, trainloader, sigmas, epoch, DEVICE)
    else:
        # Without target prototypes the TC network and sigmas are returned as received.
        tc_params, sd = get_parameters(tc_model), sigmas

    # NOTE(review): the weights are exported after train_TC_parameters, which
    # runs additional forward passes through localmodel. If the model keeps
    # running statistics (e.g. batch norm) and get_parameters exports them,
    # they will reflect those passes too -- confirm this is intended.
    parameters_updated = get_parameters(localmodel)
    status = Status(code=Code.OK, message="Success")
    return FitRes(status=status, parameters=ndarrays_to_parameters(parameters_updated), num_examples=len(dataset),
                  metrics={"proto":proto_average_byclass, "count":count_byclass,
                           "tc":tc_params, "sigma":sd})

def local_extra_train(cid, params, server_round, client_count, Config, learning_rate=0.005, Btc=Batch, ft_epoch=1) -> FitRes:
    """Fine-tune the client's model on synthetic data followed by its own data.

    Every one of the ``ft_epoch`` epochs makes a full pass over
    ``EmnistDataset_synthetic()`` and then a full pass over the client's CSV
    data, taking one SGD step on the cross-entropy loss per mini-batch.

    Args:
        cid: Client id; selects the CSV file and is passed to ``build_model``.
        params: Parameters loaded into the local model via ``set_parameters``.
        server_round: Round index; logged as ``server_round + 1``.
        client_count: Ordinal of this client within the round (logging only).
        Config: Not read by this function.
        learning_rate: SGD learning rate.
        Btc: Mini-batch size used for both data loaders.
        ft_epoch: Number of fine-tuning epochs.

    Returns:
        A ``FitRes`` with the updated parameters, ``num_examples`` equal to
        the size of the client's own dataset, and empty metrics.
    """
    print(f"Extra round {server_round+1}, training on the {client_count}-th client, id = {cid}")

    client_data = EmnistDataset("clientdata/femnist_client_"+ str(cid) + "_ALPHA_1.0.csv")
    client_loader = DataLoader(client_data, Btc, shuffle=True)
    synthetic_data = EmnistDataset_synthetic()
    synthetic_loader = DataLoader(synthetic_data, Btc, shuffle=True)

    net = build_model(cid, device=DEVICE)
    set_parameters(net, params)
    sgd = optim.SGD(net.parameters(), lr=learning_rate)
    ce_loss = torch.nn.CrossEntropyLoss()

    time1 = time.time()
    net.train()

    for _epoch in range(ft_epoch):
        # Synthetic samples first, then the client's real samples.
        for loader in (synthetic_loader, client_loader):
            for batch_x, batch_y in loader:
                batch_x = batch_x.to(DEVICE)
                batch_y = batch_y.to(DEVICE)
                sgd.zero_grad()
                logits, _features = net(batch_x)
                batch_loss = ce_loss(logits, batch_y)
                batch_loss.backward()
                sgd.step()

    time2 = time.time()
    print(f"Training done, time cost = {time2-time1} seconds\n")

    return FitRes(
        status=Status(code=Code.OK, message="Success"),
        parameters=ndarrays_to_parameters(get_parameters(net)),
        num_examples=len(client_data),
        metrics={},
    )


def train_TC_parameters(proto_Z, TC_model:TC_net, local_model:ResNet18, dataloader:DataLoader, sigmas, epoch=10, Device='cuda'):
    """Optimise the TC network and the per-label standard deviations.

    For every batch, the second output of ``local_model`` is fed through
    ``TC_model`` and the result is passed through ``local_model`` again. The
    TC loss is the batch mean of ``MSE(z_hat, proto_Z[label]) / sigmas[label]``
    and is minimised with Adam. After each epoch, for every label seen in that
    epoch, ``-log(sd**2) + sd**2 / sigmas[label]`` is minimised with a second
    Adam optimiser, where ``sd`` starts at ``sqrt(sigmas[label])``.

    Args:
        proto_Z: Per-label target tensors, indexed by integer label.
        TC_model: Network being optimised.
        local_model: Feature extractor; ``freeze_all_params()`` is called on
            it and it is not unfrozen again by this function.
        dataloader: Yields ``(samples, labels)`` batches.
        sigmas: Mapping from label to its current sigma value.
        epoch: Number of passes over ``dataloader``.
        Device: Device on which batches and targets are placed.

    Returns:
        A tuple ``(tc_parameters, updated_vars)``: the result of
        ``get_parameters(TC_model)`` and a dict mapping every key of
        ``sigmas`` to the square of its optimised ``sd`` as a Python float.
    """
    print("Start optimizing TC parameters......")

    tc_optimizer = torch.optim.Adam(TC_model.parameters(), lr=0.001, betas=(0.5, 0.999))

    # The optimised quantity is the square root of each sigma, held as a leaf tensor.
    sd = {
        idx: torch.sqrt(torch.tensor(sigmas[idx], device=Device)).requires_grad_()
        for idx in sigmas
    }
    sd_optimizer = torch.optim.Adam(list(sd.values()), lr=1e-5, betas=(0.5, 0.999))
    criterion = torch.nn.MSELoss()
    local_model.freeze_all_params()    # The local model will not get trained when training TC.

    for _ in range(epoch):
        # Labels encountered in this epoch, kept in first-seen order.
        seen_labels = {}
        for samples, labels in dataloader:
            samples, labels = samples.to(Device), labels.to(Device)
            # One host transfer per batch instead of one .item() per sample.
            label_ids = labels.tolist()

            # Reset per batch: zeroing only once per epoch let the gradients
            # of earlier batches accumulate into every later step.
            tc_optimizer.zero_grad()

            _, features = local_model(samples)
            x_hats = TC_model(features)
            _, z_hats = local_model(x_hats)

            kl_mean_loss = 0.0     # used to optimize TC params
            total = 0
            for z_hat, idx in zip(z_hats, label_ids):
                c = proto_Z[idx].detach().clone().to(Device)
                kl_mean_loss += criterion(z_hat.float(), c.float()) / sigmas[idx]
                total += 1
                seen_labels.setdefault(idx, None)

            kl_mean_loss = kl_mean_loss / total
            kl_mean_loss.backward()
            tc_optimizer.step()

        for idx in seen_labels:
            sd_optimizer.zero_grad()
            kl_var_loss = -torch.log(torch.square(sd[idx])) + torch.square(sd[idx]) / sigmas[idx]
            kl_var_loss.backward()
            sd_optimizer.step()

    # Convert the optimised standard deviations back to plain-float variances.
    with torch.no_grad():
        updated_vars = {idx: torch.square(value).item() for idx, value in sd.items()}

    print("training TC params done")
    return get_parameters(TC_model), updated_vars
