from typing import List, Optional, Tuple, Union, Dict
from models import Resnet20, freeze_layer
import random
import torch
import time
import numpy as np
from flwr.common import EvaluateRes, Metrics, Scalar, Code
from flwr.common import FitIns, FitRes, Status
from flwr.server.client_manager import ClientManager
from flwr.server.client_proxy import ClientProxy
from flwr.common import ndarrays_to_parameters, parameters_to_ndarrays
from flwr.server.strategy.aggregate import weighted_loss_avg, aggregate
from flwr.common.logger import log
from logging import WARNING
from cifar100dataset import cifar100Dataset
from Util import get_filters, get_updated_layers, set_filters, generate_subnet_SLT
from torch.utils.data import DataLoader
from datetime import datetime
from feddrop import get_mean_test_acc

# Experiment-wide settings shared by the functions in this module.
CHANNELS = 3       # first positional argument given to Resnet20
Batch = 128        # batch size handed to the training DataLoader
CLASSES = 100      # passed to Resnet20 as `outputs`
DEVICE = "cuda"    # device that models, batches and gradient masks are moved to
Num_clients = 100       # default `M` of run_SLT
Num_participants = 10   # default `P` of run_SLT (clients sampled per round)
ROUNDS = 500       # default `R` of run_SLT

# Names of the state_dict entries treated as learnable. mask_gradients() tests
# membership in this list. Layout: the classifier weight, then the conv/bn
# triples of stage 4 (units 0-6), stage 3 (units 0-6), stage 2 (units 1-6),
# and finally the stem (conv1 / bn1).
Learnable_Params = (
    ['fc.weight']
    + [
        name
        for stage, first_unit in ((4, 0), (3, 0), (2, 1))
        for unit in range(first_unit, 7)
        for name in (
            f'conv{stage}_{unit}.weight',
            f'bn{stage}_{unit}.weight',
            f'bn{stage}_{unit}.bias',
        )
    ]
    + ['conv1.weight', 'bn1.weight', 'bn1.bias']
)

def local_train(cid, params, server_round, client_count, lf, scaler, E=5, Learning_rt=0.05) -> FitRes:
    """Train a fresh Resnet20 on one client's data and return the result as a FitRes.

    Args:
        cid: Client identifier; interpolated into the client's CSV file name.
        params: Filters loaded into the new model through ``set_filters``.
        server_round: Round index, logged as ``server_round + 1``.
        client_count: Position of this client within the round (logging only).
        lf: Stage value forwarded to ``generate_subnet_SLT`` and
            ``get_updated_layers``; ``min(4, max(0, lf-1))`` is passed to
            ``freeze_layer``.
        scaler: Forwarded to ``generate_subnet_SLT``.
        E: Number of passes over the client's DataLoader.
        Learning_rt: Learning rate of the SGD optimizer.

    Returns:
        A ``FitRes`` with an OK status, the model's filters (``get_filters``)
        as parameters, ``num_examples`` equal to the dataset length, and the
        ``"updated layer"`` metric from ``get_updated_layers``.
    """
    print(f"Server round {server_round+1}, training on the {client_count}-th client, id = {cid}")

    # Client data and a fresh model seeded with the supplied filters.
    client_data = cifar100Dataset(f"clientdata/cifar100_client_{cid!s}_ALPHA_0.1.csv")
    loader = DataLoader(client_data, Batch, shuffle=True)
    net = Resnet20(CHANNELS, outputs=CLASSES).to(DEVICE)
    set_filters(net, params)

    # Install gradient masks derived from the generated sub-network, then
    # apply the stage-dependent layer freezing.
    subnet_info, _ = generate_subnet_SLT(net, scaler, lf)
    grad_masks = mask_gradients(net, subnet_info)
    SLT_freeze_filters(net, grad_masks)
    freeze_layer(net, min(4, max(0, lf-1)))

    started = time.time()
    loss_fn = torch.nn.CrossEntropyLoss()
    sgd = torch.optim.SGD(net.parameters(), lr=Learning_rt)
    net.train()
    for _ in range(E):
        for samples, labels in loader:
            samples = samples.to(DEVICE)
            labels = labels.to(DEVICE)
            sgd.zero_grad()
            batch_loss = loss_fn(net(samples), labels)
            batch_loss.backward()
            sgd.step()
    finished = time.time()
    print(f"Training done, time cost = {finished-started} seconds\n")

    layers_updated = get_updated_layers(net, lf)
    return FitRes(
        status=Status(code=Code.OK, message="Success"),
        parameters=ndarrays_to_parameters(get_filters(net)),
        num_examples=len(client_data),
        metrics={"updated layer": layers_updated},
    )

