from typing import Dict
from resnet import build_model
import random
import torch
import time
import os
from flwr.common import parameters_to_ndarrays
from flwr.server.strategy.aggregate import aggregate
from dataset import EmnistDataset
from util import generate_random_protos_gaussian, get_parameters, set_parameters
from torch.utils.data import DataLoader
from datetime import datetime
from FedVTC_local import local_train, local_extra_train
from transpose_conv import TC_net
from torchvision.utils import save_image
from hyper_params import CHANNEL, CLASSES, DEVICE, Z_Dim

# Mini-batch size used by the evaluation DataLoader in get_local_test_acc.
Batch = 16
# Default total client count (M) and per-round sampled clients (P) of run_FedVTC.
Num_clients = 100
Num_participants = 10
# Default number of federated rounds (R) of run_FedVTC.
ROUNDS = 100
# Latent samples drawn per class when the distilled images are generated.
Num_fake_samples = 500

def run_FedVTC(M=Num_clients, P=Num_participants, R=ROUNDS, extra_rounds=50, Device=DEVICE,
               seed=2024, lr=1e-4, local_epoch=5, num_samples_per_class=CHANNEL, Batch=16):
    """Run the FedVTC simulation: R federated rounds followed by extra local rounds.

    Each federated round samples P of the M clients, trains them with
    ``local_train``, aggregates the returned TC-model parameters and per-class
    prototypes with ``aggregate``, and recombines the per-class sigmas.  Every
    10th round the mean accuracy of all M local models is measured.  On the
    last round (only when ``R >= 100``) latent samples are drawn per class,
    passed through the TC model and saved as PNG files under
    ``distill_images/<class>/``.  Afterwards ``extra_rounds`` rounds of
    ``local_extra_train`` are run on every client.  The recorded accuracies are
    written to a timestamped text file under ``results/``.

    Args:
        M: Total number of clients.
        P: Number of clients sampled per federated round.
        R: Number of federated rounds.
        extra_rounds: Number of additional local-only rounds after the
            federated phase.
        Device: Torch device for the TC model, prototypes and evaluation.
        seed: Seed for Python's ``random`` module (client sampling).  Torch's
            RNG is not seeded here.
        lr: Learning rate forwarded to the local training routines.
        local_epoch: Forwarded to ``local_train`` as ``E``.
        num_samples_per_class: Only sizes the unused random image placeholder
            created during initialisation.
        Batch: Forwarded to the local training routines as ``Btc``.

    Returns:
        None.  Results are reported through stdout and the output files.
    """

    # Initialization:
    random.seed(seed)
    time0 = time.time()
    test_accuracies = []
    local_params = {}
    distill_images = {}
    global_protos = {}
    global_covs = {}
    global_sigmas = {}
    TC_model = TC_net().to(Device)

    for l in range(CLASSES):
        # distill_images is never read afterwards; the draw is kept so the
        # sequence of torch RNG calls preceding global_protos is unchanged.
        distill_images[l] = torch.randn(num_samples_per_class, CHANNEL, 28, 28, device=Device)
        global_protos[l] = torch.randn(Z_Dim, device=Device, requires_grad=False)
        global_covs[l] = torch.eye(Z_Dim, device=Device, requires_grad=False)
        global_sigmas[l] = 1.0
    for cid in range(M):
        localmodel = build_model(cid, device=Device)
        local_params[cid] = get_parameters(localmodel)
        del localmodel

    # Start FL global iteration:
    for i in range(R):

        # Fit:
        print(f"Starting FL Round {i+1}......\n")

        clients = random.sample(range(M), k=P)
        client_count = 0
        TC_list = []
        proto_average = {}
        local_sigmas_byclass, total_freq_byclass = {}, {}

        for c in clients:

            local_param = local_params[c]
            cid = str(c)
            config = {"tc": get_parameters(TC_model), "sigma": global_sigmas}

            # The global prototypes are only handed to clients after the first round.
            proto_kwargs = {} if i == 0 else {"proto_Z": global_protos}
            training_result = local_train(cid, local_param, i, client_count+1, config,
                                          E=local_epoch, learning_rate=lr, Btc=Batch, **proto_kwargs)

            new_local_param = parameters_to_ndarrays(training_result.parameters)
            local_proto = training_result.metrics["proto"]
            local_sigma = training_result.metrics["sigma"]
            frequencies = training_result.metrics["count"]
            tc_param = training_result.metrics["tc"]

            # Collect, per class, the (prototype, sample count) pairs, the
            # count-squared-weighted sigma sum and the total sample count:
            for label in local_proto.keys():
                weighted_proto = (local_proto[label].cpu().numpy(), frequencies[label])
                weighted_sigma = local_sigma[label] * frequencies[label] * frequencies[label]
                if label in proto_average:
                    proto_average[label].append(weighted_proto)
                    local_sigmas_byclass[label] += weighted_sigma
                    total_freq_byclass[label] += frequencies[label]
                else:
                    proto_average[label] = [weighted_proto]
                    local_sigmas_byclass[label] = weighted_sigma
                    total_freq_byclass[label] = frequencies[label]

            local_params[c] = new_local_param
            client_count += 1
            print(f"The {client_count}-th local training has been completed, cid = {cid}.\n")
            TC_list.append((tc_param, training_result.num_examples))

        # Aggregating TC model, local prototypes and covariances:
        print("Updating TC, prototypes and sigmas......")
        global_TC = aggregate(TC_list)
        set_parameters(TC_model, global_TC)

        for idx in local_sigmas_byclass.keys():
            global_sigmas[idx] = (local_sigmas_byclass[idx] / (total_freq_byclass[idx] * total_freq_byclass[idx]))
            global_protos[idx] = torch.tensor(aggregate(proto_average[idx]), device=Device, requires_grad=False)
            global_covs[idx] = torch.eye(Z_Dim, device=Device, requires_grad=False) * global_sigmas[idx]

        # Evaluate mean accuracy of local models on local datasets:
        print(f"Round {i+1}, evaluating......")
        # Taken before evaluation, so the reported time excludes evaluation of this round.
        time1 = time.time()
        if (i+1) % 10 == 0:
            mean_acc = _mean_local_test_acc(local_params, M, Device)
            test_accuracies.append(mean_acc)
            print(f"FedVTC: Round {i+1} completed, test accuracy = {mean_acc}, time consumed = {time1-time0}")
        else:
            print("Skipping evaluation to save time...")
            print(f"FedVTC: Round {i+1} completed, time consumed = {time1-time0}")
        time0 = time1

        # Save distilled images and evaluate the global model:

        if i == R-1 and R >= 100:
            distilled_latent_dict = generate_random_protos_gaussian(mean=global_protos, covariance=global_covs, num_samples_per_class=Num_fake_samples)
            with torch.no_grad():
                distilled_x_dict = get_distilled_images(TC_model, distilled_latent_dict, Device)
            print("Dataset Distillation Done.")
            print("Saving Distilled images......")
            _save_distilled_images(distilled_x_dict, 'distill_images')
            print("distilled image saved.")
        else:
            print("no distill image generated for pre-training.")

    for q in range(extra_rounds):
        for c in range(M):
            local_param = local_params[c]
            cid = str(c)
            config = {}
            training_result = local_extra_train(cid, local_param, q, c+1, config, learning_rate=lr, Btc=Batch, ft_epoch=1)
            local_params[c] = parameters_to_ndarrays(training_result.parameters)
            print(f"FedVTC EXTRA: The {c}-th local training has been completed, cid = {cid}.\n")

        time1 = time.time()
        if (q+1) % 10 == 0:
            print("FedVTC Extra: Evaluating......")
            mean_acc = _mean_local_test_acc(local_params, M, Device)
            test_accuracies.append(mean_acc)
            print(f"FedVTC Extra: Round {q+1} completed, test accuracy = {mean_acc}, time consumed = {time1-time0}")
        else:
            print("Skipping evaluation to save time...")
            print(f"FedVTC Extra: Round {q+1} completed, time consumed = {time1-time0}")
        time0 = time1

    # Make sure the output directory exists so a finished run cannot lose its
    # results to a FileNotFoundError.
    os.makedirs('results', exist_ok=True)
    now = datetime.now()
    with open('results/FedVTC_accuracies_alpha0.1_' + now.strftime("%Y%m%d%H%M") + '.txt', 'w') as fp:
        for item in test_accuracies:
            # write each item on a new line
            fp.write("%f\n" % item)


