from abc import ABC, abstractmethod
import torch
import sys
import numpy as np

import copy
import tqdm
import json
import random
import logging

from sklearn.metrics import accuracy_score, f1_score

class FLDeviceBase(ABC):
    """Abstract base class for one federated-learning device (client).

    A device stores a model factory (class + keyword arguments), its train and
    test datasets, and the optimizer / loss configuration. It provides the
    shared training loop (`_train`) and evaluation loop (`_test`). Subclasses
    must implement `_round`, which is invoked through the public `round`.
    """

    def __init__(self, device_id, storage_path):
        """Create a device with no model, data, optimizer or torch device set.

        Args:
            device_id: Identifier returned by `get_device_id`.
            storage_path: Path stored on the instance as `_storage_path`.
        """
        self._device_id = device_id
        self._storage_path = storage_path

        # model related
        self._model = None
        self._model_kwargs = None
        self._model_class = None

        # data related
        self._test_data = None
        self._train_data = None
        self._batch_size_test = 1024
        self._batch_size_train = 32

        # training related
        self._optimizer = None
        self._optimizer_args = None
        self._loss_F = None
        self.lr = -1.0  # placeholder until set_learning_rate is called
        self.resources = 1.0  # NOTE(review): meaning is defined by subclasses/server

        self._torch_device = None
        self._accuracy_test = None
        # selects the F1-based metric instead of plain accuracy in _test
        self.is_unbalanced = False

    def get_device_id(self):
        """Return the identifier passed at construction."""
        return self._device_id

    def set_seed(self, seed):
        """Seed torch's global random number generator."""
        torch.manual_seed(seed)

    def set_model(self, model, **kwargs):
        """Register the model class and the kwargs used by `init_model`."""
        self._model_kwargs = kwargs
        self._model_class = model

    def init_model(self):
        """Instantiate a fresh model from the registered class and kwargs."""
        self._model = self._model_class(**self._model_kwargs)

    def del_model(self):
        """Drop the reference to the current model instance."""
        self._model = None

    def set_test_data(self, dataset):
        """Set the dataset used by `_test`."""
        self._test_data = dataset

    def set_train_data(self, dataset):
        """Set the dataset used by `_train`."""
        self._train_data = dataset

    def set_torch_device(self, torch_dev):
        """Set the torch device the model and batches are moved to."""
        self._torch_device = torch_dev

    def set_learning_rate(self, lr):
        """Set the learning rate passed to the optimizer in `_train`."""
        self.lr = lr

    def set_optimizer(self, optimizer, optimizer_args):
        """Register the optimizer factory and its extra keyword arguments.

        Args:
            optimizer: Callable invoked as
                `optimizer(params, lr=..., **optimizer_args)`.
            optimizer_args: Mapping of additional keyword arguments.
        """
        self._optimizer = optimizer
        self._optimizer_args = optimizer_args

    def set_loss_function(self, loss_F):
        """Register the loss factory; it is called with no arguments in `_train`."""
        self._loss_F = loss_F

    def get_model_state_dict(self):
        """Return the current model's `state_dict()`.

        Raises:
            AssertionError: If no model is instantiated.
        """
        assert self._model is not None, "Device has no NN model"
        return self._model.state_dict()

    def set_model_state_dict(self, model_dict, strict=True):
        """Load a deep copy of `model_dict` into the current model.

        The copy prevents the model from sharing tensors with the caller's
        dictionary.
        """
        self._model.load_state_dict(copy.deepcopy(model_dict), strict=strict)

    def return_reference(self):
        """Return this device object itself."""
        return self

    def _check_trainable(self):
        """Assert that everything `_train` needs has been configured."""
        assert self._model is not None, "device has no NN model"
        assert self._train_data is not None, "device has no training dataset"
        assert self._torch_device is not None, "No torch_device is set"
        assert self._optimizer is not None, "No optimizer is set"
        assert self._loss_F is not None, "No loss function is set"

    def _check_testable(self):
        """Assert that everything `_test` needs has been configured."""
        assert self._model is not None, "device has no NN model"
        assert self._test_data is not None, "device has no test dataset"
        assert self._torch_device is not None, "No torch_device is set"

    @staticmethod
    def correct_predictions(labels, outputs):
        """Count samples whose arg-max over dimension 1 equals the label.

        Returns:
            A 0-dim tensor holding the number of correct predictions.
        """
        predicted = torch.argmax(outputs.cpu().detach(), dim=1)
        return (predicted == labels.cpu().detach()).sum()

    @staticmethod
    def correct_predictions_f1(labels, output):
        """Return the batch's macro F1 score multiplied by the batch size.

        Scaling by `output.shape[0]` lets the caller sum the values over all
        batches and divide by the dataset length afterwards.
        """
        predicted = torch.argmax(output.cpu(), dim=1)
        return f1_score(labels.cpu(), predicted, average='macro') * output.shape[0]

    @staticmethod
    def assert_if_nan(*tensors):
        """Abort if any of the given tensors contains a NaN.

        Raises:
            AssertionError: On the first tensor containing a NaN. The
                exception is raised explicitly (instead of via an `assert`
                statement) so the check is not removed under `python -O`.
        """
        for tensor in tensors:
            if torch.isnan(tensor).any():
                print("Error: loss got NaN")
                raise AssertionError("Error: loss got NaN")

    def _train(self, n_epochs=1):
        """Train the model on the local training data.

        A new loss object, data loader and optimizer are created on every
        call, so no optimizer state carries over between calls.

        Args:
            n_epochs: Number of full passes over the training data.
        """
        self._model.to(self._torch_device)
        self._model.train()

        loss_function = self._loss_F()

        trainloader = torch.utils.data.DataLoader(self._train_data,
                                    batch_size=self._batch_size_train,
                                    shuffle=True, pin_memory=True)

        optimizer = self._optimizer(self._model.parameters(), lr=self.lr, **self._optimizer_args)
        for _ in range(n_epochs):
            for inputs, labels in trainloader:
                inputs, labels = inputs.to(self._torch_device), labels.to(self._torch_device)

                optimizer.zero_grad()
                output = self._model(inputs)
                loss = loss_function(output, labels)
                self.assert_if_nan(loss)
                loss.backward()
                optimizer.step()

    def _test(self):
        """Evaluate the model on the test data and store `_accuracy_test`.

        The stored value is the accumulated per-batch metric divided by
        `len(self._test_data)`: the number of correct predictions by default,
        or the batch-size-weighted macro F1 when `is_unbalanced` is set.

        Raises:
            AssertionError: If the dataset wrapped by `_test_data` reports
                `train` as truthy.
        """
        self._model.to(self._torch_device)
        self._model.eval()

        # _test_data must expose `.dataset.train` (e.g. a Subset wrapping a
        # dataset that has a `train` flag).
        assert not self._test_data.dataset.train, "Wrong dataset for testing."

        # NOTE(review): shuffling is kept from the original implementation;
        # removing it would change how the global RNG stream is consumed.
        testloader = torch.utils.data.DataLoader(self._test_data, shuffle=True,
                                                batch_size=self._batch_size_test, pin_memory=True)

        correct_predictions = 0

        with torch.no_grad():
            for inputs, labels in testloader:
                inputs, labels = inputs.to(self._torch_device), labels.to(self._torch_device)

                output = self._model(inputs)

                if not self.is_unbalanced:
                    correct_predictions += self.correct_predictions(labels, output)
                else:
                    correct_predictions += self.correct_predictions_f1(labels, output)

            self._accuracy_test = correct_predictions / len(self._test_data)

    @abstractmethod
    def _round(self):
        """Perform one local round; must be implemented by subclasses."""
        raise NotImplementedError

    def round(self):
        """Run `_round`, move the model back to the CPU and return `self`."""
        self._round()
        self._model.to('cpu')
        return self



