import copy
import torch
import logging
import math

from federatedscope.core.auxiliaries.optimizer_builder import get_optimizer
from federatedscope.core.trainers.torch_trainer import GeneralTorchTrainer

from typing import Type

logger = logging.getLogger(__name__)


def wrap_Simple_tuning_Trainer(
        base_trainer: Type[GeneralTorchTrainer]) -> Type[GeneralTorchTrainer]:
    """Attach the Simple_tuning fine-tuning behaviour to ``base_trainer``.

    Simple_tuning (https://arxiv.org/abs/2302.01677) tunes only the linear
    classifier while the feature extractor stays frozen; the key step is
    re-initializing the linear classifier before fine-tuning.

    The trainer is modified in place: its context is populated by
    ``init_Simple_tuning_ctx`` and ``hook_on_fit_start_simple_tuning`` is
    registered as a fine-tuning hook for the ``"on_fit_start"`` trigger.

    Args:
        base_trainer: the trainer to extend.
            NOTE(review): annotated as ``Type[...]`` but used like an
            instance here (``.ctx``, ``.cfg``, bound method call) - confirm.

    Returns:
        The same ``base_trainer`` object that was passed in.
    """
    init_Simple_tuning_ctx(base_trainer)

    # NOTE(review): insert_pos=-1 presumably places the hook after the ones
    # already registered - confirm against ``register_hook_in_ft``.
    hook_registration = dict(new_hook=hook_on_fit_start_simple_tuning,
                             trigger="on_fit_start",
                             insert_pos=-1)
    base_trainer.register_hook_in_ft(**hook_registration)

    return base_trainer


def init_Simple_tuning_ctx(base_trainer):
    """Copy the Simple_tuning settings from the config into the context.

    Reads ``cfg.finetune.{epoch_linear, lr_linear, weight_decay,
    local_param}`` from the trainer and stores them on ``base_trainer.ctx``.
    It also resets the epoch counters and collects the parameters that will
    be tuned locally.

    Args:
        base_trainer: object exposing ``ctx`` (with a ``model`` providing
            ``named_parameters()``) and ``cfg`` (with a ``finetune`` node).
    """
    ctx, cfg = base_trainer.ctx, base_trainer.cfg
    finetune_cfg = cfg.finetune

    # Number of epochs spent on the linear classifier; it is also used as
    # the overall number of training epochs.
    ctx.epoch_linear = finetune_cfg.epoch_linear
    ctx.num_train_epoch = ctx.epoch_linear
    ctx.epoch_number = 0

    # Hyper-parameters of the optimizer used for the classifier.
    ctx.lr_linear = finetune_cfg.lr_linear
    ctx.weight_decay = finetune_cfg.weight_decay

    # Top-level module names whose parameters are updated locally.
    ctx.local_param = finetune_cfg.local_param

    # A parameter is selected when the first component of its dotted name
    # is listed in ``ctx.local_param``.
    ctx.local_update_param = [
        weight for full_name, weight in ctx.model.named_parameters()
        if full_name.split(".")[0] in ctx.local_param
    ]


def hook_on_fit_start_simple_tuning(ctx):
    """Prepare the context for Simple_tuning at the start of fine-tuning.

    The hook
      * resets ``ctx.num_train_epoch`` to ``ctx.epoch_linear`` and
        ``ctx.epoch_number`` to 0,
      * builds a momentum-free SGD optimizer over
        ``ctx.local_update_param`` and installs it as ``ctx.optimizer``,
      * re-initializes every parameter whose top-level module name is in
        ``ctx.local_param`` with ``U(-stdv, stdv)`` and marks it trainable,
      * freezes all remaining parameters (``requires_grad = False``).

    ``stdv`` is ``1 / sqrt(size(-1))`` of the ``weight`` that lives in the
    same module as the parameter being initialized, so a bias shares the
    bound of its own layer's weight. If a parameter has no sibling
    ``weight``, its own last dimension is used instead.

    Args:
        ctx: trainer context providing ``model``, ``epoch_linear``,
            ``lr_linear``, ``weight_decay``, ``local_param`` and
            ``local_update_param``.
    """
    ctx.num_train_epoch = ctx.epoch_linear
    ctx.epoch_number = 0

    ctx.optimizer_for_linear = torch.optim.SGD(ctx.local_update_param,
                                               lr=ctx.lr_linear,
                                               momentum=0,
                                               weight_decay=ctx.weight_decay)

    # Snapshot name -> parameter so that a bias can look up its sibling
    # weight regardless of the order in which parameters are registered.
    named_params = dict(ctx.model.named_parameters())

    for name, param in named_params.items():
        if name.split(".")[0] not in ctx.local_param:
            # Feature extractor: keep frozen.
            param.requires_grad = False
            continue

        # Use the last name component so nested heads such as
        # "fc.0.weight" are recognised as weights as well.
        module_path, _, leaf = name.rpartition(".")
        fan_in_source = param
        if leaf != 'weight' and module_path:
            fan_in_source = named_params.get(module_path + ".weight", param)

        # The bound is always derived from this parameter's own layer, never
        # from a value left over from a previously visited layer.
        stdv = 1. / math.sqrt(fan_in_source.size(-1))
        param.data.uniform_(-stdv, stdv)
        param.requires_grad = True

    ctx.optimizer = ctx.optimizer_for_linear
