from typing import Dict
import random
import torch
import time
from flwr.common import parameters_to_ndarrays
from flwr.server.strategy.aggregate import aggregate
from util import get_parameters
from datetime import datetime
from FedVTC import get_local_test_acc
from fedproto_local import local_train
from hyper_params import Z_Dim, CLASSES, DEVICE
from resnet import build_model

# Module-level experiment defaults.
# Num_clients / Num_participants / ROUNDS are the defaults for run_FedTGP's
# M (total clients), P (clients sampled per round) and R (number of rounds).
# NOTE(review): run_FedTGP's own `Batch` argument defaults to the literal 16,
# not to this constant.
Batch = 16
Num_clients = 100
Num_participants = 10
ROUNDS = 100

def run_FedTGP(M=Num_clients, P=Num_participants, R=ROUNDS, Device=DEVICE,
               seed=2024, lr=0.1, global_epoch=10, local_epoch=5, Batch=16):
    """Run a prototype-based (FedTGP-style) federated learning experiment.

    Each of the ``R`` rounds:
      1. samples ``P`` of the ``M`` clients and trains each with ``local_train``
         (the current global prototypes are passed from the second round on);
      2. averages the returned per-label prototypes, weighted by the per-label
         sample counts the clients report, using Flower's ``aggregate``;
      3. refines those averages with ``discritiminative_optimization`` and
         stores the result as the new global prototypes.
    Every 10th round the mean local test accuracy over *all* ``M`` clients is
    recorded. At the end the recorded accuracies are written, one per line, to
    ``results/FedTGP_accuracies_alpha0.1_<YYYYmmddHHMM>.txt``.

    Args:
        M: Total number of clients.
        P: Number of clients sampled per round.
        R: Number of FL rounds.
        Device: Torch device for prototypes and model evaluation.
        seed: Seed for Python's and torch's random number generators.
        lr: Learning rate for local training and prototype optimisation.
        global_epoch: Optimisation steps per label in the server-side
            prototype refinement.
        local_epoch: Local training epochs per selected client.
        Batch: Local batch size passed to ``local_train``.

    Returns:
        The list of recorded mean test accuracies (also written to disk).
    """
    import os

    time0 = time.time()

    # Create the output directory before training so a missing ``results/``
    # folder does not crash the run after all R rounds have finished.
    os.makedirs("results", exist_ok=True)

    # Seed both RNGs. ``random`` drives client sampling; torch drives the
    # random initial prototypes (and torch randomness downstream). Seeding
    # only ``random`` left ``seed`` without effect on the torch side.
    random.seed(seed)
    torch.manual_seed(seed)

    test_accuracies = []
    local_params = {}
    # One random initial prototype per class.
    global_protos = {
        label: torch.randn(Z_Dim, device=Device, requires_grad=False)
        for label in range(CLASSES)
    }

    # Materialise every client's initial parameters; the model object itself
    # is discarded right away and only its parameters are kept.
    for cid in range(M):
        localmodel = build_model(cid, device=Device)
        local_params[cid] = get_parameters(localmodel)
        del localmodel

    for i in range(R):
        print(f"Starting FL Round {i+1}......\n")

        clients = random.sample(list(range(M)), k=P)
        # label -> list of (prototype ndarray, sample count) from this round's clients
        proto_average = {}

        for client_count, c in enumerate(clients, start=1):
            cid = str(c)
            # Global prototypes are only sent once they have been aggregated.
            proto_kwargs = {} if i == 0 else {"proto_Z": global_protos}
            training_result = local_train(
                cid, local_params[c], i, client_count,
                E=local_epoch, learning_rate=lr, Btc=Batch, **proto_kwargs,
            )

            local_proto = training_result.metrics["proto"]
            frequencies = training_result.metrics["count"]
            local_params[c] = parameters_to_ndarrays(training_result.parameters)

            # Collect (prototype, sample count) pairs per label for weighted averaging.
            for label, proto in local_proto.items():
                proto_average.setdefault(label, []).append(
                    (proto.cpu().numpy(), frequencies[label])
                )

            print(f"The {client_count}-th local training has been completed, cid = {cid}.\n")

        # Sample-count-weighted average of each label's prototypes.
        for label in proto_average:
            proto_average[label] = torch.tensor(
                aggregate(proto_average[label]), device=Device, requires_grad=False
            )

        new_global_protos = discritiminative_optimization(
            central_protos=global_protos, proto_average=proto_average,
            plr=lr, ge=global_epoch,
        )
        # Labels not seen this round keep their previous global prototype.
        global_protos.update(new_global_protos)

        print("Central prototypes optimization completed.")

        # Evaluate mean accuracy of local models on local datasets.
        print(f"Round {i+1}, evaluating......")
        mean_acc = None
        if (i + 1) % 10 == 0:
            total_acc = 0.0
            for cid in range(M):
                _, acc = get_local_test_acc(cid, local_params[cid], local_device=Device)
                total_acc += acc
            mean_acc = total_acc / M
            test_accuracies.append(mean_acc)
        else:
            print("Skipping evaluation to save time...")

        # Stop the clock *after* evaluation so a round's reported time includes
        # its own evaluation instead of leaking it into the next round.
        time1 = time.time()
        if mean_acc is None:
            print(f"FedTGP: Round {i+1} completed, time consumed = {time1-time0}")
        else:
            print(f"FedTGP: Round {i+1} completed, test accuracy = {mean_acc}, time consumed = {time1-time0}")
        time0 = time1

    # NOTE(review): "alpha0.1" is hard-coded in the file name and presumably
    # describes the data-partition setting; confirm it matches the run.
    now = datetime.now()
    with open('results/FedTGP_accuracies_alpha0.1_' + now.strftime("%Y%m%d%H%M") + '.txt', 'w') as fp:
        # one accuracy per line
        fp.writelines("%f\n" % item for item in test_accuracies)

    return test_accuracies