class FLServerBase(ABC):
    """Abstract base class for the federated-learning server / coordinator.

    The server builds the devices, distributes the training data among them,
    and drives the round loop in `run`. Subclasses set `_device_class` and
    `_device_evaluation_class` and implement `pre_round` (device selection)
    and `post_round` (aggregation).
    """

    # Device classes instantiated in `initialize`; expected to be set by subclasses.
    _device_class = None
    _device_evaluation_class = None
    # Copied onto the evaluation device in `initialize`.
    is_unbalanced = False

    def __init__(self, storage_path):
        """Create a server with empty configuration.

        Args:
            storage_path: Directory path used by `save_dict_to_json` and
                handed to every created device.
        """
        self._devices_list = []
        self._storage_path = storage_path

        # general
        self.torch_device = None
        self.n_rounds = 0

        # devices
        self.n_active_devices = 0
        self.n_devices = 0
        self._device_constraints = None
        self._global_model = None

        # debug related
        self._plotting_f = None
        self._plotting_arg = None
        self.progress_output = True
        self.n_rounds_between_plot = 25

        self._measurements_dict = {}
        self._measurements_dict['accuracy'] = []
        self._measurements_dict['accuracy_top5'] = []
        self._measurements_dict['data_upload'] = []

        # nni related
        self.report_intermediate_f = None
        self.report_final_f = None
        self._main_performance_metric = None

        # training related
        self._optimizer = None
        self._optimizer_args = None
        self.lr = -1.0
        self.lr_schedule = None

        # data related
        self._test_data = None
        self._train_data = None
        self.split_f = None

        self._seed_n = 0

    def set_seed(self, seed):
        """Seed torch and `random`, and remember the seed for later use."""
        torch.manual_seed(seed)
        random.seed(seed)
        self._seed_n = seed

    def set_optimizer(self, optimizer, optimizer_args):
        """Register the optimizer factory and kwargs handed to every device."""
        self._optimizer = optimizer
        self._optimizer_args = optimizer_args

    def set_plotting_callback(self, f, arg):
        """Register a callback invoked as `f(arg)` periodically during `run`."""
        self._plotting_f = f
        self._plotting_arg = arg

    def set_device_constraints(self, device_constraints):
        """Set per-device values assigned to `device.resources` in `initialize`."""
        self._device_constraints = device_constraints

    @staticmethod
    def random_device_selection(n_devices, n_active_devices, rng):
        """Return `n_active_devices` distinct random indices below `n_devices`.

        Args:
            n_devices: Total number of devices.
            n_active_devices: Number of indices to pick.
            rng: Object providing `permutation(n)` (e.g. a numpy Generator).

        Returns:
            A list with the first `n_active_devices` entries of a random
            permutation of `range(n_devices)`.
        """
        return rng.permutation(n_devices)[0:n_active_devices].tolist()

    @staticmethod
    def count_data_footprint(state_dict, bytes_per_element=4):
        """Estimate the size in bytes of the tensors in a state dict.

        Every tensor element is counted with the same fixed width, regardless
        of the tensor's actual dtype. Non-tensor entries are ignored.

        Args:
            state_dict: Mapping whose values may be tensors.
            bytes_per_element: Assumed size of one element; defaults to 4.

        Returns:
            The total number of counted bytes.
        """
        return sum(
            bytes_per_element * param.numel()
            for param in state_dict.values()
            if isinstance(param, torch.Tensor)
        )

    def initialize(self):
        """Create the evaluation device and all training devices.

        Requires `set_dataset`, `set_model`, `set_model_evaluation` and
        `split_f` to be configured beforehand. Afterwards `_global_model`
        holds a deep copy of the freshly initialised state dict of device 0.
        """
        idxs_list = self.split_f(self._train_data, self.n_devices)

        self._evaluation_device = self._device_evaluation_class(0, self._storage_path)
        self._evaluation_device.set_seed(self._seed_n)
        self._evaluation_device.set_model(self._model_evaluation, **self._model_evaluation_args)
        self._evaluation_device.init_model()
        self._evaluation_device.set_test_data(self._test_data)
        self._evaluation_device.set_torch_device(self.torch_device)
        self._evaluation_device._batch_size_test = 1024
        self._evaluation_device.is_unbalanced = self.is_unbalanced

        self._devices_list = [self._device_class(i, self._storage_path) for i in range(self.n_devices)]

        for i, device in enumerate(self._devices_list):
            device.set_seed(self._seed_n)
            device.set_model(self._model[i], **self._model_args[i])
            # Each device gets a Subset of the underlying dataset restricted
            # to the indices produced by split_f.
            device.set_train_data(torch.utils.data.Subset(self._train_data.dataset, idxs_list[i]))
            device.set_learning_rate(self.lr)
            device.set_loss_function(torch.nn.CrossEntropyLoss)
            device.set_optimizer(self._optimizer, self._optimizer_args)
            device.set_torch_device(self.torch_device)

            if self._device_constraints is not None:
                device.resources = self._device_constraints[i]

        # Instantiate device 0's model once to obtain the initial global
        # weights, then release it again.
        self._devices_list[0].init_model()
        self._global_model = copy.deepcopy(self._devices_list[0]._model.state_dict())
        self._devices_list[0].del_model()

    def save_dict_to_json(self, filename, input_dict):
        """Write `input_dict` as indented JSON to `<storage_path>/<filename>`."""
        with open(self._storage_path + '/' + filename, 'w') as fd:
            json.dump(input_dict, fd, indent=4)

    def set_dataset(self, dataset, path, *args, **kwargs):
        """Build train and test datasets and wrap each in a full-range Subset.

        Args:
            dataset: Callable invoked as `dataset(path, train=<bool>, ...)`.
            path: First positional argument passed to `dataset`.
            *args, **kwargs: Forwarded to both `dataset` calls.
        """
        data = dataset(path, train=True, *args, **kwargs)
        data_test = dataset(path, train=False, *args, **kwargs)

        self._train_data = torch.utils.data.Subset(data, torch.arange(0, len(data)))
        self._test_data = torch.utils.data.Subset(data_test, torch.arange(0, len(data_test)))

    def set_model(self, model_list, kwargs_list):
        """Register per-device model classes and kwargs (indexed by device)."""
        self._model = model_list
        self._model_args = kwargs_list

    def set_model_evaluation(self, model, kwargs):
        """Register the model class and kwargs used by the evaluation device."""
        self._model_evaluation = model
        self._model_evaluation_args = kwargs

    def learning_rate_scheduling(self, round_n):
        """Apply scheduled learning-rate changes for the given round.

        `lr_schedule` is an iterable of entries where index 0 is the (int)
        round number and index 1 the new learning rate (< 1.0). When
        `round_n` matches an entry, every device's `lr` is overwritten.
        """
        if self.lr_schedule is None:
            return

        for schedule in self.lr_schedule:
            assert schedule[1] < 1.0
            assert isinstance(schedule[0], int)
            if round_n != schedule[0]:
                continue
            for device in self._devices_list:
                device.lr = schedule[1]
            logging.info(f'[FL_BASE]: learning_rate reduction at round {round_n} to {schedule[1]}')

    @abstractmethod
    def pre_round(self, round_n, rng):
        """Select the devices for this round; `run` iterates the return value as indices."""
        raise NotImplementedError

    @abstractmethod
    def post_round(self, round_n, idxs):
        """Aggregate the results of the devices in `idxs` after they trained.

        `run` reads `_measurements_dict['accuracy'][round_n][0]` right after
        this call, so implementations are expected to append an entry there.
        """
        raise NotImplementedError

    def run(self):
        """Execute `n_rounds` federated rounds.

        Per round: apply the learning-rate schedule, select devices via
        `pre_round`, load the global model into them, run their `round`,
        aggregate via `post_round`, free the device models, persist the
        measurements, and optionally plot / report intermediate results.
        """
        self.check_device_data()
        sample_counts = [len(dev._train_data) for dev in self._devices_list]
        print(f"#Samples on devices: {sample_counts}")
        logging.info(f"[FL_BASE]: #Samples on devices: {sample_counts}")

        rng = np.random.default_rng(self._seed_n)

        tbar = tqdm.tqdm(iterable=range(self.n_rounds), total=self.n_rounds,
                         file=sys.stdout, disable=not self.progress_output)
        for round_n in tbar:

            # learning rate scheduling
            self.learning_rate_scheduling(round_n)

            # selection of devices
            idxs = self.pre_round(round_n, rng)

            # init NN models with the current global weights
            for dev_idx in idxs:
                device = self._devices_list[dev_idx]
                device.init_model()
                device.set_model_state_dict(self._global_model)

            # local training on every selected device
            for dev_idx in idxs:
                self._devices_list[dev_idx].round()

            # knowledge aggregation // global model gets set
            self.post_round(round_n, idxs)

            # del models
            for dev_idx in idxs:
                self._devices_list[dev_idx].del_model()

            # save accuracy dict
            self.save_dict_to_json('measurements.json', self._measurements_dict)

            if self.progress_output:
                tbar.set_description(f"round_n {round_n}, acc: {self._measurements_dict['accuracy'][round_n][0]}")
            else:
                print(f"round_n {round_n}, acc={self._measurements_dict['accuracy'][round_n][0]}")

            # plotting
            if (round_n % self.n_rounds_between_plot) == 0 and round_n != 0:
                if self._plotting_f is not None:
                    try:
                        self._plotting_f(self._plotting_arg)
                    except Exception:
                        # Plotting is best effort and must not abort training,
                        # but KeyboardInterrupt/SystemExit still propagate and
                        # the traceback is kept in the log.
                        print("Error plotting!")
                        logging.exception("[FL_BASE]: plotting callback failed")

            # report intermediate performance metric
            if self.report_intermediate_f is not None:
                self.report_intermediate_f(self._main_performance_metric)

        # report final performance
        if self.report_final_f is not None:
            self.report_final_f(self._main_performance_metric)

    def check_device_data(self):
        """Assert that no two devices hold an identical index tensor.

        Each unordered pair of devices is compared once with `torch.equal`,
        so only exact duplicates are detected, not partial overlaps.

        Raises:
            AssertionError: If two devices have equal `_train_data.indices`.
        """
        for i, first in enumerate(self._devices_list):
            for second in self._devices_list[i + 1:]:
                assert not torch.equal(first._train_data.indices,
                                       second._train_data.indices), "Devices do not exclusivly have access to their data!"