def _mean_local_test_acc(local_params, num_clients, device):
    """Return the mean of get_local_test_acc's accuracy over clients 0..num_clients-1."""
    total_acc = 0.0
    for cid in range(num_clients):
        _, acc = get_local_test_acc(cid, local_params[cid], local_device=device)
        total_acc += acc
    return total_acc / num_clients


def _save_distilled_images(distilled_x_dict, parent_path):
    """Save every sample of every class as ``<parent_path>/<class>/<timestamp><index>.png``."""
    # One timestamp for the whole call keeps file names consistent even if the
    # wall-clock minute changes while images are being written.
    timestamp = datetime.now().strftime("%Y%m%d%H%M")
    for idx, samples in distilled_x_dict.items():
        distillpath = os.path.join(parent_path, str(idx))
        # makedirs also creates a missing parent directory; mkdir would raise.
        os.makedirs(distillpath, exist_ok=True)
        for sid, sample in enumerate(samples):
            img_path = os.path.join(distillpath, timestamp + str(sid) + '.png')
            # Existing files are left untouched.
            if not os.path.exists(img_path):
                save_image(sample.detach().clone().cpu(), img_path, normalize=True)


############################  Helper Functions ############################

def get_distilled_images(tcmodel:TC_net, latent_dict:Dict, device='cuda'):
    """Map each class's latent vectors through ``tcmodel`` without tracking gradients.

    Args:
        tcmodel: Callable model applied to the stacked latents of each class.
        latent_dict: Mapping from class key to a sequence of latent tensors
            that ``torch.stack`` can combine.
        device: Device the stacked latents and the model outputs are moved to.

    Returns:
        Dict with the same keys as ``latent_dict``; each value is the model
        output for that class's stacked latents, placed on ``device``.
    """
    with torch.no_grad():
        return {
            label: tcmodel(torch.stack(latents).to(device)).to(device)
            for label, latents in latent_dict.items()
        }

