from typing import List, Optional, Tuple, Union, Dict
from models import AlexNet
import random
import torch
import time
import numpy as np
from flwr.common import EvaluateRes, Metrics, Scalar, Code
from flwr.common import FitIns, FitRes, Status
from flwr.server.client_manager import ClientManager
from flwr.server.client_proxy import ClientProxy
from flwr.common import ndarrays_to_parameters, parameters_to_ndarrays
from flwr.server.strategy.aggregate import weighted_loss_avg, aggregate
from flwr.common.logger import log
from logging import WARNING
from cifardataset import cifar10Dataset
from util import get_filters, get_updated_layers, set_filters, generate_subnet_SLT
from torch.utils.data import DataLoader
from datetime import datetime
from feddrop import get_mean_test_acc

# ---- Experiment configuration --------------------------------------------
CHANNELS = 3          # input channels passed to AlexNet (and mask_gradients' default C)
Batch = 128           # local DataLoader batch size
CLASSES = 10          # number of output classes passed to AlexNet
DEVICE = "cuda"       # device used for local training and gradient masks
Num_clients, Num_participants = 100, 10   # client pool size / clients sampled per round
ROUNDS = 500          # number of federated rounds

# Parameter names (state_dict keys) that SLT gradient masking applies to;
# anything else in the state_dict (e.g. BN running statistics) is ignored.
Learnable_Params = [
    # conv / batch-norm stages 1..5
    'conv1.weight', 'conv1.bias', 'bn1.weight', 'bn1.bias',
    'conv2.weight', 'conv2.bias', 'bn2.weight', 'bn2.bias',
    'conv3.weight', 'conv3.bias', 'bn3.weight', 'bn3.bias',
    'conv4.weight', 'conv4.bias', 'bn4.weight', 'bn4.bias',
    'conv5.weight', 'conv5.bias', 'bn5.weight', 'bn5.bias',
    # fully-connected head
    'fc.weight',
    'fc.bias',
    'fc1.weight',
    'fc2.weight',
]

def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
    """Combine per-client accuracies into one accuracy weighted by example count.

    Args:
        metrics: ``(num_examples, client_metrics)`` pairs, one per client; each
            ``client_metrics`` must contain an ``"accuracy"`` entry.

    Returns:
        ``{"accuracy": sum(n_i * acc_i) / sum(n_i)}``. When the clients report no
        examples at all (empty list, or every ``n_i == 0``) the result is
        ``{"accuracy": 0.0}`` instead of raising ZeroDivisionError.
    """
    total_examples = sum(num_examples for num_examples, _ in metrics)
    if total_examples == 0:
        # Nothing to average over; avoid crashing the aggregation step.
        return {"accuracy": 0.0}
    weighted_sum = sum(num_examples * m["accuracy"] for num_examples, m in metrics)
    return {"accuracy": weighted_sum / total_examples}

def get_SLT_stage(server_round, total_round) -> int:
    """Return which of six successive-layer-training stages a round belongs to.

    The ``total_round`` rounds are split into six equal slices. A round lying
    exactly on a slice boundary is assigned to the earlier stage, and every
    round past the fifth boundary (including rounds beyond ``total_round``)
    is stage 5.

    Returns:
        An int in ``0..5``.
    """
    slice_len = total_round / 6
    for stage in range(5):
        upper_bound = slice_len * (stage + 1)
        if server_round <= upper_bound:
            return stage
    return 5

def _index_keep_mask(length, indices):
    """Return a bool tensor of size ``length`` that is True at each position listed in ``indices``.

    Entries of ``indices`` outside ``range(length)`` are ignored, matching a plain
    ``i in indices`` membership test over ``range(length)``.
    """
    wanted = {int(x) for x in indices}
    return torch.tensor([i in wanted for i in range(length)], dtype=torch.bool)


