# Flower:
import flwr as fl
from flwr.common import ndarrays_to_parameters, parameters_to_ndarrays
from flwr.common import Code, EvaluateIns, EvaluateRes, FitRes, Status
# Other dependencies:
from models import CNN
import torch
from torch.utils.data import DataLoader, random_split
from typing import Dict
from util import set_filters, get_filters
from flwr.common import Code, EvaluateIns, EvaluateRes, FitIns, FitRes, Status
from typing import List
from copy import deepcopy
import numpy as np

# Device the models and data batches in this module are moved to.
DEVICE = torch.device('cpu')
# Number of output classes; passed to CNN as `outputs=`.
CLASSES = 62
# Number of input channels; passed to CNN as `in_channels=` and used as the
# default initial input-channel set in `mask_gradients`.
CHANNELS = 1
# state_dict entries that `mask_gradients` builds masks for. The order matches
# the positional mask order used by `freeze_filters` (masks[0] .. masks[9]).
Learnable_Params = ['conv1.weight', 'conv1.bias', 'bn1.weight', 'bn1.bias', 
                    'conv2.weight', 'conv2.bias', 'bn2.weight', 'bn2.bias',
                    'fc.weight', 'fc.bias']

class STL_client(fl.client.Client):
    """Flower client that trains a local CNN, optionally with per-filter gradient masks.

    The client's dataset is split 70/30 (fixed seed) into a training subset
    used by ``fit`` and a validation subset used by ``evaluate``.
    """

    def __init__(self, cid, dataset, epoch, batch, rate):
        """Create the train/eval models and the train/validation loaders.

        Args:
            cid: Client identifier (stored only).
            dataset: Map-style dataset, split 70/30 into train/validation.
            epoch: Number of local epochs per ``fit`` call.
            batch: Batch size for both loaders.
            rate: Stored as ``self.rate``; not used by this class.
        """
        self.cid = cid
        # `model` is trained in `fit`; `testmodel` receives the parameters sent
        # with `evaluate`, so evaluation never overwrites the training model.
        self.model = CNN(in_channels=CHANNELS, outputs=CLASSES).to(DEVICE)
        self.testmodel = CNN(in_channels=CHANNELS, outputs=CLASSES).to(DEVICE)
        self.local_epoch = epoch
        self.local_batch_size = batch
        len_train = int(len(dataset) * 0.7)
        len_test = len(dataset) - len_train
        # Fixed seed keeps the train/validation split reproducible.
        ds_train, ds_val = random_split(
            dataset, [len_train, len_test], torch.Generator().manual_seed(2024)
        )
        self.trainloader = DataLoader(ds_train, self.local_batch_size, shuffle=True)
        self.testloader = DataLoader(ds_val, self.local_batch_size, shuffle=False)
        self.rate = rate

    def fit(self, ins: FitIns) -> FitRes:
        """Train the local model starting from the model passed in the config.

        Reads from ``ins.config``:
            ``'personal model'``: parameters loaded with ``set_filters``
                (presumably the list of arrays produced by ``get_filters``).
            ``'drop_info'``: ``None`` to freeze ``conv1``/``bn1`` for this
                round, otherwise a dict handed to ``mask_gradients`` that
                selects which filters stay trainable.

        Returns:
            FitRes carrying the trained parameters and the number of training
            samples (not batches), which Flower uses as the aggregation weight.
        """
        drop_info = ins.config['drop_info']
        personal_model = ins.config['personal model']
        set_filters(self.model, personal_model)

        # Set the first block's trainability explicitly every round. Otherwise a
        # `requires_grad_(False)` from an earlier round would persist on a
        # reused client, and `register_hook` would then fail on those tensors.
        first_block = (
            self.model.conv1.weight, self.model.conv1.bias,
            self.model.bn1.weight, self.model.bn1.bias,
        )
        for param in first_block:
            param.requires_grad_(drop_info is not None)

        if drop_info is not None:
            masks = mask_gradients(self.model, drop_info)
            freeze_filters(self.model, masks)
            # NOTE(review): the gradient hooks stay attached to self.model after
            # this round; if this client object is reused, confirm they are
            # cleared before a later round with drop_info=None.

        self.train()
        new_local_model = get_filters(self.model)

        status = Status(code=Code.OK, message="Success")
        return FitRes(
            status=status,
            parameters=ndarrays_to_parameters(new_local_model),
            num_examples=len(self.trainloader.dataset),
            metrics={},
        )

    def evaluate(self, ins: EvaluateIns) -> EvaluateRes:
        """Evaluate the parameters in ``ins`` on this client's validation split.

        Returns:
            EvaluateRes with the mean loss, the number of validation samples
            (not batches), and ``{"accuracy": ...}``.
        """
        # Load into the separate eval model so the training model is untouched.
        set_filters(self.testmodel, parameters_to_ndarrays(ins.parameters))
        loss, accuracy = self.test()
        status = Status(code=Code.OK, message="Success")
        return EvaluateRes(
            status=status,
            loss=float(loss),
            num_examples=len(self.testloader.dataset),
            metrics={"accuracy": float(accuracy)},
        )

    def train(self):
        """Run ``self.local_epoch`` epochs of SGD (lr=1e-4) over the training split.

        Any freezing set up in ``fit`` takes effect here: parameters with
        ``requires_grad=False`` receive no gradient, and registered gradient
        hooks rescale the gradients during ``backward``.
        """
        criterion = torch.nn.CrossEntropyLoss()
        optimizer = torch.optim.SGD(self.model.parameters(), lr=1e-4)
        self.model.train()
        for _ in range(self.local_epoch):
            for samples, labels in self.trainloader:
                samples, labels = samples.to(DEVICE), labels.to(DEVICE)
                optimizer.zero_grad()
                outputs = self.model(samples)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

    def test(self):
        """Evaluate ``self.testmodel`` on the entire validation split.

        Returns:
            ``(loss, accuracy)`` as floats: the per-sample mean cross-entropy
            and the fraction of correct predictions. Both are 0.0 when the
            validation split is empty.
        """
        criterion = torch.nn.CrossEntropyLoss()
        correct, total, loss = 0, 0, 0.0
        self.testmodel.eval()
        with torch.no_grad():
            for samples, labels in self.testloader:
                samples, labels = samples.to(DEVICE), labels.to(DEVICE)
                outputs = self.testmodel(samples)
                # The criterion averages over the batch. Re-weight by batch size
                # and accumulate, so dividing by `total` gives a per-sample mean.
                loss += criterion(outputs, labels).item() * labels.size(0)
                total += labels.size(0)
                _, predicted = torch.max(outputs, 1)
                correct += predicted.eq(labels).sum().item()
        if total == 0:
            # Empty validation split: avoid ZeroDivisionError. Paired with
            # num_examples == 0, this carries no weight in aggregation.
            return 0.0, 0.0
        return loss / total, correct / total

