from typing import Dict
from resnet import build_model
import random
import torch
import time
from flwr.common import parameters_to_ndarrays
from flwr.server.strategy.aggregate import aggregate
from util import get_filters, generate_random_protos_gaussian, get_parameters
from proto_operators import get_global_cov
from CCVR_local import local_train
from hyper_params import Z_Dim, CHANNEL, CLASSES, DEVICE, Batch
from generator import Predictor
from FedVTC import get_local_test_acc
from dataset import latentDataset
from torch.utils.data import DataLoader
from datetime import datetime

# Default experiment configuration (used as run_ccvr's default arguments).
Num_clients = 100  # total number of simulated clients (run_ccvr's M)
Num_participants = 10  # clients sampled in each round (run_ccvr's P)
ROUNDS = 100  # number of global FL rounds (run_ccvr's R)
NC = CHANNEL  # default count of random initial prototypes per class
Num_fake_samples = 1000  # synthetic latents drawn per class for classifier re-training


def run_ccvr(M=Num_clients, P=Num_participants, R=ROUNDS, Device=DEVICE,
             distill_round=10, seed=2024, lr=0.1, local_epoch=5, distill_lr=1e-4, num_samples_per_class=NC, Batch=16):
    """Run the CCVR federated training loop and save the periodic test accuracies.

    Each round samples ``P`` of the ``M`` clients and trains each one locally via
    ``local_train``. It then aggregates the returned per-class prototypes into
    global prototypes, weighted by the per-class sample counts, and updates the
    global covariances via ``get_global_cov``. Next it samples synthetic latents
    from those statistics with ``generate_random_protos_gaussian`` and re-trains
    the shared ``Predictor`` on them. Every 10th round, the mean local test
    accuracy over all ``M`` clients is recorded. At the end, the accuracies are
    written one per line to ``results/CCVR_accuracies_alpha0.1_<timestamp>.txt``.

    Args:
        M: total number of clients.
        P: number of clients sampled per round.
        R: number of global rounds.
        Device: torch device for the global predictor and global statistics.
        distill_round: epochs of classifier re-training per round.
        seed: seed for Python's ``random`` (client sampling) and for torch.
        lr: learning rate forwarded to ``local_train``.
        local_epoch: local epochs per participating client.
        distill_lr: learning rate for classifier re-training.
        num_samples_per_class: number of random initial prototypes per class.
        Batch: batch size for local training and classifier re-training.
    """
    # Initialization:
    time0 = time.time()
    random.seed(seed)
    # The random initial prototypes below are drawn with torch's RNG, which
    # random.seed does not cover; seed it too so runs are reproducible.
    torch.manual_seed(seed)
    test_accuracies = []
    local_params = {}
    global_protos = {}
    global_covs = {}
    global_predictor = Predictor(Z_Dim, CLASSES).to(Device)

    for l in range(CLASSES):
        global_protos[l], global_covs[l] = [], []
        for _ in range(num_samples_per_class):
            global_protos[l].append(torch.randn(20*7*7, device=Device, requires_grad=True))
            global_covs[l].append(torch.eye(20*7*7, device=Device, requires_grad=False))

    for cid in range(M):
        localmodel = build_model(cid, device=Device)
        local_params[cid] = get_filters(localmodel)
        del localmodel

    # Start FL global iteration:
    for i in range(R):

        # Fit:
        print(f"Starting FL Round {i+1}......\n")
        clients = random.sample(range(M), k=P)
        num_proto_cov = {}   # label -> [(count, proto, cov), ...] across clients
        proto_average = {}   # label -> [(proto as ndarray, count), ...] for aggregate()
        # The global classifier does not change during the client loop, so its
        # weights are exported once per round (round 0 trains without them).
        classifier = None if i == 0 else get_parameters(global_predictor)

        for client_count, c in enumerate(clients, start=1):
            cid = str(c)
            if i == 0:
                training_result = local_train(cid, local_params[c], i, client_count, E=local_epoch, learning_rate=lr, Btc=Batch)
            else:
                training_result = local_train(cid, local_params[c], i, client_count, E=local_epoch, learning_rate=lr, Btc=Batch, classifier_weights=classifier)

            metrics = training_result.metrics
            local_proto, local_cov, frequencies = metrics["proto"], metrics["cov"], metrics["count"]

            # Collect the sample count, prototype and covariance of each class seen on this client:
            for label, proto in local_proto.items():
                count = frequencies[label]
                num_proto_cov.setdefault(label, []).append((count, proto, local_cov[label]))
                proto_average.setdefault(label, []).append((proto.cpu().numpy(), count))

            local_params[c] = parameters_to_ndarrays(training_result.parameters)
            print(f"The {client_count}-th local training has been completed, cid = {cid}.\n")

        # Aggregate central prototypes (sample-count-weighted average):
        for idx, weighted_protos in proto_average.items():
            global_protos[idx] = torch.tensor(aggregate(weighted_protos), device=Device, requires_grad=False)
        print("Global prototype updated.")

        # Update global covariances for the classes that were observed this round:
        global_covs.update(get_global_cov(num_proto_cov))

        distilled_latent_dict = generate_random_protos_gaussian(mean=global_protos, covariance=global_covs, num_samples_per_class=Num_fake_samples)
        train_classifier(global_predictor, distilled_latent_dict, global_epoch=distill_round, distill_lr=distill_lr, btc=Batch)
        print("Classifier updated.")

        # Evaluate mean accuracy of local models on local datasets every 10 rounds.
        print(f"Round {i+1}, evaluating......")
        if (i + 1) % 10 == 0:
            mean_acc = _mean_local_test_acc(local_params, M, Device)
            test_accuracies.append(mean_acc)
            # Timestamp after evaluation so this round's time includes it.
            time1 = time.time()
            print(f"CCVR: Round {i+1} completed, test accuracy = {mean_acc}, time consumed = {time1-time0}")
        else:
            print("Skipping evaluation to save time...")
            time1 = time.time()
            print(f"CCVR: Round {i+1} completed, time consumed = {time1-time0}")
        time0 = time1

    _save_accuracies(test_accuracies)