def run_SLT(M=Num_clients, P=Num_participants, R=ROUNDS, scaler=0.5, seed=10020, lr=0.05):
    """Run the SLT federated training loop and write per-round test accuracies to disk.

    Each round samples ``P`` of the ``M`` clients, trains each one with
    ``local_train`` starting from that client's stored parameters, aggregates
    the results into the global model with ``aggregate``, and evaluates the
    global model with ``get_mean_test_acc``.

    Args:
        M: Total number of clients.
        P: Number of clients sampled per round.
        R: Number of rounds.
        scaler: Forwarded to ``local_train``.
        seed: Seed for Python's ``random`` module (controls client sampling).
        lr: Learning rate forwarded to ``local_train``.

    Side effects:
        Prints progress, and writes one accuracy per line to
        ``results/SLT_accuracies_alpha0.1_<YYYYmmddHHMM>.txt``. The ``results``
        directory is created if it does not exist.
    """
    import os  # local import: only needed to prepare the output directory

    global_model = Resnet20(CHANNELS, outputs=CLASSES).to(DEVICE)

    # Map of local models: every client starts from the initial global filters.
    localmodels = {c: get_filters(global_model) for c in range(M)}

    round_start = time.time()
    random.seed(seed)
    test_accuracies = []
    for i in range(R):
        # Fit:
        print(f"SLT: Starting FL Round {i+1}......\n")
        fit_results = []
        clients = random.sample(range(M), k=P)
        # The stage depends only on the round, so compute it once per round.
        lf = get_SLT_stage(server_round=i, total_round=R)
        for client_count, c in enumerate(clients, start=1):
            # NOTE(review): clients resume from their own previous parameters;
            # the aggregated global model is only evaluated, never sent back
            # to the clients. Confirm this is intentional.
            fitres = local_train(str(c), localmodels[c], i, client_count, lf,
                                 E=5, scaler=scaler, Learning_rt=lr)
            fit_results.append((parameters_to_ndarrays(fitres.parameters), fitres.num_examples))
            # Deserialised a second time so the stored client state is a
            # separate object from the list handed to aggregate().
            localmodels[c] = parameters_to_ndarrays(fitres.parameters)

        # Aggregate:
        print("Aggregating and updating global model.....\n")
        new_model = aggregate(fit_results)
        set_filters(global_model, new_model)

        # Evaluate:
        print(f"SLT: Round {i+1}, evaluating......")
        _, acc = get_mean_test_acc(global_model)
        test_accuracies.append(acc)
        round_end = time.time()
        print(f"Round {i+1} completed, test accuracy = {acc}, time consumed = {round_end-round_start}")
        round_start = round_end

    # Make sure the output directory exists so a long run cannot lose its
    # results to a FileNotFoundError at the very end.
    os.makedirs('results', exist_ok=True)
    now = datetime.now()
    with open('results/SLT_accuracies_alpha0.1_' + now.strftime("%Y%m%d%H%M") + '.txt', 'w') as fp:
        # One accuracy per line.
        fp.writelines("%f\n" % item for item in test_accuracies)

def get_SLT_stage(server_round, total_round) -> int:
    """Map a round index to a stage number in 0..4.

    The run is divided at 1/5, 2/5, 3/5 and 4/5 of ``total_round``. Stage ``s``
    (for s in 0..3) is returned for the first boundary satisfying
    ``server_round <= total_round / 5 * (s + 1)``; anything beyond the last
    boundary is stage 4.

    NOTE(review): the boundaries are inclusive, so with 0-based round indices
    the first stage covers one more round than the middle ones and the last
    stage one fewer (e.g. 101/100/100/100/99 for total_round=500). Confirm
    whether that is intended.
    """
    fifth = total_round / 5
    for stage in range(4):
        if server_round <= fifth * (stage + 1):
            return stage
    return 4