def mask_gradients(model:CNN, dropout_info:Dict, C=CHANNELS) -> List[torch.Tensor]:
    """Build 0/1 gradient masks for the learnable parameters of ``model``.

    Args:
        model: Model whose ``state_dict`` provides the parameter shapes. Only
            entries whose names are in ``Learnable_Params`` are considered.
        dropout_info: Maps parameter names to the indices along dimension 0
            (filters / rows / features) whose gradients are kept. Keys not in
            ``Learnable_Params`` are skipped.
            NOTE(review): the mask shape for the l-th matching key is taken from
            ``weights[l]`` by position, not by name. This assumes the keys
            appear in the same order as the learnable entries of
            ``model.state_dict()`` with none skipped. Confirm with the caller
            that builds drop_info.
        C: The initial "kept input indices" are ``range(C)``. These are the
            columns kept for the first conv weight processed.

    Returns:
        If ``dropout_info`` is empty, an all-ones mask for every learnable
        parameter, in ``state_dict`` order. Otherwise, one mask per matching
        ``dropout_info`` key, in key order.
    """
    # Collect the learnable tensors in state_dict order (used for mask shapes).
    weights = []
    params = model.state_dict()
    for k, v in params.items():
        if k in Learnable_Params:
            weights.append(v)
    # No drop information: every gradient passes through unchanged.
    if len(dropout_info) == 0:
        return [torch.ones(w.shape) for w in weights]
    # Indices kept by the previous entry; they constrain the input columns
    # (dim 1) of the next conv weight.
    last_layer_indices = list(range(C))
    Masks = []
    l = 0
    for k in dropout_info.keys():
        if k in Learnable_Params:
            non_mask_filters = dropout_info[k]
            gradient_mask = torch.ones(weights[l].shape)
            for i in range(gradient_mask.shape[0]):
                # fc.weight keeps every row; its columns are never masked here.
                if i in non_mask_filters or k == 'fc.weight':
                    # For a kept conv filter, block connections to input
                    # channels that the previous entry dropped.
                    if k == 'conv1.weight' or k == 'conv2.weight':
                        for j in range(gradient_mask.shape[1]):
                            if not (j in last_layer_indices):
                                gradient_mask[i, j] = 0.0
                else:
                    # Dropped index: zero its whole slice.
                    gradient_mask[i] = 0.0
            Masks.append(gradient_mask)
            # Updated after every entry (including bias/BN entries), so the next
            # conv weight uses the most recently processed entry's index list.
            last_layer_indices = non_mask_filters
            l += 1
    return Masks

def freeze_filters(model:CNN, masks, device=DEVICE):
    """Register gradient hooks that multiply each parameter's gradient by a mask.

    Masks are matched by position to ``conv1.weight``, ``conv1.bias``,
    ``bn1.weight``, ``bn1.bias``, ``conv2.weight``, ``conv2.bias``,
    ``bn2.weight``, ``bn2.bias``, ``fc.weight`` and ``fc.bias``. A 0 entry
    blocks the gradient for that element; a 1 lets it through.

    Hooks registered by an earlier call on the same ``model`` are removed
    first, so repeated calls (e.g. one per round on a reused model) replace
    the masks instead of compounding them.

    Args:
        model: Network exposing the attributes listed above.
        masks: Sequence of at least 10 tensors, each broadcastable to the
            gradient of its parameter.
        device: Device the masks are moved to (once, at registration time).

    Returns:
        The list of hook handles; call ``.remove()`` on them to lift the masks.

    Raises:
        ValueError: If fewer than 10 masks are supplied.
    """
    targets = (
        model.conv1.weight, model.conv1.bias,
        model.bn1.weight, model.bn1.bias,
        model.conv2.weight, model.conv2.bias,
        model.bn2.weight, model.bn2.bias,
        model.fc.weight, model.fc.bias,
    )
    # Fail here with a clear message rather than with an IndexError raised
    # from inside a hook in the middle of backward().
    if len(masks) < len(targets):
        raise ValueError(
            f"freeze_filters needs {len(targets)} masks, got {len(masks)}"
        )

    # Remove hooks from a previous call; otherwise old and new masks would be
    # multiplied together on every backward pass.
    for handle in getattr(model, '_grad_mask_hook_handles', ()):
        handle.remove()

    handles = []
    for param, mask in zip(targets, masks):
        # Move the mask to the device once instead of on every backward pass.
        # Binding it as a default argument avoids late binding of the loop variable.
        hook = lambda grad, m=mask.to(device): grad * m
        handles.append(param.register_hook(hook))
    model._grad_mask_hook_handles = handles
    return handles