def get_local_test_acc(cid, local_param, local_device='cuda'):
    """Evaluate one client's model on ``clientdata/femnist_global_test.csv``.

    Args:
        cid: Client id passed to ``build_model``.
        local_param: Parameters loaded into the freshly built model via
            ``set_parameters``.
        local_device: Device the model is built on and the batches are moved to.

    Returns:
        Tuple ``(mean_loss, accuracy)``: the sample-weighted mean cross-entropy
        loss as a Python float, and ``correct / total`` (a 0-dim tensor, since
        ``correct`` accumulates tensor sums).

    Raises:
        ZeroDivisionError: If the test loader yields no samples.
    """
    local_model = build_model(cid, device=local_device)
    dataset = EmnistDataset("clientdata/femnist_global_test.csv")
    testloader = DataLoader(dataset, Batch, shuffle=False)
    correct, total, loss = 0, 0, 0.0
    set_parameters(local_model, local_param)
    criterion = torch.nn.CrossEntropyLoss()
    local_model.eval()
    with torch.no_grad():
        for samples, labels in testloader:
            # Use the same device the model was built on, not the global default.
            samples, labels = samples.to(local_device), labels.to(local_device)
            outputs, _ = local_model(samples)
            # The criterion returns the batch mean; re-weight by batch size and
            # accumulate so the final division yields the dataset mean.
            loss += criterion(outputs, labels).item() * labels.size(0)
            total += labels.size(0)
            _, predicted = torch.max(outputs, 1)
            correct += predicted.eq(labels).sum()
    return loss/total, correct/total