def mask_gradients(model:Resnet20, dropout_info:Dict, C=CHANNELS):
    """Build one 0/1 gradient mask per learnable parameter from ``dropout_info``.

    Args:
        model: Model whose ``state_dict`` supplies the parameter shapes.
        dropout_info: Mapping from parameter name to a collection of indices
            along dim 0 that stay trainable. Keys not in ``Learnable_Params``
            are skipped.
        C: Number of input indices treated as retained for the first
            processed layer (``range(C)``).

    Returns:
        A list of tensors created with ``torch.ones`` and selectively zeroed.
        If ``dropout_info`` is empty, one all-ones mask is returned for every
        ``state_dict`` entry named in ``Learnable_Params``.

    For each processed key:
      * rows (dim 0) not listed in ``dropout_info[key]`` are zeroed entirely,
        except for ``'fc.weight'``, whose rows are always kept;
      * for keys containing ``'conv'`` and for ``'fc.weight'``, the kept rows
        additionally have every dim-1 index zeroed that is absent from the
        index collection of the previously processed key.

    NOTE(review): the n-th processed key is paired with the n-th learnable
    ``state_dict`` tensor, so this assumes ``dropout_info`` iterates in the
    same order as the model's ``state_dict`` -- confirm against
    ``generate_subnet_SLT``.
    """
    learnable = [
        tensor
        for name, tensor in model.state_dict().items()
        if name in Learnable_Params
    ]
    if len(dropout_info) == 0:
        return [torch.ones(tensor.shape) for tensor in learnable]

    prev_kept = list(range(C))
    masks = []
    position = 0
    for name in dropout_info:
        if name not in Learnable_Params:
            continue
        kept = dropout_info[name]
        mask = torch.ones(learnable[position].shape)
        is_fc = name == 'fc.weight'
        # conv and fc weights also get their dim-1 (input) indices filtered.
        filters_inputs = is_fc or 'conv' in name
        for row in range(mask.shape[0]):
            if not (row in kept or is_fc):
                mask[row] = 0.0
                continue
            if filters_inputs:
                for col in range(mask.shape[1]):
                    if col not in prev_kept:
                        mask[row, col] = 0.0
        masks.append(mask)
        # The indices kept here become the valid inputs of the next key.
        prev_kept = kept
        position += 1
    return masks

def SLT_freeze_filters(model:Resnet20, masks, device=DEVICE):
    """Register gradient hooks that multiply each parameter's gradient by its mask.

    ``masks`` is indexed positionally in this parameter order:
    ``conv1``/``bn1``, then units ``2_1``..``2_6``, ``3_0``..``3_6`` and
    ``4_0``..``4_6`` (each contributing conv weight, bn weight, bn bias), and
    finally ``fc.weight`` -- 64 entries in total (indices 0..63).

    Args:
        model: Model exposing the ``conv*``/``bn*``/``fc`` attributes above.
        masks: Sequence of at least 64 tensors, one per parameter in the order
            described. Extra entries are ignored.
        device: Device the masks are moved to before being applied.

    Raises:
        IndexError: If ``masks`` has fewer entries than there are parameters.

    Each mask is moved to ``device`` once, at registration time, instead of on
    every backward pass. The hook keeps a reference to that moved copy.
    """
    # Conv/BN unit suffixes, in the order the masks are laid out.
    units = (
        ['1']
        + [f'2_{u}' for u in range(1, 7)]
        + [f'3_{u}' for u in range(7)]
        + [f'4_{u}' for u in range(7)]
    )

    targets = []
    for unit in units:
        conv = getattr(model, 'conv' + unit)
        bn = getattr(model, 'bn' + unit)
        targets.extend((conv.weight, bn.weight, bn.bias))
    targets.append(model.fc.weight)

    for index, param in enumerate(targets):
        mask = masks[index].to(device)
        # Bind the mask through a default argument so every hook keeps its
        # own tensor (avoids the late-binding closure pitfall in a loop).
        param.register_hook(lambda grad, mask=mask: grad * mask)