import copy
import torch
import logging

from federatedscope.core.auxiliaries.optimizer_builder import get_optimizer
from federatedscope.core.trainers.torch_trainer import GeneralTorchTrainer

from typing import Type

logger = logging.getLogger(__name__)


def wrap_FedRepTrainer(
        base_trainer: Type[GeneralTorchTrainer]) -> Type[GeneralTorchTrainer]:
    # ---------------------------------------------------------------------- #
    # FedRep method:
    # https://arxiv.org/abs/2102.07078
    # First training linear classifier and then feature extractor
    # Linear classifier: local_param; feature extractor: global_param
    # ---------------------------------------------------------------------- #
    init_FedRep_ctx(base_trainer)

    base_trainer.register_hook_in_train(new_hook=hook_on_fit_start_fedrep,
                                        trigger="on_fit_start",
                                        insert_pos=-1)

    base_trainer.register_hook_in_train(new_hook=hook_on_epoch_start_fedrep,
                                        trigger="on_epoch_start",
                                        insert_pos=-1)

    return base_trainer


def init_FedRep_ctx(base_trainer):

    ctx = base_trainer.ctx
    cfg = base_trainer.cfg

    ctx.epoch_feature = cfg.personalization.epoch_feature
    ctx.epoch_linear = cfg.personalization.epoch_linear

    ctx.num_train_epoch = ctx.epoch_feature + ctx.epoch_linear

    ctx.epoch_number = 0

    ctx.lr_feature = cfg.personalization.lr_feature
    ctx.lr_linear = cfg.personalization.lr_linear
    ctx.weight_decay = cfg.personalization.weight_decay

    ctx.local_param = cfg.personalization.local_param

    ctx.local_update_param = []
    ctx.global_update_param = []

    for name, param in ctx.model.named_parameters():
        if name.split(".")[0] in ctx.local_param:
            ctx.local_update_param.append(param)
        else:
            ctx.global_update_param.append(param)


def hook_on_fit_start_fedrep(ctx):

    ctx.num_train_epoch = ctx.epoch_feature + ctx.epoch_linear
    ctx.epoch_number = 0

    ctx.optimizer_for_feature = torch.optim.SGD(ctx.global_update_param,
                                                lr=ctx.lr_feature,
                                                momentum=0,
                                                weight_decay=ctx.weight_decay)
    ctx.optimizer_for_linear = torch.optim.SGD(ctx.local_update_param,
                                               lr=ctx.lr_linear,
                                               momentum=0,
                                               weight_decay=ctx.weight_decay)

    for name, param in ctx.model.named_parameters():

        if name.split(".")[0] in ctx.local_param:
            param.requires_grad = True
        else:
            param.requires_grad = False

    ctx.optimizer = ctx.optimizer_for_linear


def hook_on_epoch_start_fedrep(ctx):

    ctx.epoch_number += 1

    if ctx.epoch_number == ctx.epoch_linear + 1:

        for name, param in ctx.model.named_parameters():

            if name.split(".")[0] in ctx.local_param:
                param.requires_grad = False
            else:
                param.requires_grad = True

        ctx.optimizer = ctx.optimizer_for_feature