def mask_gradients(model:AlexNet, dropout_info:Dict, C=CHANNELS):
    """Build per-parameter gradient masks for successive layer training (SLT).

    For every key of ``dropout_info`` that is in ``Learnable_Params``, a float mask
    with the parameter's shape is produced (1 = keep the gradient, 0 = drop it):

    * rows (output units / filters) not listed in ``dropout_info[k]`` are zeroed;
      ``fc.weight`` keeps all of its rows and ``fc.bias`` is never masked;
    * for conv weights, ``fc2.weight`` and ``fc.weight``, input columns that are
      not among the previous parameter's kept indices are zeroed as well;
    * for ``fc1.weight`` an input column is kept when its conv5 channel
      (``column // (7*7)``) was kept.

    Args:
        model: network whose ``state_dict()`` provides each parameter's shape.
        dropout_info: parameter name -> indices of rows that stay trainable.
            Its iteration order fixes the order of the returned masks; keys
            outside ``Learnable_Params`` are skipped.
        C: size of the initial "previous layer" kept set (input channels).

    Returns:
        A list of CPU float tensors, one per learnable key of ``dropout_info``.
        If ``dropout_info`` is empty, all-ones masks for every learnable
        parameter in ``state_dict`` order.
    """
    params = model.state_dict()
    if len(dropout_info) == 0:
        return [torch.ones(v.shape) for k, v in params.items() if k in Learnable_Params]

    # Weights whose dim 1 is wired directly to the previous layer's outputs.
    input_linked = {'conv1.weight', 'conv2.weight', 'conv3.weight', 'conv4.weight',
                    'conv5.weight', 'fc2.weight', 'fc.weight'}
    # Flattened spatial positions per conv5 channel as seen by fc1. This is the
    # 7x7 assumption the original code made; confirm against AlexNet's flatten.
    fc1_spatial = 7 * 7

    last_layer_indices = list(range(C))
    masks = []
    for k, non_mask_filters in dropout_info.items():
        if k not in Learnable_Params:
            continue
        # Look the shape up by name rather than by a positional counter, so the
        # mask cannot silently pick up another parameter's shape if the
        # state_dict order differs from dropout_info's order.
        shape = params[k].shape
        gradient_mask = torch.ones(shape)
        if k == 'fc.bias':
            # The output bias is always trained and does not change the kept set.
            masks.append(gradient_mask)
            continue
        if k != 'fc.weight':
            gradient_mask[~_index_keep_mask(shape[0], non_mask_filters)] = 0.0
        if k in input_linked:
            gradient_mask[:, ~_index_keep_mask(shape[1], last_layer_indices)] = 0.0
        elif k == 'fc1.weight':
            # Expand kept conv5 channels to their flattened feature columns
            # (the original computed these but then masked by raw channel index).
            kept_channels = {int(c) for c in last_layer_indices}
            col_keep = torch.tensor(
                [(j // fc1_spatial) in kept_channels for j in range(shape[1])],
                dtype=torch.bool,
            )
            gradient_mask[:, ~col_keep] = 0.0
        masks.append(gradient_mask)
        last_layer_indices = non_mask_filters
    return masks

def SLT_freeze_filters(model:AlexNet, masks, device=DEVICE):
    """Register backward hooks that multiply each SLT parameter's gradient by its mask.

    ``masks`` is indexed positionally in this order: conv1..bn5 (weight, bias for
    each conv and bn), then ``fc1.weight``, ``fc2.weight``, ``fc.weight``,
    ``fc.bias`` (24 masks). Each mask is copied to ``device`` once here instead
    of on every backward pass.

    Returns:
        The list of hook handles; call ``.remove()`` on them to detach the masks.

    Raises:
        IndexError: if ``masks`` has fewer than 24 entries.
    """
    hooked = (
        'conv1.weight', 'conv1.bias', 'bn1.weight', 'bn1.bias',
        'conv2.weight', 'conv2.bias', 'bn2.weight', 'bn2.bias',
        'conv3.weight', 'conv3.bias', 'bn3.weight', 'bn3.bias',
        'conv4.weight', 'conv4.bias', 'bn4.weight', 'bn4.bias',
        'conv5.weight', 'conv5.bias', 'bn5.weight', 'bn5.bias',
        'fc1.weight', 'fc2.weight', 'fc.weight', 'fc.bias',
    )
    handles = []
    for idx, name in enumerate(hooked):
        module_name, param_name = name.split('.')
        param = getattr(getattr(model, module_name), param_name)
        mask = masks[idx].to(device)
        # Bind ``mask`` as a default argument so each hook keeps its own tensor
        # (avoids the late-binding closure pitfall inside a loop).
        handles.append(param.register_hook(lambda grad, mask=mask: grad * mask))
    return handles

def local_train(cid, params, server_round, client_count, lf, scaler, E=5, lr=0.005) -> FitRes:
    """Train one client's AlexNet for ``E`` local epochs with SLT gradient masking.

    Args:
        cid: client id; selects ``clientdata/cifar10_client_<cid>_ALPHA_0.1.csv``.
        params: parameter arrays loaded into a fresh model via ``set_filters``.
        server_round: 0-based round index (printed 1-based).
        client_count: 1-based position of this client within the round (logging only).
        lf: SLT stage, forwarded to ``generate_subnet_SLT`` and ``get_updated_layers``.
        scaler: forwarded to ``generate_subnet_SLT``.
        E: number of local epochs.
        lr: SGD learning rate (default equals the previously hard-coded 0.005).

    Returns:
        ``FitRes`` carrying the trained parameters, the local dataset size and
        ``{"updated layer": ...}`` metrics.
    """
    print(f"Server round {server_round+1}, training on the {client_count}-th client, id = {cid}")
    dataset = cifar10Dataset("clientdata/cifar10_client_" + str(cid) + "_ALPHA_0.1.csv")
    trainloader = DataLoader(dataset, Batch, shuffle=True)
    localmodel = AlexNet(CHANNELS, outputs=CLASSES).to(DEVICE)
    set_filters(localmodel, params)
    # ``generate_subnet_SLT`` yields the per-parameter kept indices consumed by
    # ``mask_gradients``; the hooks then zero every other gradient entry.
    dropinfo, _ = generate_subnet_SLT(localmodel, scaler, lf)
    masks = mask_gradients(localmodel, dropinfo)
    SLT_freeze_filters(localmodel, masks)
    time1 = time.time()
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(localmodel.parameters(), lr=lr)
    localmodel.train()
    for _ in range(E):
        for samples, labels in trainloader:
            samples, labels = samples.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            loss = criterion(localmodel(samples), labels)
            loss.backward()
            optimizer.step()
    time2 = time.time()
    print(f"Training done, time cost = {time2-time1} seconds\n")
    layers_updated = get_updated_layers(localmodel, lf)
    status = Status(code=Code.OK, message="Success")
    return FitRes(
        status=status,
        parameters=ndarrays_to_parameters(get_filters(localmodel)),
        num_examples=len(dataset),
        metrics={"updated layer": layers_updated},
    )

def run_SLT(M=Num_clients, P=Num_participants, R=ROUNDS, scaler=0.5, seed=10020, lr=0.005):
    """Simulate ``R`` rounds of SLT federated training and save per-round test accuracy.

    Each round samples ``P`` of the ``M`` clients, trains them with
    ``local_train`` at the round's SLT stage, averages their parameters
    (weighted by example count) into the global model and evaluates it.
    Accuracies are written one per line to
    ``results/SLT_accuracies_alpha0.1_<YYYYmmddHHMM>.txt``.

    Args:
        M: total number of clients.
        P: clients sampled per round.
        R: number of rounds.
        scaler: forwarded to ``local_train``.
        seed: seeds ``torch`` (model init, data shuffling) and ``random`` (client sampling).
        lr: local SGD learning rate (previously accepted but ignored).

    Returns:
        The list of per-round test accuracies.
    """
    import os  # only needed to make sure the results directory exists

    torch.manual_seed(seed)
    global_model = AlexNet(CHANNELS, outputs=CLASSES).to(DEVICE)

    # One parameter snapshot per client, all starting from the initial global model.
    localmodels = {i: get_filters(global_model) for i in range(M)}

    time0 = time.time()
    random.seed(seed)
    test_accuracies = []
    for i in range(R):
        # Fit:
        print(f"SLT: Starting FL Round {i+1}......\n")
        clients = random.sample(range(M), k=P)
        # The SLT stage depends only on the round, not on the client.
        lf = get_SLT_stage(server_round=i, total_round=R)
        fit_results = []
        for client_count, c in enumerate(clients, start=1):
            # NOTE(review): clients resume from their own previous local model, not
            # from the aggregated global model (the global-parameters variant was
            # commented out in the original), so aggregation only affects the
            # evaluated model. Confirm this is intended.
            fitres = local_train(str(c), localmodels[c], i, client_count, lf,
                                 E=5, scaler=scaler, lr=lr)
            trained = parameters_to_ndarrays(fitres.parameters)
            fit_results.append((trained, fitres.num_examples))
            localmodels[c] = trained
        # Aggregate:
        print(f"Aggregating and updating global model.....\n")
        timex = time.time()
        new_model = aggregate(fit_results)
        timey = time.time()
        set_filters(global_model, new_model)
        print(f"time for aggregation = {timey-timex}")
        # Evaluate:
        print(f"SLT: Round {i+1}, evaluating......")
        _, acc = get_mean_test_acc(global_model)
        test_accuracies.append(acc)
        time1 = time.time()
        print(f"Round {i+1} completed, test accuracy = {acc}, time consumed = {time1-time0}")
        time0 = time1

    # Create the output directory so a long run is not lost at the final write.
    os.makedirs('results', exist_ok=True)
    stamp = datetime.now().strftime("%Y%m%d%H%M")
    with open('results/SLT_accuracies_alpha0.1_' + stamp + '.txt', 'w') as fp:
        # One accuracy per line.
        fp.writelines("%f\n" % item for item in test_accuracies)
    return test_accuracies

    