from generator import Generator, Predictor
import random
import torch
import time
from flwr.common import parameters_to_ndarrays
from flwr.server.strategy.aggregate import aggregate
from util import get_parameters, set_parameters
from datetime import datetime
from hyper_params import CLASSES, DEVICE, Z_Dim, Batch
from FedVTC import get_local_test_acc
from Fedgen_local import local_train
from resnet import build_model

# Default federation size: total number of clients, and how many of them are
# sampled to train in each round (used as the defaults of run_Fedgen's M and P).
Num_clients = 100
Num_participants = 10
# Default number of federated rounds (default of run_Fedgen's R).
ROUNDS = 100

def run_Fedgen(M=Num_clients, P=Num_participants, R=ROUNDS, seed=2024, lr=0.05, local_epoch=5, Device=DEVICE, Batch=16):
    """Run the FedGen training loop and dump periodic accuracies to a text file.

    Each client keeps its own model parameters between rounds.  In every round
    ``P`` clients are sampled, trained with ``local_train`` (receiving the
    current generator parameters from the second round on), their classifier
    parameters are aggregated into the global predictor, and the generator is
    then trained against that predictor.  Every 10th round the mean accuracy
    returned by ``get_local_test_acc`` over all ``M`` clients is recorded.

    Args:
        M: Total number of clients.
        P: Number of clients sampled (without replacement) per round.
        R: Number of rounds.
        seed: Seed for the ``random`` module, which drives client sampling.
        lr: Forwarded to ``local_train`` as ``learning_rate``.
        local_epoch: Forwarded to ``local_train`` as ``E``.
        Device: Device used for the global predictor, the generator, the
            initial local models, generator training and evaluation.
        Batch: Forwarded to ``local_train`` as ``Btc``.

    Side effects:
        Prints progress and writes the recorded accuracies, one per line, to
        ``results/FedGen_accuracies_alpha0.1_<YYYYmmddHHMM>.txt``.
    """
    import os  # function-scope import so this block does not rely on file-level imports

    # Create the output directory up front: failing here is cheap, failing
    # after R rounds of training would throw the whole run away.
    os.makedirs('results', exist_ok=True)

    global_predictor = Predictor(Z_Dim, CLASSES).to(Device)
    generator = Generator(z=Z_Dim, classes=CLASSES).to(Device)
    time0 = time.time()
    random.seed(seed)
    test_accuracies = []

    # Only the parameters of each initial local model are kept; the model
    # object itself is dropped right away.
    local_params = {
        cid: get_parameters(build_model(cid, device=Device))
        for cid in range(M)
    }

    for i in range(R):

        # Fit:
        print(f"Starting FL Round {i+1}......\n")
        fit_results = []
        clients = random.sample(range(M), k=P)
        gen_param = get_parameters(generator)
        # In the very first round the generator has not been trained yet, so
        # clients receive None instead of generator parameters.
        gen_for_clients = None if i == 0 else gen_param
        for client_count, c in enumerate(clients, start=1):
            fitres = local_train(str(c), local_params[c], gen_for_clients, i, client_count,
                                 E=local_epoch, learning_rate=lr, Btc=Batch)
            # The full local model stays on the client; only the classifier
            # part reported in the metrics takes part in aggregation.
            local_params[c] = parameters_to_ndarrays(fitres.parameters)
            fit_results.append((fitres.metrics["classifier"], fitres.num_examples))

        # Aggregate:
        print("Aggregating and updating global model.....\n")
        new_model_dict = aggregate(fit_results)
        set_parameters(global_predictor, new_model_dict)
        print("Training generator......")
        # Train on the same device the models were moved to; relying on
        # train_generator's default would break whenever Device != DEVICE.
        train_generator(generator, global_predictor, device=Device)

        # Evaluate mean accuracy of local models on local datasets:
        print(f"Round {i+1}, evaluating......")
        # NOTE: the timestamp is taken before evaluation, so evaluation time is
        # attributed to the following round's "time consumed".
        time1 = time.time()
        if (i+1) % 10 == 0:
            total_acc = 0.0
            for cid in range(M):
                _, acc = get_local_test_acc(cid, local_params[cid], local_device=Device)
                total_acc += acc
            test_accuracies.append(total_acc / M)
            print(f"FedGen: Round {i+1} completed, test accuracy = {total_acc / M}, time consumed = {time1-time0}")
        else:
            print("Skipping evaluation to save time...")
            print(f"FedGen: Round {i+1} completed, time consumed = {time1-time0}")
        time0 = time1

    now = datetime.now()
    with open('results/FedGen_accuracies_alpha0.1_' + now.strftime("%Y%m%d%H%M") + '.txt', 'w') as fp:
        for item in test_accuracies:
            # write each item on a new line
            fp.write("%f\n" % item)

def train_generator(generator:Generator, classifier:Predictor, classes=CLASSES, global_repoch=5, device=DEVICE):
    """Train ``generator`` so that ``classifier`` predicts the requested labels.

    One label batch is built per class: ``Batch`` copies of the class index.
    For ``global_repoch`` epochs, each batch is fed through the generator, the
    generator output is classified, and the cross-entropy between the
    prediction and the requested labels is minimised with Adam.

    Args:
        generator: Model called with a batch of labels; updated in place.
        classifier: Model called with the generator output.  Its parameters
            are not registered with the optimizer, so it is not updated here.
        classes: Number of classes; labels ``0 .. classes - 1`` are used.
        global_repoch: Number of passes over the per-class label batches.
        device: Device on which the label batches are created.
    """
    g_optimizer = torch.optim.Adam(generator.parameters(), lr=1e-4, betas=(0.5, 0.999))
    criterion = torch.nn.CrossEntropyLoss()

    # One batch per class.  (Previously the append sat outside the class loop,
    # so only the last class was ever used for training.)  The batches are
    # created directly on the target device, once, instead of being copied
    # there in every epoch.
    batched_labels = [
        torch.full((Batch,), c, dtype=torch.long, device=device)
        for c in range(classes)
    ]

    for _ in range(global_repoch):
        for ys in batched_labels:
            # Reset before every step so gradients of earlier batches do not
            # leak into the current update.
            g_optimizer.zero_grad()
            zs = generator(ys)
            y_pred = classifier(zs)
            loss = criterion(y_pred, ys)
            loss.backward()
            g_optimizer.step()
