from typing import List, Optional, Tuple, Union, Dict
from models import AlexNet
import flwr as fl
import random
import torch
import time
from flwr.common import EvaluateRes, Metrics, Scalar, Code
from flwr.common import FitIns, FitRes, Status
from flwr.server.client_manager import ClientManager
from flwr.server.client_proxy import ClientProxy
from flwr.common import ndarrays_to_parameters
from flwr.server.strategy.aggregate import weighted_loss_avg, aggregate
from flwr.common.logger import log
from logging import WARNING
from cifardataset import cifar10Dataset
from util import get_filters, get_parameters, set_filters, dropout_aggregation, generate_filters_random
from torch.utils.data import DataLoader, random_split
from datetime import datetime

# Model / data-loading configuration.
CHANNEL, Batch, CLASSES = 3, 128, 10  # AlexNet input channels, DataLoader batch size, AlexNet output classes
DEVICE = "cpu" if not torch.cuda.is_available() else "cuda"  # prefer GPU when one is visible

# Federated-learning schedule.
Num_clients = 100      # size of the client pool sampled from each round
Num_participants = 10  # clients sampled (without replacement) per round
ROUNDS = 500           # number of federated rounds

def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
    """Combine per-client accuracies into an example-weighted mean.

    Args:
        metrics: ``(num_examples, metrics_dict)`` pairs, one per client; every
            ``metrics_dict`` must contain an ``"accuracy"`` entry.

    Returns:
        ``{"accuracy": sum(n_i * acc_i) / sum(n_i)}``. If no examples were
        reported at all (empty input, or every ``n_i`` is 0), the accuracy is
        ``0.0`` instead of raising ``ZeroDivisionError``.
    """
    total_examples = sum(num_examples for num_examples, _ in metrics)
    if total_examples == 0:
        # Nothing to weight by; avoid dividing by zero.
        return {"accuracy": 0.0}
    weighted_sum = sum(num_examples * m["accuracy"] for num_examples, m in metrics)
    return {"accuracy": weighted_sum / total_examples}

def get_rate(cid):
    """Return the sub-model width rate assigned to client ``cid``.

    ``cid`` may be an int or a numeric string. Rates cycle through
    0.2, 0.4, 0.6, 0.8, 1.0 by ``int(cid) % 5``: clients 0, 5, 10, ... get 0.2,
    and clients 4, 9, 14, ... get 1.0.
    """
    rates = (0.2, 0.4, 0.6, 0.8, 1.0)
    slot = int(cid) % len(rates)
    return rates[slot]

def local_train(cid, rate, sub_params, server_round, drop_info, client_count, lR=0.005, E=5) -> FitRes:
    """Train a width-reduced AlexNet sub-model on a single client's local data.

    Args:
        cid: Client id; selects ``clientdata/cifar10_client_<cid>_ALPHA_0.1.csv``.
        rate: Width rate forwarded to ``AlexNet(rate=...)``.
        sub_params: Sub-model parameters loaded with ``set_filters`` before training.
        server_round: Zero-based round index (printed as ``server_round + 1``).
        drop_info: Value produced alongside ``sub_params`` by the caller; passed
            through unchanged in ``FitRes.metrics["drop_info"]`` so the
            aggregation step can use it.
        client_count: Position of this client within the round (logging only).
        lR: SGD learning rate.
        E: Number of local epochs.

    Returns:
        A ``FitRes`` with status OK, the trained parameters from ``get_filters``,
        ``num_examples`` set to the number of local training samples, and
        ``metrics={"drop_info": drop_info}``.
    """
    print(f"Server round {server_round+1}, training on the {client_count}-th client, id = {cid}\n")
    dataset = cifar10Dataset("clientdata/cifar10_client_" + str(cid) + "_ALPHA_0.1.csv")
    trainloader = DataLoader(dataset, Batch, shuffle=True)
    submodel = AlexNet(rate=rate).to(DEVICE)
    set_filters(submodel, sub_params)

    start = time.time()
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(submodel.parameters(), lr=lR)
    submodel.train()
    for _ in range(E):
        for samples, labels in trainloader:
            samples, labels = samples.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            loss = criterion(submodel(samples), labels)
            loss.backward()
            optimizer.step()
    print(f"Training done, time cost = {time.time() - start} seconds\n")

    parameters_updated = get_filters(submodel)
    status = Status(code=Code.OK, message="Success")
    # num_examples must count training samples, not batches: len(trainloader)
    # is ceil(len(dataset) / Batch) and would distort any example-weighted aggregation.
    return FitRes(
        status=status,
        parameters=ndarrays_to_parameters(parameters_updated),
        num_examples=len(dataset),
        metrics={"drop_info": drop_info},
    )

