from federatedscope.register import register_trainer
from federatedscope.core.trainers import BaseTrainer

try:
    import habana_frameworks.torch.core as htcore
except ImportError:
    htcore = None

# An example of converting a plain torch training process into a
# FederatedScope (FS) training process

# Refer to `federatedscope.core.trainers.BaseTrainer` for interface.

# Try with FEMNIST:
#  python federatedscope/main.py --cfg scripts/example_configs/femnist.yaml \
#  trainer.type mytorchtrainer federate.sample_client_rate 0.01 \
#  federate.total_round_num 5 eval.best_res_update_round_wise_key test_loss


class MyTorchTrainer(BaseTrainer):
    """Minimal example of wrapping a plain torch loop as an FS trainer.

    Training runs SGD with a cross-entropy criterion over the ``'train'``
    split; evaluation computes the same loss, without gradient tracking,
    over a caller-chosen split.

    Args:
        model: The torch module to train and evaluate.
        data: Mapping from split name (e.g. ``'train'``, ``'test'``) to an
            iterable yielding ``(x, y)`` batches.
        device: Device the model and every batch are moved to.
        **kwargs: Extra keyword arguments, stored as-is in ``self.kwargs``.
    """
    def __init__(self, model, data, device, **kwargs):
        # torch is imported lazily inside each method, as in the rest of
        # this example, so importing the module itself does not need torch.
        import torch
        # NN modules
        self.model = model
        # FS `ClientData` or your own data
        self.data = data
        # Device name
        self.device = device
        # kwargs
        self.kwargs = kwargs
        # Criterion & Optimizer (the optimizer is created in `train`)
        self.criterion = torch.nn.CrossEntropyLoss()

    @staticmethod
    def _mean_loss(total_loss, num_samples):
        """Return ``total_loss / num_samples``, or 0.0 for an empty split."""
        # An empty split would otherwise raise ZeroDivisionError.
        if not num_samples:
            return 0.0
        return total_loss / float(num_samples)

    def train(self):
        """Run one pass over ``self.data['train']`` and update the model.

        Returns:
            A tuple ``(num_samples, state_dict, metrics)`` where
            ``num_samples`` is the number of samples seen, ``state_dict`` is
            the model's state dict after moving the model to CPU, and
            ``metrics`` holds ``'loss_total'`` (sample-weighted loss sum) and
            ``'avg_loss'`` (0.0 when the split yielded no samples).
        """
        import torch
        # A fresh optimizer is built on every call, so momentum buffers do
        # not carry over between successive `train` calls.
        self.optimizer = torch.optim.SGD(self.model.parameters(),
                                         lr=0.001,
                                         momentum=0.9,
                                         weight_decay=1e-4)

        # _hook_on_fit_start_init
        self.model.to(self.device)
        self.model.train()

        total_loss = num_samples = 0
        # _hook_on_batch_start_init
        for x, y in self.data['train']:
            # _hook_on_batch_forward
            x, y = x.to(self.device), y.to(self.device)
            outputs = self.model(x)
            loss = self.criterion(outputs, y)

            # _hook_on_batch_backward
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            # Only when the Habana bridge could be imported: mark the step
            # boundary (presumably triggers lazy-mode execution on HPU;
            # confirm against the Habana documentation).
            if htcore is not None:
                htcore.mark_step()

            # _hook_on_batch_end
            # The criterion averages over the batch, so scale by the batch
            # size to accumulate a per-sample sum.
            batch_size = y.shape[0]
            total_loss += loss.item() * batch_size
            num_samples += batch_size

        # _hook_on_fit_end
        metrics = {
            'loss_total': total_loss,
            'avg_loss': self._mean_loss(total_loss, num_samples)
        }
        return num_samples, self.model.cpu().state_dict(), metrics

    def evaluate(self, target_data_split_name='test'):
        """Compute the loss over one data split without updating the model.

        Args:
            target_data_split_name: Key of the split in ``self.data`` to
                evaluate on.

        Returns:
            A dict with ``'<split>_loss'`` (sample-weighted loss sum),
            ``'<split>_total'`` (number of samples) and ``'<split>_avg_loss'``
            (0.0 when the split yielded no samples).
        """
        import torch
        with torch.no_grad():
            self.model.to(self.device)
            self.model.eval()
            total_loss = num_samples = 0
            # _hook_on_batch_start_init
            for x, y in self.data[target_data_split_name]:
                # _hook_on_batch_forward
                x, y = x.to(self.device), y.to(self.device)
                pred = self.model(x)
                loss = self.criterion(pred, y)

                # _hook_on_batch_end
                batch_size = y.shape[0]
                total_loss += loss.item() * batch_size
                num_samples += batch_size

            # _hook_on_fit_end
            return {
                f'{target_data_split_name}_loss': total_loss,
                f'{target_data_split_name}_total': num_samples,
                f'{target_data_split_name}_avg_loss': self._mean_loss(
                    total_loss, num_samples)
            }

    def update(self, model_parameters, strict=False):
        """Load ``model_parameters`` into the model.

        Args:
            model_parameters: State dict passed to ``load_state_dict``.
            strict: Forwarded to ``load_state_dict`` as its second argument.
        """
        self.model.load_state_dict(model_parameters, strict)

    def get_model_para(self):
        """Move the model to CPU and return its state dict."""
        return self.model.cpu().state_dict()


def call_my_torch_trainer(trainer_type):
    """Return the trainer class for ``trainer_type``, if it is ours.

    Args:
        trainer_type: Trainer name to look up.

    Returns:
        ``MyTorchTrainer`` when ``trainer_type`` equals ``'mytorchtrainer'``,
        otherwise ``None``.
    """
    return MyTorchTrainer if trainer_type == 'mytorchtrainer' else None


# Register the builder under the name 'mytorchtrainer', the same value
# passed as `trainer.type` in the example command at the top of this file.
# Runs at import time.
register_trainer('mytorchtrainer', call_my_torch_trainer)