def _mean_local_test_acc(local_params, num_clients, device):
    """Return the mean of ``get_local_test_acc`` accuracies over clients ``0..num_clients-1``."""
    total_acc = 0.0
    for cid in range(num_clients):
        _, acc = get_local_test_acc(cid, local_params[cid], local_device=device)
        total_acc += acc
    return total_acc / num_clients


def _save_accuracies(test_accuracies):
    """Write one accuracy per line to a timestamped file under ``results/``."""
    import os

    # Create the output directory so a long run is not lost to a missing folder.
    os.makedirs('results', exist_ok=True)
    now = datetime.now()
    with open('results/CCVR_accuracies_alpha0.1_' + now.strftime("%Y%m%d%H%M") + '.txt', 'w') as fp:
        # write each item on a new line
        fp.writelines("%f\n" % item for item in test_accuracies)


def train_classifier(model:Predictor, latent_dict:Dict, global_epoch=10, distill_lr=0.001, btc=Batch, device=None):
    """Re-train ``model`` in place on synthetic latent samples with cross-entropy loss.

    A fresh Adam optimizer (betas=(0.5, 0.999)) is created on every call.

    Args:
        model: classifier to train; its parameters are updated in place.
        latent_dict: synthetic samples, wrapped by ``latentDataset``. Presumably
            a mapping from class label to latent samples (TODO confirm against
            ``latentDataset``).
        global_epoch: number of passes over the synthetic dataset.
        distill_lr: Adam learning rate.
        btc: mini-batch size.
        device: device to move each batch to. Defaults to the device of the
            model's parameters, so inputs always live where the model does
            (the module-global DEVICE may differ from the caller's device).
    """
    if device is None:
        device = next(model.parameters()).device
    dataset = latentDataset(latent_dict)
    dataloader = DataLoader(dataset, btc, shuffle=True)
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=distill_lr, betas=(0.5, 0.999))
    # Ensure training-mode behaviour for any dropout/batch-norm layers.
    model.train()
    for _ in range(global_epoch):
        for samples, labels in dataloader:
            samples, labels = samples.to(device), labels.to(device)
            optimizer.zero_grad()
            loss = criterion(model(samples), labels)
            loss.backward()
            optimizer.step()