def discritiminative_optimization(central_protos: Dict, proto_average: Dict, plr=0.001, tau=100, ge=10):
    """Refine aggregated class prototypes with a margin-based contrastive loss.

    For every label ``l`` in ``proto_average``, a fresh prototype ``c`` is drawn
    from a standard normal distribution. It is optimised with Adam for ``ge``
    steps to minimise

        -log( e^{-(d(c, p_l) + delta)} / (e^{-(d(c, p_l) + delta)} + sum_k e^{-d(c, n_k)}) )

    Here ``d`` is the Euclidean distance and ``p_l = proto_average[l]``. The
    ``n_k`` are the prototypes of every other label in ``central_protos``,
    taken from ``proto_average`` when that label was aggregated this round.
    ``delta`` is the largest pairwise distance between the aggregated
    prototypes, capped at ``tau``.

    Args:
        central_protos: label -> current global prototype tensor. Its keys
            define the set of negative labels.
        proto_average: label -> aggregated prototype tensor for this round.
        plr: Adam learning rate.
        tau: Upper bound on the margin ``delta``.
        ge: Number of optimisation steps per label.

    Returns:
        A new dict label -> detached prototype for every label in
        ``proto_average``. If the loss ever becomes non-finite, returns
        ``central_protos`` itself, unchanged.
    """
    updated_Zs = {}

    # delta(t): maximum pairwise distance between aggregated prototypes, capped at tau.
    max_dist = 0.0
    with torch.no_grad():
        for l1, p1 in proto_average.items():
            for l2, p2 in proto_average.items():
                if l1 != l2:
                    max_dist = max(max_dist, torch.norm(p1 - p2).item())
    delta = min(max_dist, tau)
    print(f"Delta(t) = {delta}")

    for idx, c_in in proto_average.items():
        # NOTE: the prototype is re-initialised from N(0, 1) on every call.
        # Allocate it on the prototype's own device instead of a hard-coded
        # 'cuda', so CPU-only runs work.
        c = torch.randn(c_in.shape, requires_grad=True, device=c_in.device)
        p_optimizer = torch.optim.Adam([c], lr=plr, betas=(0.5, 0.999))

        # The positive target and the negatives do not change between steps,
        # so build them once. They are moved to c's device/dtype to allow stacking.
        with torch.no_grad():
            positive = c_in.to(device=c.device, dtype=c.dtype)
            negatives = [
                (proto_average[idx_] if idx_ in proto_average else central_protos[idx_])
                .to(device=c.device, dtype=c.dtype)
                for idx_ in central_protos.keys()
                if idx_ != idx
            ]

        for _ in range(ge):
            p_optimizer.zero_grad()
            # Logits are negative distances; the positive one carries the margin.
            logits = torch.stack(
                [-(torch.norm(c - positive) + delta)]
                + [-torch.norm(c - c_out) for c_out in negatives]
            )
            # Mathematically equal to -log(l1 / (l1 + l2)) with l1 = exp(logits[0])
            # and l2 = sum(exp(logits[1:])). Computing it via logsumexp avoids the
            # exp() underflow that made that ratio 0/0 = nan for large distances.
            loss = torch.logsumexp(logits, dim=0) - logits[0]
            if not torch.isfinite(loss):
                print("nan detected, prototype learning stopped......")
                return central_protos
            loss.backward()
            p_optimizer.step()

        updated_Zs[idx] = c.detach().clone()

    return updated_Zs