def get_mean_test_acc(global_model: AlexNet):
    """Evaluate ``global_model`` on the global CIFAR-10 test split.

    Reads ``clientdata/cifar10_global_test.csv`` and puts the model in eval mode.

    Returns:
        ``(loss, accuracy)`` as Python floats. ``loss`` is the mean
        cross-entropy per test sample. ``accuracy`` is the fraction of samples
        whose arg-max prediction equals the label. Both are ``0.0`` when the
        test set is empty.
    """
    dataset = cifar10Dataset("clientdata/cifar10_global_test.csv")
    testloader = DataLoader(dataset, Batch, shuffle=False)
    # Sum (not mean) per batch so dividing by the sample count gives an exact
    # per-sample mean, even when the last batch is smaller.
    criterion = torch.nn.CrossEntropyLoss(reduction="sum")
    correct, total, loss = 0, 0, 0.0
    global_model.eval()
    with torch.no_grad():
        for samples, labels in testloader:
            samples, labels = samples.to(DEVICE), labels.to(DEVICE)
            outputs = global_model(samples)
            loss += criterion(outputs, labels).item()
            total += labels.size(0)
            _, predicted = torch.max(outputs, 1)
            # .item() keeps the counter a plain int instead of a device tensor.
            correct += (predicted == labels).sum().item()
    if total == 0:
        return 0.0, 0.0
    return loss / total, correct / total

def run_FL_dropout(M=Num_clients, P=Num_participants, R=ROUNDS, seed=2024, lr=0.005, E=5):
    """Run federated training with random sub-model dropout and record test accuracy.

    Each of the ``R`` rounds does the following:
      1. samples ``P`` of the ``M`` clients without replacement;
      2. trains, for each sampled client, a sub-model built by
         ``generate_filters_random`` at the rate from ``get_rate``;
      3. merges the results into the global model with ``dropout_aggregation``;
      4. evaluates the global model on the global test set.

    Args:
        M: Size of the client pool.
        P: Clients sampled per round.
        R: Number of rounds.
        seed: Seed for Python's ``random`` and PyTorch's RNGs.
        lr: Local SGD learning rate.
        E: Local epochs per client (default 5, previously hard-coded).

    Returns:
        The per-round test accuracies as floats. They are also written one per
        line to ``results/feddrop_accuracies_alpha0.1_<YYYYmmddHHMM>.txt``; the
        ``results`` directory is created if missing.
    """
    import os

    # Seed every RNG the run touches *before* building the model, so that
    # weight init and DataLoader shuffling are reproducible, not just client sampling.
    random.seed(seed)
    torch.manual_seed(seed)
    global_model = AlexNet(CHANNEL, outputs=CLASSES).to(DEVICE)
    test_accuracies = []
    round_start = time.time()

    for i in range(R):
        # Fit:
        print(f"Starting FL Round {i+1}......\n")
        fit_results = []
        clients = random.sample(range(M), k=P)
        for client_count, c in enumerate(clients, start=1):
            cid = str(c)
            drop_rate = get_rate(cid)
            drop_info, sub_parameters = generate_filters_random(global_model, drop_rate)
            fit_results.append(
                local_train(cid, drop_rate, sub_parameters, i, drop_info, client_count, lR=lr, E=E)
            )

        # Aggregate:
        print("Aggregating and updating global model.....\n")
        current_parameter = get_filters(global_model)
        agg_start = time.time()
        aggregated_parameters = dropout_aggregation(fit_results, current_parameter)
        print(f"time for aggregation = {time.time() - agg_start}")
        set_filters(global_model, aggregated_parameters)

        # Evaluate:
        print(f"Round {i+1}, evaluating......")
        _, acc = get_mean_test_acc(global_model)
        # float() accepts either a Python number or a 0-d tensor, so the stored
        # history never holds device tensors.
        acc = float(acc)
        test_accuracies.append(acc)
        round_end = time.time()
        print(f"Round {i+1} completed, test acc = {acc}, time consumed = {round_end - round_start}")
        round_start = round_end

    os.makedirs("results", exist_ok=True)
    stamp = datetime.now().strftime("%Y%m%d%H%M")
    out_path = os.path.join("results", f"feddrop_accuracies_alpha0.1_{stamp}.txt")
    with open(out_path, "w") as fp:
        # One accuracy per line, same "%f" formatting as before.
        fp.writelines(f"{item:f}\n" for item in test_accuracies)
    return test_accuracies
