from typing import List, Optional, Tuple, Union, Dict
from models import AlexNet
import flwr as fl
import random
import torch
import time
from flwr.common import EvaluateRes, Metrics, Scalar, Code
from flwr.common import FitIns, FitRes, Status
from flwr.server.client_manager import ClientManager
from flwr.server.client_proxy import ClientProxy
from flwr.common import ndarrays_to_parameters
from flwr.server.strategy.aggregate import weighted_loss_avg, aggregate
from flwr.common.logger import log
from logging import WARNING
from cifardataset import cifar10Dataset
from util import get_filters, get_parameters, set_filters, dropout_aggregation, generate_subnet_ordered
from torch.utils.data import DataLoader, random_split
from feddrop import get_mean_test_acc
from datetime import datetime

# Number of input image channels handed to AlexNet (CIFAR-10 images are RGB).
CHANNEL = 3
# Mini-batch size of each client's local DataLoader.
Batch = 128
# Number of classifier outputs passed to AlexNet as `outputs=`.
CLASSES = 10
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
# Size of the client pool, and how many clients are sampled per round.
Num_clients = 100
Num_participants = 10
# Default number of federated rounds.
ROUNDS = 500

def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
    """Combine per-client accuracies into an example-weighted mean.

    Args:
        metrics: ``(num_examples, metrics_dict)`` pairs, one per client; every
            ``metrics_dict`` must contain an ``"accuracy"`` entry.

    Returns:
        ``{"accuracy": sum(n_i * acc_i) / sum(n_i)}``. When no examples were
        reported at all (empty ``metrics`` or every count is zero) the
        accuracy is ``0.0`` rather than raising ``ZeroDivisionError``.
    """
    total_examples = sum(num_examples for num_examples, _ in metrics)
    if total_examples == 0:
        # Nothing to weight by; report 0.0 instead of dividing by zero.
        return {"accuracy": 0.0}
    # Multiply each client's accuracy by the number of examples it used.
    weighted_sum = sum(num_examples * m["accuracy"] for num_examples, m in metrics)
    return {"accuracy": weighted_sum / total_examples}

def get_rate(cid, rates=(0.2, 0.4, 0.6, 0.8, 1.0)):
    """Return the fixed sub-model width rate assigned to a client.

    Clients are assigned to rate tiers round-robin by integer id, so client
    ``k`` always gets ``rates[k % len(rates)]``.

    Args:
        cid: Client id; anything ``int()`` accepts (callers pass a decimal
            string).
        rates: Rate tiers to cycle through. Defaults to the five widths
            0.2, 0.4, 0.6, 0.8 and 1.0. A tuple is used so the default is
            immutable and not rebuilt on every call.

    Returns:
        The rate assigned to ``cid``.

    Raises:
        ValueError: if ``rates`` is empty or ``cid`` is not integer-like.
    """
    if not rates:
        raise ValueError("rates must contain at least one value")
    return rates[int(cid) % len(rates)]

def local_train(cid, rate, sub_params, server_round, drop_info, client_count, E=5) -> FitRes:
    """Train one client's sub-model locally and package the result as a FitRes.

    Loads the client's data from
    ``clientdata/cifar10_client_<cid>_ALPHA_0.1.csv``, builds an AlexNet of
    width ``rate`` initialised from ``sub_params``, and runs ``E`` epochs of
    plain SGD (lr 0.005) with cross-entropy loss on ``DEVICE``.

    Args:
        cid: Client id, used to locate the client's data file.
        rate: Width rate passed to ``AlexNet(rate=...)``.
        sub_params: Parameters loaded into the sub-model via ``set_filters``.
        server_round: Zero-based round index; printed as ``server_round + 1``.
        drop_info: Sub-network description; returned unchanged in
            ``FitRes.metrics["drop_info"]``.
        client_count: Position of this client within the round (logging only).
        E: Number of local training epochs.

    Returns:
        A ``FitRes`` with status OK, the trained parameters, the number of
        local training samples and ``{"drop_info": drop_info}`` as metrics.
    """
    print(f"Server round {server_round+1}, training on the {client_count}-th client, id = {cid}\n")
    dataset = cifar10Dataset("clientdata/cifar10_client_" + str(cid) + "_ALPHA_0.1.csv")
    trainloader = DataLoader(dataset, Batch, shuffle=True)
    # Build the sub-model with the same channel/class configuration as the
    # global model created in run_FL_fjord, so the sliced parameters fit it.
    submodel = AlexNet(CHANNEL, outputs=CLASSES, rate=rate).to(DEVICE)
    set_filters(submodel, sub_params)

    start = time.time()
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(submodel.parameters(), lr=0.005)
    submodel.train()
    for _ in range(E):
        for samples, labels in trainloader:
            samples, labels = samples.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            loss = criterion(submodel(samples), labels)
            loss.backward()
            optimizer.step()
    print(f"Training done, time cost = {time.time() - start} seconds\n")

    parameters_updated = get_filters(submodel)
    status = Status(code=Code.OK, message="Success")
    # FitRes.num_examples is the number of training *samples*; len(trainloader)
    # would count mini-batches and skew any example-weighted aggregation.
    return FitRes(
        status=status,
        parameters=ndarrays_to_parameters(parameters_updated),
        num_examples=len(dataset),
        metrics={"drop_info": drop_info},
    )

