from typing import List, Optional, Tuple, Union, Dict
from models import Resnet20, freeze_layer
import random
import torch
import time
import numpy as np
from flwr.common import EvaluateRes, Metrics, Scalar, Code
from flwr.common import FitIns, FitRes, Status
from flwr.server.client_manager import ClientManager
from flwr.server.client_proxy import ClientProxy
from flwr.common import ndarrays_to_parameters
from flwr.server.strategy.aggregate import weighted_loss_avg, aggregate
from flwr.common.logger import log
from logging import WARNING
from dataset import cinicDataset
from Util import get_filters, get_updated_layers, set_filters, get_mean_test_acc
from torch.utils.data import DataLoader
from datetime import datetime

# Model shape: input channels and number of outputs passed to Resnet20.
CHANNEL = 3
CLASSES = 10

# Local training: DataLoader batch size and the device models/batches are moved to.
Batch = 128
DEVICE = "cuda"

# Federation size: total client ids sampled from, and clients sampled per round.
Num_clients = 100
Num_participants = 10

# Default number of FL rounds and local SGD learning rate for run_FL_FedLF.
ROUNDS = 500
LR = 0.01

def get_lf_number(cid):
    """Return the layer-freezing value ``lf`` assigned to client ``cid``.

    Values cycle with ``int(cid) % 5``: remainder 0 -> 4, 1 -> 3, 2 -> 2,
    3 -> 1, 4 -> 0. ``cid`` may be an int or a numeric string. The result is
    what the caller passes to ``freeze_layer`` / ``get_updated_layers``.
    """
    remainder = int(cid) % 5  # always in 0..4 for Python ints, even negative ones
    return 4 - remainder

def _layer_to_numpy(value) -> np.ndarray:
    """Return one received layer value as a NumPy array (torch tensors are moved to host)."""
    if isinstance(value, torch.Tensor):
        return value.detach().cpu().numpy()
    return np.asarray(value)


def aggregate_updated_layer(model: torch.nn.Module, results: List[Tuple[Dict, int]]) -> Dict:
    """Merge clients' partially-updated layers into the global state dict.

    Each entry of ``results`` is ``(local_dict, num_examples)``. ``local_dict`` maps
    state-dict keys to the layer values that client sent back. For every key
    received from at least one client, the new value is the example-weighted mean
    over the clients that sent it: ``sum(v_i * n_i) / sum(n_i)``. Keys that no client
    sent keep their current value from ``model.state_dict()``. So do keys whose
    senders report zero examples in total.

    Args:
        model: The global model; only ``state_dict()`` is read.
        results: List of ``(local_dict, num_examples)`` pairs, one per client.
            Values may be NumPy arrays, array-likes, or torch tensors.

    Returns:
        Dict of state-dict key -> tensor, suitable for ``load_state_dict``.
    """
    # Frozen/untouched layers keep the current global value.
    global_dict = dict(model.state_dict())

    # Group received values per layer: key -> [(array, num_examples), ...].
    received_layer_dict: Dict[str, List[Tuple[np.ndarray, int]]] = {}
    for local_dict, num in results:
        for k, v in local_dict.items():
            received_layer_dict.setdefault(k, []).append((_layer_to_numpy(v), num))

    for k, entries in received_layer_dict.items():
        total = sum(num for _, num in entries)
        if total <= 0:
            # No examples back this layer: avoid 0/0 and keep the old value.
            continue
        # Weighted mean of whole arrays, summed in client order. This also handles
        # 0-d values, which row-wise iteration cannot.
        weighted_sum = sum(v * num for v, num in entries)
        global_dict[k] = torch.tensor(np.asarray(weighted_sum / total))

    return global_dict

def _sgd_epochs(net, loader, epochs, lr):
    """Run ``epochs`` passes of plain SGD with cross-entropy loss over ``loader`` on DEVICE."""
    loss_fn = torch.nn.CrossEntropyLoss()
    opt = torch.optim.SGD(net.parameters(), lr=lr)
    net.train()
    for _ in range(epochs):
        for batch_x, batch_y in loader:
            batch_x = batch_x.to(DEVICE)
            batch_y = batch_y.to(DEVICE)
            opt.zero_grad()
            loss = loss_fn(net(batch_x), batch_y)
            loss.backward()
            opt.step()