def run_FL_fjord(M=Num_clients, P=Num_participants, R=ROUNDS, seed=2024):
    """Run an in-process FjORD-style federated training simulation.

    Each round samples ``P`` distinct clients out of ``M``. For each client,
    the ordered sub-network of the global model at that client's width rate
    (``get_rate``) is extracted and trained locally for 5 epochs. The updates
    are merged back with ``dropout_aggregation``, and the global model is then
    evaluated with ``get_mean_test_acc``.

    Args:
        M: Total number of clients (ids ``0 .. M-1``).
        P: Clients sampled per round; ``random.sample`` requires ``P <= M``.
        R: Number of federated rounds.
        seed: Seed for client sampling (``random``) and for torch's RNGs
            (model initialisation, DataLoader shuffling). GPU kernels may
            still be non-deterministic.

    Returns:
        The per-round test accuracies in round order. They are also written,
        one ``%f`` value per line, to
        ``results/fjord_accuracies_alpha0.1_<YYYYmmddHHMM>.txt``.
    """
    import os  # local import keeps this block self-contained

    # Seed torch before building the model so initial weights (and client
    # DataLoader shuffles) are reproducible, not only the client sampling.
    torch.manual_seed(seed)
    global_model = AlexNet(CHANNEL, outputs=CLASSES).to(DEVICE)
    time0 = time.time()
    random.seed(seed)
    test_accuracies = []

    for i in range(R):
        # Fit: sample this round's participants and train their sub-models.
        print(f"Starting FL Round {i+1}......\n")
        clients = random.sample(range(M), k=P)
        fit_results = []
        for client_count, c in enumerate(clients, start=1):
            cid = str(c)
            drop_rate = get_rate(cid)
            drop_info, sub_parameters = generate_subnet_ordered(global_model, drop_rate)
            fit_results.append(
                local_train(cid, drop_rate, sub_parameters, i, drop_info, client_count, E=5)
            )

        # Aggregate the sub-model updates into the global model.
        print("Aggregating and updating global model.....\n")
        current_parameter = get_filters(global_model)
        aggregated_parameters = dropout_aggregation(fit_results, current_parameter)
        set_filters(global_model, aggregated_parameters)

        # Evaluate the updated global model.
        print(f"Round {i+1}, evaluating......")
        _, acc = get_mean_test_acc(global_model)
        test_accuracies.append(acc)
        time1 = time.time()
        print(f"Round {i+1} completed, test acc = {acc} time consumed = {time1-time0}")
        time0 = time1

    # Ensure the output directory exists: otherwise open() would fail only
    # after all R rounds have run, discarding every recorded accuracy.
    os.makedirs("results", exist_ok=True)
    stamp = datetime.now().strftime("%Y%m%d%H%M")
    out_path = os.path.join("results", "fjord_accuracies_alpha0.1_" + stamp + ".txt")
    with open(out_path, "w") as fp:
        # One accuracy per line.
        fp.writelines("%f\n" % item for item in test_accuracies)
    return test_accuracies