def local_train(cid, params, server_round, client_count, lf, E=5, lr=0.05) -> FitRes:
    """Train one client's copy of the global model and package the result.

    Builds a fresh Resnet20 on DEVICE and loads ``params`` via ``set_filters``.
    Applies ``freeze_layer(net, lf)``, then runs ``E`` epochs of SGD on the
    client's CSV shard.

    Args:
        cid: Client id; selects ``clientdata/cinic_client_<cid>_ALPHA_0.1.csv``.
        params: Global weights as produced by ``get_filters``.
        server_round: Zero-based round index (printed as ``server_round + 1``).
        client_count: Position of this client within the round (logging only).
        lf: Layer-freezing value for ``freeze_layer`` / ``get_updated_layers``.
        E: Number of local epochs.
        lr: SGD learning rate.

    Returns:
        A ``FitRes`` with OK status and ``num_examples = len(dataset)``. The output
        of ``get_updated_layers`` is serialized into ``parameters`` and also stored
        raw under ``metrics["updated layer"]``, which is what the server aggregates.
    """
    print(f"Server round {server_round+1}, training on the {client_count}-th client, id = {cid}")
    shard = cinicDataset(f"clientdata/cinic_client_{cid!s}_ALPHA_0.1.csv")
    loader = DataLoader(shard, batch_size=Batch, shuffle=True)

    net = Resnet20(CHANNEL, outputs=CLASSES).to(DEVICE)
    set_filters(net, params)
    freeze_layer(net, lf)

    started = time.time()
    _sgd_epochs(net, loader, E, lr)
    finished = time.time()
    print(f"Training done, time cost = {finished-started} seconds\n")

    updated = get_updated_layers(net, lf)
    return FitRes(
        status=Status(code=Code.OK, message="Success"),
        parameters=ndarrays_to_parameters(updated),
        num_examples=len(shard),
        metrics={"updated layer": updated},
    )

def _save_accuracies(accuracies):
    """Write one ``"%f"``-formatted accuracy per line to a timestamped file under ``results/``.

    The ``results`` directory is created if missing. Otherwise ``open`` raises
    FileNotFoundError only after the whole training run has finished.
    """
    from pathlib import Path

    out_dir = Path("results")
    out_dir.mkdir(parents=True, exist_ok=True)
    stamp = datetime.now().strftime("%Y%m%d%H%M")
    with open(out_dir / f"FedLF_accuracies_alpha0.1_{stamp}.txt", "w") as fp:
        fp.writelines("%f\n" % item for item in accuracies)


def run_FL_FedLF(M=Num_clients, P=Num_participants, R=ROUNDS, seed=2024, lr=LR):
    """Simulate R rounds of FedLF training and record per-round test accuracy.

    Each round does three things:
      1. Sample P distinct client ids from ``range(M)`` and train each one locally
         (``local_train``) with its ``lf`` from ``get_lf_number``.
      2. Merge the layers the clients returned with ``aggregate_updated_layer``.
      3. Evaluate the global model with ``get_mean_test_acc``.

    Args:
        M: Total number of clients to sample from.
        P: Clients sampled per round.
        R: Number of rounds.
        seed: Seeds Python's ``random`` (client sampling), NumPy and torch (model
            init, DataLoader shuffling) before the global model is built.
        lr: Learning rate passed to each client's local SGD.

    Returns:
        List of per-round test accuracies. The list is also written to
        ``results/FedLF_accuracies_alpha0.1_<YYYYmmddHHMM>.txt``.
    """
    # Seed all RNGs *before* building the model, so init, sampling and shuffling are reproducible.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    global_model = Resnet20(CHANNEL, outputs=CLASSES).to(DEVICE)
    time0 = time.time()
    test_accuracies = []

    for i in range(R):
        # Fit:
        print(f"Starting FL Round {i+1}......\n")
        clients = random.sample(range(M), k=P)
        # The global model is unchanged during the fit phase, so snapshot it once per round.
        global_parameters = get_filters(global_model)
        fit_results = []
        for client_count, c in enumerate(clients, start=1):
            cid = str(c)
            lf = get_lf_number(cid)
            fitres = local_train(cid, global_parameters, i, client_count, lf, E=5, lr=lr)
            fit_results.append((fitres.metrics["updated layer"], fitres.num_examples))

        # Aggregate:
        print(f"Aggregating and updating global model.....\n")
        new_model_dict = aggregate_updated_layer(global_model, fit_results)
        # strict=False: load_state_dict tolerates missing/unexpected keys instead of raising.
        global_model.load_state_dict(new_model_dict, strict=False)

        # Evaluate:
        print(f"Round {i+1}, evaluating......")
        _, acc = get_mean_test_acc(global_model)
        test_accuracies.append(acc)
        time1 = time.time()
        print(f"Round {i+1} completed, test accuracy = {acc}, time consumed = {time1-time0}")
        time0 = time1

    _save_accuracies(test_accuracies)
    return test_accuracies